from pathlib import Path
from typing import Optional, Callable

import psycopg2

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Assigns categories to files in the database using rules first, then an optional ML classifier."""

    def __init__(self, db_config: DatabaseConfig, logger: ProgressLogger, use_ml: bool = False):
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        # Only treat ML as available when a real (non-dummy) classifier was created.
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Return a live database connection, reconnecting if the cached one was closed."""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password,
            )
        return self._connection

    def classify_all(self, disk: Optional[str] = None, batch_size: int = 1000,
                     progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None) -> ProcessingStats:
        """Classify every uncategorized file, optionally restricted to a single disk."""
        self.logger.section('Starting Classification')
        conn = self._get_connection()
        cursor = conn.cursor()

        if disk:
            cursor.execute(
                """
                SELECT path, checksum
                FROM files
                WHERE disk_label = %s AND category IS NULL
                """, (disk,))
        else:
            cursor.execute(
                """
                SELECT path, checksum
                FROM files
                WHERE category IS NULL
                """)

        files_to_classify = cursor.fetchall()
        total_files = len(files_to_classify)
        self.logger.info(f'Found {total_files} files to classify')

        stats = ProcessingStats()
        batch = []
        for path_str, checksum in files_to_classify:
            path = Path(path_str)
            # Rules first; fall back to the ML classifier, then to a holding category.
            category = self.rule_classifier.classify(path)
            if category is None and self.use_ml and self.ml_classifier:
                category = self.ml_classifier.classify(path)
            if category is None:
                category = 'temp/processing'
            batch.append((category, str(path)))
            stats.files_processed += 1

            # Flush updates in batches to keep transactions small.
            if len(batch) >= batch_size:
                self._update_categories(cursor, batch)
                conn.commit()
                batch.clear()
                if progress_callback:
                    progress_callback(stats.files_processed, total_files, stats)
                if stats.files_processed % (batch_size * 10) == 0:
                    self.logger.progress(stats.files_processed, total_files,
                                         prefix='Files classified',
                                         elapsed_seconds=stats.elapsed_seconds)

        # Commit any remaining rows from the final partial batch.
        if batch:
            self._update_categories(cursor, batch)
            conn.commit()

        stats.files_succeeded = stats.files_processed
        cursor.close()
        self.logger.info(f'Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s')
        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Apply a batch of (category, path) updates via execute_batch to reduce round-trips."""
        from psycopg2.extras import execute_batch

        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path without touching the database; returns None when no classifier matches."""
        category = self.rule_classifier.classify(path)
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)
        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Return per-category file counts and total sizes, largest categories first."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT
                category,
                COUNT(*) AS file_count,
                SUM(size) AS total_size
            FROM files
            WHERE category IS NOT NULL
            GROUP BY category
            ORDER BY total_size DESC
            """)
        stats = {}
        for category, file_count, total_size in cursor.fetchall():
            stats[category] = {'file_count': file_count, 'total_size': total_size}
        cursor.close()
        return stats

    def get_uncategorized_count(self) -> int:
        """Return how many files still have no category assigned."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute('SELECT COUNT(*) FROM files WHERE category IS NULL')
        count = cursor.fetchone()[0]
        cursor.close()
        return count

    def reclassify_category(self, old_category: str, new_category: str) -> int:
        """Move every file from old_category to new_category and return the number of rows updated."""
        self.logger.info(f'Reclassifying {old_category} -> {new_category}')
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """
            UPDATE files
            SET category = %s
            WHERE category = %s
            """, (new_category, old_category))
        count = cursor.rowcount
        conn.commit()
        cursor.close()
        self.logger.info(f'Reclassified {count} files')
        return count

    def train_ml_classifier(self, min_samples: int = 10) -> bool:
        """Train the ML classifier on already-categorized files; returns True on success."""
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning('ML classifier not available')
            return False

        self.logger.subsection('Training ML Classifier')
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT path, category
            FROM files
            WHERE category IS NOT NULL
            """)
        training_data = [(Path(path), category) for path, category in cursor.fetchall()]
        cursor.close()

        if not training_data:
            self.logger.warning('No training data available')
            return False

        # Drop categories with too few examples to learn from.
        category_counts = {}
        for _, category in training_data:
            category_counts[category] = category_counts.get(category, 0) + 1
        filtered_data = [(path, category) for path, category in training_data
                         if category_counts[category] >= min_samples]

        if not filtered_data:
            self.logger.warning(f'No categories with >= {min_samples} samples')
            return False

        self.logger.info(f'Training with {len(filtered_data)} samples')
        try:
            self.ml_classifier.train(filtered_data)
            self.logger.info('ML classifier trained successfully')
            return True
        except Exception as e:
            self.logger.error(f'Failed to train ML classifier: {e}')
            return False

    def get_all_categories(self) -> list[str]:
        """Return all distinct categories currently assigned, sorted alphabetically."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT DISTINCT category
            FROM files
            WHERE category IS NOT NULL
            ORDER BY category
            """)
        categories = [row[0] for row in cursor.fetchall()]
        cursor.close()
        return categories

    def close(self):
        """Close the cached database connection if it is still open."""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
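

# Example usage (a minimal sketch for illustration only; the DatabaseConfig and
# ProgressLogger constructor arguments shown here are assumed, not defined in this
# module):
#
#     config = DatabaseConfig(host='localhost', port=5432, database='catalog',
#                             user='indexer', password='secret')
#     with ClassificationEngine(config, ProgressLogger(), use_ml=True) as engine:
#         engine.train_ml_classifier(min_samples=25)
#         stats = engine.classify_all(batch_size=500)
#         print(engine.get_category_stats())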