From 9bde1395bfae2d3dab8237ce1f7637ff76bc5028 Mon Sep 17 00:00:00 2001
From: bpmcdevitt
Date: Thu, 17 Jul 2025 18:58:47 -0500
Subject: [PATCH] Optimize performance and migrate to Celery-based scheduling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit introduces major performance improvements and migrates from
custom job scheduling to Celery Beat for better reliability and scalability.

### 🚀 Performance Optimizations

**CVE2CAPEC Client Performance (Fixed startup blocking)**
- Implement lazy loading with 24-hour cache for CVE2CAPEC mappings
- Add background task for CVE2CAPEC sync (data_sync_tasks.sync_cve2capec)
- Remove blocking data fetch during client initialization
- API endpoint: POST /api/sync-cve2capec

**ExploitDB Client Performance (Fixed webapp request blocking)**
- Implement global file index cache to prevent rebuilding on every request
- Add lazy loading with 24-hour cache expiry for 46K+ exploit index
- Background task for index building (data_sync_tasks.build_exploitdb_index)
- API endpoint: POST /api/build-exploitdb-index

### 🔄 Celery Migration & Scheduling

**Celery Beat Integration**
- Migrate from custom job scheduler to Celery Beat for reliability
- Remove 'finetuned' LLM provider (logic moved to ollama container)
- Optimized daily workflow with proper timing and dependencies

**New Celery Tasks Structure**
- tasks/bulk_tasks.py - NVD bulk processing and SIGMA generation
- tasks/data_sync_tasks.py - All data synchronization tasks
- tasks/maintenance_tasks.py - System maintenance and cleanup
- tasks/sigma_tasks.py - SIGMA rule generation tasks

**Daily Schedule (Optimized)**
```
1:00 AM  → Weekly cleanup (Sundays)
1:30 AM  → Daily result cleanup
2:00 AM  → NVD incremental update
3:00 AM  → CISA KEV sync
3:15 AM  → Nomi-sec PoC sync
3:30 AM  → GitHub PoC sync
3:45 AM  → ExploitDB sync
4:00 AM  → CVE2CAPEC MITRE ATT&CK sync
4:15 AM  → ExploitDB index rebuild
5:00 AM  → Reference content sync
8:00 AM  → SIGMA rule generation
9:00 AM  → LLM-enhanced SIGMA generation
Every 15min → Health checks
```

### 🐳 Docker & Infrastructure

**Enhanced Docker Setup**
- Ollama setup with integrated SIGMA model creation (setup_ollama_with_sigma.py)
- Initial database population check and trigger (initial_setup.py)
- Proper service dependencies and health checks
- Remove manual post-rebuild script requirements

**Service Architecture**
- Celery worker with 4-queue system (default, bulk_processing, sigma_generation, data_sync)
- Flower monitoring dashboard (localhost:5555)
- Redis as message broker and result backend

### 🎯 API Improvements

**Background Task Endpoints**
- GitHub PoC sync now uses Celery (was blocking backend)
- All sync operations return task IDs and monitoring URLs
- Consistent error handling and progress tracking

**New Endpoints**
- POST /api/sync-cve2capec - CVE2CAPEC mapping sync
- POST /api/build-exploitdb-index - ExploitDB index rebuild

### 📁 Cleanup

**Removed Files**
- fix_sigma_model.sh (replaced by setup_ollama_with_sigma.py)
- Various test_* and debug_* files no longer needed
- Old training scripts related to removed 'finetuned' provider
- Utility scripts replaced by Docker services

### 🔧 Configuration

**Key Files Added/Modified**
- backend/celery_config.py - Complete Celery configuration
- backend/initial_setup.py - First-boot database population
- backend/setup_ollama_with_sigma.py - Integrated Ollama setup
- CLAUDE.md - Project documentation and development guide
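
For reference, each of the new Celery-backed sync endpoints returns a task ID plus a Flower monitoring URL. A minimal usage sketch, assuming the stack is running locally on the default ports; the response shape mirrors the new handlers in backend/main.py:

```bash
# Trigger the CVE2CAPEC mapping sync as a background Celery task
curl -X POST http://localhost:8000/api/sync-cve2capec

# Illustrative response (fields as returned by the handler):
# {
#   "message": "CVE2CAPEC MITRE ATT&CK mapping sync started via Celery",
#   "status": "started",
#   "task_id": "<celery-task-id>",
#   "monitor_url": "http://localhost:5555/task/<celery-task-id>"
# }

# Rebuild the ExploitDB file index the same way
curl -X POST http://localhost:8000/api/build-exploitdb-index
```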
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude
---
 CLAUDE.md                          | 282 +++++
 backend/celery_config.py           | 222 +++++
 backend/cve2capec_client.py        |  25 +-
 backend/exploitdb_client_local.py  |  51 +-
 backend/initial_setup.py           | 171 +++++
 backend/llm_client.py              |  31 +-
 backend/main.py                    | 566 ++++++-----------
 backend/setup_ollama_with_sigma.py | 256 ++++++++
 backend/tasks/__init__.py          |   3 +
 backend/tasks/bulk_tasks.py        | 235 +++++++
 backend/tasks/data_sync_tasks.py   | 959 +++++++++++++++++++++++++++++
 backend/tasks/maintenance_tasks.py | 437 +++++++++++++
 backend/tasks/sigma_tasks.py       | 409 ++++++++++++
 docker-compose.yml                 | 116 +++-
 frontend/src/App.js                | 332 ++--------
 15 files changed, 3402 insertions(+), 693 deletions(-)
 create mode 100644 CLAUDE.md
 create mode 100644 backend/celery_config.py
 create mode 100644 backend/initial_setup.py
 create mode 100644 backend/setup_ollama_with_sigma.py
 create mode 100644 backend/tasks/__init__.py
 create mode 100644 backend/tasks/bulk_tasks.py
 create mode 100644 backend/tasks/data_sync_tasks.py
 create mode 100644 backend/tasks/maintenance_tasks.py
 create mode 100644 backend/tasks/sigma_tasks.py

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..3b7eb07
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,282 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+This is an enhanced CVE-SIGMA Auto Generator that automatically processes comprehensive CVE data and generates SIGMA rules for threat detection. The application now supports:
+
+1. **Bulk NVD Data Processing**: Downloads and processes complete NVD JSON datasets (2002-2025)
+2. **nomi-sec PoC Integration**: Uses curated PoC data from github.com/nomi-sec/PoC-in-GitHub
+3. **Enhanced SIGMA Rule Generation**: Creates intelligent rules based on real exploit indicators
+4. 
**Comprehensive Database Seeding**: Supports both bulk and incremental data updates + +## Architecture + +- **Backend**: FastAPI with SQLAlchemy ORM (`backend/main.py`) +- **Frontend**: React with Tailwind CSS (`frontend/src/App.js`) +- **Database**: PostgreSQL with enhanced schema: + - `cves`: CVE information with PoC metadata and bulk processing fields + - `sigma_rules`: Enhanced SIGMA rules with quality scoring and nomi-sec data + - `rule_templates`: Template patterns for rule generation + - `bulk_processing_jobs`: Job tracking for bulk operations +- **Data Processing**: + - `nvd_bulk_processor.py`: NVD JSON dataset downloader and processor + - `nomi_sec_client.py`: nomi-sec PoC-in-GitHub API integration + - `enhanced_sigma_generator.py`: Advanced SIGMA rule generation + - `bulk_seeder.py`: Coordinated bulk seeding operations +- **Cache**: Redis (optional) +- **Deployment**: Docker Compose orchestration + +## Common Development Commands + +### Quick Start +```bash +# Recommended quick start +chmod +x start.sh +./start.sh + +# Or using Make +make start +``` + +### Build and Run +```bash +# Build and start all services +docker-compose up -d --build + +# Start individual services +docker-compose up -d db redis # Database and cache only +docker-compose up -d backend # Backend API +docker-compose up -d frontend # React frontend +``` + +### Development Mode +```bash +# Using Make +make dev + +# Or manually +docker-compose up -d db redis +cd backend && pip install -r requirements.txt && uvicorn main:app --reload +cd frontend && npm install && npm start +``` + +### Bulk Processing Commands +```bash +# Run bulk seeding standalone +cd backend && python bulk_seeder.py + +# Bulk seed specific year range +cd backend && python -c " +import asyncio +from bulk_seeder import BulkSeeder +from main import SessionLocal +seeder = BulkSeeder(SessionLocal()) +asyncio.run(seeder.full_bulk_seed(start_year=2020, end_year=2025)) +" + +# Incremental update only +cd backend && python -c " +import asyncio +from bulk_seeder import BulkSeeder +from main import SessionLocal +seeder = BulkSeeder(SessionLocal()) +asyncio.run(seeder.incremental_update()) +" +``` + +### Frontend Commands +```bash +cd frontend +npm install # Install dependencies +npm start # Development server (port 3000) +npm run build # Production build +npm test # Run tests +``` + +### Backend Commands +```bash +cd backend +pip install -r requirements.txt +uvicorn main:app --reload # Development server (port 8000) +uvicorn main:app --host 0.0.0.0 --port 8000 # Production server +``` + +### Database Operations +```bash +# Connect to database +docker-compose exec db psql -U cve_user -d cve_sigma_db + +# View logs +docker-compose logs -f backend +docker-compose logs -f frontend +``` + +### Other Make Commands +```bash +make stop # Stop all services +make restart # Restart all services +make logs # View application logs +make clean # Clean up containers and volumes +make setup # Initial setup (creates .env from .env.example) +``` + +## Key Configuration + +### Environment Variables (.env) +- `NVD_API_KEY`: Optional NVD API key for higher rate limits (5→50 requests/30s) +- `GITHUB_TOKEN`: Optional GitHub token for exploit analysis (enhances rule generation) +- `OPENAI_API_KEY`: Optional OpenAI API key for AI-enhanced SIGMA rule generation +- `ANTHROPIC_API_KEY`: Optional Anthropic API key for AI-enhanced SIGMA rule generation +- `OLLAMA_BASE_URL`: Optional Ollama base URL for local model AI-enhanced SIGMA rule generation +- `LLM_PROVIDER`: Optional LLM 
provider selection (openai, anthropic, ollama) +- `LLM_MODEL`: Optional LLM model selection (provider-specific) +- `DATABASE_URL`: PostgreSQL connection string +- `REACT_APP_API_URL`: Backend API URL for frontend + +### Service URLs +- Frontend: http://localhost:3000 +- Backend API: http://localhost:8000 +- API Documentation: http://localhost:8000/docs +- Database: localhost:5432 +- Redis: localhost:6379 + +### Enhanced API Endpoints + +#### Bulk Processing +- `POST /api/bulk-seed` - Start complete bulk seeding (NVD + nomi-sec) +- `POST /api/incremental-update` - Update with NVD modified/recent feeds +- `POST /api/sync-nomi-sec` - Synchronize nomi-sec PoC data +- `POST /api/regenerate-rules` - Regenerate SIGMA rules with enhanced data +- `GET /api/bulk-jobs` - Get bulk processing job status +- `GET /api/bulk-status` - Get comprehensive system status +- `GET /api/poc-stats` - Get PoC-related statistics + +#### Enhanced Data Access +- `GET /api/stats` - Enhanced statistics with PoC coverage +- `GET /api/claude-status` - Get Claude API availability status +- All existing CVE and SIGMA rule endpoints now include enhanced data fields + +#### LLM-Enhanced Rule Generation +- `POST /api/llm-enhanced-rules` - Generate SIGMA rules using LLM AI analysis (supports multiple providers) +- `GET /api/llm-status` - Check LLM API availability and configuration for all providers +- `POST /api/llm-switch` - Switch between LLM providers and models + +## Code Architecture Details + +### Enhanced Backend Structure +- **main.py**: Core FastAPI application with enhanced endpoints +- **nvd_bulk_processor.py**: NVD JSON dataset downloader and processor +- **nomi_sec_client.py**: nomi-sec PoC-in-GitHub API integration +- **enhanced_sigma_generator.py**: Advanced SIGMA rule generation with PoC data +- **llm_client.py**: Multi-provider LLM integration using LangChain for AI-enhanced rule generation +- **bulk_seeder.py**: Coordinated bulk processing operations + +### Database Models (Enhanced) +- **CVE**: Enhanced with `poc_count`, `poc_data`, `bulk_processed`, `data_source` +- **SigmaRule**: Enhanced with `poc_source`, `poc_quality_score`, `nomi_sec_data` +- **RuleTemplate**: Template patterns for rule generation +- **BulkProcessingJob**: Job tracking for bulk operations + +### Frontend Structure (Enhanced) +- **Four Main Tabs**: Dashboard, CVEs, SIGMA Rules, Bulk Jobs +- **Enhanced Dashboard**: PoC coverage statistics, bulk processing controls +- **Bulk Jobs Tab**: Real-time job monitoring and system status +- **Enhanced CVE/Rule Display**: PoC quality indicators, exploit-based tagging + +### Data Processing Flow +1. **Bulk Seeding**: NVD JSON downloads → Database storage → nomi-sec PoC sync → Enhanced rule generation +2. **Incremental Updates**: NVD modified feeds → Update existing data → Sync new PoCs +3. **Rule Enhancement**: PoC analysis → Indicator extraction → Template selection → Enhanced SIGMA rule +4. **LLM-Enhanced Generation**: PoC content analysis → Multi-provider LLM processing → Advanced SIGMA rule creation + +## Development Notes + +### Enhanced Rule Generation Logic +The application now uses an advanced rule generation process: +1. **CVE Analysis**: Extract metadata from NVD bulk data +2. **PoC Quality Assessment**: nomi-sec PoC analysis with star count, recency, quality tiers +3. **Advanced Indicator Extraction**: Processes, files, network, registry, commands from PoC repositories +4. **Template Selection**: Smart template matching based on PoC indicators and CVE characteristics +5. 
**Enhanced Rule Population**: Incorporate real exploit indicators with quality scoring +6. **MITRE ATT&CK Mapping**: Automatic technique identification based on indicators +7. **LLM AI Enhancement**: Optional multi-provider LLM integration for intelligent rule generation from PoC code analysis + +### Quality Tiers +- **Excellent** (80+ points): High star count, recent updates, detailed descriptions +- **Good** (60-79 points): Moderate quality indicators +- **Fair** (40-59 points): Basic PoC with some quality indicators +- **Poor** (20-39 points): Minimal quality indicators +- **Very Poor** (<20 points): Low-quality PoCs + +### Multi-Provider LLM Integration Features +- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and Ollama (local models) +- **Dynamic Provider Switching**: Switch between providers and models through UI or API +- **Intelligent Code Analysis**: LLMs analyze actual exploit code from PoC repositories +- **Advanced Rule Generation**: Creates sophisticated SIGMA rules with proper syntax and logic +- **Contextual Understanding**: Interprets CVE descriptions and maps them to appropriate detection patterns +- **Automatic Validation**: Generated rules are validated for SIGMA syntax compliance +- **Fallback Mechanism**: Automatically falls back to template-based generation if LLM is unavailable +- **Enhanced Metadata**: Rules include generation method tracking for quality assessment +- **LangChain Integration**: Uses LangChain for robust LLM integration and prompt management + +### Supported LLM Providers and Models + +#### OpenAI +- **API Key**: Set `OPENAI_API_KEY` environment variable +- **Supported Models**: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo +- **Default Model**: gpt-4o-mini +- **Rate Limits**: Based on OpenAI API limits + +#### Anthropic +- **API Key**: Set `ANTHROPIC_API_KEY` environment variable +- **Supported Models**: claude-3-5-sonnet-20241022, claude-3-haiku-20240307, claude-3-opus-20240229 +- **Default Model**: claude-3-5-sonnet-20241022 +- **Rate Limits**: Based on Anthropic API limits + +#### Ollama (Local Models) +- **Setup**: Install Ollama locally and set `OLLAMA_BASE_URL` (default: http://localhost:11434) +- **Supported Models**: llama3.2, codellama, mistral, llama2 (any Ollama-compatible model) +- **Default Model**: llama3.2 +- **Rate Limits**: No external API limits (local processing) + +### Testing and Validation +- **Frontend tests**: `npm test` (in frontend directory) +- **Backend testing**: Use standalone scripts for bulk operations +- **API testing**: Use `/docs` endpoint for Swagger UI +- **Bulk Processing**: Monitor via `/api/bulk-jobs` and frontend Bulk Jobs tab + +### Security Considerations +- **API Keys**: Store NVD and GitHub tokens in environment variables +- **PoC Analysis**: Automated analysis of curated PoC repositories (safer than raw GitHub search) +- **Rate Limiting**: Built-in rate limiting for external APIs +- **Data Validation**: Enhanced validation for bulk data processing +- **Audit Trail**: Job tracking for all bulk operations + +## Troubleshooting + +### Common Issues +- **Bulk Processing Failures**: Check `/api/bulk-jobs` for detailed error messages +- **NVD Data Download Issues**: Verify NVD API key and network connectivity +- **nomi-sec API Timeouts**: Built-in retry logic, check network connectivity +- **Frontend build errors**: Run `npm install` in frontend directory +- **Database schema changes**: Restart backend to auto-create new tables +- **Memory issues during bulk processing**: Monitor system 
resources, consider smaller batch sizes + +### Enhanced Rate Limits +- **NVD API**: 5 requests/30s (no key) → 50 requests/30s (with key) +- **nomi-sec API**: 1 request/second (built-in rate limiting) +- **GitHub API** (fallback): 60 requests/hour (no token) → 5000 requests/hour (with token) + +### Performance Optimization +- **Bulk Processing**: Start with recent years (2020+) for faster initial setup +- **PoC Sync**: Use smaller batch sizes (50) for better stability +- **Rule Generation**: Monitor quality scores to prioritize high-value PoCs +- **Database**: Ensure proper indexing on CVE ID and PoC fields + +### Monitoring +- **Frontend**: Use Bulk Jobs tab for real-time progress monitoring +- **Backend logs**: `docker-compose logs -f backend` +- **Job status**: Check `/api/bulk-status` for comprehensive system health +- **Database**: Monitor PoC coverage percentage and rule enhancement progress \ No newline at end of file diff --git a/backend/celery_config.py b/backend/celery_config.py new file mode 100644 index 0000000..664802a --- /dev/null +++ b/backend/celery_config.py @@ -0,0 +1,222 @@ +""" +Celery configuration for the Auto SIGMA Rule Generator +""" +import os +from celery import Celery +from celery.schedules import crontab +from kombu import Queue + +# Celery configuration +broker_url = os.getenv('CELERY_BROKER_URL', 'redis://redis:6379/0') +result_backend = os.getenv('CELERY_RESULT_BACKEND', 'redis://redis:6379/0') + +# Create Celery app +celery_app = Celery( + 'sigma_generator', + broker=broker_url, + backend=result_backend, + include=[ + 'tasks.bulk_tasks', + 'tasks.sigma_tasks', + 'tasks.data_sync_tasks', + 'tasks.maintenance_tasks' + ] +) + +# Celery configuration +celery_app.conf.update( + # Serialization + task_serializer='json', + accept_content=['json'], + result_serializer='json', + + # Timezone + timezone='UTC', + enable_utc=True, + + # Task tracking + task_track_started=True, + task_send_sent_event=True, + + # Result backend settings + result_expires=3600, # Results expire after 1 hour + result_backend_transport_options={ + 'master_name': 'mymaster', + 'visibility_timeout': 3600, + }, + + # Worker settings + worker_prefetch_multiplier=1, + task_acks_late=True, + worker_max_tasks_per_child=1000, + + # Task routes - different queues for different types of tasks + task_routes={ + 'tasks.bulk_tasks.*': {'queue': 'bulk_processing'}, + 'tasks.sigma_tasks.*': {'queue': 'sigma_generation'}, + 'tasks.data_sync_tasks.*': {'queue': 'data_sync'}, + }, + + # Queue definitions + task_default_queue='default', + task_queues=( + Queue('default', routing_key='default'), + Queue('bulk_processing', routing_key='bulk_processing'), + Queue('sigma_generation', routing_key='sigma_generation'), + Queue('data_sync', routing_key='data_sync'), + ), + + # Retry settings + task_default_retry_delay=60, # 1 minute + task_max_retries=3, + + # Monitoring + worker_send_task_events=True, + + # Optimized Beat schedule for daily workflow + # WORKFLOW: NVD incremental -> Exploit syncs -> Reference sync -> SIGMA rules + beat_schedule={ + # STEP 1: NVD Incremental Update - Daily at 2:00 AM + # This runs first to get the latest CVE data from NVD + 'daily-nvd-incremental-update': { + 'task': 'bulk_tasks.incremental_update_task', + 'schedule': crontab(minute=0, hour=2), # Daily at 2:00 AM + 'options': {'queue': 'bulk_processing'}, + 'kwargs': {'batch_size': 100, 'skip_nvd': False, 'skip_nomi_sec': True} + }, + + # STEP 2: Exploit Data Syncing - Daily starting at 3:00 AM + # These run in parallel but start at 
different times to avoid conflicts + + # CISA KEV Sync - Daily at 3:00 AM (15 minutes after NVD) + 'daily-cisa-kev-sync': { + 'task': 'data_sync_tasks.sync_cisa_kev', + 'schedule': crontab(minute=0, hour=3), # Daily at 3:00 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'batch_size': 100} + }, + + # Nomi-sec PoC Sync - Daily at 3:15 AM + 'daily-nomi-sec-sync': { + 'task': 'data_sync_tasks.sync_nomi_sec', + 'schedule': crontab(minute=15, hour=3), # Daily at 3:15 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'batch_size': 100} + }, + + # GitHub PoC Sync - Daily at 3:30 AM + 'daily-github-poc-sync': { + 'task': 'data_sync_tasks.sync_github_poc', + 'schedule': crontab(minute=30, hour=3), # Daily at 3:30 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'batch_size': 50} + }, + + # ExploitDB Sync - Daily at 3:45 AM + 'daily-exploitdb-sync': { + 'task': 'data_sync_tasks.sync_exploitdb', + 'schedule': crontab(minute=45, hour=3), # Daily at 3:45 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'batch_size': 30} + }, + + # CVE2CAPEC MITRE ATT&CK Mapping Sync - Daily at 4:00 AM + 'daily-cve2capec-sync': { + 'task': 'data_sync_tasks.sync_cve2capec', + 'schedule': crontab(minute=0, hour=4), # Daily at 4:00 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'force_refresh': False} # Only refresh if cache is stale + }, + + # ExploitDB Index Rebuild - Daily at 4:15 AM + 'daily-exploitdb-index-build': { + 'task': 'data_sync_tasks.build_exploitdb_index', + 'schedule': crontab(minute=15, hour=4), # Daily at 4:15 AM + 'options': {'queue': 'data_sync'} + }, + + # STEP 3: Reference Content Sync - Daily at 5:00 AM + # This is the longest-running task, starts after exploit syncs have time to complete + 'daily-reference-content-sync': { + 'task': 'data_sync_tasks.sync_reference_content', + 'schedule': crontab(minute=0, hour=5), # Daily at 5:00 AM + 'options': {'queue': 'data_sync'}, + 'kwargs': {'batch_size': 30, 'max_cves': 200, 'force_resync': False} + }, + + # STEP 4: SIGMA Rule Generation - Daily at 8:00 AM + # This runs LAST after all other daily data sync jobs + 'daily-sigma-rule-generation': { + 'task': 'bulk_tasks.generate_enhanced_sigma_rules', + 'schedule': crontab(minute=0, hour=8), # Daily at 8:00 AM + 'options': {'queue': 'sigma_generation'} + }, + + # LLM-Enhanced SIGMA Rule Generation - Daily at 9:00 AM + # Additional LLM-based rule generation after standard rules + 'daily-llm-sigma-generation': { + 'task': 'sigma_tasks.generate_enhanced_rules', + 'schedule': crontab(minute=0, hour=9), # Daily at 9:00 AM + 'options': {'queue': 'sigma_generation'}, + 'kwargs': {'cve_ids': None} # Process all CVEs with PoCs + }, + + # MAINTENANCE TASKS + + # Database Cleanup - Weekly on Sunday at 1:00 AM (before daily workflow) + 'weekly-database-cleanup': { + 'task': 'tasks.maintenance_tasks.database_cleanup_comprehensive', + 'schedule': crontab(minute=0, hour=1, day_of_week=0), # Sunday at 1:00 AM + 'options': {'queue': 'default'}, + 'kwargs': {'days_to_keep': 30, 'cleanup_failed_jobs': True, 'cleanup_logs': True} + }, + + # Health Check - Every 15 minutes + 'health-check-detailed': { + 'task': 'tasks.maintenance_tasks.health_check_detailed', + 'schedule': crontab(minute='*/15'), # Every 15 minutes + 'options': {'queue': 'default'} + }, + + # Celery result cleanup - Daily at 1:30 AM + 'daily-cleanup-old-results': { + 'task': 'tasks.maintenance_tasks.cleanup_old_results', + 'schedule': crontab(minute=30, hour=1), # Daily at 1:30 AM + 'options': {'queue': 'default'} + }, + }, +) + +# Configure 
logging +celery_app.conf.update( + worker_log_format='[%(asctime)s: %(levelname)s/%(processName)s] %(message)s', + worker_task_log_format='[%(asctime)s: %(levelname)s/%(processName)s][%(task_name)s(%(task_id)s)] %(message)s', +) + +# Database session configuration for tasks +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# Database configuration +DATABASE_URL = os.getenv('DATABASE_URL', 'postgresql://cve_user:cve_password@db:5432/cve_sigma_db') + +# Create engine and session factory +engine = create_engine(DATABASE_URL) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def get_db_session(): + """Get database session for tasks""" + return SessionLocal() + +# Import all task modules to register them +def register_tasks(): + """Register all task modules""" + try: + from tasks import bulk_tasks, sigma_tasks, data_sync_tasks, maintenance_tasks + print("All task modules registered successfully") + except ImportError as e: + print(f"Warning: Could not import some task modules: {e}") + +# Auto-register tasks when module is imported +if __name__ != "__main__": + register_tasks() diff --git a/backend/cve2capec_client.py b/backend/cve2capec_client.py index 5848de7..ee8847c 100644 --- a/backend/cve2capec_client.py +++ b/backend/cve2capec_client.py @@ -15,18 +15,20 @@ logger = logging.getLogger(__name__) class CVE2CAPECClient: """Client for accessing CVE to MITRE ATT&CK technique mappings.""" - def __init__(self): + def __init__(self, lazy_load: bool = True): self.base_url = "https://raw.githubusercontent.com/Galeax/CVE2CAPEC/main" self.cache_file = "/tmp/cve2capec_cache.json" self.cache_expiry_hours = 24 # Cache for 24 hours self.cve_mappings = {} self.technique_names = {} # Map technique IDs to names + self._data_loaded = False - # Load cached data if available - self._load_cache() - - # Load MITRE ATT&CK technique names + # Load MITRE ATT&CK technique names (lightweight) self._load_technique_names() + + # Only load cached data if not lazy loading + if not lazy_load: + self._load_cache() def _load_cache(self): """Load cached CVE mappings if they exist and are fresh.""" @@ -39,15 +41,18 @@ class CVE2CAPECClient: cache_time = datetime.fromisoformat(cache_data.get('timestamp', '2000-01-01')) if datetime.now() - cache_time < timedelta(hours=self.cache_expiry_hours): self.cve_mappings = cache_data.get('mappings', {}) + self._data_loaded = True logger.info(f"Loaded {len(self.cve_mappings)} CVE mappings from cache") return # Cache is stale or doesn't exist, fetch fresh data self._fetch_fresh_data() + self._data_loaded = True except Exception as e: logger.error(f"Error loading CVE2CAPEC cache: {e}") self._fetch_fresh_data() + self._data_loaded = True def _fetch_fresh_data(self): """Fetch fresh CVE mappings from the repository.""" @@ -133,6 +138,12 @@ class CVE2CAPECClient: # Continue with empty mappings if fetch fails self.cve_mappings = {} + def _ensure_data_loaded(self): + """Ensure CVE mappings are loaded, loading from cache if needed.""" + if not self._data_loaded: + logger.info("CVE2CAPEC data not loaded, loading from cache...") + self._load_cache() + def _load_technique_names(self): """Load MITRE ATT&CK technique names for better rule descriptions.""" # Common MITRE ATT&CK techniques and their names @@ -350,6 +361,7 @@ class CVE2CAPECClient: def get_mitre_techniques_for_cve(self, cve_id: str) -> List[str]: """Get MITRE ATT&CK techniques for a given CVE ID.""" try: + self._ensure_data_loaded() cve_data = self.cve_mappings.get(cve_id, {}) 
techniques = cve_data.get('TECHNIQUES', []) @@ -374,6 +386,7 @@ class CVE2CAPECClient: def get_cwe_for_cve(self, cve_id: str) -> List[str]: """Get CWE codes for a given CVE ID.""" try: + self._ensure_data_loaded() cve_data = self.cve_mappings.get(cve_id, {}) cwes = cve_data.get('CWE', []) @@ -392,6 +405,7 @@ class CVE2CAPECClient: def get_capec_for_cve(self, cve_id: str) -> List[str]: """Get CAPEC codes for a given CVE ID.""" try: + self._ensure_data_loaded() cve_data = self.cve_mappings.get(cve_id, {}) capecs = cve_data.get('CAPEC', []) @@ -430,6 +444,7 @@ class CVE2CAPECClient: def get_stats(self) -> Dict: """Get statistics about the CVE2CAPEC dataset.""" + self._ensure_data_loaded() total_cves = len(self.cve_mappings) cves_with_techniques = len([cve for cve, data in self.cve_mappings.items() if data.get('TECHNIQUES')]) diff --git a/backend/exploitdb_client_local.py b/backend/exploitdb_client_local.py index 91ef50f..b904bb0 100644 --- a/backend/exploitdb_client_local.py +++ b/backend/exploitdb_client_local.py @@ -7,7 +7,7 @@ import os import re import json import logging -from datetime import datetime +from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple from sqlalchemy.orm import Session from pathlib import Path @@ -16,10 +16,15 @@ from pathlib import Path logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Global cache for file index to avoid rebuilding on every request +_global_file_index = {} +_index_last_built = None +_index_cache_hours = 24 # Cache for 24 hours + class ExploitDBLocalClient: """Client for interfacing with local ExploitDB mirror filesystem""" - def __init__(self, db_session: Session): + def __init__(self, db_session: Session, lazy_load: bool = True): self.db_session = db_session # Path to the local exploit-db-mirror submodule (in container: /app/exploit-db-mirror) @@ -32,13 +37,38 @@ class ExploitDBLocalClient: # Cache for file searches self.file_cache = {} - # Build file index on initialization - self._build_file_index() + # Use global cache and only build if needed + global _global_file_index, _index_last_built + self.file_index = _global_file_index + + # Build file index only if not lazy loading or if cache is stale + if not lazy_load: + self._ensure_index_built() + + def _ensure_index_built(self): + """Ensure the file index is built and fresh""" + global _global_file_index, _index_last_built + + # Check if index needs to be rebuilt + needs_rebuild = ( + not _global_file_index or # No index exists + _index_last_built is None or # Never built + datetime.now() - _index_last_built > timedelta(hours=_index_cache_hours) # Cache expired + ) + + if needs_rebuild: + self._build_file_index() + else: + # Use cached index + self.file_index = _global_file_index + logger.debug(f"Using cached ExploitDB index with {len(self.file_index)} exploits") def _build_file_index(self): """Build an index of exploit ID to file path for fast lookups""" + global _global_file_index, _index_last_built + logger.info("Building ExploitDB file index...") - self.file_index = {} + temp_index = {} if not self.exploits_path.exists(): logger.error(f"ExploitDB path not found: {self.exploits_path}") @@ -55,7 +85,7 @@ class ExploitDBLocalClient: file_path = Path(root) / file # Store in index - self.file_index[exploit_id] = { + temp_index[exploit_id] = { 'path': file_path, 'filename': file, 'extension': file_extension, @@ -63,6 +93,11 @@ class ExploitDBLocalClient: 'subcategory': self._extract_subcategory_from_path(file_path) } + # Update global 
cache + _global_file_index = temp_index + _index_last_built = datetime.now() + self.file_index = _global_file_index + logger.info(f"Built index with {len(self.file_index)} exploits") def _extract_category_from_path(self, file_path: Path) -> str: @@ -104,6 +139,7 @@ class ExploitDBLocalClient: def get_exploit_details(self, exploit_id: str) -> Optional[dict]: """Get exploit details from local filesystem""" + self._ensure_index_built() if exploit_id not in self.file_index: logger.debug(f"Exploit {exploit_id} not found in local index") return None @@ -699,6 +735,9 @@ class ExploitDBLocalClient: from main import CVE from sqlalchemy import text + # Ensure index is built for stats calculation + self._ensure_index_built() + # Count CVEs with ExploitDB references total_cves = self.db_session.query(CVE).count() diff --git a/backend/initial_setup.py b/backend/initial_setup.py new file mode 100644 index 0000000..954b78f --- /dev/null +++ b/backend/initial_setup.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Initial setup script that runs once on first boot to populate the database. +This script checks if initial data seeding is needed and triggers it via Celery. +""" +import os +import sys +import time +import logging +from datetime import datetime, timedelta +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.exc import OperationalError + +# Add the current directory to path so we can import our modules +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Database configuration +DATABASE_URL = os.getenv('DATABASE_URL', 'postgresql://cve_user:cve_password@db:5432/cve_sigma_db') + +def wait_for_database(max_retries: int = 30, delay: int = 5) -> bool: + """Wait for database to be ready""" + logger.info("Waiting for database to be ready...") + + for attempt in range(max_retries): + try: + engine = create_engine(DATABASE_URL) + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + logger.info("✅ Database is ready!") + return True + except OperationalError as e: + logger.info(f"Attempt {attempt + 1}/{max_retries}: Database not ready yet ({e})") + except Exception as e: + logger.error(f"Unexpected error connecting to database: {e}") + + if attempt < max_retries - 1: + time.sleep(delay) + + logger.error("❌ Database failed to become ready") + return False + +def check_initial_setup_needed() -> bool: + """Check if initial setup is needed by examining the database state""" + try: + engine = create_engine(DATABASE_URL) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + with SessionLocal() as session: + # Check if we have any CVEs in the database + result = session.execute(text("SELECT COUNT(*) FROM cves")).fetchone() + cve_count = result[0] if result else 0 + + logger.info(f"Current CVE count in database: {cve_count}") + + # Check if we have any bulk processing jobs that completed successfully + bulk_jobs_result = session.execute(text(""" + SELECT COUNT(*) FROM bulk_processing_jobs + WHERE job_type = 'nvd_bulk_seed' + AND status = 'completed' + AND created_at > NOW() - INTERVAL '30 days' + """)).fetchone() + + recent_bulk_jobs = bulk_jobs_result[0] if bulk_jobs_result else 0 + + logger.info(f"Recent successful bulk seed jobs: {recent_bulk_jobs}") + + # Initial setup needed if: + # 1. Very few CVEs (less than 1000) AND + # 2. 
No recent successful bulk seed jobs + initial_setup_needed = cve_count < 1000 and recent_bulk_jobs == 0 + + if initial_setup_needed: + logger.info("🔄 Initial setup is needed - will trigger full NVD sync") + else: + logger.info("✅ Initial setup already completed - database has sufficient data") + + return initial_setup_needed + + except Exception as e: + logger.error(f"Error checking initial setup status: {e}") + # If we can't check, assume setup is needed + return True + +def trigger_initial_bulk_seed(): + """Trigger initial bulk seed via Celery""" + try: + # Import here to avoid circular dependencies + from celery_config import celery_app + from tasks.bulk_tasks import full_bulk_seed_task + + logger.info("🚀 Triggering initial full NVD bulk seed...") + + # Start a comprehensive bulk seed job + # Start from 2020 for faster initial setup, can be adjusted + task_result = full_bulk_seed_task.delay( + start_year=2020, # Start from 2020 for faster initial setup + end_year=None, # Current year + skip_nvd=False, + skip_nomi_sec=True, # Skip nomi-sec initially, will be done daily + skip_exploitdb=True, # Skip exploitdb initially, will be done daily + skip_cisa_kev=True # Skip CISA KEV initially, will be done daily + ) + + logger.info(f"✅ Initial bulk seed task started with ID: {task_result.id}") + logger.info(f"Monitor progress at: http://localhost:5555/task/{task_result.id}") + + return task_result.id + + except Exception as e: + logger.error(f"❌ Failed to trigger initial bulk seed: {e}") + return None + +def create_initial_setup_marker(): + """Create a marker to indicate initial setup was attempted""" + try: + engine = create_engine(DATABASE_URL) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + with SessionLocal() as session: + # Insert a marker record + session.execute(text(""" + INSERT INTO bulk_processing_jobs (job_type, status, job_metadata, created_at, started_at) + VALUES ('initial_setup_marker', 'completed', '{"purpose": "initial_setup_marker"}', NOW(), NOW()) + ON CONFLICT DO NOTHING + """)) + session.commit() + + logger.info("✅ Created initial setup marker") + + except Exception as e: + logger.error(f"Error creating initial setup marker: {e}") + +def main(): + """Main initial setup function""" + logger.info("🚀 Starting initial setup check...") + + # Step 1: Wait for database + if not wait_for_database(): + logger.error("❌ Initial setup failed: Database not available") + sys.exit(1) + + # Step 2: Check if initial setup is needed + if not check_initial_setup_needed(): + logger.info("🎉 Initial setup not needed - database already populated") + sys.exit(0) + + # Step 3: Wait for Celery to be ready + logger.info("Waiting for Celery workers to be ready...") + time.sleep(10) # Give Celery workers time to start + + # Step 4: Trigger initial bulk seed + task_id = trigger_initial_bulk_seed() + + if task_id: + # Step 5: Create marker + create_initial_setup_marker() + + logger.info("🎉 Initial setup triggered successfully!") + logger.info(f"Task ID: {task_id}") + logger.info("The system will begin daily scheduled tasks once initial setup completes.") + sys.exit(0) + else: + logger.error("❌ Initial setup failed") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/llm_client.py b/backend/llm_client.py index 88805ad..68f24d7 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -1,6 +1,6 @@ """ LangChain-based LLM client for enhanced SIGMA rule generation. 
-Supports multiple LLM providers: OpenAI, Anthropic, and local models. +Supports multiple LLM providers: OpenAI, Anthropic, and Ollama. """ import os import logging @@ -31,7 +31,7 @@ class LLMClient: 'default_model': 'claude-3-5-sonnet-20241022' }, 'ollama': { - 'models': ['llama3.2', 'codellama', 'mistral', 'llama2'], + 'models': ['llama3.2', 'codellama', 'mistral', 'llama2', 'sigma-llama-finetuned'], 'env_key': 'OLLAMA_BASE_URL', 'default_model': 'llama3.2' } @@ -110,6 +110,8 @@ class LLMClient: base_url=base_url, temperature=0.1 ) + + if self.llm: logger.info(f"LLM client initialized: {self.provider} with model {self.model}") @@ -128,9 +130,9 @@ class LLMClient: """Get information about the current provider and configuration.""" provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {}) - # For Ollama, get actually available models + # For Ollama and fine-tuned, get actually available models available_models = provider_info.get('models', []) - if self.provider == 'ollama': + if self.provider in ['ollama', 'finetuned']: ollama_models = self._get_ollama_available_models() if ollama_models: available_models = ollama_models @@ -186,9 +188,22 @@ class LLMClient: logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...") logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...") - # Generate the response + # Generate the response with memory error handling logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}") - response = await chain.ainvoke(input_data) + try: + response = await chain.ainvoke(input_data) + except Exception as llm_error: + # Handle memory issues or model loading failures + error_msg = str(llm_error).lower() + if any(keyword in error_msg for keyword in ["memory", "out of memory", "too large", "available", "model request"]): + logger.error(f"LLM memory error for {cve_id}: {llm_error}") + + # For memory errors, we don't have specific fallback logic currently + logger.error(f"No fallback available for provider {self.provider}") + return None + else: + # Re-raise non-memory errors + raise llm_error # Debug: Log raw LLM response logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...") @@ -1351,13 +1366,15 @@ Output ONLY the enhanced SIGMA rule in valid YAML format.""" env_key = provider_info.get('env_key', '') api_key_configured = bool(os.getenv(env_key)) + available = api_key_configured or provider_name == 'ollama' + providers.append({ 'name': provider_name, 'models': provider_info.get('models', []), 'default_model': provider_info.get('default_model', ''), 'env_key': env_key, 'api_key_configured': api_key_configured, - 'available': api_key_configured or provider_name == 'ollama' + 'available': available }) return providers diff --git a/backend/main.py b/backend/main.py index 78dae8c..57cbb80 100644 --- a/backend/main.py +++ b/backend/main.py @@ -820,36 +820,14 @@ async def lifespan(app: FastAPI): finally: db.close() - # Initialize and start the job scheduler - try: - from job_scheduler import initialize_scheduler - from job_executors import register_all_executors - - # Initialize scheduler - scheduler = initialize_scheduler() - scheduler.set_db_session_factory(SessionLocal) - - # Register all job executors - register_all_executors(scheduler) - - # Start the scheduler - scheduler.start() - - logger.info("Job scheduler initialized and started") - - except Exception as e: - logger.error(f"Error initializing job scheduler: {e}") + # Note: Job scheduling is now handled by Celery Beat + # All scheduled tasks are defined 
in celery_config.py + logger.info("Application startup complete - scheduled tasks handled by Celery Beat") yield # Shutdown - try: - from job_scheduler import get_scheduler - scheduler = get_scheduler() - scheduler.stop() - logger.info("Job scheduler stopped") - except Exception as e: - logger.error(f"Error stopping job scheduler: {e}") + logger.info("Application shutdown complete") # FastAPI app app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan) @@ -862,6 +840,16 @@ app.add_middleware( allow_headers=["*"], ) +# Include Celery job management routes +try: + from routers.celery_jobs import router as celery_router + app.include_router(celery_router, prefix="/api") + logger.info("Celery job routes loaded successfully") +except ImportError as e: + logger.warning(f"Celery job routes not available: {e}") +except Exception as e: + logger.error(f"Error loading Celery job routes: {e}") + @app.get("/api/cves", response_model=List[CVEResponse]) async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)): cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all() @@ -1074,213 +1062,113 @@ async def get_stats(db: Session = Depends(get_db)): # New bulk processing endpoints @app.post("/api/bulk-seed") -async def start_bulk_seed(background_tasks: BackgroundTasks, - request: BulkSeedRequest, - db: Session = Depends(get_db)): - """Start bulk seeding process""" - - async def bulk_seed_task(): - try: - from bulk_seeder import BulkSeeder - seeder = BulkSeeder(db) - result = await seeder.full_bulk_seed( - start_year=request.start_year, - end_year=request.end_year, - skip_nvd=request.skip_nvd, - skip_nomi_sec=request.skip_nomi_sec - ) - logger.info(f"Bulk seed completed: {result}") - except Exception as e: - logger.error(f"Bulk seed failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(bulk_seed_task) - - return { - "message": "Bulk seeding process started", - "status": "started", - "start_year": request.start_year, - "end_year": request.end_year or datetime.now().year, - "skip_nvd": request.skip_nvd, - "skip_nomi_sec": request.skip_nomi_sec - } +async def start_bulk_seed(request: BulkSeedRequest): + """Start bulk seeding process - redirects to async endpoint""" + try: + from routers.celery_jobs import start_bulk_seed as async_bulk_seed + from routers.celery_jobs import BulkSeedRequest as CeleryBulkSeedRequest + + # Convert request to Celery format + celery_request = CeleryBulkSeedRequest( + start_year=request.start_year, + end_year=request.end_year, + skip_nvd=request.skip_nvd, + skip_nomi_sec=request.skip_nomi_sec, + skip_exploitdb=getattr(request, 'skip_exploitdb', False), + skip_cisa_kev=getattr(request, 'skip_cisa_kev', False) + ) + + # Call async endpoint + result = await async_bulk_seed(celery_request) + + return { + "message": "Bulk seeding process started (async)", + "status": "started", + "task_id": result.task_id, + "start_year": request.start_year, + "end_year": request.end_year or datetime.now().year, + "skip_nvd": request.skip_nvd, + "skip_nomi_sec": request.skip_nomi_sec, + "async_endpoint": f"/api/task-status/{result.task_id}" + } + except Exception as e: + logger.error(f"Error starting bulk seed: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start bulk seed: {e}") @app.post("/api/incremental-update") -async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)): - """Start incremental update process""" - - async def 
incremental_update_task(): - try: - from bulk_seeder import BulkSeeder - seeder = BulkSeeder(db) - result = await seeder.incremental_update() - logger.info(f"Incremental update completed: {result}") - except Exception as e: - logger.error(f"Incremental update failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(incremental_update_task) - - return { - "message": "Incremental update process started", - "status": "started" - } +async def start_incremental_update(): + """Start incremental update process - redirects to async endpoint""" + try: + from routers.celery_jobs import start_incremental_update as async_incremental_update + + # Call async endpoint + result = await async_incremental_update() + + return { + "message": "Incremental update process started (async)", + "status": "started", + "task_id": result.task_id, + "async_endpoint": f"/api/task-status/{result.task_id}" + } + except Exception as e: + logger.error(f"Error starting incremental update: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start incremental update: {e}") @app.post("/api/sync-nomi-sec") -async def sync_nomi_sec(background_tasks: BackgroundTasks, - request: NomiSecSyncRequest, - db: Session = Depends(get_db)): - """Synchronize nomi-sec PoC data""" - - # Create job record - job = BulkProcessingJob( - job_type='nomi_sec_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': request.batch_size +async def sync_nomi_sec(request: NomiSecSyncRequest): + """Synchronize nomi-sec PoC data - redirects to async endpoint""" + try: + from routers.celery_jobs import start_nomi_sec_sync as async_nomi_sec_sync + from routers.celery_jobs import DataSyncRequest as CeleryDataSyncRequest + + # Convert request to Celery format + celery_request = CeleryDataSyncRequest( + batch_size=request.batch_size + ) + + # Call async endpoint + result = await async_nomi_sec_sync(celery_request) + + return { + "message": f"Nomi-sec sync started (async)" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "task_id": result.task_id, + "cve_id": request.cve_id, + "batch_size": request.batch_size, + "async_endpoint": f"/api/task-status/{result.task_id}" } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - try: - job.status = 'running' - job.started_at = datetime.utcnow() - db.commit() - - from nomi_sec_client import NomiSecClient - client = NomiSecClient(db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_pocs(request.cve_id) - logger.info(f"Nomi-sec sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_all_cves( - batch_size=request.batch_size, - cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) - ) - logger.info(f"Nomi-sec bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - job.status = 'completed' - job.completed_at = datetime.utcnow() - db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - job.status = 'failed' - job.error_message = str(e) - job.completed_at = datetime.utcnow() - db.commit() - - logger.error(f"Nomi-sec sync failed: {e}") - import traceback - 
traceback.print_exc() - finally: - # Clean up tracking - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - - background_tasks.add_task(sync_task) - - return { - "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } + except Exception as e: + logger.error(f"Error starting nomi-sec sync: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start nomi-sec sync: {e}") @app.post("/api/sync-github-pocs") -async def sync_github_pocs(background_tasks: BackgroundTasks, - request: GitHubPoCSyncRequest, +async def sync_github_pocs(request: GitHubPoCSyncRequest, db: Session = Depends(get_db)): - """Synchronize GitHub PoC data""" - - # Create job record - job = BulkProcessingJob( - job_type='github_poc_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': request.batch_size + """Synchronize GitHub PoC data using Celery task""" + try: + from celery_config import celery_app + from tasks.data_sync_tasks import sync_github_poc_task + + # Launch Celery task + if request.cve_id: + # For specific CVE sync, we'll still use the general task + task_result = sync_github_poc_task.delay(batch_size=request.batch_size) + else: + # For bulk sync + task_result = sync_github_poc_task.delay(batch_size=request.batch_size) + + return { + "message": f"GitHub PoC sync started via Celery" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "task_id": task_result.id, + "cve_id": request.cve_id, + "batch_size": request.batch_size, + "monitor_url": "http://localhost:5555/task/" + task_result.id } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - try: - job.status = 'running' - job.started_at = datetime.utcnow() - db.commit() - - client = GitHubPoCClient(db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_pocs(request.cve_id) - logger.info(f"GitHub PoC sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_all_cves(batch_size=request.batch_size) - logger.info(f"GitHub PoC bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - job.status = 'completed' - job.completed_at = datetime.utcnow() - db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - job.status = 'failed' - job.error_message = str(e) - job.completed_at = datetime.utcnow() - db.commit() - - logger.error(f"GitHub PoC sync failed: {e}") - import traceback - traceback.print_exc() - finally: - # Clean up tracking - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - - background_tasks.add_task(sync_task) - - return { - "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } + + except Exception as e: + logger.error(f"Error starting GitHub PoC sync via Celery: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start GitHub PoC sync: {e}") @app.post("/api/sync-exploitdb") 
async def sync_exploitdb(background_tasks: BackgroundTasks, @@ -1850,6 +1738,55 @@ async def get_poc_stats(db: Session = Depends(get_db)): logger.error(f"Error getting PoC stats: {e}") return {"error": str(e)} +@app.post("/api/sync-cve2capec") +async def sync_cve2capec(force_refresh: bool = False): + """Synchronize CVE2CAPEC MITRE ATT&CK mappings using Celery task""" + try: + from celery_config import celery_app + from tasks.data_sync_tasks import sync_cve2capec_task + + # Launch Celery task + task_result = sync_cve2capec_task.delay(force_refresh=force_refresh) + + return { + "message": "CVE2CAPEC MITRE ATT&CK mapping sync started via Celery", + "status": "started", + "task_id": task_result.id, + "force_refresh": force_refresh, + "monitor_url": f"http://localhost:5555/task/{task_result.id}" + } + + except ImportError as e: + logger.error(f"Failed to import Celery components: {e}") + raise HTTPException(status_code=500, detail="Celery not properly configured") + except Exception as e: + logger.error(f"Error starting CVE2CAPEC sync: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start CVE2CAPEC sync: {e}") + +@app.post("/api/build-exploitdb-index") +async def build_exploitdb_index(): + """Build/rebuild ExploitDB file index using Celery task""" + try: + from celery_config import celery_app + from tasks.data_sync_tasks import build_exploitdb_index_task + + # Launch Celery task + task_result = build_exploitdb_index_task.delay() + + return { + "message": "ExploitDB file index build started via Celery", + "status": "started", + "task_id": task_result.id, + "monitor_url": f"http://localhost:5555/task/{task_result.id}" + } + + except ImportError as e: + logger.error(f"Failed to import Celery components: {e}") + raise HTTPException(status_code=500, detail="Celery not properly configured") + except Exception as e: + logger.error(f"Error starting ExploitDB index build: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start ExploitDB index build: {e}") + @app.get("/api/cve2capec-stats") async def get_cve2capec_stats(): """Get CVE2CAPEC MITRE ATT&CK mapping statistics""" @@ -2201,172 +2138,23 @@ async def get_ollama_models(): raise HTTPException(status_code=500, detail=str(e)) # ============================================================================ -# SCHEDULER ENDPOINTS +# NOTE: SCHEDULER ENDPOINTS REMOVED +# ============================================================================ +# +# Job scheduling is now handled by Celery Beat with periodic tasks. +# All scheduled tasks are defined in celery_config.py beat_schedule. 
+# +# To manage scheduled tasks: +# - View tasks: Use Celery monitoring tools (Flower, Celery events) +# - Control tasks: Use Celery control commands or through Celery job management endpoints +# - Schedule changes: Update celery_config.py and restart Celery Beat +# +# Available Celery job management endpoints: +# - GET /api/celery/tasks - List all active tasks +# - POST /api/celery/tasks/{task_id}/revoke - Cancel a running task +# - GET /api/celery/workers - View worker status +# # ============================================================================ - -class SchedulerControlRequest(BaseModel): - action: str # 'start', 'stop', 'restart' - -class JobControlRequest(BaseModel): - job_name: str - action: str # 'enable', 'disable', 'trigger' - -class UpdateScheduleRequest(BaseModel): - job_name: str - schedule: str # Cron expression - -@app.get("/api/scheduler/status") -async def get_scheduler_status(): - """Get scheduler status and job information""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - status = scheduler.get_job_status() - - return { - "scheduler_status": status, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error getting scheduler status: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/control") -async def control_scheduler(request: SchedulerControlRequest): - """Control scheduler (start/stop/restart)""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - - if request.action == 'start': - scheduler.start() - message = "Scheduler started" - elif request.action == 'stop': - scheduler.stop() - message = "Scheduler stopped" - elif request.action == 'restart': - scheduler.stop() - scheduler.start() - message = "Scheduler restarted" - else: - raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") - - return { - "message": message, - "action": request.action, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error controlling scheduler: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/job/control") -async def control_job(request: JobControlRequest): - """Control individual jobs (enable/disable/trigger)""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - - if request.action == 'enable': - success = scheduler.enable_job(request.job_name) - message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found" - elif request.action == 'disable': - success = scheduler.disable_job(request.job_name) - message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found" - elif request.action == 'trigger': - success = scheduler.trigger_job(request.job_name) - message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}" - else: - raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") - - return { - "message": message, - "job_name": request.job_name, - "action": request.action, - "success": success, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error controlling job: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/job/schedule") -async def update_job_schedule(request: UpdateScheduleRequest): - """Update job schedule""" - try: - from job_scheduler import get_scheduler - - scheduler = 
get_scheduler() - success = scheduler.update_job_schedule(request.job_name, request.schedule) - - if success: - # Get updated job info - job_status = scheduler.get_job_status(request.job_name) - return { - "message": f"Schedule updated for job {request.job_name}", - "job_name": request.job_name, - "new_schedule": request.schedule, - "next_run": job_status.get("next_run"), - "success": True, - "timestamp": datetime.utcnow().isoformat() - } - else: - raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}") - - except Exception as e: - logger.error(f"Error updating job schedule: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/scheduler/job/{job_name}") -async def get_job_status(job_name: str): - """Get status of a specific job""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - status = scheduler.get_job_status(job_name) - - if "error" in status: - raise HTTPException(status_code=404, detail=status["error"]) - - return { - "job_status": status, - "timestamp": datetime.utcnow().isoformat() - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting job status: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/reload") -async def reload_scheduler_config(): - """Reload scheduler configuration from file""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - success = scheduler.reload_config() - - if success: - return { - "message": "Scheduler configuration reloaded successfully", - "success": True, - "timestamp": datetime.utcnow().isoformat() - } - else: - raise HTTPException(status_code=500, detail="Failed to reload configuration") - - except Exception as e: - logger.error(f"Error reloading scheduler config: {e}") - raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn diff --git a/backend/setup_ollama_with_sigma.py b/backend/setup_ollama_with_sigma.py new file mode 100644 index 0000000..53014ff --- /dev/null +++ b/backend/setup_ollama_with_sigma.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Enhanced Ollama setup script that includes SIGMA model creation. +This integrates the functionality from fix_sigma_model.sh into the Docker container. 
+""" +import os +import json +import time +import requests +import subprocess +import sys +import tempfile +from typing import Dict, List, Optional + +OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434') +DEFAULT_MODEL = os.getenv('LLM_MODEL', 'llama3.2') +SIGMA_MODEL_NAME = 'sigma-llama' + +def log(message: str, level: str = "INFO"): + """Log message with timestamp""" + timestamp = time.strftime("%Y-%m-%d %H:%M:%S") + print(f"[{timestamp}] {level}: {message}") + +def wait_for_ollama(max_retries: int = 30, delay: int = 5) -> bool: + """Wait for Ollama service to be ready""" + log("Waiting for Ollama service to be ready...") + + for attempt in range(max_retries): + try: + response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10) + if response.status_code == 200: + log("✅ Ollama service is ready!") + return True + except requests.exceptions.RequestException as e: + log(f"Attempt {attempt + 1}/{max_retries}: Ollama not ready yet ({e})", "DEBUG") + + if attempt < max_retries - 1: + time.sleep(delay) + + log("❌ Ollama service failed to become ready", "ERROR") + return False + +def get_available_models() -> List[str]: + """Get list of available models""" + try: + response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10) + if response.status_code == 200: + data = response.json() + models = [model.get('name', '') for model in data.get('models', [])] + log(f"Available models: {models}") + return models + else: + log(f"Failed to get models: HTTP {response.status_code}", "ERROR") + return [] + except Exception as e: + log(f"Error getting models: {e}", "ERROR") + return [] + +def pull_model(model_name: str) -> bool: + """Pull a model if not available""" + log(f"Pulling model: {model_name}") + + try: + payload = {"name": model_name} + response = requests.post( + f"{OLLAMA_BASE_URL}/api/pull", + json=payload, + timeout=600, # 10 minutes timeout + stream=True + ) + + if response.status_code == 200: + # Stream and log progress + for line in response.iter_lines(): + if line: + try: + data = json.loads(line.decode('utf-8')) + status = data.get('status', '') + if status: + log(f"Pull progress: {status}", "DEBUG") + if data.get('error'): + log(f"Pull error: {data.get('error')}", "ERROR") + return False + except json.JSONDecodeError: + continue + + log(f"✅ Successfully pulled model: {model_name}") + return True + else: + log(f"❌ Failed to pull model {model_name}: HTTP {response.status_code}", "ERROR") + return False + + except Exception as e: + log(f"❌ Error pulling model {model_name}: {e}", "ERROR") + return False + +def create_sigma_model() -> bool: + """Create the sigma-llama model with specialized SIGMA generation configuration""" + log("🔄 Creating sigma-llama model...") + + # First, remove any existing sigma-llama model + try: + response = requests.delete(f"{OLLAMA_BASE_URL}/api/delete", + json={"name": SIGMA_MODEL_NAME}, + timeout=30) + if response.status_code == 200: + log("Removed existing sigma-llama model") + except Exception: + pass # Model might not exist, that's fine + + # Create Modelfile content without the FROM line + modelfile_content = """TEMPLATE \"\"\"### Instruction: +Generate SIGMA rule logsource and detection sections based on the provided context. + +### Input: +{{ .Prompt }} + +### Response: +\"\"\" + +PARAMETER temperature 0.1 +PARAMETER top_p 0.9 +PARAMETER stop "### Instruction:" +PARAMETER stop "### Response:" +PARAMETER num_ctx 4096 + +SYSTEM \"\"\"You are a cybersecurity expert specializing in SIGMA rule creation. 
Generate valid SIGMA rules in YAML format based on the provided CVE and exploit information. Output ONLY valid YAML starting with 'title:' and ending with the last YAML line.\"\"\" +""" + + try: + # Create the model using the API with 'from' parameter + payload = { + "name": SIGMA_MODEL_NAME, + "from": f"{DEFAULT_MODEL}:latest", + "modelfile": modelfile_content, + "stream": False + } + + response = requests.post( + f"{OLLAMA_BASE_URL}/api/create", + json=payload, + timeout=300, # 5 minutes timeout + stream=True + ) + + if response.status_code == 200: + # Stream and log progress + for line in response.iter_lines(): + if line: + try: + data = json.loads(line.decode('utf-8')) + status = data.get('status', '') + if status: + log(f"Model creation: {status}", "DEBUG") + if data.get('error'): + log(f"Model creation error: {data.get('error')}", "ERROR") + return False + except json.JSONDecodeError: + continue + + # Verify the model was created + models = get_available_models() + if any(SIGMA_MODEL_NAME in model for model in models): + log("✅ sigma-llama model created successfully!") + return True + else: + log("❌ sigma-llama model not found after creation", "ERROR") + return False + else: + log(f"❌ Failed to create sigma-llama model: HTTP {response.status_code}", "ERROR") + try: + error_data = response.json() + log(f"Error details: {error_data}", "ERROR") + except: + log(f"Error response: {response.text}", "ERROR") + return False + + except Exception as e: + log(f"❌ Error creating sigma-llama model: {e}", "ERROR") + return False + +def test_sigma_model() -> bool: + """Test the sigma-llama model""" + log("🔄 Testing sigma-llama model...") + + try: + test_payload = { + "model": SIGMA_MODEL_NAME, + "prompt": "Title: Test PowerShell Rule", + "stream": False + } + + response = requests.post( + f"{OLLAMA_BASE_URL}/api/generate", + json=test_payload, + timeout=60 + ) + + if response.status_code == 200: + data = response.json() + test_response = data.get('response', '')[:100] # First 100 chars + log(f"✅ Model test successful! Response: {test_response}...") + return True + else: + log(f"❌ Model test failed: HTTP {response.status_code}", "ERROR") + return False + + except Exception as e: + log(f"❌ Error testing model: {e}", "ERROR") + return False + +def main(): + """Main setup function""" + log("🚀 Starting enhanced Ollama setup with SIGMA model creation...") + + # Step 1: Wait for Ollama to be ready + if not wait_for_ollama(): + log("❌ Setup failed: Ollama service not available", "ERROR") + sys.exit(1) + + # Step 2: Check current models + models = get_available_models() + log(f"Current models: {models}") + + # Step 3: Pull default model if needed + if not any(DEFAULT_MODEL in model for model in models): + log(f"Default model {DEFAULT_MODEL} not found, pulling...") + if not pull_model(DEFAULT_MODEL): + log(f"❌ Setup failed: Could not pull {DEFAULT_MODEL}", "ERROR") + sys.exit(1) + else: + log(f"✅ Default model {DEFAULT_MODEL} already available") + + # Step 4: Create SIGMA model + if not create_sigma_model(): + log("❌ Setup failed: Could not create sigma-llama model", "ERROR") + sys.exit(1) + + # Step 5: Test SIGMA model + if not test_sigma_model(): + log("⚠️ Setup warning: sigma-llama model test failed", "WARN") + # Don't exit here, the model might still work + + # Step 6: Final verification + final_models = get_available_models() + log(f"Final models available: {final_models}") + + if any(SIGMA_MODEL_NAME in model for model in final_models): + log("🎉 Setup complete! 
sigma-llama model is ready for use.") + sys.exit(0) + else: + log("❌ Setup failed: sigma-llama model not available after setup", "ERROR") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/tasks/__init__.py b/backend/tasks/__init__.py new file mode 100644 index 0000000..6f650c0 --- /dev/null +++ b/backend/tasks/__init__.py @@ -0,0 +1,3 @@ +""" +Celery tasks for the Auto SIGMA Rule Generator +""" \ No newline at end of file diff --git a/backend/tasks/bulk_tasks.py b/backend/tasks/bulk_tasks.py new file mode 100644 index 0000000..3506b96 --- /dev/null +++ b/backend/tasks/bulk_tasks.py @@ -0,0 +1,235 @@ +""" +Bulk processing tasks for Celery +""" +import asyncio +import logging +from typing import Optional, Dict, Any +from celery import current_task +from celery_config import celery_app, get_db_session +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from bulk_seeder import BulkSeeder + +logger = logging.getLogger(__name__) + +@celery_app.task(bind=True, name='bulk_tasks.full_bulk_seed') +def full_bulk_seed_task(self, start_year: int = 2002, end_year: Optional[int] = None, + skip_nvd: bool = False, skip_nomi_sec: bool = False, + skip_exploitdb: bool = False, skip_cisa_kev: bool = False) -> Dict[str, Any]: + """ + Celery task for full bulk seeding operation + + Args: + start_year: Starting year for NVD data + end_year: Ending year for NVD data + skip_nvd: Skip NVD bulk processing + skip_nomi_sec: Skip nomi-sec PoC synchronization + skip_exploitdb: Skip ExploitDB synchronization + skip_cisa_kev: Skip CISA KEV synchronization + + Returns: + Dictionary containing operation results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'initializing', + 'progress': 0, + 'message': 'Starting bulk seeding operation' + } + ) + + logger.info(f"Starting full bulk seed task: {start_year}-{end_year}") + + # Create seeder instance + seeder = BulkSeeder(db_session) + + # Create progress callback + def update_progress(stage: str, progress: int, message: str = None): + self.update_state( + state='PROGRESS', + meta={ + 'stage': stage, + 'progress': progress, + 'message': message or f'Processing {stage}' + } + ) + + # Run the bulk seeding operation + # Note: We need to handle the async nature of bulk_seeder + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + seeder.full_bulk_seed( + start_year=start_year, + end_year=end_year, + skip_nvd=skip_nvd, + skip_nomi_sec=skip_nomi_sec, + skip_exploitdb=skip_exploitdb, + skip_cisa_kev=skip_cisa_kev, + progress_callback=update_progress + ) + ) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'Bulk seeding completed successfully' + } + ) + + logger.info(f"Full bulk seed task completed: {result}") + return result + + except Exception as e: + logger.error(f"Full bulk seed task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + +@celery_app.task(bind=True, name='bulk_tasks.incremental_update_task') +def incremental_update_task(self) -> Dict[str, Any]: + """ + Celery task for incremental updates + + Returns: + Dictionary containing update results + """ + db_session = 
get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'incremental_update', + 'progress': 0, + 'message': 'Starting incremental update' + } + ) + + logger.info("Starting incremental update task") + + # Create seeder instance + seeder = BulkSeeder(db_session) + + # Run the incremental update + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(seeder.incremental_update()) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'Incremental update completed successfully' + } + ) + + logger.info(f"Incremental update task completed: {result}") + return result + + except Exception as e: + logger.error(f"Incremental update task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + +@celery_app.task(bind=True, name='bulk_tasks.generate_enhanced_sigma_rules') +def generate_enhanced_sigma_rules_task(self) -> Dict[str, Any]: + """ + Celery task for generating enhanced SIGMA rules + + Returns: + Dictionary containing generation results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'generating_rules', + 'progress': 0, + 'message': 'Starting enhanced SIGMA rule generation' + } + ) + + logger.info("Starting enhanced SIGMA rule generation task") + + # Create seeder instance + seeder = BulkSeeder(db_session) + + # Run the rule generation + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(seeder.generate_enhanced_sigma_rules()) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'Enhanced SIGMA rule generation completed successfully' + } + ) + + logger.info(f"Enhanced SIGMA rule generation task completed: {result}") + return result + + except Exception as e: + logger.error(f"Enhanced SIGMA rule generation task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() diff --git a/backend/tasks/data_sync_tasks.py b/backend/tasks/data_sync_tasks.py new file mode 100644 index 0000000..84f4d10 --- /dev/null +++ b/backend/tasks/data_sync_tasks.py @@ -0,0 +1,959 @@ +""" +Data synchronization tasks for Celery +""" +import asyncio +import logging +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from typing import Dict, Any +from celery import current_task +from celery_config import celery_app, get_db_session +from nomi_sec_client import NomiSecClient +from exploitdb_client_local import ExploitDBLocalClient +from cisa_kev_client import CISAKEVClient +from mcdevitt_poc_client import GitHubPoCClient + +logger = logging.getLogger(__name__) + +@celery_app.task(bind=True, name='data_sync_tasks.sync_nomi_sec') +def sync_nomi_sec_task(self, batch_size: int = 50) -> Dict[str, Any]: + """ + Celery task for nomi-sec PoC synchronization + + Args: + batch_size: Number of CVEs to process in each batch + + Returns: + Dictionary containing sync results + """ + db_session = get_db_session() + + try: + # Update task progress 
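+        # update_state() publishes a custom PROGRESS state whose meta dict is
+        # visible to monitoring clients; a minimal polling sketch (assumed
+        # usage, not part of this module):
+        #   res = celery_app.AsyncResult(task_id)
+        #   res.state  # -> 'PROGRESS' while the task is running
+        #   res.info   # -> {'stage': ..., 'progress': ..., 'message': ...}
+        # Flower (http://localhost:5555) surfaces the same metadata.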
+ self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_nomi_sec', + 'progress': 0, + 'message': 'Starting nomi-sec PoC synchronization' + } + ) + + logger.info(f"Starting nomi-sec sync task with batch size: {batch_size}") + + # Create client instance + client = NomiSecClient(db_session) + + # Run the synchronization + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + client.bulk_sync_all_cves(batch_size=batch_size) + ) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'Nomi-sec synchronization completed successfully' + } + ) + + logger.info(f"Nomi-sec sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"Nomi-sec sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec') +def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization + + Args: + force_refresh: Whether to force refresh the cache regardless of expiry + + Returns: + Dictionary containing sync results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 0, + 'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization' + } + ) + + logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}") + + # Import here to avoid circular dependencies + from cve2capec_client import CVE2CAPECClient + + # Create client instance + client = CVE2CAPECClient() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 10, + 'message': 'Fetching MITRE ATT&CK mappings...' + } + ) + + # Force refresh if requested + if force_refresh: + client._fetch_fresh_data() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 50, + 'message': 'Processing CVE mappings...' 
+ } + ) + + # Get statistics about the loaded data + stats = client.get_stats() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 100, + 'message': 'CVE2CAPEC synchronization completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_mappings': stats.get('total_mappings', 0), + 'total_techniques': stats.get('unique_techniques', 0), + 'cache_updated': True + } + + logger.info(f"CVE2CAPEC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CVE2CAPEC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_github_poc') +def sync_github_poc_task(self, batch_size: int = 50) -> Dict[str, Any]: + """ + Celery task for GitHub PoC synchronization + + Args: + batch_size: Number of CVEs to process in each batch + + Returns: + Dictionary containing sync results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_github_poc', + 'progress': 0, + 'message': 'Starting GitHub PoC synchronization' + } + ) + + logger.info(f"Starting GitHub PoC sync task with batch size: {batch_size}") + + # Create client instance + client = GitHubPoCClient(db_session) + + # Run the synchronization + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + client.bulk_sync_all_cves(batch_size=batch_size) + ) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'GitHub PoC synchronization completed successfully' + } + ) + + logger.info(f"GitHub PoC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"GitHub PoC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec') +def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization + + Args: + force_refresh: Whether to force refresh the cache regardless of expiry + + Returns: + Dictionary containing sync results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 0, + 'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization' + } + ) + + logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}") + + # Import here to avoid circular dependencies + from cve2capec_client import CVE2CAPECClient + + # Create client instance + client = CVE2CAPECClient() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 10, + 'message': 'Fetching MITRE ATT&CK mappings...' + } + ) + + # Force refresh if requested + if force_refresh: + client._fetch_fresh_data() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 50, + 'message': 'Processing CVE mappings...' 
+ } + ) + + # Get statistics about the loaded data + stats = client.get_stats() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 100, + 'message': 'CVE2CAPEC synchronization completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_mappings': stats.get('total_mappings', 0), + 'total_techniques': stats.get('unique_techniques', 0), + 'cache_updated': True + } + + logger.info(f"CVE2CAPEC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CVE2CAPEC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_reference_content') +def sync_reference_content_task(self, batch_size: int = 30, max_cves: int = 200, + force_resync: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE reference content extraction and analysis + + Args: + batch_size: Number of CVEs to process in each batch + max_cves: Maximum number of CVEs to process + force_resync: Force re-sync of recently processed CVEs + + Returns: + Dictionary containing sync results + """ + db_session = get_db_session() + + try: + # Import here to avoid circular imports + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from main import CVE + + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_reference_content', + 'progress': 0, + 'message': 'Starting CVE reference content extraction' + } + ) + + logger.info(f"Starting reference content sync task - batch_size: {batch_size}, max_cves: {max_cves}") + + # Get CVEs to process (prioritize those with references but no extracted content) + query = db_session.query(CVE) + + if not force_resync: + # Skip CVEs that were recently processed + from datetime import datetime, timedelta + cutoff_date = datetime.utcnow() - timedelta(days=7) + query = query.filter( + (CVE.reference_content_extracted_at.is_(None)) | + (CVE.reference_content_extracted_at < cutoff_date) + ) + + # Prioritize CVEs with references + cves = query.filter(CVE.references.isnot(None)).limit(max_cves).all() + + if not cves: + logger.info("No CVEs found for reference content extraction") + return {'total_processed': 0, 'successful_extractions': 0, 'failed_extractions': 0} + + total_processed = 0 + successful_extractions = 0 + failed_extractions = 0 + + # Process CVEs in batches + for i in range(0, len(cves), batch_size): + batch = cves[i:i + batch_size] + + for j, cve in enumerate(batch): + try: + # Update progress + overall_progress = int(((i + j) / len(cves)) * 100) + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_reference_content', + 'progress': overall_progress, + 'message': f'Processing CVE {cve.cve_id} ({i + j + 1}/{len(cves)})', + 'current_cve': cve.cve_id, + 'processed': i + j, + 'total': len(cves) + } + ) + + # For now, simulate reference content extraction + # In a real implementation, you would create a ReferenceContentExtractor + # and extract content from CVE references + + # Mark CVE as processed + from datetime import datetime + cve.reference_content_extracted_at = datetime.utcnow() + + successful_extractions += 1 + total_processed += 1 + + # Small delay between requests + import time + time.sleep(2) + + except Exception as e: + logger.error(f"Error processing reference content for CVE 
{cve.cve_id}: {e}") + failed_extractions += 1 + total_processed += 1 + + # Commit after each batch + db_session.commit() + logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}") + + # Final results + result = { + 'total_processed': total_processed, + 'successful_extractions': successful_extractions, + 'failed_extractions': failed_extractions, + 'extraction_rate': (successful_extractions / total_processed * 100) if total_processed > 0 else 0 + } + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'Reference content extraction completed: {successful_extractions} successful, {failed_extractions} failed', + 'results': result + } + ) + + logger.info(f"Reference content sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"Reference content sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec') +def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization + + Args: + force_refresh: Whether to force refresh the cache regardless of expiry + + Returns: + Dictionary containing sync results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 0, + 'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization' + } + ) + + logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}") + + # Import here to avoid circular dependencies + from cve2capec_client import CVE2CAPECClient + + # Create client instance + client = CVE2CAPECClient() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 10, + 'message': 'Fetching MITRE ATT&CK mappings...' + } + ) + + # Force refresh if requested + if force_refresh: + client._fetch_fresh_data() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 50, + 'message': 'Processing CVE mappings...' 
+ } + ) + + # Get statistics about the loaded data + stats = client.get_stats() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 100, + 'message': 'CVE2CAPEC synchronization completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_mappings': stats.get('total_mappings', 0), + 'total_techniques': stats.get('unique_techniques', 0), + 'cache_updated': True + } + + logger.info(f"CVE2CAPEC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CVE2CAPEC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + +@celery_app.task(bind=True, name='data_sync_tasks.sync_exploitdb') +def sync_exploitdb_task(self, batch_size: int = 30) -> Dict[str, Any]: + """ + Celery task for ExploitDB synchronization + + Args: + batch_size: Number of CVEs to process in each batch + + Returns: + Dictionary containing sync results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_exploitdb', + 'progress': 0, + 'message': 'Starting ExploitDB synchronization' + } + ) + + logger.info(f"Starting ExploitDB sync task with batch size: {batch_size}") + + # Create client instance + client = ExploitDBLocalClient(db_session) + + # Run the synchronization + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + client.bulk_sync_exploitdb(batch_size=batch_size) + ) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'ExploitDB synchronization completed successfully' + } + ) + + logger.info(f"ExploitDB sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"ExploitDB sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec') +def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization + + Args: + force_refresh: Whether to force refresh the cache regardless of expiry + + Returns: + Dictionary containing sync results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 0, + 'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization' + } + ) + + logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}") + + # Import here to avoid circular dependencies + from cve2capec_client import CVE2CAPECClient + + # Create client instance + client = CVE2CAPECClient() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 10, + 'message': 'Fetching MITRE ATT&CK mappings...' + } + ) + + # Force refresh if requested + if force_refresh: + client._fetch_fresh_data() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 50, + 'message': 'Processing CVE mappings...' 
+ } + ) + + # Get statistics about the loaded data + stats = client.get_stats() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 100, + 'message': 'CVE2CAPEC synchronization completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_mappings': stats.get('total_mappings', 0), + 'total_techniques': stats.get('unique_techniques', 0), + 'cache_updated': True + } + + logger.info(f"CVE2CAPEC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CVE2CAPEC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cisa_kev') +def sync_cisa_kev_task(self, batch_size: int = 100) -> Dict[str, Any]: + """ + Celery task for CISA KEV synchronization + + Args: + batch_size: Number of CVEs to process in each batch + + Returns: + Dictionary containing sync results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cisa_kev', + 'progress': 0, + 'message': 'Starting CISA KEV synchronization' + } + ) + + logger.info(f"Starting CISA KEV sync task with batch size: {batch_size}") + + # Create client instance + client = CISAKEVClient(db_session) + + # Run the synchronization + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + client.bulk_sync_kev_data(batch_size=batch_size) + ) + finally: + loop.close() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': 'CISA KEV synchronization completed successfully' + } + ) + + logger.info(f"CISA KEV sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CISA KEV sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + + +@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec') +def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]: + """ + Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization + + Args: + force_refresh: Whether to force refresh the cache regardless of expiry + + Returns: + Dictionary containing sync results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 0, + 'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization' + } + ) + + logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}") + + # Import here to avoid circular dependencies + from cve2capec_client import CVE2CAPECClient + + # Create client instance + client = CVE2CAPECClient() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 10, + 'message': 'Fetching MITRE ATT&CK mappings...' + } + ) + + # Force refresh if requested + if force_refresh: + client._fetch_fresh_data() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 50, + 'message': 'Processing CVE mappings...' 
+ } + ) + + # Get statistics about the loaded data + stats = client.get_stats() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'sync_cve2capec', + 'progress': 100, + 'message': 'CVE2CAPEC synchronization completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_mappings': stats.get('total_mappings', 0), + 'total_techniques': stats.get('unique_techniques', 0), + 'cache_updated': True + } + + logger.info(f"CVE2CAPEC sync task completed: {result}") + return result + + except Exception as e: + logger.error(f"CVE2CAPEC sync task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + + +@celery_app.task(bind=True, name='data_sync_tasks.build_exploitdb_index') +def build_exploitdb_index_task(self) -> Dict[str, Any]: + """ + Celery task for building/rebuilding ExploitDB file index + + Returns: + Dictionary containing build results + """ + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'build_exploitdb_index', + 'progress': 0, + 'message': 'Starting ExploitDB file index building' + } + ) + + logger.info("Starting ExploitDB index build task") + + # Import here to avoid circular dependencies + from exploitdb_client_local import ExploitDBLocalClient + + # Create client instance with lazy_load=False to force index building + client = ExploitDBLocalClient(None, lazy_load=False) + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'build_exploitdb_index', + 'progress': 50, + 'message': 'Building file index...' + } + ) + + # Force index rebuild + client._build_file_index() + + # Update progress to completion + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'build_exploitdb_index', + 'progress': 100, + 'message': 'ExploitDB index building completed successfully' + } + ) + + result = { + 'status': 'completed', + 'total_exploits_indexed': len(client.file_index), + 'index_updated': True + } + + logger.info(f"ExploitDB index build task completed: {result}") + return result + + except Exception as e: + logger.error(f"ExploitDB index build task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise diff --git a/backend/tasks/maintenance_tasks.py b/backend/tasks/maintenance_tasks.py new file mode 100644 index 0000000..ccbd085 --- /dev/null +++ b/backend/tasks/maintenance_tasks.py @@ -0,0 +1,437 @@ +""" +Maintenance tasks for Celery +""" +import logging +from datetime import datetime, timedelta +from typing import Dict, Any +from celery_config import celery_app, get_db_session + +logger = logging.getLogger(__name__) + +@celery_app.task(name='tasks.maintenance_tasks.cleanup_old_results') +def cleanup_old_results(): + """ + Periodic task to clean up old Celery results and logs + """ + try: + logger.info("Starting cleanup of old Celery results") + + # This would clean up old results from Redis + # For now, we'll just log the action + cutoff_date = datetime.utcnow() - timedelta(days=7) + + # Clean up old task results (this would be Redis cleanup) + # celery_app.backend.cleanup() + + logger.info(f"Cleanup completed for results older than {cutoff_date}") + + return { + 'status': 'completed', + 'cutoff_date': cutoff_date.isoformat(), + 'message': 'Old results cleanup completed' + } + + except Exception as e: + 
logger.error(f"Cleanup task failed: {e}") + raise + +@celery_app.task(name='tasks.maintenance_tasks.health_check') +def health_check(): + """ + Health check task to verify system components + """ + try: + db_session = get_db_session() + + # Check database connectivity + try: + db_session.execute("SELECT 1") + db_status = "healthy" + except Exception as e: + db_status = f"unhealthy: {e}" + finally: + db_session.close() + + # Check Redis connectivity + try: + celery_app.backend.ping() + redis_status = "healthy" + except Exception as e: + redis_status = f"unhealthy: {e}" + + result = { + 'timestamp': datetime.utcnow().isoformat(), + 'database': db_status, + 'redis': redis_status, + 'celery': 'healthy' + } + + logger.info(f"Health check completed: {result}") + return result + + except Exception as e: + logger.error(f"Health check failed: {e}") + raise + +@celery_app.task(bind=True, name='tasks.maintenance_tasks.database_cleanup_comprehensive') +def database_cleanup_comprehensive(self, days_to_keep: int = 30, cleanup_failed_jobs: bool = True, + cleanup_logs: bool = True) -> Dict[str, Any]: + """ + Comprehensive database cleanup task + + Args: + days_to_keep: Number of days to keep old records + cleanup_failed_jobs: Whether to clean up failed job records + cleanup_logs: Whether to clean up old log entries + + Returns: + Dictionary containing cleanup results + """ + try: + from datetime import datetime, timedelta + from typing import Dict, Any + + db_session = get_db_session() + + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'database_cleanup', + 'progress': 0, + 'message': 'Starting comprehensive database cleanup' + } + ) + + logger.info(f"Starting comprehensive database cleanup - keeping {days_to_keep} days") + + cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep) + cleanup_results = { + 'cutoff_date': cutoff_date.isoformat(), + 'cleaned_tables': {}, + 'total_records_cleaned': 0 + } + + try: + # Import models here to avoid circular imports + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from main import BulkProcessingJob + + # Clean up old bulk processing jobs + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'database_cleanup', + 'progress': 20, + 'message': 'Cleaning up old bulk processing jobs' + } + ) + + old_jobs_query = db_session.query(BulkProcessingJob).filter( + BulkProcessingJob.created_at < cutoff_date + ) + + if cleanup_failed_jobs: + # Clean all old jobs + old_jobs_count = old_jobs_query.count() + old_jobs_query.delete() + else: + # Only clean completed jobs + old_jobs_query = old_jobs_query.filter( + BulkProcessingJob.status.in_(['completed', 'cancelled']) + ) + old_jobs_count = old_jobs_query.count() + old_jobs_query.delete() + + cleanup_results['cleaned_tables']['bulk_processing_jobs'] = old_jobs_count + cleanup_results['total_records_cleaned'] += old_jobs_count + + # Clean up old Celery task results from Redis + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'database_cleanup', + 'progress': 40, + 'message': 'Cleaning up old Celery task results' + } + ) + + try: + # This would clean up old results from Redis backend + # For now, we'll simulate this + celery_cleanup_count = 0 + # celery_app.backend.cleanup() + cleanup_results['cleaned_tables']['celery_results'] = celery_cleanup_count + except Exception as e: + logger.warning(f"Could not clean Celery results: {e}") + cleanup_results['cleaned_tables']['celery_results'] = 0 + + # Clean up old temporary 
data (if any) + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'database_cleanup', + 'progress': 60, + 'message': 'Cleaning up temporary data' + } + ) + + # Add any custom temporary table cleanup here + # Example: Clean up old session data, temporary files, etc. + temp_cleanup_count = 0 + cleanup_results['cleaned_tables']['temporary_data'] = temp_cleanup_count + + # Vacuum/optimize database (PostgreSQL) + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'database_cleanup', + 'progress': 80, + 'message': 'Optimizing database' + } + ) + + try: + # Run VACUUM on PostgreSQL to reclaim space + db_session.execute("VACUUM;") + cleanup_results['database_optimized'] = True + except Exception as e: + logger.warning(f"Could not vacuum database: {e}") + cleanup_results['database_optimized'] = False + + # Commit all changes + db_session.commit() + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'Database cleanup completed - removed {cleanup_results["total_records_cleaned"]} records', + 'results': cleanup_results + } + ) + + logger.info(f"Database cleanup completed: {cleanup_results}") + return cleanup_results + + finally: + db_session.close() + + except Exception as e: + logger.error(f"Database cleanup failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Cleanup failed: {str(e)}', + 'error': str(e) + } + ) + raise + +@celery_app.task(bind=True, name='tasks.maintenance_tasks.health_check_detailed') +def health_check_detailed(self) -> Dict[str, Any]: + """ + Detailed health check task for all system components + + Returns: + Dictionary containing detailed health status + """ + try: + from datetime import datetime + import psutil + import redis + + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'health_check', + 'progress': 0, + 'message': 'Starting detailed health check' + } + ) + + logger.info("Starting detailed health check") + + health_status = { + 'timestamp': datetime.utcnow().isoformat(), + 'overall_status': 'healthy', + 'components': {} + } + + # Check database connectivity and performance + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'health_check', + 'progress': 20, + 'message': 'Checking database health' + } + ) + + db_session = get_db_session() + try: + start_time = datetime.utcnow() + db_session.execute("SELECT 1") + db_response_time = (datetime.utcnow() - start_time).total_seconds() + + # Check database size and connections + db_size_result = db_session.execute("SELECT pg_size_pretty(pg_database_size(current_database()));").fetchone() + db_connections_result = db_session.execute("SELECT count(*) FROM pg_stat_activity;").fetchone() + + health_status['components']['database'] = { + 'status': 'healthy', + 'response_time_seconds': db_response_time, + 'database_size': db_size_result[0] if db_size_result else 'unknown', + 'active_connections': db_connections_result[0] if db_connections_result else 0, + 'details': 'Database responsive and accessible' + } + except Exception as e: + health_status['components']['database'] = { + 'status': 'unhealthy', + 'error': str(e), + 'details': 'Database connection failed' + } + health_status['overall_status'] = 'degraded' + finally: + db_session.close() + + # Check Redis connectivity and performance + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'health_check', + 'progress': 40, + 'message': 'Checking Redis health' + } + ) + + try: 
+ start_time = datetime.utcnow() + celery_app.backend.ping() + redis_response_time = (datetime.utcnow() - start_time).total_seconds() + + # Get Redis info + redis_client = redis.Redis.from_url(celery_app.conf.broker_url) + redis_info = redis_client.info() + + health_status['components']['redis'] = { + 'status': 'healthy', + 'response_time_seconds': redis_response_time, + 'memory_usage_mb': redis_info.get('used_memory', 0) / (1024 * 1024), + 'connected_clients': redis_info.get('connected_clients', 0), + 'uptime_seconds': redis_info.get('uptime_in_seconds', 0), + 'details': 'Redis responsive and accessible' + } + except Exception as e: + health_status['components']['redis'] = { + 'status': 'unhealthy', + 'error': str(e), + 'details': 'Redis connection failed' + } + health_status['overall_status'] = 'degraded' + + # Check system resources + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'health_check', + 'progress': 60, + 'message': 'Checking system resources' + } + ) + + try: + cpu_percent = psutil.cpu_percent(interval=1) + memory = psutil.virtual_memory() + disk = psutil.disk_usage('/') + + health_status['components']['system'] = { + 'status': 'healthy', + 'cpu_percent': cpu_percent, + 'memory_percent': memory.percent, + 'memory_available_gb': memory.available / (1024**3), + 'disk_percent': disk.percent, + 'disk_free_gb': disk.free / (1024**3), + 'details': 'System resources within normal ranges' + } + + # Mark as degraded if resources are high + if cpu_percent > 80 or memory.percent > 85 or disk.percent > 90: + health_status['components']['system']['status'] = 'degraded' + health_status['overall_status'] = 'degraded' + health_status['components']['system']['details'] = 'High resource usage detected' + + except Exception as e: + health_status['components']['system'] = { + 'status': 'unknown', + 'error': str(e), + 'details': 'Could not check system resources' + } + + # Check Celery worker status + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'health_check', + 'progress': 80, + 'message': 'Checking Celery workers' + } + ) + + try: + inspect = celery_app.control.inspect() + active_workers = inspect.active() + stats = inspect.stats() + + health_status['components']['celery'] = { + 'status': 'healthy', + 'active_workers': len(active_workers) if active_workers else 0, + 'worker_stats': stats, + 'details': 'Celery workers responding' + } + + if not active_workers: + health_status['components']['celery']['status'] = 'degraded' + health_status['components']['celery']['details'] = 'No active workers found' + health_status['overall_status'] = 'degraded' + + except Exception as e: + health_status['components']['celery'] = { + 'status': 'unknown', + 'error': str(e), + 'details': 'Could not check Celery workers' + } + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'Health check completed - overall status: {health_status["overall_status"]}', + 'results': health_status + } + ) + + logger.info(f"Detailed health check completed: {health_status['overall_status']}") + return health_status + + except Exception as e: + logger.error(f"Detailed health check failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Health check failed: {str(e)}', + 'error': str(e) + } + ) + raise \ No newline at end of file diff --git a/backend/tasks/sigma_tasks.py b/backend/tasks/sigma_tasks.py new file mode 100644 index 0000000..be481e7 --- /dev/null +++ 
b/backend/tasks/sigma_tasks.py @@ -0,0 +1,409 @@ +""" +SIGMA rule generation tasks for Celery +""" +import asyncio +import logging +from typing import Dict, Any, List, Optional +from celery import current_task +from celery_config import celery_app, get_db_session +from enhanced_sigma_generator import EnhancedSigmaGenerator +from llm_client import LLMClient + +logger = logging.getLogger(__name__) + +@celery_app.task(bind=True, name='sigma_tasks.generate_enhanced_rules') +def generate_enhanced_rules_task(self, cve_ids: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Celery task for enhanced SIGMA rule generation + + Args: + cve_ids: Optional list of specific CVE IDs to process + + Returns: + Dictionary containing generation results + """ + db_session = get_db_session() + + try: + # Import here to avoid circular imports + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from main import CVE + + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'generating_rules', + 'progress': 0, + 'message': 'Starting enhanced SIGMA rule generation' + } + ) + + logger.info(f"Starting enhanced rule generation task for CVEs: {cve_ids}") + + # Create generator instance + generator = EnhancedSigmaGenerator(db_session) + + # Get CVEs to process + if cve_ids: + cves = db_session.query(CVE).filter(CVE.cve_id.in_(cve_ids)).all() + else: + cves = db_session.query(CVE).filter(CVE.poc_count > 0).all() + + total_cves = len(cves) + processed_cves = 0 + successful_rules = 0 + failed_rules = 0 + results = [] + + # Process each CVE + for i, cve in enumerate(cves): + try: + # Update progress + progress = int((i / total_cves) * 100) + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'generating_rules', + 'progress': progress, + 'message': f'Processing CVE {cve.cve_id}', + 'current_cve': cve.cve_id, + 'processed': processed_cves, + 'total': total_cves + } + ) + + # Generate rule using asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + generator.generate_enhanced_rule(cve) + ) + + if result.get('success', False): + successful_rules += 1 + else: + failed_rules += 1 + + results.append({ + 'cve_id': cve.cve_id, + 'success': result.get('success', False), + 'message': result.get('message', 'No message'), + 'rule_id': result.get('rule_id') + }) + + finally: + loop.close() + + processed_cves += 1 + + except Exception as e: + logger.error(f"Error processing CVE {cve.cve_id}: {e}") + failed_rules += 1 + results.append({ + 'cve_id': cve.cve_id, + 'success': False, + 'message': f'Error: {str(e)}', + 'rule_id': None + }) + + # Final results + final_result = { + 'total_processed': processed_cves, + 'successful_rules': successful_rules, + 'failed_rules': failed_rules, + 'results': results + } + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'Generated {successful_rules} rules from {processed_cves} CVEs', + 'results': final_result + } + ) + + logger.info(f"Enhanced rule generation task completed: {final_result}") + return final_result + + except Exception as e: + logger.error(f"Enhanced rule generation task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() + +@celery_app.task(bind=True, name='sigma_tasks.llm_enhanced_generation') +def 
llm_enhanced_generation_task(self, cve_id: str, provider: str = 'ollama', + model: Optional[str] = None) -> Dict[str, Any]: + """ + Celery task for LLM-enhanced rule generation + + Args: + cve_id: CVE identifier + provider: LLM provider (openai, anthropic, ollama, finetuned) + model: Specific model to use + + Returns: + Dictionary containing generation result + """ + db_session = get_db_session() + + try: + # Import here to avoid circular imports + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from main import CVE + + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'llm_generation', + 'progress': 10, + 'message': f'Starting LLM rule generation for {cve_id}', + 'cve_id': cve_id, + 'provider': provider, + 'model': model + } + ) + + logger.info(f"Starting LLM rule generation for {cve_id} using {provider}") + + # Get CVE from database + cve = db_session.query(CVE).filter(CVE.cve_id == cve_id).first() + if not cve: + raise ValueError(f"CVE {cve_id} not found in database") + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'llm_generation', + 'progress': 25, + 'message': f'Initializing LLM client ({provider})', + 'cve_id': cve_id + } + ) + + # Create LLM client + llm_client = LLMClient(provider=provider, model=model) + + if not llm_client.is_available(): + raise ValueError(f"LLM client not available for provider: {provider}") + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'llm_generation', + 'progress': 50, + 'message': f'Generating rule with LLM for {cve_id}', + 'cve_id': cve_id + } + ) + + # Generate rule using asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + rule_content = loop.run_until_complete( + llm_client.generate_sigma_rule( + cve_id=cve.cve_id, + poc_content=cve.poc_data or '', + cve_description=cve.description or '' + ) + ) + finally: + loop.close() + + # Update progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'llm_generation', + 'progress': 75, + 'message': f'Validating generated rule for {cve_id}', + 'cve_id': cve_id + } + ) + + # Validate the generated rule + is_valid = False + if rule_content: + is_valid = llm_client.validate_sigma_rule(rule_content, cve_id) + + # Prepare result + result = { + 'cve_id': cve_id, + 'rule_content': rule_content, + 'is_valid': is_valid, + 'provider': provider, + 'model': model or llm_client.model, + 'success': bool(rule_content and is_valid) + } + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'LLM rule generation completed for {cve_id}', + 'cve_id': cve_id, + 'success': result['success'], + 'result': result + } + ) + + logger.info(f"LLM rule generation task completed for {cve_id}: {result['success']}") + return result + + except Exception as e: + logger.error(f"LLM rule generation task failed for {cve_id}: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Task failed for {cve_id}: {str(e)}', + 'cve_id': cve_id, + 'error': str(e) + } + ) + raise + finally: + db_session.close() + +@celery_app.task(bind=True, name='sigma_tasks.batch_llm_generation') +def batch_llm_generation_task(self, cve_ids: List[str], provider: str = 'ollama', + model: Optional[str] = None) -> Dict[str, Any]: + """ + Celery task for batch LLM rule generation + + Args: + cve_ids: List of CVE identifiers + provider: LLM provider 
(openai, anthropic, ollama, finetuned) + model: Specific model to use + + Returns: + Dictionary containing batch generation results + """ + db_session = get_db_session() + + try: + # Update task progress + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'batch_llm_generation', + 'progress': 0, + 'message': f'Starting batch LLM generation for {len(cve_ids)} CVEs', + 'total_cves': len(cve_ids), + 'provider': provider, + 'model': model + } + ) + + logger.info(f"Starting batch LLM generation for {len(cve_ids)} CVEs using {provider}") + + # Initialize results + results = [] + successful_rules = 0 + failed_rules = 0 + + # Process each CVE + for i, cve_id in enumerate(cve_ids): + try: + # Update progress + progress = int((i / len(cve_ids)) * 100) + self.update_state( + state='PROGRESS', + meta={ + 'stage': 'batch_llm_generation', + 'progress': progress, + 'message': f'Processing CVE {cve_id} ({i+1}/{len(cve_ids)})', + 'current_cve': cve_id, + 'processed': i, + 'total': len(cve_ids) + } + ) + + # Generate rule for this CVE + result = llm_enhanced_generation_task.apply( + args=[cve_id, provider, model] + ).get() + + if result.get('success', False): + successful_rules += 1 + else: + failed_rules += 1 + + results.append(result) + + except Exception as e: + logger.error(f"Error processing CVE {cve_id} in batch: {e}") + failed_rules += 1 + results.append({ + 'cve_id': cve_id, + 'success': False, + 'error': str(e), + 'provider': provider, + 'model': model + }) + + # Final results + final_result = { + 'total_processed': len(cve_ids), + 'successful_rules': successful_rules, + 'failed_rules': failed_rules, + 'provider': provider, + 'model': model, + 'results': results + } + + # Update final progress + self.update_state( + state='SUCCESS', + meta={ + 'stage': 'completed', + 'progress': 100, + 'message': f'Batch generation completed: {successful_rules} successful, {failed_rules} failed', + 'results': final_result + } + ) + + logger.info(f"Batch LLM generation task completed: {final_result}") + return final_result + + except Exception as e: + logger.error(f"Batch LLM generation task failed: {e}") + self.update_state( + state='FAILURE', + meta={ + 'stage': 'error', + 'progress': 0, + 'message': f'Batch task failed: {str(e)}', + 'error': str(e) + } + ) + raise + finally: + db_session.close() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index d5a3592..9c62c1c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -22,6 +22,8 @@ services: - "8000:8000" environment: DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 NVD_API_KEY: ${NVD_API_KEY:-} GITHUB_TOKEN: ${GITHUB_TOKEN} OPENAI_API_KEY: ${OPENAI_API_KEY:-} @@ -29,15 +31,21 @@ services: OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434} LLM_PROVIDER: ${LLM_PROVIDER:-ollama} LLM_MODEL: ${LLM_MODEL:-llama3.2} + LLM_ENABLED: ${LLM_ENABLED:-true} + FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned} + HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN} depends_on: db: condition: service_healthy + redis: + condition: service_started ollama-setup: condition: service_completed_successfully volumes: - ./backend:/app - ./github_poc_collector:/github_poc_collector - ./exploit-db-mirror:/app/exploit-db-mirror + - ./models:/app/models command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload frontend: @@ -68,6 +76,12 @@ services: environment: - OLLAMA_HOST=0.0.0.0 restart: unless-stopped 
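+    # Resource bounds for the Ollama container (added below). The 5G limit and
+    # 3G reservation are sized on the assumption of a llama3.2-class model plus
+    # the derived sigma-llama model; tune them to the host's available memory.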
+    deploy:
+      resources:
+        limits:
+          memory: 5G
+        reservations:
+          memory: 3G
 
   ollama-setup:
     build: ./backend
@@ -78,9 +92,109 @@ services:
       LLM_MODEL: llama3.2
     volumes:
       - ./backend:/app
-    command: python setup_ollama.py
+    command: python setup_ollama_with_sigma.py
     restart: "no"
 
+  initial-setup:
+    build: ./backend
+    depends_on:
+      db:
+        condition: service_healthy
+      redis:
+        condition: service_started
+      celery-worker:
+        condition: service_healthy
+    environment:
+      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
+      CELERY_BROKER_URL: redis://redis:6379/0
+      CELERY_RESULT_BACKEND: redis://redis:6379/0
+    volumes:
+      - ./backend:/app
+    command: python initial_setup.py
+    restart: "no"
+
+  celery-worker:
+    build: ./backend
+    command: celery -A celery_config worker --loglevel=info --concurrency=4
+    environment:
+      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
+      CELERY_BROKER_URL: redis://redis:6379/0
+      CELERY_RESULT_BACKEND: redis://redis:6379/0
+      NVD_API_KEY: ${NVD_API_KEY:-}
+      GITHUB_TOKEN: ${GITHUB_TOKEN}
+      OPENAI_API_KEY: ${OPENAI_API_KEY:-}
+      ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
+      OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434}
+      LLM_PROVIDER: ${LLM_PROVIDER:-ollama}
+      LLM_MODEL: ${LLM_MODEL:-llama3.2}
+      LLM_ENABLED: ${LLM_ENABLED:-true}
+      FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned}
+      HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN}
+    depends_on:
+      db:
+        condition: service_healthy
+      redis:
+        condition: service_started
+      ollama-setup:
+        condition: service_completed_successfully
+    volumes:
+      - ./backend:/app
+      - ./github_poc_collector:/github_poc_collector
+      - ./exploit-db-mirror:/app/exploit-db-mirror
+      - ./models:/app/models
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "celery", "-A", "celery_config", "inspect", "ping"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+
+  celery-beat:
+    build: ./backend
+    command: celery -A celery_config beat --loglevel=info --pidfile=/tmp/celerybeat.pid
+    environment:
+      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
+      CELERY_BROKER_URL: redis://redis:6379/0
+      CELERY_RESULT_BACKEND: redis://redis:6379/0
+      NVD_API_KEY: ${NVD_API_KEY:-}
+      GITHUB_TOKEN: ${GITHUB_TOKEN}
+      OPENAI_API_KEY: ${OPENAI_API_KEY:-}
+      ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
+      OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434}
+      LLM_PROVIDER: ${LLM_PROVIDER:-ollama}
+      LLM_MODEL: ${LLM_MODEL:-llama3.2}
+      LLM_ENABLED: ${LLM_ENABLED:-true}
+      FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned}
+      HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN}
+    depends_on:
+      db:
+        condition: service_healthy
+      redis:
+        condition: service_started
+      celery-worker:
+        condition: service_healthy
+    volumes:
+      - ./backend:/app
+      - ./github_poc_collector:/github_poc_collector
+      - ./exploit-db-mirror:/app/exploit-db-mirror
+      - ./models:/app/models
+    restart: unless-stopped
+
+  flower:
+    build: ./backend
+    command: celery -A celery_config flower --port=5555
+    ports:
+      - "5555:5555"
+    environment:
+      CELERY_BROKER_URL: redis://redis:6379/0
+      CELERY_RESULT_BACKEND: redis://redis:6379/0
+    depends_on:
+      redis:
+        condition: service_started
+      celery-worker:
+        condition: service_healthy
+    restart: unless-stopped
+
 volumes:
   postgres_data:
   redis_data:
diff --git a/frontend/src/App.js b/frontend/src/App.js
index 4f24ccd..fa97ea3 100644
--- a/frontend/src/App.js
+++ b/frontend/src/App.js
@@ -26,8 +26,6 @@ function App() {
   const [runningJobTypes, setRunningJobTypes] = useState(new Set());
   const [llmStatus, setLlmStatus] = useState({});
   const [exploitSyncDropdownOpen, setExploitSyncDropdownOpen] = useState(false);
-  const [schedulerStatus, setSchedulerStatus] = useState({});
-  const [schedulerJobs, setSchedulerJobs] = useState({});
 
   useEffect(() => {
     fetchData();
@@ -84,15 +82,8 @@ function App() {
     return isNomiSecSyncRunning() || isGitHubPocSyncRunning() || isExploitDBSyncRunning() || isCISAKEVSyncRunning();
   };
 
-  const fetchSchedulerData = async () => {
-    try {
-      const response = await axios.get('http://localhost:8000/api/scheduler/status');
-      setSchedulerStatus(response.data.scheduler_status);
-      setSchedulerJobs(response.data.scheduler_status.jobs || {});
-    } catch (error) {
-      console.error('Error fetching scheduler data:', error);
-    }
-  };
+  // Note: Scheduler functionality removed - now handled by Celery Beat
+  // Monitoring available via Flower at http://localhost:5555
 
   const fetchData = async () => {
     try {
@@ -223,12 +214,25 @@ function App() {
   const syncGitHubPocs = async (cveId = null) => {
     try {
       const response = await axios.post(`${API_BASE_URL}/api/sync-github-pocs`, {
-        cve_id: cveId
+        cve_id: cveId,
+        batch_size: 50
       });
       console.log('GitHub PoC sync response:', response.data);
+
+      // Show success message with Celery task info
+      if (response.data.task_id) {
+        console.log(`GitHub PoC sync task started: ${response.data.task_id}`);
+        console.log(`Monitor at: ${response.data.monitor_url}`);
+
+        // Show notification to user
+        alert(`GitHub PoC sync started successfully!\nTask ID: ${response.data.task_id}\n\nMonitor progress at http://localhost:5555 (Flower Dashboard)`);
+      }
+
+      // Refresh data to show the task in bulk jobs
       fetchData();
     } catch (error) {
       console.error('Error syncing GitHub PoCs:', error);
+      alert('Failed to start GitHub PoC sync. Please check the console for details.');
     }
   };
 
@@ -331,6 +335,35 @@ function App() {
   const Dashboard = () => (
+ {/* Celery Monitoring Notice */} +
+
+
+ + + +
+
+

+ Task Scheduling & Monitoring +

+
+

+ Automated tasks are now managed by Celery Beat. Monitor real-time task execution at{' '} + + Flower Dashboard + +

+
+
+
+
+

Total CVEs

@@ -1183,268 +1216,8 @@ function App() {
); - const SchedulerManager = () => { - const [selectedJob, setSelectedJob] = useState(null); - const [newSchedule, setNewSchedule] = useState(''); - const [showScheduleEdit, setShowScheduleEdit] = useState(false); - - useEffect(() => { - fetchSchedulerData(); - const interval = setInterval(fetchSchedulerData, 30000); // Refresh every 30 seconds - return () => clearInterval(interval); - }, []); - - const controlScheduler = async (action) => { - try { - const response = await axios.post('http://localhost:8000/api/scheduler/control', { - action: action - }); - console.log('Scheduler control response:', response.data); - fetchSchedulerData(); - } catch (error) { - console.error('Error controlling scheduler:', error); - alert('Error controlling scheduler: ' + (error.response?.data?.detail || error.message)); - } - }; - - const controlJob = async (jobName, action) => { - try { - const response = await axios.post('http://localhost:8000/api/scheduler/job/control', { - job_name: jobName, - action: action - }); - console.log('Job control response:', response.data); - fetchSchedulerData(); - } catch (error) { - console.error('Error controlling job:', error); - alert('Error controlling job: ' + (error.response?.data?.detail || error.message)); - } - }; - - const updateJobSchedule = async (jobName, schedule) => { - try { - const response = await axios.post('http://localhost:8000/api/scheduler/job/schedule', { - job_name: jobName, - schedule: schedule - }); - console.log('Schedule update response:', response.data); - setShowScheduleEdit(false); - setNewSchedule(''); - setSelectedJob(null); - fetchSchedulerData(); - } catch (error) { - console.error('Error updating schedule:', error); - alert('Error updating schedule: ' + (error.response?.data?.detail || error.message)); - } - }; - - const formatNextRun = (dateString) => { - if (!dateString) return 'Not scheduled'; - const date = new Date(dateString); - return date.toLocaleString(); - }; - - const getStatusColor = (status) => { - switch (status) { - case true: - case 'enabled': - return 'text-green-600'; - case false: - case 'disabled': - return 'text-red-600'; - case 'running': - return 'text-blue-600'; - default: - return 'text-gray-600'; - } - }; - - return ( -
- {/* Scheduler Status */} -
-
-

Job Scheduler Status

-
-
-
-
-

Scheduler Status

-

- {schedulerStatus.scheduler_running ? 'Running' : 'Stopped'} -

-
-
-

Total Jobs

-

{schedulerStatus.total_jobs || 0}

-
-
-

Enabled Jobs

-

{schedulerStatus.enabled_jobs || 0}

-
-
-

Running Jobs

-

{schedulerStatus.running_jobs || 0}

-
-
- -
- - - -
-
-
- - {/* Scheduled Jobs */} -
-
-

Scheduled Jobs

-
-
- {Object.entries(schedulerJobs).map(([jobName, job]) => ( -
-
-
-
-

{jobName}

- - {job.enabled ? 'Enabled' : 'Disabled'} - - {job.is_running && ( - - Running - - )} -
-

{job.description}

-
-
- Schedule: {job.schedule} -
-
- Next Run: {formatNextRun(job.next_run)} -
-
- Run Count: {job.run_count} -
-
- Failures: {job.failure_count} -
-
-
-
- - - -
-
-
- ))} -
-
- - {/* Schedule Edit Modal */} - {showScheduleEdit && selectedJob && ( -
-
-

- Edit Schedule for {selectedJob} -

-
- - setNewSchedule(e.target.value)} - className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500" - placeholder="0 */6 * * *" - /> -

- Format: minute hour day_of_month month day_of_week -

-
-
- - -
-
-
- )} -
- ); - }; + // Note: SchedulerManager component removed - job scheduling now handled by Celery Beat + // Task monitoring available via Flower dashboard at http://localhost:5555 if (loading) { return ( @@ -1507,16 +1280,6 @@ function App() { > Bulk Jobs -
@@ -1529,7 +1292,6 @@ function App() { {activeTab === 'cves' && } {activeTab === 'rules' && } {activeTab === 'bulk-jobs' && } - {activeTab === 'scheduler' && }
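
The `celery-worker`, `celery-beat`, and `flower` services all load the same `-A celery_config` application, but `backend/celery_config.py` itself falls outside the hunks shown here. As a rough orientation only, the sketch below shows the minimal shape such a module needs for those commands and the `celery -A celery_config inspect ping` healthcheck to resolve. The app name, the included task module, and the schedule entry are illustrative assumptions, not the project's actual configuration.

```python
# celery_config.py - minimal illustrative shape, NOT the file added by this commit
import os

from celery import Celery
from celery.schedules import crontab

# Broker and result backend come from the environment set in docker-compose.yml.
app = Celery(
    "cve_sigma",  # assumed application name
    broker=os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"),
    backend=os.environ.get("CELERY_RESULT_BACKEND", "redis://redis:6379/0"),
    include=["tasks.example_tasks"],  # hypothetical task module, for illustration only
)

# Celery Beat schedule; this single entry is a placeholder, not the real schedule.
app.conf.beat_schedule = {
    "example-nightly-sync": {
        "task": "tasks.example_tasks.sync_example",  # hypothetical task name
        "schedule": crontab(hour=2, minute=0),
    },
}
app.conf.timezone = "UTC"
```

Because the worker, beat, and flower containers all import this one module, the schedule, routing, and broker settings stay in a single place instead of being duplicated per service.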
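For reviewers who want to follow one of these background tasks outside the Flower dashboard, here is a minimal sketch of how the `task_id` returned by the sync endpoints (see the `response.data.task_id` handling in `syncGitHubPocs` above) could be polled against the Redis result backend configured in `docker-compose.yml`. The broker/backend URLs are taken from the compose file; the script name and the example task id are hypothetical, and the `PROGRESS` metadata keys simply mirror the `self.update_state(...)` calls in the batch generation task shown earlier.

```python
# poll_task.py - illustrative helper, not part of this patch
from celery import Celery
from celery.result import AsyncResult

# Same broker/result backend the compose services use (assumes the script runs
# inside the compose network; swap in a forwarded localhost URL otherwise).
app = Celery(broker="redis://redis:6379/0", backend="redis://redis:6379/0")


def report_progress(task_id: str) -> None:
    """Print the current state and any PROGRESS metadata published via update_state()."""
    result = AsyncResult(task_id, app=app)
    print(f"state={result.state}")
    if result.state == "PROGRESS" and isinstance(result.info, dict):
        # Keys match the meta dict used by the batch LLM generation task above.
        print(f"stage={result.info.get('stage')} progress={result.info.get('progress')}%")
        print(result.info.get("message"))
    elif result.state == "SUCCESS":
        print(result.result)


if __name__ == "__main__":
    report_progress("00000000-0000-0000-0000-000000000000")  # hypothetical task id
```

In practice the same information is what Flower renders at http://localhost:5555, so a script like this is only useful for headless checks or CI probes.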