from collections import Counter
from pathlib import Path
from typing import Callable, Optional

import psycopg2

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Assign categories to file records stored in PostgreSQL.

    Classification cascades: rule-based first, then (if enabled and a real
    model is available) an ML classifier, then the 'temp/processing'
    fallback.  A single psycopg2 connection is opened lazily and reused;
    the engine is usable as a context manager to guarantee it is closed.
    """

    def __init__(self, db_config: DatabaseConfig, logger: ProgressLogger,
                 use_ml: bool = False):
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        # ML is only effectively enabled when create_ml_classifier() returned
        # a real model rather than the DummyMLClassifier placeholder.
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Return the cached psycopg2 connection, reconnecting if closed."""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password,
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None,
    ) -> ProcessingStats:
        """Classify every uncategorized file, optionally limited to one disk.

        Category updates are committed in batches of *batch_size*.
        *progress_callback(processed, total, stats)* is invoked after each
        committed batch, including the final partial one, so callers always
        observe completion.

        Returns the accumulated ProcessingStats.
        """
        self.logger.section('Starting Classification')
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            if disk:
                cursor.execute(
                    """
                    SELECT path, checksum
                    FROM files
                    WHERE disk_label = %s AND category IS NULL
                    """,
                    (disk,),
                )
            else:
                cursor.execute(
                    """
                    SELECT path, checksum
                    FROM files
                    WHERE category IS NULL
                    """
                )
            files_to_classify = cursor.fetchall()
            total_files = len(files_to_classify)
            self.logger.info(f'Found {total_files} files to classify')

            stats = ProcessingStats()
            batch: list[tuple[str, str]] = []
            for path_str, _checksum in files_to_classify:
                category = self.classify_path(Path(path_str))
                if category is None:
                    # Unclassifiable files are parked for later review.
                    category = 'temp/processing'
                # Write back the raw DB path string: Path() normalizes
                # (collapses '//', strips trailing separators), and a
                # normalized string might not match the stored `path` key.
                batch.append((category, path_str))
                stats.files_processed += 1

                if len(batch) >= batch_size:
                    self._update_categories(cursor, batch)
                    conn.commit()
                    batch.clear()
                    if progress_callback:
                        progress_callback(stats.files_processed, total_files, stats)
                    # Flushes happen at exact multiples of batch_size, so
                    # this logs roughly every 10th batch.
                    if stats.files_processed % (batch_size * 10) == 0:
                        self.logger.progress(
                            stats.files_processed,
                            total_files,
                            prefix='Files classified',
                            elapsed_seconds=stats.elapsed_seconds,
                        )

            if batch:
                self._update_categories(cursor, batch)
                conn.commit()
                # Report the final partial batch so callers see 100%.
                if progress_callback:
                    progress_callback(stats.files_processed, total_files, stats)

            stats.files_succeeded = stats.files_processed
        finally:
            cursor.close()

        self.logger.info(
            f'Classification complete: {stats.files_processed} files '
            f'in {stats.elapsed_seconds:.1f}s'
        )
        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Apply a batch of (category, path) updates via execute_batch."""
        from psycopg2.extras import execute_batch
        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path; returns None when no classifier matches."""
        category = self.rule_classifier.classify(path)
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)
        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Return {category: {'file_count': int, 'total_size': int}},
        ordered by total size descending in the underlying query."""
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    category,
                    COUNT(*) as file_count,
                    SUM(size) as total_size
                FROM files
                WHERE category IS NOT NULL
                GROUP BY category
                ORDER BY total_size DESC
                """
            )
            return {
                category: {'file_count': file_count, 'total_size': total_size}
                for category, file_count, total_size in cursor.fetchall()
            }

    def get_uncategorized_count(self) -> int:
        """Return the number of files with no category assigned yet."""
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute('SELECT COUNT(*) FROM files WHERE category IS NULL')
            return cursor.fetchone()[0]

    def reclassify_category(self, old_category: str, new_category: str) -> int:
        """Rename a category in bulk; returns the number of rows updated."""
        self.logger.info(f'Reclassifying {old_category} -> {new_category}')
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(
                """
                UPDATE files
                SET category = %s
                WHERE category = %s
                """,
                (new_category, old_category),
            )
            count = cursor.rowcount
        conn.commit()
        self.logger.info(f'Reclassified {count} files')
        return count

    def train_ml_classifier(self, min_samples: int = 10) -> bool:
        """Train the ML classifier on already-categorized files.

        Categories with fewer than *min_samples* examples are excluded.
        Returns True on successful training, False otherwise.
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning('ML classifier not available')
            return False
        self.logger.subsection('Training ML Classifier')

        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT path, category
                FROM files
                WHERE category IS NOT NULL
                """
            )
            training_data = [
                (Path(path), category) for path, category in cursor.fetchall()
            ]

        if not training_data:
            self.logger.warning('No training data available')
            return False

        category_counts = Counter(category for _, category in training_data)
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]
        if not filtered_data:
            self.logger.warning(f'No categories with >= {min_samples} samples')
            return False

        self.logger.info(f'Training with {len(filtered_data)} samples')
        try:
            self.ml_classifier.train(filtered_data)
            self.logger.info('ML classifier trained successfully')
            return True
        except Exception as e:
            # Best-effort: training failure is reported, not propagated.
            self.logger.error(f'Failed to train ML classifier: {e}')
            return False

    def get_all_categories(self) -> list[str]:
        """Return all distinct categories in alphabetical order."""
        conn = self._get_connection()
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT DISTINCT category
                FROM files
                WHERE category IS NOT NULL
                ORDER BY category
                """
            )
            return [row[0] for row in cursor.fetchall()]

    def close(self):
        """Close the cached connection if one is open."""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()