Files
defrag/app/classification/engine.py
2025-12-12 23:04:51 +01:00

351 lines
9.9 KiB
Python

"""Main classification engine"""
from pathlib import Path
from typing import Optional, Callable
import psycopg2
from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger
class ClassificationEngine:
    """Engine for classifying files recorded in the ``files`` database table.

    Classification is attempted with the rule-based classifier first; when
    that yields nothing and ML is enabled, the ML classifier is consulted.
    Files that still have no category are assigned ``DEFAULT_CATEGORY``.
    """

    # Fallback category assigned when neither rules nor ML produce a match.
    DEFAULT_CATEGORY = "temp/processing"

    def __init__(
        self,
        db_config: DatabaseConfig,
        logger: ProgressLogger,
        use_ml: bool = False
    ):
        """Initialize classification engine.

        Args:
            db_config: Database configuration
            logger: Progress logger
            use_ml: Whether to use ML classification in addition to rules
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        # A DummyMLClassifier means no real ML backend is available, so the
        # effective ML flag is cleared even when the caller requested ML.
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Get or create a database connection, reconnecting if it was closed."""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None
    ) -> ProcessingStats:
        """Classify every file in the database that has no category yet.

        Args:
            disk: Optional disk-label filter; when None, all disks are scanned
            batch_size: Number of category updates per database round-trip
            progress_callback: Optional callback invoked after each committed
                batch with (files_processed, total_files, stats)

        Returns:
            ProcessingStats with classification statistics
        """
        self.logger.section("Starting Classification")
        # Guard against a zero/negative batch size, which would otherwise
        # prevent flushing entirely and divide by zero in the progress check.
        batch_size = max(1, batch_size)
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            # Fetch only files that still lack a category.
            if disk:
                cursor.execute("""
                    SELECT path, checksum
                    FROM files
                    WHERE disk_label = %s AND category IS NULL
                """, (disk,))
            else:
                cursor.execute("""
                    SELECT path, checksum
                    FROM files
                    WHERE category IS NULL
                """)
            files_to_classify = cursor.fetchall()
            total_files = len(files_to_classify)
            self.logger.info(f"Found {total_files} files to classify")
            stats = ProcessingStats()
            batch = []
            for path_str, _checksum in files_to_classify:
                path = Path(path_str)
                # Rules first, then ML — same policy as classify_path().
                category = self.classify_path(path)
                if category is None:
                    category = self.DEFAULT_CATEGORY
                batch.append((category, str(path)))
                stats.files_processed += 1
                # Flush a full batch: update, commit, then report progress.
                if len(batch) >= batch_size:
                    self._update_categories(cursor, batch)
                    conn.commit()
                    batch.clear()
                    if progress_callback:
                        progress_callback(stats.files_processed, total_files, stats)
                    # Log roughly every 10 committed batches.
                    if stats.files_processed % (batch_size * 10) == 0:
                        self.logger.progress(
                            stats.files_processed,
                            total_files,
                            prefix="Files classified",
                            elapsed_seconds=stats.elapsed_seconds
                        )
            # Flush whatever is left after the loop.
            if batch:
                self._update_categories(cursor, batch)
                conn.commit()
            stats.files_succeeded = stats.files_processed
        finally:
            # Always release the cursor, even if a database call raised.
            cursor.close()
        self.logger.info(
            f"Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s"
        )
        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Update categories for a batch of files in one round-trip.

        Args:
            cursor: Database cursor
            batch: List of (category, path) tuples
        """
        # Local import keeps the psycopg2.extras dependency off the module
        # import path until a batch update is actually needed.
        from psycopg2.extras import execute_batch
        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path using rules first, then ML if enabled.

        Args:
            path: Path to classify

        Returns:
            Category name, or None when no classifier matched
        """
        category = self.rule_classifier.classify(path)
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)
        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Get per-category file counts and total sizes.

        Returns:
            Dictionary mapping category to {'file_count': int, 'total_size': int}
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT
                    category,
                    COUNT(*) as file_count,
                    SUM(size) as total_size
                FROM files
                WHERE category IS NOT NULL
                GROUP BY category
                ORDER BY total_size DESC
            """)
            return {
                category: {'file_count': file_count, 'total_size': total_size}
                for category, file_count, total_size in cursor.fetchall()
            }
        finally:
            cursor.close()

    def get_uncategorized_count(self) -> int:
        """Get the number of files without a category."""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT COUNT(*) FROM files WHERE category IS NULL")
            return cursor.fetchone()[0]
        finally:
            cursor.close()

    def reclassify_category(
        self,
        old_category: str,
        new_category: str
    ) -> int:
        """Move every file from one category to another.

        Args:
            old_category: Current category
            new_category: New category

        Returns:
            Number of files reclassified
        """
        self.logger.info(f"Reclassifying {old_category} -> {new_category}")
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                UPDATE files
                SET category = %s
                WHERE category = %s
            """, (new_category, old_category))
            count = cursor.rowcount
            conn.commit()
        finally:
            cursor.close()
        self.logger.info(f"Reclassified {count} files")
        return count

    def train_ml_classifier(
        self,
        min_samples: int = 10
    ) -> bool:
        """Train the ML classifier from already-categorized files.

        Args:
            min_samples: Minimum samples a category needs to be used

        Returns:
            True if training completed successfully
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning("ML classifier not available")
            return False
        self.logger.subsection("Training ML Classifier")
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT path, category
                FROM files
                WHERE category IS NOT NULL
            """)
            training_data = [(Path(path), category) for path, category in cursor.fetchall()]
        finally:
            cursor.close()
        if not training_data:
            self.logger.warning("No training data available")
            return False
        # Drop categories with too few samples to train on meaningfully.
        category_counts = {}
        for _, category in training_data:
            category_counts[category] = category_counts.get(category, 0) + 1
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]
        if not filtered_data:
            self.logger.warning(f"No categories with >= {min_samples} samples")
            return False
        self.logger.info(f"Training with {len(filtered_data)} samples")
        try:
            self.ml_classifier.train(filtered_data)
            self.logger.info("ML classifier trained successfully")
            return True
        except Exception as e:
            # Training failure is reported, not propagated: callers use the
            # boolean result to decide whether ML classification is usable.
            self.logger.error(f"Failed to train ML classifier: {e}")
            return False

    def get_all_categories(self) -> list[str]:
        """Get all distinct categories present in the database, sorted."""
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT DISTINCT category
                FROM files
                WHERE category IS NOT NULL
                ORDER BY category
            """)
            return [row[0] for row in cursor.fetchall()]
        finally:
            cursor.close()

    def close(self):
        """Close the database connection if it is open."""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        """Context manager entry: returns the engine itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: closes the database connection."""
        self.close()