351 lines
9.9 KiB
Python
351 lines
9.9 KiB
Python
"""Main classification engine"""
|
|
from pathlib import Path
|
|
from typing import Optional, Callable
|
|
import psycopg2
|
|
|
|
from .rules import RuleBasedClassifier
|
|
from .ml import create_ml_classifier, DummyMLClassifier
|
|
from ..shared.models import ProcessingStats
|
|
from ..shared.config import DatabaseConfig
|
|
from ..shared.logger import ProgressLogger
|
|
|
|
|
|
class ClassificationEngine:
    """Engine for classifying files stored in the ``files`` table.

    Rule-based classification is always tried first; an optional ML
    classifier is consulted only when no rule matches. A single psycopg2
    connection is created lazily and reused; release it with ``close()``
    or by using the engine as a context manager.
    """

    # Category written for files that neither the rules nor ML could classify.
    DEFAULT_CATEGORY = "temp/processing"

    def __init__(
        self,
        db_config: DatabaseConfig,
        logger: ProgressLogger,
        use_ml: bool = False
    ):
        """Initialize classification engine

        Args:
            db_config: Database configuration (its host, port, database,
                user and password are used when connecting)
            logger: Progress logger
            use_ml: Whether to use ML classification in addition to rules.
                ML stays disabled if ``create_ml_classifier()`` returns a
                ``DummyMLClassifier`` (presumably its fallback when no ML
                backend is available -- confirm in ``.ml``).
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        # Only treat ML as enabled when a real (non-dummy) classifier exists.
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Get the cached database connection, (re)connecting if it is
        missing or has been closed."""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None
    ) -> ProcessingStats:
        """Classify all files in database that have no category yet

        Updates are written and committed every ``batch_size`` files, so
        work from completed batches survives a later failure; the
        in-flight batch is rolled back if an exception is raised.

        Args:
            disk: Optional disk filter (matched against ``disk_label``;
                a falsy value means all disks)
            batch_size: Number of files to process per batch (must be >= 1)
            progress_callback: Optional callback invoked after every file
                with ``(files_processed, total_files, stats)``

        Returns:
            ProcessingStats with classification statistics

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.logger.section("Starting Classification")

        conn = self._get_connection()
        # Emit a progress log line once every 10 batches.
        log_every = batch_size * 10

        # `with conn` commits on success and rolls back on error, so a failure
        # does not leave the shared connection in an aborted transaction;
        # `with conn.cursor()` closes the cursor in both cases.
        with conn, conn.cursor() as cursor:
            paths = self._fetch_unclassified_paths(cursor, disk)
            total_files = len(paths)

            self.logger.info(f"Found {total_files} files to classify")

            stats = ProcessingStats()
            batch: list[tuple[str, str]] = []

            for path_str in paths:
                category = self.classify_path(Path(path_str))
                if category is None:
                    category = self.DEFAULT_CATEGORY

                # Key the UPDATE on the path string exactly as stored:
                # str(Path(...)) normalizes it (drops trailing slashes,
                # collapses '//', strips './'), which would make
                # `WHERE path = %s` miss the row and leave it uncategorized.
                batch.append((category, path_str))
                stats.files_processed += 1

                if len(batch) >= batch_size:
                    self._update_categories(cursor, batch)
                    conn.commit()
                    batch.clear()

                if progress_callback:
                    progress_callback(stats.files_processed, total_files, stats)

                if stats.files_processed % log_every == 0:
                    self.logger.progress(
                        stats.files_processed,
                        total_files,
                        prefix="Files classified",
                        elapsed_seconds=stats.elapsed_seconds
                    )

            # Flush the final partial batch.
            if batch:
                self._update_categories(cursor, batch)
                conn.commit()

        stats.files_succeeded = stats.files_processed

        self.logger.info(
            f"Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s"
        )

        return stats

    def _fetch_unclassified_paths(self, cursor, disk: Optional[str]) -> list[str]:
        """Return the stored ``path`` of every file whose category is NULL.

        Args:
            cursor: Database cursor
            disk: If truthy, only rows whose ``disk_label`` equals it

        Returns:
            List of path strings, exactly as stored in the database
        """
        if disk:
            cursor.execute("""
                SELECT path
                FROM files
                WHERE disk_label = %s AND category IS NULL
            """, (disk,))
        else:
            cursor.execute("""
                SELECT path
                FROM files
                WHERE category IS NULL
            """)
        return [row[0] for row in cursor.fetchall()]

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Update categories in batch (does not commit)

        Args:
            cursor: Database cursor
            batch: List of (category, path) tuples
        """
        from psycopg2.extras import execute_batch

        # NOTE(review): rows are matched on `path` alone, even when
        # classify_all was filtered by disk. If the same path can exist on
        # several disks, rows on other disks are updated too -- confirm
        # whether `path` is unique in the schema.
        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """

        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path

        Rules are tried first; the ML classifier is consulted only when no
        rule matches and ML is enabled.

        Args:
            path: Path to classify

        Returns:
            Category name or None
        """
        category = self.rule_classifier.classify(path)

        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)

        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Get statistics by category

        Returns:
            Dictionary mapping category to ``{'file_count', 'total_size'}``,
            in descending order of total size. Categories whose total size
            is NULL (every ``size`` NULL) are listed last.
        """
        conn = self._get_connection()

        with conn, conn.cursor() as cursor:
            # NULLS LAST: PostgreSQL sorts NULLs first under DESC by default,
            # which would put size-less categories ahead of the largest ones.
            cursor.execute("""
                SELECT
                    category,
                    COUNT(*) as file_count,
                    SUM(size) as total_size
                FROM files
                WHERE category IS NOT NULL
                GROUP BY category
                ORDER BY total_size DESC NULLS LAST
            """)
            rows = cursor.fetchall()

        return {
            category: {
                'file_count': file_count,
                'total_size': total_size
            }
            for category, file_count, total_size in rows
        }

    def get_uncategorized_count(self) -> int:
        """Get count of uncategorized files

        Returns:
            Number of files without category
        """
        conn = self._get_connection()

        with conn, conn.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM files WHERE category IS NULL")
            count = cursor.fetchone()[0]

        return count

    def reclassify_category(
        self,
        old_category: str,
        new_category: str
    ) -> int:
        """Reclassify all files in a category

        Args:
            old_category: Current category
            new_category: New category

        Returns:
            Number of files reclassified
        """
        self.logger.info(f"Reclassifying {old_category} -> {new_category}")

        conn = self._get_connection()

        # Leaving `with conn` commits the UPDATE (or rolls it back on error).
        with conn, conn.cursor() as cursor:
            cursor.execute("""
                UPDATE files
                SET category = %s
                WHERE category = %s
            """, (new_category, old_category))
            count = cursor.rowcount

        self.logger.info(f"Reclassified {count} files")

        return count

    def train_ml_classifier(
        self,
        min_samples: int = 10
    ) -> bool:
        """Train ML classifier from existing categorized data

        Only categories with at least ``min_samples`` files are used.

        Args:
            min_samples: Minimum samples per category

        Returns:
            True if training successful
        """
        from collections import Counter

        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning("ML classifier not available")
            return False

        self.logger.subsection("Training ML Classifier")

        conn = self._get_connection()

        with conn, conn.cursor() as cursor:
            cursor.execute("""
                SELECT path, category
                FROM files
                WHERE category IS NOT NULL
            """)
            training_data = [(Path(path), category) for path, category in cursor.fetchall()]

        if not training_data:
            self.logger.warning("No training data available")
            return False

        category_counts = Counter(category for _, category in training_data)

        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]

        if not filtered_data:
            self.logger.warning(f"No categories with >= {min_samples} samples")
            return False

        self.logger.info(f"Training with {len(filtered_data)} samples")

        try:
            self.ml_classifier.train(filtered_data)
        except Exception as e:
            # Training failures are reported via the return value, not raised.
            self.logger.error(f"Failed to train ML classifier: {e}")
            return False

        self.logger.info("ML classifier trained successfully")
        return True

    def get_all_categories(self) -> list[str]:
        """Get all categories from database

        Returns:
            Sorted list of distinct, non-NULL category names
        """
        conn = self._get_connection()

        with conn, conn.cursor() as cursor:
            cursor.execute("""
                SELECT DISTINCT category
                FROM files
                WHERE category IS NOT NULL
                ORDER BY category
            """)
            categories = [row[0] for row in cursor.fetchall()]

        return categories

    def close(self):
        """Close database connection, if one is open"""
        if self._connection is not None and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close the connection (exceptions propagate)"""
        self.close()
|