270 lines
7.4 KiB
Python
270 lines
7.4 KiB
Python
"""ML-based classification (optional, using sklearn if available)"""
|
|
from pathlib import Path
|
|
from typing import Optional, List, Tuple
|
|
import pickle
|
|
|
|
try:
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.naive_bayes import MultinomialNB
|
|
from sklearn.pipeline import Pipeline
|
|
SKLEARN_AVAILABLE = True
|
|
except ImportError:
|
|
SKLEARN_AVAILABLE = False
|
|
|
|
|
|
class MLClassifier:
    """Machine learning-based file classifier.

    Uses path-based features (directory components, extension and filename
    pieces) to classify files with a TF-IDF + multinomial naive Bayes
    pipeline. Requires scikit-learn to be installed.
    """

    # Tokenize feature strings on whitespace only. The TfidfVectorizer default
    # pattern (r"(?u)\b\w\w+\b") would split "ext:.txt" into "ext"/"txt" and
    # drop single-character tokens entirely, so the "ext:"/"name:" prefixes
    # built by _extract_features were lost and one-letter extensions such as
    # ".c" or ".h" produced no extension feature at all.
    _TOKEN_PATTERN = r"\S+"

    def __init__(self):
        """Initialize an untrained classifier.

        Raises:
            ImportError: If scikit-learn is not installed.
        """
        if not SKLEARN_AVAILABLE:
            raise ImportError(
                "scikit-learn is required for ML classification. "
                "Install with: pip install scikit-learn"
            )

        self.model: Optional[Pipeline] = None
        self.categories: List[str] = []
        self._is_trained = False

    def _extract_features(self, path: Path) -> str:
        """Build a whitespace-separated feature string for a path.

        The string contains every path component, an ``ext:<suffix>`` token
        when the path has a suffix, and one ``name:<piece>`` token per piece
        of the filename split on '-', '_' and '.'.

        Args:
            path: Path to extract features from.

        Returns:
            Feature string, e.g. ``"src main.c ext:.c name:main name:c"``
            for ``Path("src/main.c")``.
        """
        features: List[str] = list(path.parts)

        if path.suffix:
            features.append(f"ext:{path.suffix}")

        name_parts = (
            path.name.replace('-', ' ').replace('_', ' ').replace('.', ' ').split()
        )
        features.extend(f"name:{part}" for part in name_parts)

        return ' '.join(features)

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Train the classifier, replacing any previously trained model.

        Args:
            training_data: Iterable of (path, category) pairs.

        Raises:
            ValueError: If ``training_data`` is empty.
        """
        # Materialize so generators are handled and the emptiness check works.
        training_data = list(training_data)
        if not training_data:
            raise ValueError("Training data cannot be empty")

        X = [self._extract_features(path) for path, _ in training_data]
        y = [category for _, category in training_data]

        model = Pipeline([
            ('tfidf', TfidfVectorizer(
                max_features=1000,
                ngram_range=(1, 2),
                min_df=1,
                token_pattern=self._TOKEN_PATTERN,
            )),
            ('classifier', MultinomialNB()),
        ])

        # Fit a local pipeline first so that a failed fit leaves any
        # previously trained model and its categories untouched.
        model.fit(X, y)

        self.model = model
        self.categories = sorted(set(y))
        self._is_trained = True

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Classify a file path.

        Args:
            path: Path to classify.
            file_type: Optional file type hint (not used by the ML classifier).

        Returns:
            Predicted category, or None if the model is not trained or
            prediction fails.
        """
        if not self._is_trained or self.model is None:
            return None

        features = self._extract_features(path)

        try:
            prediction = self.model.predict([features])[0]
        except Exception:
            # Best effort: any prediction failure is reported as "no category".
            return None
        return prediction

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Get prediction probabilities for all categories.

        Args:
            path: Path to classify.

        Returns:
            Mapping of category to probability; empty if the model is not
            trained or prediction fails.
        """
        if not self._is_trained or self.model is None:
            return {}

        features = self._extract_features(path)

        try:
            probabilities = self.model.predict_proba([features])[0]
        except Exception:
            # Best effort, mirroring classify().
            return {}

        # Pair probabilities with the fitted model's own class order rather
        # than assuming it matches self.categories (e.g. after load_model).
        classes = getattr(self.model, 'classes_', self.categories)
        return {
            category: float(prob)
            for category, prob in zip(classes, probabilities)
        }

    def save_model(self, model_path: Path) -> None:
        """Pickle the trained model and its categories to disk.

        Args:
            model_path: Destination file path.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self._is_trained:
            raise ValueError("Cannot save untrained model")

        model_data = {
            'model': self.model,
            'categories': self.categories,
            'is_trained': self._is_trained,
        }

        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)

    def load_model(self, model_path: Path) -> None:
        """Load a model previously written by ``save_model``.

        Warning:
            This uses ``pickle.load``, which can execute arbitrary code.
            Only load model files from trusted sources.

        Args:
            model_path: Path to the model file.
        """
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        self.model = model_data['model']
        self.categories = model_data['categories']
        self._is_trained = model_data['is_trained']

    @property
    def is_trained(self) -> bool:
        """Whether the model has been trained or loaded."""
        return self._is_trained
|
|
|
|
|
|
class DummyMLClassifier:
    """Placeholder classifier used when scikit-learn is not available.

    Offers the same public methods as MLClassifier. Queries return empty
    results (None / {}), while training and model persistence raise
    NotImplementedError.
    """

    def __init__(self):
        """Create the placeholder; it keeps no state."""

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Refuse to train because scikit-learn is missing."""
        message = (
            "ML classification requires scikit-learn. "
            "Install with: pip install scikit-learn"
        )
        raise NotImplementedError(message)

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Always report that no category is known."""
        return None

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Always report an empty probability mapping."""
        return dict()

    def save_model(self, model_path: Path) -> None:
        """Refuse to save; there is no model."""
        raise NotImplementedError("ML classification not available")

    def load_model(self, model_path: Path) -> None:
        """Refuse to load; there is nothing to load into."""
        raise NotImplementedError("ML classification not available")

    @property
    def is_trained(self) -> bool:
        """Always False: the placeholder can never be trained."""
        return False
|
|
|
|
|
|
def create_ml_classifier() -> MLClassifier | DummyMLClassifier:
    """Return the best available classifier implementation.

    Returns:
        A new MLClassifier when the scikit-learn import succeeded
        (SKLEARN_AVAILABLE is True), otherwise a DummyMLClassifier.
    """
    factory = MLClassifier if SKLEARN_AVAILABLE else DummyMLClassifier
    return factory()
|
|
|
|
|
|
def train_from_database(
    db_connection,
    min_samples_per_category: int = 10
) -> MLClassifier | DummyMLClassifier:
    """Train an ML classifier from already-categorized files in the database.

    Reads ``path``/``category`` rows from the ``files`` table (rows with a
    NULL category are skipped). It then discards categories with fewer than
    ``min_samples_per_category`` rows and trains on the remainder.

    Args:
        db_connection: Connection object exposing ``cursor()``
            (DB-API style).
        min_samples_per_category: Minimum number of rows a category needs
            to be included in training.

    Returns:
        The classifier. It is returned untrained when scikit-learn is
        missing, when no categorized rows exist, or when no category meets
        the sample threshold.
    """
    from collections import Counter

    classifier = create_ml_classifier()

    # Without scikit-learn there is nothing to train; skip the query entirely.
    if isinstance(classifier, DummyMLClassifier):
        return classifier

    cursor = db_connection.cursor()
    try:
        cursor.execute("""
            SELECT path, category
            FROM files
            WHERE category IS NOT NULL
        """)
        rows = cursor.fetchall()
    finally:
        # Release the cursor even if the query fails (e.g. missing table).
        cursor.close()

    training_data = [(Path(path), category) for path, category in rows]
    if not training_data:
        return classifier

    category_counts = Counter(category for _, category in training_data)

    # Keep only categories with enough samples to be learnable.
    filtered_data = [
        (path, category)
        for path, category in training_data
        if category_counts[category] >= min_samples_per_category
    ]

    if filtered_data:
        classifier.train(filtered_data)

    return classifier
|