# defrag/app/classification/engine.py
# Last modified 2025-12-13 11:56:06 +01:00 (149 lines, 7.1 KiB, Python)

from collections import Counter
from pathlib import Path
from typing import Optional, Callable

import psycopg2

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Assign categories to rows of the ``files`` table.

    Each path is classified by ``RuleBasedClassifier`` first. If the rules
    return ``None`` and ML is enabled, the ML classifier is consulted.
    ``classify_all`` stores ``FALLBACK_CATEGORY`` for paths that neither
    classifier could place.

    One psycopg2 connection is opened lazily and reused across calls. Every
    database method runs inside ``with conn``, which commits on success and
    rolls back on error. A failed query therefore does not leave the shared
    connection stuck in an aborted transaction. Use the engine as a context
    manager, or call ``close()``, to release the connection.
    """

    # Category written by classify_all for files no classifier could place.
    FALLBACK_CATEGORY = 'temp/processing'

    def __init__(self, db_config: DatabaseConfig, logger: ProgressLogger, use_ml: bool = False):
        """Create the engine. No database connection is opened here.

        Args:
            db_config: Connection settings (host, port, database, user, password).
            logger: Logger used for section/progress/status messages.
            use_ml: Request ML fallback classification. It is only enabled when
                ``create_ml_classifier()`` returns something other than a
                ``DummyMLClassifier``.
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Return the cached connection, (re)opening it if it is missing or closed."""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password,
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None,
    ) -> ProcessingStats:
        """Classify every file whose ``category`` is NULL and persist the result.

        Args:
            disk: If truthy, only files whose ``disk_label`` equals this value
                are processed.
            batch_size: Number of UPDATEs sent and committed together. Must be >= 1.
            progress_callback: Called as ``(processed, total, stats)`` after each
                committed batch, including the final partial batch.

        Returns:
            The ``ProcessingStats`` for this run. ``files_succeeded`` is set
            to ``files_processed``.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.logger.section('Starting Classification')
        conn = self._get_connection()
        # Batches committed inside the block stay committed. If an error occurs,
        # ``with conn`` rolls back only the current, uncommitted batch.
        with conn, conn.cursor() as cursor:
            if disk:
                cursor.execute(
                    """
                    SELECT path
                    FROM files
                    WHERE disk_label = %s AND category IS NULL
                    """,
                    (disk,),
                )
            else:
                cursor.execute(
                    """
                    SELECT path
                    FROM files
                    WHERE category IS NULL
                    """
                )
            # Materialise the rows first. The UPDATEs below reuse this cursor,
            # and each execute replaces its pending result set.
            paths = [row[0] for row in cursor.fetchall()]
            total_files = len(paths)
            self.logger.info(f'Found {total_files} files to classify')
            stats = ProcessingStats()
            batch: list[tuple[str, str]] = []
            for path_str in paths:
                category = self.classify_path(Path(path_str))
                if category is None:
                    category = self.FALLBACK_CATEGORY
                # Key the UPDATE on the exact string read from the table.
                # str(Path(...)) normalises paths (collapses '//' and '/./',
                # drops a trailing '/', converts separators on Windows), and
                # the normalised form may not match the stored row.
                batch.append((category, path_str))
                stats.files_processed += 1
                if len(batch) >= batch_size:
                    self._update_categories(cursor, batch)
                    conn.commit()
                    batch.clear()
                    if progress_callback:
                        progress_callback(stats.files_processed, total_files, stats)
                if stats.files_processed % (batch_size * 10) == 0:
                    self.logger.progress(
                        stats.files_processed,
                        total_files,
                        prefix='Files classified',
                        elapsed_seconds=stats.elapsed_seconds,
                    )
            if batch:
                self._update_categories(cursor, batch)
                conn.commit()
                # Report the final partial batch so callers see processed == total.
                if progress_callback:
                    progress_callback(stats.files_processed, total_files, stats)
        stats.files_succeeded = stats.files_processed
        self.logger.info(
            f'Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s'
        )
        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Write ``(category, path)`` pairs using psycopg2's ``execute_batch``.

        The caller is responsible for committing.

        NOTE(review): rows are matched on ``path`` alone. If ``path`` is not
        unique across ``disk_label`` values, this also updates same-path rows
        on other disks, including rows that already have a category. Confirm
        against the schema.
        """
        from psycopg2.extras import execute_batch
        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify one path without touching the database.

        Rules are tried first. The ML classifier is consulted only when the
        rules return ``None`` and ML is enabled. Returns ``None`` when neither
        classifier places the path; no fallback category is applied here.
        """
        category = self.rule_classifier.classify(path)
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)
        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Return file count and summed size per non-NULL category.

        Returns:
            ``{category: {'file_count': ..., 'total_size': ...}}`` in descending
            ``total_size`` order. ``total_size`` is whatever ``SUM(size)`` yields:
            ``None`` when every size in the category is NULL. Such categories are
            placed last.
        """
        conn = self._get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    category,
                    COUNT(*) AS file_count,
                    SUM(size) AS total_size
                FROM files
                WHERE category IS NOT NULL
                GROUP BY category
                ORDER BY total_size DESC NULLS LAST
                """
            )
            return {
                category: {'file_count': file_count, 'total_size': total_size}
                for category, file_count, total_size in cursor.fetchall()
            }

    def get_uncategorized_count(self) -> int:
        """Return the number of rows in ``files`` whose ``category`` is NULL."""
        conn = self._get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute('SELECT COUNT(*) FROM files WHERE category IS NULL')
            return cursor.fetchone()[0]

    def reclassify_category(self, old_category: str, new_category: str) -> int:
        """Rename every row in ``old_category`` to ``new_category`` and commit.

        Returns:
            The number of updated rows, as reported by ``cursor.rowcount``.
        """
        self.logger.info(f'Reclassifying {old_category} -> {new_category}')
        conn = self._get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute(
                """
                UPDATE files
                SET category = %s
                WHERE category = %s
                """,
                (new_category, old_category),
            )
            count = cursor.rowcount
        self.logger.info(f'Reclassified {count} files')
        return count

    def train_ml_classifier(self, min_samples: int = 10) -> bool:
        """Train the ML classifier on already-categorised files.

        Only categories with at least ``min_samples`` rows are used.

        Returns:
            True if training completed. False if ML is unavailable, there is
            no usable data, or ``train()`` raised. Training errors are logged,
            not propagated.
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning('ML classifier not available')
            return False
        self.logger.subsection('Training ML Classifier')
        conn = self._get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT path, category
                FROM files
                WHERE category IS NOT NULL
                """
            )
            training_data = [(Path(path), category) for path, category in cursor.fetchall()]
        if not training_data:
            self.logger.warning('No training data available')
            return False
        # NOTE(review): rows that classify_all labelled FALLBACK_CATEGORY are
        # included here, so the model may learn to predict the placeholder.
        # Confirm whether that label should be excluded.
        category_counts = Counter(category for _, category in training_data)
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]
        if not filtered_data:
            self.logger.warning(f'No categories with >= {min_samples} samples')
            return False
        self.logger.info(f'Training with {len(filtered_data)} samples')
        try:
            self.ml_classifier.train(filtered_data)
        except Exception as e:
            # Best effort: a training failure is reported, not raised.
            self.logger.error(f'Failed to train ML classifier: {e}')
            return False
        self.logger.info('ML classifier trained successfully')
        return True

    def get_all_categories(self) -> list[str]:
        """Return the distinct non-NULL categories, sorted ascending."""
        conn = self._get_connection()
        with conn, conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT DISTINCT category
                FROM files
                WHERE category IS NOT NULL
                ORDER BY category
                """
            )
            return [row[0] for row in cursor.fetchall()]

    def close(self):
        """Close the cached connection if it is open. Safe to call repeatedly."""
        if self._connection is not None and not self._connection.closed:
            self._connection.close()
        self._connection = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()