initial
This commit is contained in:
127
app/classification/ml.py
Normal file
127
app/classification/ml.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple
|
||||
import pickle
|
||||
try:
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.pipeline import Pipeline
|
||||
SKLEARN_AVAILABLE = True
|
||||
except ImportError:
|
||||
SKLEARN_AVAILABLE = False
|
||||
|
||||
class MLClassifier:
    """Classify files by path using a TF-IDF + Multinomial Naive Bayes pipeline.

    A path is turned into a whitespace-separated feature string (see
    ``_extract_features``) which is fed to a scikit-learn ``Pipeline``.
    Constructing an instance raises ``ImportError`` when scikit-learn could
    not be imported.
    """

    # Maps the filename separators '-', '_' and '.' to spaces in one pass.
    _NAME_SEPARATORS = str.maketrans('-_.', '   ')

    def __init__(self):
        """Create an untrained classifier.

        Raises:
            ImportError: if scikit-learn is not installed.
        """
        if not SKLEARN_AVAILABLE:
            raise ImportError('scikit-learn is required for ML classification. Install with: pip install scikit-learn')
        self.model: Optional[Pipeline] = None
        self.categories: List[str] = []
        self._is_trained = False

    def _extract_features(self, path: Path) -> str:
        """Return the space-joined feature string for *path*.

        The features are, in order: every component of ``path.parts``, an
        ``ext:<suffix>`` token when the path has a suffix, and one
        ``name:<piece>`` token per piece of the file name after splitting it
        on '-', '_', '.' and whitespace.
        """
        # NOTE(review): TfidfVectorizer's default tokenizer splits on
        # punctuation, so the 'ext:' / 'name:' prefixes probably end up as
        # separate tokens rather than as part of one token -- confirm intended.
        features = list(path.parts)
        if path.suffix:
            features.append(f'ext:{path.suffix}')
        name_parts = path.name.translate(self._NAME_SEPARATORS).split()
        features.extend(f'name:{part}' for part in name_parts)
        return ' '.join(features)

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Fit a fresh pipeline on ``(path, category)`` pairs.

        Any previously held model is replaced. ``self.categories`` is set to
        the sorted distinct categories seen in *training_data*.

        Raises:
            ValueError: if *training_data* is empty.
        """
        if not training_data:
            raise ValueError('Training data cannot be empty')
        X = [self._extract_features(path) for path, _ in training_data]
        y = [category for _, category in training_data]
        self.categories = sorted(set(y))
        self.model = Pipeline([
            ('tfidf', TfidfVectorizer(max_features=1000, ngram_range=(1, 2), min_df=1)),
            ('classifier', MultinomialNB()),
        ])
        self.model.fit(X, y)
        self._is_trained = True

    def classify(self, path: Path, file_type: Optional[str]=None) -> Optional[str]:
        """Return the predicted category for *path*, or ``None``.

        ``None`` is returned when the classifier is untrained or when the
        prediction itself fails. *file_type* is accepted but not used.
        """
        if not self._is_trained or self.model is None:
            return None
        features = self._extract_features(path)
        try:
            prediction = self.model.predict([features])[0]
        except Exception:
            # Deliberate best-effort: a failed prediction means "no answer",
            # not a crash of the caller.
            return None
        # The model yields a numpy scalar; hand back a plain str as annotated.
        return str(prediction)

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Return a ``{category: probability}`` mapping for *path*.

        An empty dict is returned when the classifier is untrained or when
        the prediction fails.
        """
        if not self._is_trained or self.model is None:
            return {}
        features = self._extract_features(path)
        try:
            probabilities = self.model.predict_proba([features])[0]
            # Probability columns follow the fitted model's ``classes_``
            # order, so label them from there; ``self.categories`` is only a
            # fallback for model objects that do not expose ``classes_``.
            labels = getattr(self.model, 'classes_', self.categories)
            return {str(label): float(prob) for label, prob in zip(labels, probabilities)}
        except Exception:
            # Deliberate best-effort, mirroring ``classify``.
            return {}

    def save_model(self, model_path: Path) -> None:
        """Pickle the model, categories and trained flag to *model_path*.

        Raises:
            ValueError: if the classifier has not been trained.
        """
        if not self._is_trained:
            raise ValueError('Cannot save untrained model')
        model_data = {'model': self.model, 'categories': self.categories, 'is_trained': self._is_trained}
        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)

    def load_model(self, model_path: Path) -> None:
        """Restore state previously written by ``save_model``.

        The instance is only modified once every expected entry has been
        read, so a malformed file leaves the current state untouched.

        SECURITY: unpickling can execute arbitrary code. Only load model
        files that come from a trusted source.

        Raises:
            KeyError: if the payload lacks 'model', 'categories' or
                'is_trained'.
        """
        with open(model_path, 'rb') as f:
            # SECURITY: pickle.load on an untrusted file is unsafe (see above).
            model_data = pickle.load(f)
        # Read everything first so a missing key cannot leave the instance
        # half-updated.
        model = model_data['model']
        categories = model_data['categories']
        is_trained = model_data['is_trained']
        self.model = model
        self.categories = categories
        self._is_trained = is_trained

    @property
    def is_trained(self) -> bool:
        """Whether a model has been trained or loaded."""
        return self._is_trained
|
||||
|
||||
class DummyMLClassifier:
    """Stand-in with the same methods as the real classifier.

    Read-only queries answer with "nothing" (``None``, ``{}``, ``False``);
    operations that would need a model raise ``NotImplementedError``.
    """

    def __init__(self):
        """Nothing to set up: this object holds no state."""

    @property
    def is_trained(self) -> bool:
        """Always ``False``; this object never holds a model."""
        return False

    def classify(self, path: Path, file_type: Optional[str]=None) -> Optional[str]:
        """Always return ``None``; both arguments are ignored."""
        return None

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Always return an empty mapping; *path* is ignored."""
        return {}

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Always raise ``NotImplementedError``; *training_data* is ignored."""
        raise NotImplementedError('ML classification requires scikit-learn. Install with: pip install scikit-learn')

    def save_model(self, model_path: Path) -> None:
        """Always raise ``NotImplementedError``; nothing is written."""
        raise NotImplementedError('ML classification not available')

    def load_model(self, model_path: Path) -> None:
        """Always raise ``NotImplementedError``; nothing is read."""
        raise NotImplementedError('ML classification not available')
|
||||
|
||||
def create_ml_classifier() -> MLClassifier | DummyMLClassifier:
    """Build the classifier matching the current environment.

    Returns an ``MLClassifier`` when ``SKLEARN_AVAILABLE`` is true and a
    ``DummyMLClassifier`` otherwise.
    """
    return MLClassifier() if SKLEARN_AVAILABLE else DummyMLClassifier()
|
||||
|
||||
def train_from_database(db_connection, min_samples_per_category: int=10) -> MLClassifier | DummyMLClassifier:
    """Create a classifier and train it from the ``files`` table.

    Rows with a non-NULL ``category`` are read as ``(path, category)`` pairs.
    Categories with fewer than *min_samples_per_category* rows are dropped
    before training.

    Args:
        db_connection: object providing ``cursor()``; the cursor must support
            ``execute``, ``fetchall`` and ``close``.
        min_samples_per_category: minimum number of rows a category needs to
            be included in training.

    Returns:
        The classifier from ``create_ml_classifier()``. It is returned
        untrained when it is a ``DummyMLClassifier``, when the query yields
        no rows, or when no category has enough samples.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    classifier = create_ml_classifier()
    if isinstance(classifier, DummyMLClassifier):
        return classifier

    cursor = db_connection.cursor()
    try:
        cursor.execute('\n SELECT path, category\n FROM files\n WHERE category IS NOT NULL\n ')
        rows = cursor.fetchall()
    finally:
        # Release the cursor even if the query or the fetch raises.
        cursor.close()

    training_data = [(Path(path), category) for path, category in rows]
    if not training_data:
        return classifier

    category_counts = Counter(category for _, category in training_data)
    filtered_data = [
        (path, category)
        for path, category in training_data
        if category_counts[category] >= min_samples_per_category
    ]
    if filtered_data:
        classifier.train(filtered_data)
    return classifier
|
||||
Reference in New Issue
Block a user