"""Main classification engine"""

from pathlib import Path
from typing import Optional, Callable

import psycopg2

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Engine for classifying files"""

    def __init__(
        self,
        db_config: DatabaseConfig,
        logger: ProgressLogger,
        use_ml: bool = False
    ):
        """Initialize classification engine

        Args:
            db_config: Database configuration
            logger: Progress logger
            use_ml: Whether to use ML classification in addition to rules
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Get or create database connection"""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None
    ) -> ProcessingStats:
        """Classify all files in database

        Args:
            disk: Optional disk filter
            batch_size: Number of files to process per batch
            progress_callback: Optional callback for progress updates

        Returns:
            ProcessingStats with classification statistics
        """
        self.logger.section("Starting Classification")

        conn = self._get_connection()
        cursor = conn.cursor()

        # Get files without categories
        if disk:
            cursor.execute("""
                SELECT path, checksum
                FROM files
                WHERE disk_label = %s AND category IS NULL
            """, (disk,))
        else:
            cursor.execute("""
                SELECT path, checksum
                FROM files
                WHERE category IS NULL
            """)

        files_to_classify = cursor.fetchall()
        total_files = len(files_to_classify)

        self.logger.info(f"Found {total_files} files to classify")

        stats = ProcessingStats()
        batch = []

        for path_str, checksum in files_to_classify:
            path = Path(path_str)

            # Classify using rules first
            category = self.rule_classifier.classify(path)

            # If no rule match and ML is available, try ML
            if category is None and self.use_ml and self.ml_classifier:
                category = self.ml_classifier.classify(path)

            # If still no category, assign default
            if category is None:
                category = "temp/processing"

            batch.append((category, str(path)))
            stats.files_processed += 1

            # Batch update
            if len(batch) >= batch_size:
                self._update_categories(cursor, batch)
                conn.commit()
                batch.clear()

            # Progress callback
            if progress_callback:
                progress_callback(stats.files_processed, total_files, stats)

            # Log progress
            if stats.files_processed % (batch_size * 10) == 0:
                self.logger.progress(
                    stats.files_processed,
                    total_files,
                    prefix="Files classified",
                    elapsed_seconds=stats.elapsed_seconds
                )

        # Update remaining batch
        if batch:
            self._update_categories(cursor, batch)
            conn.commit()

        stats.files_succeeded = stats.files_processed

        cursor.close()

        self.logger.info(
            f"Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s"
        )

        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Update categories in batch

        Args:
            cursor: Database cursor
            batch: List of (category, path) tuples
        """
        from psycopg2.extras import execute_batch

        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """

        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path

        Args:
            path: Path to classify

        Returns:
            Category name or None
        """
        # Try rules first
        category = self.rule_classifier.classify(path)

        # Try ML if available
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)

        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Get statistics by category

        Returns:
            Dictionary mapping category to statistics
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            SELECT
                category,
                COUNT(*) as file_count,
                SUM(size) as total_size
            FROM files
            WHERE category IS NOT NULL
            GROUP BY category
            ORDER BY total_size DESC
        """)

        stats = {}
        for category, file_count, total_size in cursor.fetchall():
            stats[category] = {
                'file_count': file_count,
                'total_size': total_size
            }

        cursor.close()

        return stats

    def get_uncategorized_count(self) -> int:
        """Get count of uncategorized files

        Returns:
            Number of files without category
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM files WHERE category IS NULL")
        count = cursor.fetchone()[0]

        cursor.close()

        return count

    def reclassify_category(
        self,
        old_category: str,
        new_category: str
    ) -> int:
        """Reclassify all files in a category

        Args:
            old_category: Current category
            new_category: New category

        Returns:
            Number of files reclassified
        """
        self.logger.info(f"Reclassifying {old_category} -> {new_category}")

        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            UPDATE files
            SET category = %s
            WHERE category = %s
        """, (new_category, old_category))

        count = cursor.rowcount
        conn.commit()
        cursor.close()

        self.logger.info(f"Reclassified {count} files")

        return count

    def train_ml_classifier(
        self,
        min_samples: int = 10
    ) -> bool:
        """Train ML classifier from existing categorized data

        Args:
            min_samples: Minimum samples per category

        Returns:
            True if training successful
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning("ML classifier not available")
            return False

        self.logger.subsection("Training ML Classifier")

        conn = self._get_connection()
        cursor = conn.cursor()

        # Get categorized files
        cursor.execute("""
            SELECT path, category
            FROM files
            WHERE category IS NOT NULL
        """)

        training_data = [(Path(path), category) for path, category in cursor.fetchall()]
        cursor.close()

        if not training_data:
            self.logger.warning("No training data available")
            return False

        # Count samples per category
        category_counts = {}
        for _, category in training_data:
            category_counts[category] = category_counts.get(category, 0) + 1

        # Filter categories with enough samples
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]

        if not filtered_data:
            self.logger.warning(f"No categories with >= {min_samples} samples")
            return False

        self.logger.info(f"Training with {len(filtered_data)} samples")

        try:
            self.ml_classifier.train(filtered_data)
            self.logger.info("ML classifier trained successfully")
            return True
        except Exception as e:
            self.logger.error(f"Failed to train ML classifier: {e}")
            return False

    def get_all_categories(self) -> list[str]:
        """Get all categories from database

        Returns:
            List of category names
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            SELECT DISTINCT category
            FROM files
            WHERE category IS NOT NULL
            ORDER BY category
        """)

        categories = [row[0] for row in cursor.fetchall()]
        cursor.close()

        return categories

    def close(self):
        """Close database connection"""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.close()
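

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch added for documentation; not part of the
# original module). It shows the intended driving pattern: construct the engine
# with a DatabaseConfig and ProgressLogger, use it as a context manager so
# close() runs automatically, and pass a callback with the
# Callable[[int, int, ProcessingStats], None] signature to classify_all().
# The constructor arguments for DatabaseConfig and ProgressLogger below are
# assumptions; the real classes live in ..shared.config and ..shared.logger.
# Because of the relative imports above, run this via `python -m <package>.engine`.
# ---------------------------------------------------------------------------

def _print_progress(done: int, total: int, stats: ProcessingStats) -> None:
    """Example progress callback matching classify_all's expected signature."""
    print(f"{done}/{total} files classified ({stats.elapsed_seconds:.1f}s elapsed)")


if __name__ == "__main__":
    db_config = DatabaseConfig(  # assumed keyword arguments
        host="localhost",
        port=5432,
        database="files",
        user="app",
        password="secret",
    )
    logger = ProgressLogger()  # assumed no-argument constructor

    with ClassificationEngine(db_config, logger, use_ml=True) as engine:
        engine.train_ml_classifier(min_samples=10)
        stats = engine.classify_all(batch_size=1000, progress_callback=_print_progress)
        print(f"Processed {stats.files_processed} files")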