fly wa
@@ -3,83 +3,42 @@ from typing import Dict, Set, List
from collections import Counter


class FolderAnalyzer:

    def __init__(self):
        self.manifest_files = {
            'java': ['pom.xml', 'build.gradle', 'build.gradle.kts'],
            'javascript': ['package.json', 'yarn.lock', 'package-lock.json'],
            'python': ['pyproject.toml', 'setup.py', 'requirements.txt', 'Pipfile'],
            'go': ['go.mod', 'go.sum'],
            'rust': ['Cargo.toml', 'Cargo.lock'],
            'docker': ['Dockerfile', 'docker-compose.yml', 'docker-compose.yaml'],
            'k8s': ['helm', 'kustomization.yaml', 'deployment.yaml']
        }

        self.intent_keywords = {
            'infrastructure': ['infra', 'deploy', 'k8s', 'docker', 'terraform', 'ansible'],
            'application': ['app', 'service', 'api', 'server', 'client'],
            'data': ['data', 'dataset', 'models', 'training', 'ml'],
            'documentation': ['docs', 'documentation', 'wiki', 'readme'],
            'testing': ['test', 'tests', 'spec', 'e2e', 'integration'],
            'build': ['build', 'dist', 'target', 'out', 'bin'],
            'config': ['config', 'conf', 'settings', 'env']
        }

    def analyze_folder(self, folder_path: Path, files: List[Dict]) -> Dict:
        files_list = [Path(f['path']) for f in files]

        has_readme = any('readme' in f.name.lower() for f in files_list)
        has_git = any('.git' in str(f) for f in files_list)

        manifest_types = self._detect_manifests(files_list)
        has_manifest = len(manifest_types) > 0

        file_types = Counter(f.suffix.lower() for f in files_list if f.suffix)
        dominant_types = dict(file_types.most_common(10))

        intent = self._infer_intent(folder_path.name.lower(), files_list)
        project_type = self._infer_project_type(manifest_types, dominant_types)

        structure = {
            'depth': len(folder_path.parts),
            'has_src': any('src' in str(f) for f in files_list[:20]),
            'has_tests': any('test' in str(f) for f in files_list[:20]),
            'has_docs': any('doc' in str(f) for f in files_list[:20])
        }

        return {
            'has_readme': has_readme,
            'has_git': has_git,
            'has_manifest': has_manifest,
            'manifest_types': manifest_types,
            'dominant_file_types': dominant_types,
            'project_type': project_type,
            'intent': intent,
            'structure': structure
        }

    def _detect_manifests(self, files: List[Path]) -> List[str]:
        detected = []
        file_names = {f.name for f in files}

        for tech, manifests in self.manifest_files.items():
            if any(m in file_names for m in manifests):
                detected.append(tech)

        return detected

    def _infer_intent(self, folder_name: str, files: List[Path]) -> str:
        file_str = ' '.join(str(f) for f in files[:50])

        for intent, keywords in self.intent_keywords.items():
            if any(kw in folder_name or kw in file_str.lower() for kw in keywords):
                return intent

        return 'unknown'

    def _infer_project_type(self, manifests: List[str], file_types: Dict) -> str:
        if manifests:
            return manifests[0]

        if '.py' in file_types and file_types.get('.py', 0) > 5:
            return 'python'
        if '.js' in file_types or '.ts' in file_types:
@@ -88,23 +47,17 @@ class FolderAnalyzer:
            return 'java'
        if '.go' in file_types:
            return 'go'

        return 'mixed'

    def generate_summary(self, folder_analysis: Dict, readme_text: str = None) -> str:
        parts = []

        if folder_analysis.get('project_type'):
            parts.append(f"{folder_analysis['project_type']} project")

        if folder_analysis.get('intent'):
            parts.append(f"for {folder_analysis['intent']}")

        if folder_analysis.get('manifest_types'):
            parts.append(f"using {', '.join(folder_analysis['manifest_types'])}")

        if readme_text:
            first_para = readme_text.split('\n\n')[0][:200]
            parts.append(f"Description: {first_para}")

        return ' '.join(parts) if parts else 'Mixed content folder'

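A minimal usage sketch for the analyzer above; the folder path and file listing are made up for illustration, but the dict-with-'path' shape follows what analyze_folder reads.

# Illustrative only: exercise FolderAnalyzer as defined above.
from pathlib import Path

analyzer = FolderAnalyzer()
files = [
    {'path': 'services/billing/pyproject.toml'},
    {'path': 'services/billing/src/app.py'},
    {'path': 'services/billing/tests/test_app.py'},
    {'path': 'services/billing/README.md'},
]
analysis = analyzer.analyze_folder(Path('services/billing'), files)
print(analysis['project_type'])             # 'python' (pyproject.toml manifest detected)
print(analyzer.generate_summary(analysis))  # e.g. "python project for application using python"
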
@@ -1,3 +1,2 @@
from .classifier import FileClassifier

__all__ = ['FileClassifier']

@@ -1,72 +1,30 @@
"""Protocol definitions for the classification package"""
from typing import Protocol, Optional
from pathlib import Path
from dataclasses import dataclass


@dataclass
class ClassificationRule:
    """Rule for classifying files"""
    name: str
    category: str
    patterns: list[str]
    priority: int = 0
    description: str = ""


class IClassifier(Protocol):
    """Protocol for classification operations"""

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Classify a file path

        Args:
            path: Path to classify
            file_type: Optional file type hint

        Returns:
            Category name or None if no match
        """
        ...

    def get_category_rules(self, category: str) -> list[ClassificationRule]:
        """Get all rules for a category

        Args:
            category: Category name

        Returns:
            List of rules for the category
        """
        ...


class IRuleEngine(Protocol):
    """Protocol for rule-based classification"""

    def add_rule(self, rule: ClassificationRule) -> None:
        """Add a classification rule

        Args:
            rule: Rule to add
        """
        ...

    def remove_rule(self, rule_name: str) -> None:
        """Remove a rule by name

        Args:
            rule_name: Name of rule to remove
        """
        ...

    def match_path(self, path: Path) -> Optional[str]:
        """Match path against rules

        Args:
            path: Path to match

        Returns:
            Category name or None if no match
        """
        ...

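A structural-typing sketch: the hypothetical SuffixClassifier below is not part of the commit, but any object with matching classify/get_category_rules signatures satisfies IClassifier without inheriting from it.

# Illustrative only: a trivial classifier that satisfies the IClassifier Protocol.
from pathlib import Path
from typing import Optional


class SuffixClassifier:
    """Maps one file suffix straight to a category."""

    def __init__(self):
        self._rules = [
            ClassificationRule(name='python_code', category='dev', patterns=['*.py'], priority=1),
        ]

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        return 'dev' if path.suffix == '.py' else None

    def get_category_rules(self, category: str) -> list[ClassificationRule]:
        return [r for r in self._rules if r.category == category]


clf: IClassifier = SuffixClassifier()          # type-checks structurally, no inheritance needed
print(clf.classify(Path('tools/report.py')))   # 'dev'
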
@@ -3,122 +3,72 @@ from typing import List, Set, Dict, Tuple
import re


class FileClassifier:

    def __init__(self):
        self.build_patterns = {
            'node_modules', '__pycache__', '.pytest_cache', 'target', 'build', 'dist',
            '.gradle', 'bin', 'obj', '.next', '.nuxt', 'vendor', '.venv', 'venv',
            'site-packages', 'bower_components', 'jspm_packages'
        }

        self.artifact_patterns = {
            'java': {'.jar', '.war', '.ear', '.class'},
            'python': {'.pyc', '.pyo', '.whl', '.egg'},
            'node': {'node_modules'},
            'go': {'vendor', 'pkg'},
            'rust': {'target'},
            'docker': {'.dockerignore', 'Dockerfile'}
        }

        self.category_keywords = {
            'apps': {'app', 'application', 'service', 'api', 'server', 'client'},
            'infra': {'infrastructure', 'devops', 'docker', 'kubernetes', 'terraform', 'ansible', 'gitea', 'jenkins'},
            'dev': {'project', 'workspace', 'repo', 'src', 'code', 'dev'},
            'cache': {'cache', 'temp', 'tmp', '.cache'},
            'databases': {'postgres', 'mysql', 'redis', 'mongo', 'db', 'database'},
            'backups': {'backup', 'bak', 'snapshot', 'archive'},
            'user': {'documents', 'pictures', 'videos', 'downloads', 'desktop', 'music'},
            'artifacts': {'build', 'dist', 'release', 'output'},
            'temp': {'tmp', 'temp', 'staging', 'processing'}
        }

        self.media_extensions = {
            'video': {'.mp4', '.mkv', '.avi', '.mov', '.wmv', '.flv', '.webm'},
            'audio': {'.mp3', '.flac', '.wav', '.ogg', '.m4a', '.aac'},
            'image': {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.svg', '.webp'},
            'document': {'.pdf', '.doc', '.docx', '.txt', '.md', '.odt'},
            'spreadsheet': {'.xls', '.xlsx', '.csv', '.ods'},
            'presentation': {'.ppt', '.pptx', '.odp'}
        }

        self.code_extensions = {
            '.py', '.js', '.ts', '.java', '.go', '.rs', '.c', '.cpp', '.h',
            '.cs', '.rb', '.php', '.swift', '.kt', '.scala', '.clj', '.r'
        }

    def classify_path(self, path: str, size: int = 0) -> Tuple[Set[str], str, bool]:
        p = Path(path)
        labels = set()
        primary_category = 'misc'
        is_build_artifact = False

        parts = p.parts
        name_lower = p.name.lower()

        for part in parts:
            part_lower = part.lower()
            if part_lower in self.build_patterns:
                is_build_artifact = True
                labels.add('build-artifact')
                break

        if is_build_artifact:
            for artifact_type, patterns in self.artifact_patterns.items():
                if any(part.lower() in patterns for part in parts) or p.suffix in patterns:
                    primary_category = f'artifacts/{artifact_type}'
                    labels.add('artifact')
                    return labels, primary_category, is_build_artifact

        if '.git' in parts:
            labels.add('vcs')
            primary_category = 'infra/git-infrastructure'
            return labels, primary_category, False

        for category, keywords in self.category_keywords.items():
            if any(kw in name_lower or any(kw in part.lower() for part in parts) for kw in keywords):
                labels.add(category)
                primary_category = category
                break

        for media_type, extensions in self.media_extensions.items():
            if p.suffix.lower() in extensions:
                labels.add(media_type)
                labels.add('media')
                primary_category = f'user/{media_type}'
                break

        if p.suffix.lower() in self.code_extensions:
            labels.add('code')
            if primary_category == 'misc':
                primary_category = 'dev'

        if size > 100 * 1024 * 1024:
            labels.add('large-file')

        if any(kw in name_lower for kw in ['test', 'spec', 'mock']):
            labels.add('test')

        if any(kw in name_lower for kw in ['config', 'settings', 'env']):
            labels.add('config')

        return labels, primary_category, is_build_artifact

    def suggest_target_path(self, source_path: str, category: str, labels: Set[str]) -> str:
        p = Path(source_path)

        if 'build-artifact' in labels:
            return f'trash/build-artifacts/{source_path}'

        if category.startswith('artifacts/'):
            artifact_type = category.split('/')[-1]
            return f'artifacts/{artifact_type}/{p.name}'

        if category.startswith('user/'):
            media_type = category.split('/')[-1]
            return f'user/{media_type}/{p.name}'

        parts = [part for part in p.parts if part not in self.build_patterns]
        if len(parts) > 3:
            project_name = parts[0] if parts else 'misc'
            return f'{category}/{project_name}/{"/".join(parts[1:])}'

        return f'{category}/{source_path}'

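A minimal sketch of how classify_path and suggest_target_path combine; the path below is made up for illustration.

# Illustrative only: label a path and derive a target location for it.
classifier = FileClassifier()

labels, category, is_artifact = classifier.classify_path(
    'projects/webapp/node_modules/react/index.js', size=4096
)
# 'node_modules' is in build_patterns, so the path is flagged as a build artifact
# and short-circuits into the 'artifacts/node' bucket.
print(labels, category, is_artifact)

target = classifier.suggest_target_path(
    'projects/webapp/node_modules/react/index.js', category, labels
)
print(target)   # 'trash/build-artifacts/...' because the 'build-artifact' label is set
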
@@ -1,350 +1,148 @@
"""Main classification engine"""
from pathlib import Path
from typing import Optional, Callable
import psycopg2

from .rules import RuleBasedClassifier
from .ml import create_ml_classifier, DummyMLClassifier
from ..shared.models import ProcessingStats
from ..shared.config import DatabaseConfig
from ..shared.logger import ProgressLogger


class ClassificationEngine:
    """Engine for classifying files"""

    def __init__(
        self,
        db_config: DatabaseConfig,
        logger: ProgressLogger,
        use_ml: bool = False
    ):
        """Initialize classification engine

        Args:
            db_config: Database configuration
            logger: Progress logger
            use_ml: Whether to use ML classification in addition to rules
        """
        self.db_config = db_config
        self.logger = logger
        self.rule_classifier = RuleBasedClassifier()
        self.ml_classifier = create_ml_classifier() if use_ml else None
        self.use_ml = use_ml and not isinstance(self.ml_classifier, DummyMLClassifier)
        self._connection = None

    def _get_connection(self):
        """Get or create database connection"""
        if self._connection is None or self._connection.closed:
            self._connection = psycopg2.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                database=self.db_config.database,
                user=self.db_config.user,
                password=self.db_config.password
            )
        return self._connection

    def classify_all(
        self,
        disk: Optional[str] = None,
        batch_size: int = 1000,
        progress_callback: Optional[Callable[[int, int, ProcessingStats], None]] = None
    ) -> ProcessingStats:
        """Classify all files in database

        Args:
            disk: Optional disk filter
            batch_size: Number of files to process per batch
            progress_callback: Optional callback for progress updates

        Returns:
            ProcessingStats with classification statistics
        """
        self.logger.section("Starting Classification")

        conn = self._get_connection()
        cursor = conn.cursor()

        # Get files without categories
        if disk:
            cursor.execute("""
                SELECT path, checksum
                FROM files
                WHERE disk_label = %s AND category IS NULL
            """, (disk,))
        else:
            cursor.execute("""
                SELECT path, checksum
                FROM files
                WHERE category IS NULL
            """)

        files_to_classify = cursor.fetchall()
        total_files = len(files_to_classify)
        self.logger.info(f"Found {total_files} files to classify")

        stats = ProcessingStats()
        batch = []

        for path_str, checksum in files_to_classify:
            path = Path(path_str)

            # Classify using rules first
            category = self.rule_classifier.classify(path)

            # If no rule match and ML is available, try ML
            if category is None and self.use_ml and self.ml_classifier:
                category = self.ml_classifier.classify(path)

            # If still no category, assign default
            if category is None:
                category = "temp/processing"

            batch.append((category, str(path)))
            stats.files_processed += 1

            # Batch update
            if len(batch) >= batch_size:
                self._update_categories(cursor, batch)
                conn.commit()
                batch.clear()

            # Progress callback
            if progress_callback:
                progress_callback(stats.files_processed, total_files, stats)

            # Log progress
            if stats.files_processed % (batch_size * 10) == 0:
                self.logger.progress(
                    stats.files_processed,
                    total_files,
                    prefix="Files classified",
                    elapsed_seconds=stats.elapsed_seconds
                )

        # Update remaining batch
        if batch:
            self._update_categories(cursor, batch)
            conn.commit()

        stats.files_succeeded = stats.files_processed

        cursor.close()
        self.logger.info(
            f"Classification complete: {stats.files_processed} files in {stats.elapsed_seconds:.1f}s"
        )

        return stats

    def _update_categories(self, cursor, batch: list[tuple[str, str]]):
        """Update categories in batch

        Args:
            cursor: Database cursor
            batch: List of (category, path) tuples
        """
        from psycopg2.extras import execute_batch

        query = """
            UPDATE files
            SET category = %s
            WHERE path = %s
        """
        execute_batch(cursor, query, batch)

    def classify_path(self, path: Path) -> Optional[str]:
        """Classify a single path

        Args:
            path: Path to classify

        Returns:
            Category name or None
        """
        # Try rules first
        category = self.rule_classifier.classify(path)

        # Try ML if available
        if category is None and self.use_ml and self.ml_classifier:
            category = self.ml_classifier.classify(path)

        return category

    def get_category_stats(self) -> dict[str, dict]:
        """Get statistics by category

        Returns:
            Dictionary mapping category to statistics
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            SELECT
                category,
                COUNT(*) as file_count,
                SUM(size) as total_size
            FROM files
            WHERE category IS NOT NULL
            GROUP BY category
            ORDER BY total_size DESC
        """)

        stats = {}
        for category, file_count, total_size in cursor.fetchall():
            stats[category] = {
                'file_count': file_count,
                'total_size': total_size
            }

        cursor.close()

        return stats

    def get_uncategorized_count(self) -> int:
        """Get count of uncategorized files

        Returns:
            Number of files without category
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM files WHERE category IS NULL")
        count = cursor.fetchone()[0]

        cursor.close()

        return count

    def reclassify_category(
        self,
        old_category: str,
        new_category: str
    ) -> int:
        """Reclassify all files in a category

        Args:
            old_category: Current category
            new_category: New category

        Returns:
            Number of files reclassified
        """
        self.logger.info(f"Reclassifying {old_category} -> {new_category}")

        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            UPDATE files
            SET category = %s
            WHERE category = %s
        """, (new_category, old_category))

        count = cursor.rowcount
        conn.commit()
        cursor.close()

        self.logger.info(f"Reclassified {count} files")

        return count

    def train_ml_classifier(
        self,
        min_samples: int = 10
    ) -> bool:
        """Train ML classifier from existing categorized data

        Args:
            min_samples: Minimum samples per category

        Returns:
            True if training successful
        """
        if not self.use_ml or self.ml_classifier is None:
            self.logger.warning("ML classifier not available")
            return False

        self.logger.subsection("Training ML Classifier")

        conn = self._get_connection()
        cursor = conn.cursor()

        # Get categorized files
        cursor.execute("""
            SELECT path, category
            FROM files
            WHERE category IS NOT NULL
        """)

        training_data = [(Path(path), category) for path, category in cursor.fetchall()]
        cursor.close()

        if not training_data:
            self.logger.warning("No training data available")
            return False

        # Count samples per category
        category_counts = {}
        for _, category in training_data:
            category_counts[category] = category_counts.get(category, 0) + 1

        # Filter categories with enough samples
        filtered_data = [
            (path, category)
            for path, category in training_data
            if category_counts[category] >= min_samples
        ]

        if not filtered_data:
            self.logger.warning(f"No categories with >= {min_samples} samples")
            return False

        self.logger.info(f"Training with {len(filtered_data)} samples")

        try:
            self.ml_classifier.train(filtered_data)
            self.logger.info("ML classifier trained successfully")
            return True
        except Exception as e:
            self.logger.error(f"Failed to train ML classifier: {e}")
            return False

    def get_all_categories(self) -> list[str]:
        """Get all categories from database

        Returns:
            List of category names
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            SELECT DISTINCT category
            FROM files
            WHERE category IS NOT NULL
            ORDER BY category
        """)

        categories = [row[0] for row in cursor.fetchall()]
        cursor.close()

        return categories

    def close(self):
        """Close database connection"""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.close()

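A hedged usage sketch for the engine. The DatabaseConfig field names match what _get_connection reads, but the keyword-argument constructors shown for DatabaseConfig and ProgressLogger are assumptions about those shared helpers, not confirmed by this diff.

# Illustrative only: run a rules-first classification pass over one disk.
config = DatabaseConfig(host='localhost', port=5432, database='catalog',
                        user='catalog', password='secret')   # assumed constructor
logger = ProgressLogger()                                     # assumed no-arg constructor

with ClassificationEngine(config, logger, use_ml=False) as engine:
    stats = engine.classify_all(disk='disk-01', batch_size=500)
    print(engine.get_category_stats())
    print('still uncategorized:', engine.get_uncategorized_count())
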
@@ -1,8 +1,6 @@
"""ML-based classification (optional, using sklearn if available)"""
from pathlib import Path
from typing import Optional, List, Tuple
import pickle

try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.naive_bayes import MultinomialNB
@@ -11,100 +9,41 @@ try:
except ImportError:
    SKLEARN_AVAILABLE = False


class MLClassifier:
    """Machine learning-based file classifier

    Uses path-based features and optional metadata to classify files.
    Requires scikit-learn to be installed.
    """

    def __init__(self):
        """Initialize ML classifier"""
        if not SKLEARN_AVAILABLE:
            raise ImportError(
                "scikit-learn is required for ML classification. "
                "Install with: pip install scikit-learn"
            )

        self.model: Optional[Pipeline] = None
        self.categories: List[str] = []
        self._is_trained = False

    def _extract_features(self, path: Path) -> str:
        """Extract features from path

        Args:
            path: Path to extract features from

        Returns:
            Feature string
        """
        # Convert path to feature string
        # Include: path parts, extension, filename
        parts = path.parts
        extension = path.suffix
        filename = path.name

        features = []

        # Add path components
        features.extend(parts)

        # Add extension
        if extension:
            features.append(f"ext:{extension}")

        # Add filename components (split on common separators)
        name_parts = filename.replace('-', ' ').replace('_', ' ').replace('.', ' ').split()
        features.extend([f"name:{part}" for part in name_parts])

        return ' '.join(features)

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Train the classifier

        Args:
            training_data: List of (path, category) tuples
        """
        if not training_data:
            raise ValueError("Training data cannot be empty")

        # Extract features and labels
        X = [self._extract_features(path) for path, _ in training_data]
        y = [category for _, category in training_data]

        # Store unique categories
        self.categories = sorted(set(y))

        # Create and train pipeline
        self.model = Pipeline([
            ('tfidf', TfidfVectorizer(
                max_features=1000,
                ngram_range=(1, 2),
                min_df=1
            )),
            ('classifier', MultinomialNB())
        ])

        self.model.fit(X, y)
        self._is_trained = True

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Classify a file path

        Args:
            path: Path to classify
            file_type: Optional file type hint (not used in ML classifier)

        Returns:
            Category name or None if not trained
        """
        if not self._is_trained or self.model is None:
            return None

        features = self._extract_features(path)

        try:
            prediction = self.model.predict([features])[0]
            return prediction
@@ -112,158 +51,77 @@ class MLClassifier:
            return None

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Get prediction probabilities for all categories

        Args:
            path: Path to classify

        Returns:
            Dictionary mapping category to probability
        """
        if not self._is_trained or self.model is None:
            return {}

        features = self._extract_features(path)

        try:
            probabilities = self.model.predict_proba([features])[0]
            return {
                category: float(prob)
                for category, prob in zip(self.categories, probabilities)
            }
        except Exception:
            return {}

    def save_model(self, model_path: Path) -> None:
        """Save trained model to disk

        Args:
            model_path: Path to save model
        """
        if not self._is_trained:
            raise ValueError("Cannot save untrained model")

        model_data = {
            'model': self.model,
            'categories': self.categories,
            'is_trained': self._is_trained
        }

        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)

    def load_model(self, model_path: Path) -> None:
        """Load trained model from disk

        Args:
            model_path: Path to model file
        """
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        self.model = model_data['model']
        self.categories = model_data['categories']
        self._is_trained = model_data['is_trained']

    @property
    def is_trained(self) -> bool:
        """Check if model is trained"""
        return self._is_trained


class DummyMLClassifier:
    """Dummy ML classifier for when sklearn is not available"""

    def __init__(self):
        """Initialize dummy classifier"""
        pass

    def train(self, training_data: List[Tuple[Path, str]]) -> None:
        """Dummy train method"""
        raise NotImplementedError(
            "ML classification requires scikit-learn. "
            "Install with: pip install scikit-learn"
        )

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Dummy classify method"""
        return None

    def predict_proba(self, path: Path) -> dict[str, float]:
        """Dummy predict_proba method"""
        return {}

    def save_model(self, model_path: Path) -> None:
        """Dummy save_model method"""
        raise NotImplementedError("ML classification not available")

    def load_model(self, model_path: Path) -> None:
        """Dummy load_model method"""
        raise NotImplementedError("ML classification not available")

    @property
    def is_trained(self) -> bool:
        """Check if model is trained"""
        return False


def create_ml_classifier() -> MLClassifier | DummyMLClassifier:
    """Create ML classifier if sklearn is available, otherwise return dummy

    Returns:
        MLClassifier or DummyMLClassifier
    """
    if SKLEARN_AVAILABLE:
        return MLClassifier()
    else:
        return DummyMLClassifier()


def train_from_database(
    db_connection,
    min_samples_per_category: int = 10
) -> MLClassifier | DummyMLClassifier:
    """Train ML classifier from database

    Args:
        db_connection: Database connection
        min_samples_per_category: Minimum samples required per category

    Returns:
        Trained classifier
    """
    classifier = create_ml_classifier()

    if isinstance(classifier, DummyMLClassifier):
        return classifier

    # Query classified files from database
    cursor = db_connection.cursor()
    cursor.execute("""
        SELECT path, category
        FROM files
        WHERE category IS NOT NULL
    """)

    training_data = [(Path(path), category) for path, category in cursor.fetchall()]
    cursor.close()

    if not training_data:
        return classifier

    # Count samples per category
    category_counts = {}
    for _, category in training_data:
        category_counts[category] = category_counts.get(category, 0) + 1

    # Filter to categories with enough samples
    filtered_data = [
        (path, category)
        for path, category in training_data
        if category_counts[category] >= min_samples_per_category
    ]

    if filtered_data:
        classifier.train(filtered_data)

    return classifier

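A short sketch of the train/classify/persist cycle above; it assumes scikit-learn is installed so create_ml_classifier returns the real MLClassifier, and the paths, labels, and model file location are made up for illustration.

# Illustrative only: train on a handful of labelled paths, then classify and persist.
from pathlib import Path

clf = create_ml_classifier()
if isinstance(clf, DummyMLClassifier):
    raise SystemExit('scikit-learn not installed')

clf.train([
    (Path('projects/api/server/main.py'), 'dev'),
    (Path('projects/api/server/routes.py'), 'dev'),
    (Path('media/photos/2021/holiday.jpg'), 'user/image'),
    (Path('media/photos/2021/beach.jpg'), 'user/image'),
])
print(clf.classify(Path('projects/api/server/models.py')))   # likely 'dev'
print(clf.predict_proba(Path('media/photos/2022/cat.jpg')))  # per-category probabilities
clf.save_model(Path('/tmp/file-classifier.pkl'))             # hypothetical location
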
@@ -1,282 +1,60 @@
|
|||||||
"""Rule-based classification engine"""
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
|
||||||
from ._protocols import ClassificationRule
|
from ._protocols import ClassificationRule
|
||||||
|
|
||||||
|
|
||||||
class RuleBasedClassifier:
|
class RuleBasedClassifier:
|
||||||
"""Rule-based file classifier using pattern matching"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize rule-based classifier"""
|
|
||||||
self.rules: list[ClassificationRule] = []
|
self.rules: list[ClassificationRule] = []
|
||||||
self._load_default_rules()
|
self._load_default_rules()
|
||||||
|
|
||||||
def _load_default_rules(self):
|
def _load_default_rules(self):
|
||||||
"""Load default classification rules based on ARCHITECTURE.md"""
|
self.add_rule(ClassificationRule(name='maven_cache', category='artifacts/java/maven', patterns=['**/.m2/**', '**/.maven/**', '**/maven-central-cache/**'], priority=10, description='Maven repository and cache'))
|
||||||
|
self.add_rule(ClassificationRule(name='gradle_cache', category='artifacts/java/gradle', patterns=['**/.gradle/**', '**/gradle-cache/**', '**/gradle-build-cache/**'], priority=10, description='Gradle cache and artifacts'))
|
||||||
# Build artifacts and caches
|
self.add_rule(ClassificationRule(name='python_cache', category='cache/pycache', patterns=['**/__pycache__/**', '**/*.pyc', '**/*.pyo'], priority=10, description='Python cache files'))
|
||||||
self.add_rule(ClassificationRule(
|
self.add_rule(ClassificationRule(name='python_artifacts', category='artifacts/python', patterns=['**/pip-cache/**', '**/pypi-cache/**', '**/wheelhouse/**'], priority=10, description='Python package artifacts'))
|
||||||
name="maven_cache",
|
self.add_rule(ClassificationRule(name='node_modules', category='cache/node_modules-archive', patterns=['**/node_modules/**'], priority=10, description='Node.js modules'))
|
||||||
category="artifacts/java/maven",
|
self.add_rule(ClassificationRule(name='node_cache', category='artifacts/node', patterns=['**/.npm/**', '**/npm-registry/**', '**/yarn-cache/**', '**/pnpm-store/**'], priority=10, description='Node.js package managers cache'))
|
||||||
patterns=["**/.m2/**", "**/.maven/**", "**/maven-central-cache/**"],
|
self.add_rule(ClassificationRule(name='go_cache', category='artifacts/go', patterns=['**/goproxy-cache/**', '**/go/pkg/mod/**', '**/go-module-cache/**'], priority=10, description='Go module cache'))
|
||||||
priority=10,
|
self.add_rule(ClassificationRule(name='git_repos', category='development/git-infrastructure', patterns=['**/.git/**', '**/gitea/repositories/**'], priority=15, description='Git repositories and infrastructure'))
|
||||||
description="Maven repository and cache"
|
self.add_rule(ClassificationRule(name='gitea', category='development/gitea', patterns=['**/gitea/**'], priority=12, description='Gitea server data'))
|
||||||
))
|
        self.add_rule(ClassificationRule(
            name="gradle_cache",
            category="artifacts/java/gradle",
            patterns=["**/.gradle/**", "**/gradle-cache/**", "**/gradle-build-cache/**"],
            priority=10,
            description="Gradle cache and artifacts"
        ))

        self.add_rule(ClassificationRule(
            name="python_cache",
            category="cache/pycache",
            patterns=["**/__pycache__/**", "**/*.pyc", "**/*.pyo"],
            priority=10,
            description="Python cache files"
        ))

        self.add_rule(ClassificationRule(
            name="python_artifacts",
            category="artifacts/python",
            patterns=["**/pip-cache/**", "**/pypi-cache/**", "**/wheelhouse/**"],
            priority=10,
            description="Python package artifacts"
        ))

        self.add_rule(ClassificationRule(
            name="node_modules",
            category="cache/node_modules-archive",
            patterns=["**/node_modules/**"],
            priority=10,
            description="Node.js modules"
        ))

        self.add_rule(ClassificationRule(
            name="node_cache",
            category="artifacts/node",
            patterns=["**/.npm/**", "**/npm-registry/**", "**/yarn-cache/**", "**/pnpm-store/**"],
            priority=10,
            description="Node.js package manager caches"
        ))

        self.add_rule(ClassificationRule(
            name="go_cache",
            category="artifacts/go",
            patterns=["**/goproxy-cache/**", "**/go/pkg/mod/**", "**/go-module-cache/**"],
            priority=10,
            description="Go module cache"
        ))

        # Version control
        self.add_rule(ClassificationRule(
            name="git_repos",
            category="development/git-infrastructure",
            patterns=["**/.git/**", "**/gitea/repositories/**"],
            priority=15,
            description="Git repositories and infrastructure"
        ))

        self.add_rule(ClassificationRule(
            name="gitea",
            category="development/gitea",
            patterns=["**/gitea/**"],
            priority=12,
            description="Gitea server data"
        ))
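
        # Note: priorities are relative, not absolute. When a path matches several
        # rules (e.g. anything under "**/gitea/repositories/**" also matches
        # "**/gitea/**"), the higher-priority rule wins (git_repos at 15 beats
        # gitea at 12), because the rule list is kept sorted by priority and
        # match_path() returns the first hit.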

        # Databases
        self.add_rule(ClassificationRule(
            name="postgresql",
            category="databases/postgresql",
            patterns=["**/postgresql/**", "**/postgres/**", "**/*.sql"],
            priority=10,
            description="PostgreSQL databases"
        ))

        self.add_rule(ClassificationRule(
            name="mysql",
            category="databases/mysql",
            patterns=["**/mysql/**", "**/mariadb/**"],
            priority=10,
            description="MySQL/MariaDB databases"
        ))

        self.add_rule(ClassificationRule(
            name="mongodb",
            category="databases/mongodb",
            patterns=["**/mongodb/**", "**/mongo/**"],
            priority=10,
            description="MongoDB databases"
        ))

        self.add_rule(ClassificationRule(
            name="redis",
            category="databases/redis",
            patterns=["**/redis/**", "**/*.rdb"],
            priority=10,
            description="Redis databases"
        ))

        self.add_rule(ClassificationRule(
            name="sqlite",
            category="databases/sqlite",
            patterns=["**/*.db", "**/*.sqlite", "**/*.sqlite3"],
            priority=8,
            description="SQLite databases"
        ))

        # LLM and AI models
        self.add_rule(ClassificationRule(
            name="llm_models",
            category="cache/llm-models",
            patterns=[
                "**/hugging-face/**",
                "**/huggingface/**",
                "**/.cache/huggingface/**",
                "**/models/**/*.bin",
                "**/models/**/*.onnx",
                "**/models/**/*.safetensors",
                "**/llm*/**",
                "**/openai-cache/**"
            ],
            priority=12,
            description="LLM and AI model files"
        ))

        # Docker and containers
        self.add_rule(ClassificationRule(
            name="docker_volumes",
            category="apps/volumes/docker-volumes",
            patterns=["**/docker/volumes/**", "**/var/lib/docker/volumes/**"],
            priority=10,
            description="Docker volumes"
        ))

        self.add_rule(ClassificationRule(
            name="app_data",
            category="apps/volumes/app-data",
            patterns=["**/app-data/**", "**/application-data/**"],
            priority=8,
            description="Application data"
        ))

        # Build outputs
        self.add_rule(ClassificationRule(
            name="build_output",
            category="development/build-tools",
            patterns=["**/target/**", "**/build/**", "**/dist/**", "**/out/**"],
            priority=5,
            description="Build output directories"
        ))

        # Backups
        self.add_rule(ClassificationRule(
            name="system_backups",
            category="backups/system",
            patterns=["**/backup/**", "**/backups/**", "**/*.bak", "**/*.backup"],
            priority=10,
            description="System backups"
        ))

        self.add_rule(ClassificationRule(
            name="database_backups",
            category="backups/database",
            patterns=["**/*.sql.gz", "**/*.dump", "**/db-backup/**"],
            priority=11,
            description="Database backups"
        ))

        # Archives
        self.add_rule(ClassificationRule(
            name="archives",
            category="backups/archive",
            patterns=["**/*.tar", "**/*.tar.gz", "**/*.tgz", "**/*.zip", "**/*.7z"],
            priority=5,
            description="Archive files"
        ))
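
        # Worked example (sketch): "/srv/backups/db-backup/prod.sql.gz" matches both
        # database_backups (priority 11, via "**/db-backup/**" and "**/*.sql.gz") and
        # system_backups (priority 10, via "**/backups/**"); because rules are checked
        # in priority order, the path is classified as "backups/database".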

    def add_rule(self, rule: ClassificationRule) -> None:
        """Add a classification rule

        Args:
            rule: Rule to add
        """
        self.rules.append(rule)
        # Sort rules by priority (higher priority first)
        self.rules.sort(key=lambda r: r.priority, reverse=True)
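
    # Note (illustrative): because add_rule() re-sorts after every append, a rule
    # registered later with, say, priority=20 is still tried before the default
    # rules above when match_path() runs.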

    def remove_rule(self, rule_name: str) -> None:
        """Remove a rule by name

        Args:
            rule_name: Name of rule to remove
        """
        self.rules = [r for r in self.rules if r.name != rule_name]

    def match_path(self, path: Path) -> Optional[str]:
        """Match path against rules

        Args:
            path: Path to match

        Returns:
            Category name or None if no match
        """
        path_str = str(path)

        # Try to match each rule in priority order
        for rule in self.rules:
            for pattern in rule.patterns:
                if fnmatch.fnmatch(path_str, pattern):
                    return rule.category

        return None
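
    # Note: fnmatch treats "*" as matching any characters, including "/", so the
    # "**" in the patterns above behaves like a plain "*" rather than a true
    # recursive glob. For example, fnmatch.fnmatch("/var/lib/docker/volumes/web/_data",
    # "**/docker/volumes/**") is True because the pattern effectively asks for
    # "/docker/volumes/" anywhere in the path string.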

    def classify(self, path: Path, file_type: Optional[str] = None) -> Optional[str]:
        """Classify a file path

        Args:
            path: Path to classify
            file_type: Optional file type hint

        Returns:
            Category name or None if no match
        """
        return self.match_path(path)
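
    # Note: file_type is accepted as a hint but is not used yet; classification
    # currently relies solely on path-pattern matching via match_path().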

    def get_category_rules(self, category: str) -> list[ClassificationRule]:
        """Get all rules for a category

        Args:
            category: Category name

        Returns:
            List of rules for the category
        """
        return [r for r in self.rules if r.category == category]

    def get_all_categories(self) -> set[str]:
        """Get all defined categories

        Returns:
            Set of category names
        """
        return {r.category for r in self.rules}

    def get_rules_by_priority(self, min_priority: int = 0) -> list[ClassificationRule]:
        """Get rules above a minimum priority

        Args:
            min_priority: Minimum priority threshold

        Returns:
            List of rules with priority >= min_priority
        """
        return [r for r in self.rules if r.priority >= min_priority]
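
# Standalone sketch of the priority-ordered fnmatch lookup used by match_path()
# above. Illustrative only: it uses a hypothetical two-rule table and imports
# nothing from this module, but mirrors the behaviour (sort by priority, return
# the first matching rule's category).
if __name__ == "__main__":
    import fnmatch

    demo_rules = [
        ("system_backups", "backups/system", ["**/backup/**", "**/backups/**"], 10),
        ("database_backups", "backups/database", ["**/*.sql.gz", "**/db-backup/**"], 11),
    ]
    # Higher priority first, as add_rule() does
    demo_rules.sort(key=lambda r: r[3], reverse=True)

    demo_path = "/srv/backups/db-backup/prod.sql.gz"
    for _name, category, patterns, _priority in demo_rules:
        if any(fnmatch.fnmatch(demo_path, p) for p in patterns):
            print(category)  # prints "backups/database": the priority-11 rule wins
            break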