commit 5098f5b291
parent 2bd4f93777
Author: mike
Date: 2025-12-13 11:53:29 +01:00
8 changed files with 100 additions and 806 deletions


@@ -1,8 +1,6 @@
"""ML-based classification (optional, using sklearn if available)"""
from pathlib import Path
from typing import Optional, List, Tuple
import pickle
try:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
@@ -11,100 +9,41 @@ try:
 except ImportError:
     SKLEARN_AVAILABLE = False
 class MLClassifier:
-    """Machine learning-based file classifier
-    Uses path-based features and optional metadata to classify files.
-    Requires scikit-learn to be installed.
-    """
     def __init__(self):
-        """Initialize ML classifier"""
         if not SKLEARN_AVAILABLE:
-            raise ImportError(
-                "scikit-learn is required for ML classification. "
-                "Install with: pip install scikit-learn"
-            )
+            raise ImportError('scikit-learn is required for ML classification. Install with: pip install scikit-learn')
         self.model: Optional[Pipeline] = None
         self.categories: List[str] = []
         self._is_trained = False
     def _extract_features(self, path: Path) -> str:
-        """Extract features from path
-        Args:
-            path: Path to extract features from
-        Returns:
-            Feature string
-        """
-        # Convert path to feature string
-        # Include: path parts, extension, filename
         parts = path.parts
         extension = path.suffix
         filename = path.name
         features = []
-        # Add path components
         features.extend(parts)
-        # Add extension
         if extension:
-            features.append(f"ext:{extension}")
-        # Add filename components (split on common separators)
+            features.append(f'ext:{extension}')
         name_parts = filename.replace('-', ' ').replace('_', ' ').replace('.', ' ').split()
-        features.extend([f"name:{part}" for part in name_parts])
+        features.extend([f'name:{part}' for part in name_parts])
         return ' '.join(features)
     def train(self, training_data: List[Tuple[Path, str]]) -> None:
-        """Train the classifier
-        Args:
-            training_data: List of (path, category) tuples
-        """
         if not training_data:
-            raise ValueError("Training data cannot be empty")
-        # Extract features and labels
+            raise ValueError('Training data cannot be empty')
         X = [self._extract_features(path) for path, _ in training_data]
         y = [category for _, category in training_data]
-        # Store unique categories
         self.categories = sorted(set(y))
-        # Create and train pipeline
-        self.model = Pipeline([
-            ('tfidf', TfidfVectorizer(
-                max_features=1000,
-                ngram_range=(1, 2),
-                min_df=1
-            )),
-            ('classifier', MultinomialNB())
-        ])
+        self.model = Pipeline([('tfidf', TfidfVectorizer(max_features=1000, ngram_range=(1, 2), min_df=1)), ('classifier', MultinomialNB())])
         self.model.fit(X, y)
         self._is_trained = True
-    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
-        """Classify a file path
-        Args:
-            path: Path to classify
-            file_type: Optional file type hint (not used in ML classifier)
-        Returns:
-            Category name or None if not trained
-        """
+    def classify(self, path: Path, file_type: Optional[str]=None) -> Optional[str]:
         if not self._is_trained or self.model is None:
             return None
         features = self._extract_features(path)
         try:
             prediction = self.model.predict([features])[0]
             return prediction
@@ -112,158 +51,77 @@ class MLClassifier:
             return None
     def predict_proba(self, path: Path) -> dict[str, float]:
-        """Get prediction probabilities for all categories
-        Args:
-            path: Path to classify
-        Returns:
-            Dictionary mapping category to probability
-        """
         if not self._is_trained or self.model is None:
             return {}
         features = self._extract_features(path)
         try:
             probabilities = self.model.predict_proba([features])[0]
-            return {
-                category: float(prob)
-                for category, prob in zip(self.categories, probabilities)
-            }
+            return {category: float(prob) for category, prob in zip(self.categories, probabilities)}
         except Exception:
             return {}
     def save_model(self, model_path: Path) -> None:
-        """Save trained model to disk
-        Args:
-            model_path: Path to save model
-        """
         if not self._is_trained:
-            raise ValueError("Cannot save untrained model")
-        model_data = {
-            'model': self.model,
-            'categories': self.categories,
-            'is_trained': self._is_trained
-        }
+            raise ValueError('Cannot save untrained model')
+        model_data = {'model': self.model, 'categories': self.categories, 'is_trained': self._is_trained}
         with open(model_path, 'wb') as f:
             pickle.dump(model_data, f)
     def load_model(self, model_path: Path) -> None:
-        """Load trained model from disk
-        Args:
-            model_path: Path to model file
-        """
         with open(model_path, 'rb') as f:
             model_data = pickle.load(f)
         self.model = model_data['model']
         self.categories = model_data['categories']
         self._is_trained = model_data['is_trained']
     @property
     def is_trained(self) -> bool:
-        """Check if model is trained"""
         return self._is_trained
 class DummyMLClassifier:
-    """Dummy ML classifier for when sklearn is not available"""
     def __init__(self):
-        """Initialize dummy classifier"""
         pass
     def train(self, training_data: List[Tuple[Path, str]]) -> None:
-        """Dummy train method"""
-        raise NotImplementedError(
-            "ML classification requires scikit-learn. "
-            "Install with: pip install scikit-learn"
-        )
+        raise NotImplementedError('ML classification requires scikit-learn. Install with: pip install scikit-learn')
-    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
-        """Dummy classify method"""
+    def classify(self, path: Path, file_type: Optional[str]=None) -> Optional[str]:
         return None
     def predict_proba(self, path: Path) -> dict[str, float]:
-        """Dummy predict_proba method"""
         return {}
     def save_model(self, model_path: Path) -> None:
-        """Dummy save_model method"""
-        raise NotImplementedError("ML classification not available")
+        raise NotImplementedError('ML classification not available')
     def load_model(self, model_path: Path) -> None:
-        """Dummy load_model method"""
-        raise NotImplementedError("ML classification not available")
+        raise NotImplementedError('ML classification not available')
     @property
     def is_trained(self) -> bool:
-        """Check if model is trained"""
         return False
 def create_ml_classifier() -> MLClassifier | DummyMLClassifier:
-    """Create ML classifier if sklearn is available, otherwise return dummy
-    Returns:
-        MLClassifier or DummyMLClassifier
-    """
     if SKLEARN_AVAILABLE:
         return MLClassifier()
     else:
         return DummyMLClassifier()
-def train_from_database(
-    db_connection,
-    min_samples_per_category: int = 10
-) -> MLClassifier | DummyMLClassifier:
-    """Train ML classifier from database
-    Args:
-        db_connection: Database connection
-        min_samples_per_category: Minimum samples required per category
-    Returns:
-        Trained classifier
-    """
+def train_from_database(db_connection, min_samples_per_category: int=10) -> MLClassifier | DummyMLClassifier:
     classifier = create_ml_classifier()
     if isinstance(classifier, DummyMLClassifier):
         return classifier
-    # Query classified files from database
     cursor = db_connection.cursor()
-    cursor.execute("""
-        SELECT path, category
-        FROM files
-        WHERE category IS NOT NULL
-    """)
+    cursor.execute('\n        SELECT path, category\n        FROM files\n        WHERE category IS NOT NULL\n    ')
     training_data = [(Path(path), category) for path, category in cursor.fetchall()]
     cursor.close()
     if not training_data:
         return classifier
-    # Count samples per category
     category_counts = {}
     for _, category in training_data:
         category_counts[category] = category_counts.get(category, 0) + 1
-    # Filter to categories with enough samples
-    filtered_data = [
-        (path, category)
-        for path, category in training_data
-        if category_counts[category] >= min_samples_per_category
-    ]
+    filtered_data = [(path, category) for path, category in training_data if category_counts[category] >= min_samples_per_category]
     if filtered_data:
         classifier.train(filtered_data)
     return classifier
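
For reference, a minimal usage sketch of the public API this file exposes after the change. It is not part of the commit: the module name ml_classifier, the file paths, and the category labels are made up for illustration.

# Illustrative only: assumes the module above is importable as "ml_classifier"
# (hypothetical name); scikit-learn may or may not be installed.
from pathlib import Path
from ml_classifier import DummyMLClassifier, create_ml_classifier

classifier = create_ml_classifier()
if isinstance(classifier, DummyMLClassifier):
    print('scikit-learn not installed; ML classification is unavailable')
else:
    # Tiny hand-labelled training set of (path, category) tuples.
    training_data = [
        (Path('docs/report-2024.pdf'), 'documents'),
        (Path('docs/meeting_notes.txt'), 'documents'),
        (Path('photos/holiday/IMG_0001.jpg'), 'images'),
        (Path('photos/scan-receipt.png'), 'images'),
        (Path('src/main.py'), 'code'),
        (Path('src/utils/helpers.py'), 'code'),
    ]
    classifier.train(training_data)

    # Path-derived TF-IDF features drive the Naive Bayes prediction.
    print(classifier.classify(Path('photos/beach/IMG_0042.jpg')))  # e.g. 'images'
    print(classifier.predict_proba(Path('src/app.py')))            # category -> probability

    # The model pickles to disk and can be restored into a fresh classifier.
    classifier.save_model(Path('classifier.pkl'))
    restored = create_ml_classifier()
    restored.load_model(Path('classifier.pkl'))
    assert restored.is_trained

Note that load_model unpickles arbitrary data, so model files should only be loaded from a trusted source.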