"""Main classification engine"""

from collections import Counter
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Iterator, Optional

import psycopg2
from psycopg2.extras import execute_batch

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Engine for classifying files

    Categories are produced by the rule-based classifier first and, when
    enabled, by the ML classifier as a fallback. Results are stored in the
    ``category`` column of the ``files_bak`` table.
    """

    # Category written by ``classify_all`` when neither classifier matches.
    DEFAULT_CATEGORY = "temp/processing"

    def __init__(
        self,
        db_config: DatabaseConfig,
        logger: ProgressLogger,
        use_ml: bool = False
    ):
        """Initialize classification engine

        Args:
            db_config: Database configuration
            logger: Progress logger
            use_ml: Whether to use ML classification in addition to rules
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        # A DummyMLClassifier from the factory is treated as "ML not in use".
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Get or create database connection

        Returns:
            The cached psycopg2 connection, reconnecting if it is missing
            or has been closed.
        """
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password
            )
        return self._connection

    @contextmanager
    def _cursor(self) -> Iterator:
        """Yield a cursor on the shared connection inside a transaction

        Commits when the ``with`` body finishes normally, rolls back when it
        raises, and always closes the cursor. Ending the transaction after
        every operation keeps the shared connection from staying inside an
        open or aborted transaction between calls.

        Yields:
            A psycopg2 cursor bound to the shared connection.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except BaseException:
            # Roll back on any failure (including KeyboardInterrupt) so the
            # next call does not inherit a failed transaction.
            if not conn.closed:
                conn.rollback()
            raise
        finally:
            cursor.close()

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None
    ) -> ProcessingStats:
        """Classify all files in database

        Every row of ``files_bak`` whose category is NULL is classified and
        updated. Updates are committed once per batch, so an error only rolls
        back the batch in flight; earlier batches stay committed.

        Args:
            disk: Optional disk filter
            batch_size: Number of files to process per batch (values below 1
                are treated as 1)
            progress_callback: Optional callback for progress updates, called
                as ``callback(files_processed, total_files, stats)`` after
                each committed batch, including the final partial one

        Returns:
            ProcessingStats with classification statistics
        """
        self.logger.section("Starting Classification")

        # Guard against a zero/negative batch size, which would otherwise
        # divide by zero in the progress-log interval below.
        batch_size = max(1, batch_size)
        log_interval = batch_size * 10

        # Get files without categories
        if disk:
            query = """
                SELECT path
                FROM files_bak
                WHERE disk = %s AND category IS NULL
            """
            params = (disk,)
        else:
            query = """
                SELECT path
                FROM files_bak
                WHERE category IS NULL
            """
            params = None

        conn = self._get_connection()

        with self._cursor() as cursor:
            cursor.execute(query, params)
            path_strings = [row[0] for row in cursor.fetchall()]
            total_files = len(path_strings)
            self.logger.info(f"Found {total_files} files to classify")

            stats = ProcessingStats()
            batch: list[tuple[str, str]] = []

            def flush() -> None:
                """Write the pending batch, commit it and report progress."""
                self._update_categories(cursor, batch)
                conn.commit()
                batch.clear()
                if progress_callback:
                    progress_callback(stats.files_processed, total_files, stats)

            for path_str in path_strings:
                # Rules first, then ML (if enabled), then the default.
                category = self.classify_path(Path(path_str))
                if category is None:
                    category = self.DEFAULT_CATEGORY

                # Key the update on the exact string stored in the database:
                # round-tripping through Path can normalise it (e.g. "a//b",
                # "./a") so that "WHERE path = %s" no longer matches the row.
                batch.append((category, path_str))
                stats.files_processed += 1

                # Batch update
                if len(batch) >= batch_size:
                    flush()

                # Log progress
                if stats.files_processed % log_interval == 0:
                    self.logger.progress(
                        stats.files_processed,
                        total_files,
                        prefix="Files classified",
                        elapsed_seconds=stats.elapsed_seconds
                    )

            # Update remaining batch
            if batch:
                flush()

            stats.files_succeeded = stats.files_processed

        self.logger.info(
            f"Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s"
        )
        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Update categories in batch

        Does not commit; the caller owns the transaction.

        Args:
            cursor: Database cursor
            batch: List of (category, path) tuples
        """
        if not batch:
            return

        query = """
            UPDATE files_bak
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path

        Args:
            path: Path to classify

        Returns:
            Category name or None
        """
        # Try rules first
        category = self.rule_classifier.classify(path)

        # Try ML if available
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)

        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Get statistics by category

        Returns:
            Dictionary mapping category to ``{'file_count': ..., 'total_size': ...}``,
            in the order returned by the query (``total_size`` descending).
            ``total_size`` is whatever ``SUM(size)`` yields for the category.
        """
        with self._cursor() as cursor:
            cursor.execute("""
                SELECT
                    category,
                    COUNT(*) as file_count,
                    SUM(size) as total_size
                FROM files_bak
                WHERE category IS NOT NULL
                GROUP BY category
                ORDER BY total_size DESC
            """)
            return {
                category: {'file_count': file_count, 'total_size': total_size}
                for category, file_count, total_size in cursor.fetchall()
            }

    def get_uncategorized_count(self) -> int:
        """Get count of uncategorized files

        Returns:
            Number of files without category
        """
        with self._cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM files_bak WHERE category IS NULL")
            return cursor.fetchone()[0]

    def reclassify_category(
        self,
        old_category: str,
        new_category: str
    ) -> int:
        """Reclassify all files in a category

        Args:
            old_category: Current category
            new_category: New category

        Returns:
            Number of files reclassified
        """
        self.logger.info(f"Reclassifying {old_category} -> {new_category}")

        with self._cursor() as cursor:
            cursor.execute("""
                UPDATE files_bak
                SET category = %s
                WHERE category = %s
            """, (new_category, old_category))
            # Read rowcount before the cursor is closed.
            count = cursor.rowcount

        self.logger.info(f"Reclassified {count} files")
        return count

    def train_ml_classifier(
        self,
        min_samples: int = 10
    ) -> bool:
        """Train ML classifier from existing categorized data

        Args:
            min_samples: Minimum samples per category

        Returns:
            True if training successful; False when ML is unavailable, there
            is no usable training data, or training raised an exception
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning("ML classifier not available")
            return False

        self.logger.subsection("Training ML Classifier")

        # Get categorized files
        with self._cursor() as cursor:
            cursor.execute("""
                SELECT path, category
                FROM files_bak
                WHERE category IS NOT NULL
            """)
            training_data = [
                (Path(path), category) for path, category in cursor.fetchall()
            ]

        if not training_data:
            self.logger.warning("No training data available")
            return False

        # Count samples per category
        category_counts = Counter(category for _, category in training_data)

        # Filter categories with enough samples
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]

        if not filtered_data:
            self.logger.warning(f"No categories with >= {min_samples} samples")
            return False

        self.logger.info(f"Training with {len(filtered_data)} samples")

        try:
            self.ml_classifier.train(filtered_data)
        except Exception as e:
            # Deliberately broad: callers rely on the boolean result rather
            # than on exceptions from the ML backend.
            self.logger.error(f"Failed to train ML classifier: {e}")
            return False

        self.logger.info("ML classifier trained successfully")
        return True

    def get_all_categories(self) -> list[str]:
        """Get all categories from database

        Returns:
            List of category names
        """
        with self._cursor() as cursor:
            cursor.execute("""
                SELECT DISTINCT category
                FROM files_bak
                WHERE category IS NOT NULL
                ORDER BY category
            """)
            return [row[0] for row in cursor.fetchall()]

    def close(self):
        """Close database connection"""
        if self._connection and not self._connection.closed:
            self._connection.close()
        # Drop the reference; _get_connection() reconnects on demand.
        self._connection = None

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.close()