Optimize performance and migrate to Celery-based scheduling
This commit introduces major performance improvements and migrates from custom job scheduling to Celery Beat for better reliability and scalability.

### 🚀 Performance Optimizations

**CVE2CAPEC Client Performance (Fixed startup blocking)**
- Implement lazy loading with 24-hour cache for CVE2CAPEC mappings
- Add background task for CVE2CAPEC sync (data_sync_tasks.sync_cve2capec)
- Remove blocking data fetch during client initialization
- API endpoint: POST /api/sync-cve2capec

**ExploitDB Client Performance (Fixed webapp request blocking)**
- Implement global file index cache to prevent rebuilding on every request
- Add lazy loading with 24-hour cache expiry for 46K+ exploit index
- Background task for index building (data_sync_tasks.build_exploitdb_index)
- API endpoint: POST /api/build-exploitdb-index

### 🔄 Celery Migration & Scheduling

**Celery Beat Integration**
- Migrate from custom job scheduler to Celery Beat for reliability
- Remove 'finetuned' LLM provider (logic moved to ollama container)
- Optimized daily workflow with proper timing and dependencies

**New Celery Tasks Structure**
- tasks/bulk_tasks.py - NVD bulk processing and SIGMA generation
- tasks/data_sync_tasks.py - All data synchronization tasks
- tasks/maintenance_tasks.py - System maintenance and cleanup
- tasks/sigma_tasks.py - SIGMA rule generation tasks

**Daily Schedule (Optimized)**
```
1:00 AM  → Weekly cleanup (Sundays)
1:30 AM  → Daily result cleanup
2:00 AM  → NVD incremental update
3:00 AM  → CISA KEV sync
3:15 AM  → Nomi-sec PoC sync
3:30 AM  → GitHub PoC sync
3:45 AM  → ExploitDB sync
4:00 AM  → CVE2CAPEC MITRE ATT&CK sync
4:15 AM  → ExploitDB index rebuild
5:00 AM  → Reference content sync
8:00 AM  → SIGMA rule generation
9:00 AM  → LLM-enhanced SIGMA generation
Every 15 min → Health checks
```

### 🐳 Docker & Infrastructure

**Enhanced Docker Setup**
- Ollama setup with integrated SIGMA model creation (setup_ollama_with_sigma.py)
- Initial database population check and trigger (initial_setup.py)
- Proper service dependencies and health checks
- Remove manual post-rebuild script requirements

**Service Architecture**
- Celery worker with 4-queue system (default, bulk_processing, sigma_generation, data_sync)
- Flower monitoring dashboard (localhost:5555)
- Redis as message broker and result backend

### 🎯 API Improvements

**Background Task Endpoints**
- GitHub PoC sync now uses Celery (was blocking the backend)
- All sync operations return task IDs and monitoring URLs
- Consistent error handling and progress tracking

**New Endpoints**
- POST /api/sync-cve2capec - CVE2CAPEC mapping sync
- POST /api/build-exploitdb-index - ExploitDB index rebuild

### 📁 Cleanup

**Removed Files**
- fix_sigma_model.sh (replaced by setup_ollama_with_sigma.py)
- Various test_* and debug_* files no longer needed
- Old training scripts related to the removed 'finetuned' provider
- Utility scripts replaced by Docker services

### 🔧 Configuration

**Key Files Added/Modified**
- backend/celery_config.py - Complete Celery configuration
- backend/initial_setup.py - First-boot database population
- backend/setup_ollama_with_sigma.py - Integrated Ollama setup
- CLAUDE.md - Project documentation and development guide

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
parent 54db665711
commit 9bde1395bf
15 changed files with 3402 additions and 693 deletions
282  CLAUDE.md  Normal file
@@ -0,0 +1,282 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This is an enhanced CVE-SIGMA Auto Generator that automatically processes comprehensive CVE data and generates SIGMA rules for threat detection. The application now supports:

1. **Bulk NVD Data Processing**: Downloads and processes complete NVD JSON datasets (2002-2025)
2. **nomi-sec PoC Integration**: Uses curated PoC data from github.com/nomi-sec/PoC-in-GitHub
3. **Enhanced SIGMA Rule Generation**: Creates intelligent rules based on real exploit indicators
4. **Comprehensive Database Seeding**: Supports both bulk and incremental data updates

## Architecture

- **Backend**: FastAPI with SQLAlchemy ORM (`backend/main.py`)
- **Frontend**: React with Tailwind CSS (`frontend/src/App.js`)
- **Database**: PostgreSQL with enhanced schema:
  - `cves`: CVE information with PoC metadata and bulk processing fields
  - `sigma_rules`: Enhanced SIGMA rules with quality scoring and nomi-sec data
  - `rule_templates`: Template patterns for rule generation
  - `bulk_processing_jobs`: Job tracking for bulk operations
- **Data Processing**:
  - `nvd_bulk_processor.py`: NVD JSON dataset downloader and processor
  - `nomi_sec_client.py`: nomi-sec PoC-in-GitHub API integration
  - `enhanced_sigma_generator.py`: Advanced SIGMA rule generation
  - `bulk_seeder.py`: Coordinated bulk seeding operations
- **Cache**: Redis (optional)
- **Deployment**: Docker Compose orchestration
## Common Development Commands

### Quick Start
```bash
# Recommended quick start
chmod +x start.sh
./start.sh

# Or using Make
make start
```

### Build and Run
```bash
# Build and start all services
docker-compose up -d --build

# Start individual services
docker-compose up -d db redis    # Database and cache only
docker-compose up -d backend     # Backend API
docker-compose up -d frontend    # React frontend
```

### Development Mode
```bash
# Using Make
make dev

# Or manually
docker-compose up -d db redis
cd backend && pip install -r requirements.txt && uvicorn main:app --reload
cd frontend && npm install && npm start
```

### Bulk Processing Commands
```bash
# Run bulk seeding standalone
cd backend && python bulk_seeder.py

# Bulk seed specific year range
cd backend && python -c "
import asyncio
from bulk_seeder import BulkSeeder
from main import SessionLocal
seeder = BulkSeeder(SessionLocal())
asyncio.run(seeder.full_bulk_seed(start_year=2020, end_year=2025))
"

# Incremental update only
cd backend && python -c "
import asyncio
from bulk_seeder import BulkSeeder
from main import SessionLocal
seeder = BulkSeeder(SessionLocal())
asyncio.run(seeder.incremental_update())
"
```

### Frontend Commands
```bash
cd frontend
npm install      # Install dependencies
npm start        # Development server (port 3000)
npm run build    # Production build
npm test         # Run tests
```

### Backend Commands
```bash
cd backend
pip install -r requirements.txt
uvicorn main:app --reload                      # Development server (port 8000)
uvicorn main:app --host 0.0.0.0 --port 8000    # Production server
```

### Database Operations
```bash
# Connect to database
docker-compose exec db psql -U cve_user -d cve_sigma_db

# View logs
docker-compose logs -f backend
docker-compose logs -f frontend
```
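The same connection can be used for quick one-off checks; for example, a rough PoC-coverage query against the `cves` table (column names follow the Database Models section below — adjust if the schema differs):

```bash
# Rough PoC coverage check (assumes the schema described under "Database Models")
docker-compose exec db psql -U cve_user -d cve_sigma_db -c \
  "SELECT COUNT(*) AS total_cves,
          COUNT(*) FILTER (WHERE poc_count > 0) AS cves_with_pocs
   FROM cves;"
```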
### Other Make Commands
```bash
make stop       # Stop all services
make restart    # Restart all services
make logs       # View application logs
make clean      # Clean up containers and volumes
make setup      # Initial setup (creates .env from .env.example)
```

## Key Configuration

### Environment Variables (.env)
- `NVD_API_KEY`: Optional NVD API key for higher rate limits (5→50 requests/30s)
- `GITHUB_TOKEN`: Optional GitHub token for exploit analysis (enhances rule generation)
- `OPENAI_API_KEY`: Optional OpenAI API key for AI-enhanced SIGMA rule generation
- `ANTHROPIC_API_KEY`: Optional Anthropic API key for AI-enhanced SIGMA rule generation
- `OLLAMA_BASE_URL`: Optional Ollama base URL for local model AI-enhanced SIGMA rule generation
- `LLM_PROVIDER`: Optional LLM provider selection (openai, anthropic, ollama)
- `LLM_MODEL`: Optional LLM model selection (provider-specific)
- `DATABASE_URL`: PostgreSQL connection string
- `REACT_APP_API_URL`: Backend API URL for frontend
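A minimal `.env` sketch (key values are placeholders, the rest are the documented defaults; only the providers you actually use need keys):

```bash
# Example .env — replace the placeholder keys; unset providers are simply unavailable
NVD_API_KEY=your-nvd-api-key            # optional
GITHUB_TOKEN=ghp_your-github-token      # optional
LLM_PROVIDER=ollama
LLM_MODEL=llama3.2
OLLAMA_BASE_URL=http://localhost:11434
DATABASE_URL=postgresql://cve_user:cve_password@db:5432/cve_sigma_db
REACT_APP_API_URL=http://localhost:8000
```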
### Service URLs
- Frontend: http://localhost:3000
- Backend API: http://localhost:8000
- API Documentation: http://localhost:8000/docs
- Database: localhost:5432
- Redis: localhost:6379

### Enhanced API Endpoints

#### Bulk Processing
- `POST /api/bulk-seed` - Start complete bulk seeding (NVD + nomi-sec)
- `POST /api/incremental-update` - Update with NVD modified/recent feeds
- `POST /api/sync-nomi-sec` - Synchronize nomi-sec PoC data
- `POST /api/regenerate-rules` - Regenerate SIGMA rules with enhanced data
- `GET /api/bulk-jobs` - Get bulk processing job status
- `GET /api/bulk-status` - Get comprehensive system status
- `GET /api/poc-stats` - Get PoC-related statistics
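For example, a bulk seed limited to recent years can be started from the shell; the fields mirror the `BulkSeedRequest` model in `backend/main.py`, and the response includes a Celery task ID:

```bash
# Kick off a bulk seed for 2020+ and skip the nomi-sec pass for now
curl -X POST http://localhost:8000/api/bulk-seed \
  -H "Content-Type: application/json" \
  -d '{"start_year": 2020, "skip_nvd": false, "skip_nomi_sec": true}'
```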
#### Enhanced Data Access
- `GET /api/stats` - Enhanced statistics with PoC coverage
- `GET /api/claude-status` - Get Claude API availability status
- All existing CVE and SIGMA rule endpoints now include enhanced data fields

#### LLM-Enhanced Rule Generation
- `POST /api/llm-enhanced-rules` - Generate SIGMA rules using LLM AI analysis (supports multiple providers)
- `GET /api/llm-status` - Check LLM API availability and configuration for all providers
- `POST /api/llm-switch` - Switch between LLM providers and models
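A typical check-then-switch sequence looks roughly like this; the exact JSON body accepted by `/api/llm-switch` is assumed here, so confirm the field names against the `/docs` schema:

```bash
# Inspect configured providers, then switch to a local Ollama model
curl http://localhost:8000/api/llm-status

# Field names ("provider", "model") are assumptions — check the Swagger UI for the actual request model
curl -X POST http://localhost:8000/api/llm-switch \
  -H "Content-Type: application/json" \
  -d '{"provider": "ollama", "model": "llama3.2"}'
```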
## Code Architecture Details

### Enhanced Backend Structure
- **main.py**: Core FastAPI application with enhanced endpoints
- **nvd_bulk_processor.py**: NVD JSON dataset downloader and processor
- **nomi_sec_client.py**: nomi-sec PoC-in-GitHub API integration
- **enhanced_sigma_generator.py**: Advanced SIGMA rule generation with PoC data
- **llm_client.py**: Multi-provider LLM integration using LangChain for AI-enhanced rule generation
- **bulk_seeder.py**: Coordinated bulk processing operations

### Database Models (Enhanced)
- **CVE**: Enhanced with `poc_count`, `poc_data`, `bulk_processed`, `data_source`
- **SigmaRule**: Enhanced with `poc_source`, `poc_quality_score`, `nomi_sec_data`
- **RuleTemplate**: Template patterns for rule generation
- **BulkProcessingJob**: Job tracking for bulk operations

### Frontend Structure (Enhanced)
- **Four Main Tabs**: Dashboard, CVEs, SIGMA Rules, Bulk Jobs
- **Enhanced Dashboard**: PoC coverage statistics, bulk processing controls
- **Bulk Jobs Tab**: Real-time job monitoring and system status
- **Enhanced CVE/Rule Display**: PoC quality indicators, exploit-based tagging

### Data Processing Flow
1. **Bulk Seeding**: NVD JSON downloads → Database storage → nomi-sec PoC sync → Enhanced rule generation
2. **Incremental Updates**: NVD modified feeds → Update existing data → Sync new PoCs
3. **Rule Enhancement**: PoC analysis → Indicator extraction → Template selection → Enhanced SIGMA rule
4. **LLM-Enhanced Generation**: PoC content analysis → Multi-provider LLM processing → Advanced SIGMA rule creation

## Development Notes

### Enhanced Rule Generation Logic
The application now uses an advanced rule generation process:
1. **CVE Analysis**: Extract metadata from NVD bulk data
2. **PoC Quality Assessment**: nomi-sec PoC analysis with star count, recency, quality tiers
3. **Advanced Indicator Extraction**: Processes, files, network, registry, commands from PoC repositories
4. **Template Selection**: Smart template matching based on PoC indicators and CVE characteristics
5. **Enhanced Rule Population**: Incorporate real exploit indicators with quality scoring
6. **MITRE ATT&CK Mapping**: Automatic technique identification based on indicators
7. **LLM AI Enhancement**: Optional multi-provider LLM integration for intelligent rule generation from PoC code analysis

### Quality Tiers
- **Excellent** (80+ points): High star count, recent updates, detailed descriptions
- **Good** (60-79 points): Moderate quality indicators
- **Fair** (40-59 points): Basic PoC with some quality indicators
- **Poor** (20-39 points): Minimal quality indicators
- **Very Poor** (<20 points): Low-quality PoCs

### Multi-Provider LLM Integration Features
- **Multiple LLM Providers**: Support for OpenAI, Anthropic, and Ollama (local models)
- **Dynamic Provider Switching**: Switch between providers and models through UI or API
- **Intelligent Code Analysis**: LLMs analyze actual exploit code from PoC repositories
- **Advanced Rule Generation**: Creates sophisticated SIGMA rules with proper syntax and logic
- **Contextual Understanding**: Interprets CVE descriptions and maps them to appropriate detection patterns
- **Automatic Validation**: Generated rules are validated for SIGMA syntax compliance
- **Fallback Mechanism**: Automatically falls back to template-based generation if LLM is unavailable
- **Enhanced Metadata**: Rules include generation method tracking for quality assessment
- **LangChain Integration**: Uses LangChain for robust LLM integration and prompt management

### Supported LLM Providers and Models

#### OpenAI
- **API Key**: Set `OPENAI_API_KEY` environment variable
- **Supported Models**: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
- **Default Model**: gpt-4o-mini
- **Rate Limits**: Based on OpenAI API limits

#### Anthropic
- **API Key**: Set `ANTHROPIC_API_KEY` environment variable
- **Supported Models**: claude-3-5-sonnet-20241022, claude-3-haiku-20240307, claude-3-opus-20240229
- **Default Model**: claude-3-5-sonnet-20241022
- **Rate Limits**: Based on Anthropic API limits

#### Ollama (Local Models)
- **Setup**: Install Ollama locally and set `OLLAMA_BASE_URL` (default: http://localhost:11434)
- **Supported Models**: llama3.2, codellama, mistral, llama2 (any Ollama-compatible model)
- **Default Model**: llama3.2
- **Rate Limits**: No external API limits (local processing)
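Getting a local model ready usually amounts to pulling it and pointing the backend at the Ollama server:

```bash
# Pull the default local model and point the backend at the Ollama server
ollama pull llama3.2
export OLLAMA_BASE_URL=http://localhost:11434
export LLM_PROVIDER=ollama
export LLM_MODEL=llama3.2
```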
### Testing and Validation
- **Frontend tests**: `npm test` (in frontend directory)
- **Backend testing**: Use standalone scripts for bulk operations
- **API testing**: Use `/docs` endpoint for Swagger UI
- **Bulk Processing**: Monitor via `/api/bulk-jobs` and frontend Bulk Jobs tab

### Security Considerations
- **API Keys**: Store NVD and GitHub tokens in environment variables
- **PoC Analysis**: Automated analysis of curated PoC repositories (safer than raw GitHub search)
- **Rate Limiting**: Built-in rate limiting for external APIs
- **Data Validation**: Enhanced validation for bulk data processing
- **Audit Trail**: Job tracking for all bulk operations

## Troubleshooting

### Common Issues
- **Bulk Processing Failures**: Check `/api/bulk-jobs` for detailed error messages
- **NVD Data Download Issues**: Verify NVD API key and network connectivity
- **nomi-sec API Timeouts**: Built-in retry logic, check network connectivity
- **Frontend build errors**: Run `npm install` in frontend directory
- **Database schema changes**: Restart backend to auto-create new tables
- **Memory issues during bulk processing**: Monitor system resources, consider smaller batch sizes

### Enhanced Rate Limits
- **NVD API**: 5 requests/30s (no key) → 50 requests/30s (with key)
- **nomi-sec API**: 1 request/second (built-in rate limiting)
- **GitHub API** (fallback): 60 requests/hour (no token) → 5000 requests/hour (with token)

### Performance Optimization
- **Bulk Processing**: Start with recent years (2020+) for faster initial setup
- **PoC Sync**: Use smaller batch sizes (50) for better stability
- **Rule Generation**: Monitor quality scores to prioritize high-value PoCs
- **Database**: Ensure proper indexing on CVE ID and PoC fields

### Monitoring
- **Frontend**: Use Bulk Jobs tab for real-time progress monitoring
- **Backend logs**: `docker-compose logs -f backend`
- **Job status**: Check `/api/bulk-status` for comprehensive system health
- **Database**: Monitor PoC coverage percentage and rule enhancement progress
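The same checks can be scripted from the shell against the documented endpoints:

```bash
# Snapshot of system health, recent bulk jobs, and PoC coverage
curl http://localhost:8000/api/bulk-status
curl http://localhost:8000/api/bulk-jobs
curl http://localhost:8000/api/poc-stats
```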
222  backend/celery_config.py  Normal file
@@ -0,0 +1,222 @@
"""
|
||||
Celery configuration for the Auto SIGMA Rule Generator
|
||||
"""
|
||||
import os
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
from kombu import Queue
|
||||
|
||||
# Celery configuration
|
||||
broker_url = os.getenv('CELERY_BROKER_URL', 'redis://redis:6379/0')
|
||||
result_backend = os.getenv('CELERY_RESULT_BACKEND', 'redis://redis:6379/0')
|
||||
|
||||
# Create Celery app
|
||||
celery_app = Celery(
|
||||
'sigma_generator',
|
||||
broker=broker_url,
|
||||
backend=result_backend,
|
||||
include=[
|
||||
'tasks.bulk_tasks',
|
||||
'tasks.sigma_tasks',
|
||||
'tasks.data_sync_tasks',
|
||||
'tasks.maintenance_tasks'
|
||||
]
|
||||
)
|
||||
|
||||
# Celery configuration
|
||||
celery_app.conf.update(
|
||||
# Serialization
|
||||
task_serializer='json',
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
# Timezone
|
||||
timezone='UTC',
|
||||
enable_utc=True,
|
||||
|
||||
# Task tracking
|
||||
task_track_started=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
# Result backend settings
|
||||
result_expires=3600, # Results expire after 1 hour
|
||||
result_backend_transport_options={
|
||||
'master_name': 'mymaster',
|
||||
'visibility_timeout': 3600,
|
||||
},
|
||||
|
||||
# Worker settings
|
||||
worker_prefetch_multiplier=1,
|
||||
task_acks_late=True,
|
||||
worker_max_tasks_per_child=1000,
|
||||
|
||||
# Task routes - different queues for different types of tasks
|
||||
task_routes={
|
||||
'tasks.bulk_tasks.*': {'queue': 'bulk_processing'},
|
||||
'tasks.sigma_tasks.*': {'queue': 'sigma_generation'},
|
||||
'tasks.data_sync_tasks.*': {'queue': 'data_sync'},
|
||||
},
|
||||
|
||||
# Queue definitions
|
||||
task_default_queue='default',
|
||||
task_queues=(
|
||||
Queue('default', routing_key='default'),
|
||||
Queue('bulk_processing', routing_key='bulk_processing'),
|
||||
Queue('sigma_generation', routing_key='sigma_generation'),
|
||||
Queue('data_sync', routing_key='data_sync'),
|
||||
),
|
||||
|
||||
# Retry settings
|
||||
task_default_retry_delay=60, # 1 minute
|
||||
task_max_retries=3,
|
||||
|
||||
# Monitoring
|
||||
worker_send_task_events=True,
|
||||
|
||||
    # Optimized Beat schedule for daily workflow
    # WORKFLOW: NVD incremental -> Exploit syncs -> Reference sync -> SIGMA rules
    beat_schedule={
        # STEP 1: NVD Incremental Update - Daily at 2:00 AM
        # This runs first to get the latest CVE data from NVD
        'daily-nvd-incremental-update': {
            'task': 'bulk_tasks.incremental_update_task',
            'schedule': crontab(minute=0, hour=2),  # Daily at 2:00 AM
            'options': {'queue': 'bulk_processing'},
            'kwargs': {'batch_size': 100, 'skip_nvd': False, 'skip_nomi_sec': True}
        },

        # STEP 2: Exploit Data Syncing - Daily starting at 3:00 AM
        # These run in parallel but start at different times to avoid conflicts

        # CISA KEV Sync - Daily at 3:00 AM (one hour after the NVD update starts)
        'daily-cisa-kev-sync': {
            'task': 'data_sync_tasks.sync_cisa_kev',
            'schedule': crontab(minute=0, hour=3),  # Daily at 3:00 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'batch_size': 100}
        },

        # Nomi-sec PoC Sync - Daily at 3:15 AM
        'daily-nomi-sec-sync': {
            'task': 'data_sync_tasks.sync_nomi_sec',
            'schedule': crontab(minute=15, hour=3),  # Daily at 3:15 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'batch_size': 100}
        },

        # GitHub PoC Sync - Daily at 3:30 AM
        'daily-github-poc-sync': {
            'task': 'data_sync_tasks.sync_github_poc',
            'schedule': crontab(minute=30, hour=3),  # Daily at 3:30 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'batch_size': 50}
        },

        # ExploitDB Sync - Daily at 3:45 AM
        'daily-exploitdb-sync': {
            'task': 'data_sync_tasks.sync_exploitdb',
            'schedule': crontab(minute=45, hour=3),  # Daily at 3:45 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'batch_size': 30}
        },

        # CVE2CAPEC MITRE ATT&CK Mapping Sync - Daily at 4:00 AM
        'daily-cve2capec-sync': {
            'task': 'data_sync_tasks.sync_cve2capec',
            'schedule': crontab(minute=0, hour=4),  # Daily at 4:00 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'force_refresh': False}  # Only refresh if cache is stale
        },

        # ExploitDB Index Rebuild - Daily at 4:15 AM
        'daily-exploitdb-index-build': {
            'task': 'data_sync_tasks.build_exploitdb_index',
            'schedule': crontab(minute=15, hour=4),  # Daily at 4:15 AM
            'options': {'queue': 'data_sync'}
        },

        # STEP 3: Reference Content Sync - Daily at 5:00 AM
        # This is the longest-running task; it starts after the exploit syncs have had time to complete
        'daily-reference-content-sync': {
            'task': 'data_sync_tasks.sync_reference_content',
            'schedule': crontab(minute=0, hour=5),  # Daily at 5:00 AM
            'options': {'queue': 'data_sync'},
            'kwargs': {'batch_size': 30, 'max_cves': 200, 'force_resync': False}
        },

        # STEP 4: SIGMA Rule Generation - Daily at 8:00 AM
        # This runs LAST, after all other daily data sync jobs
        'daily-sigma-rule-generation': {
            'task': 'bulk_tasks.generate_enhanced_sigma_rules',
            'schedule': crontab(minute=0, hour=8),  # Daily at 8:00 AM
            'options': {'queue': 'sigma_generation'}
        },

        # LLM-Enhanced SIGMA Rule Generation - Daily at 9:00 AM
        # Additional LLM-based rule generation after standard rules
        'daily-llm-sigma-generation': {
            'task': 'sigma_tasks.generate_enhanced_rules',
            'schedule': crontab(minute=0, hour=9),  # Daily at 9:00 AM
            'options': {'queue': 'sigma_generation'},
            'kwargs': {'cve_ids': None}  # Process all CVEs with PoCs
        },

        # MAINTENANCE TASKS

        # Database Cleanup - Weekly on Sunday at 1:00 AM (before the daily workflow)
        'weekly-database-cleanup': {
            'task': 'tasks.maintenance_tasks.database_cleanup_comprehensive',
            'schedule': crontab(minute=0, hour=1, day_of_week=0),  # Sunday at 1:00 AM
            'options': {'queue': 'default'},
            'kwargs': {'days_to_keep': 30, 'cleanup_failed_jobs': True, 'cleanup_logs': True}
        },

        # Health Check - Every 15 minutes
        'health-check-detailed': {
            'task': 'tasks.maintenance_tasks.health_check_detailed',
            'schedule': crontab(minute='*/15'),  # Every 15 minutes
            'options': {'queue': 'default'}
        },

        # Celery result cleanup - Daily at 1:30 AM
        'daily-cleanup-old-results': {
            'task': 'tasks.maintenance_tasks.cleanup_old_results',
            'schedule': crontab(minute=30, hour=1),  # Daily at 1:30 AM
            'options': {'queue': 'default'}
        },
    },
)

# Configure logging
celery_app.conf.update(
    worker_log_format='[%(asctime)s: %(levelname)s/%(processName)s] %(message)s',
    worker_task_log_format='[%(asctime)s: %(levelname)s/%(processName)s][%(task_name)s(%(task_id)s)] %(message)s',
)

# Database session configuration for tasks
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'postgresql://cve_user:cve_password@db:5432/cve_sigma_db')

# Create engine and session factory
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def get_db_session():
    """Get database session for tasks"""
    return SessionLocal()

# Import all task modules to register them
def register_tasks():
    """Register all task modules"""
    try:
        from tasks import bulk_tasks, sigma_tasks, data_sync_tasks, maintenance_tasks
        print("All task modules registered successfully")
    except ImportError as e:
        print(f"Warning: Could not import some task modules: {e}")

# Auto-register tasks when module is imported
if __name__ != "__main__":
    register_tasks()
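For reference, the worker, beat scheduler, and Flower dashboard that consume this configuration can be started roughly as follows; the module path and queue names match the config above, and in normal operation the Docker Compose services run these for you:

```bash
# Worker listening on all four queues defined above
celery -A celery_config.celery_app worker --loglevel=info \
  -Q default,bulk_processing,sigma_generation,data_sync

# Beat scheduler driving the beat_schedule entries
celery -A celery_config.celery_app beat --loglevel=info

# Flower dashboard on the port used by the monitor URLs (localhost:5555)
celery -A celery_config.celery_app flower --port=5555
```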
@@ -15,18 +15,20 @@ logger = logging.getLogger(__name__)
class CVE2CAPECClient:
    """Client for accessing CVE to MITRE ATT&CK technique mappings."""

    def __init__(self):
    def __init__(self, lazy_load: bool = True):
        self.base_url = "https://raw.githubusercontent.com/Galeax/CVE2CAPEC/main"
        self.cache_file = "/tmp/cve2capec_cache.json"
        self.cache_expiry_hours = 24  # Cache for 24 hours
        self.cve_mappings = {}
        self.technique_names = {}  # Map technique IDs to names
        self._data_loaded = False

        # Load cached data if available
        self._load_cache()

        # Load MITRE ATT&CK technique names
        # Load MITRE ATT&CK technique names (lightweight)
        self._load_technique_names()

        # Only load cached data if not lazy loading
        if not lazy_load:
            self._load_cache()

    def _load_cache(self):
        """Load cached CVE mappings if they exist and are fresh."""

@@ -39,15 +41,18 @@
                cache_time = datetime.fromisoformat(cache_data.get('timestamp', '2000-01-01'))
                if datetime.now() - cache_time < timedelta(hours=self.cache_expiry_hours):
                    self.cve_mappings = cache_data.get('mappings', {})
                    self._data_loaded = True
                    logger.info(f"Loaded {len(self.cve_mappings)} CVE mappings from cache")
                    return

            # Cache is stale or doesn't exist, fetch fresh data
            self._fetch_fresh_data()
            self._data_loaded = True

        except Exception as e:
            logger.error(f"Error loading CVE2CAPEC cache: {e}")
            self._fetch_fresh_data()
            self._data_loaded = True

    def _fetch_fresh_data(self):
        """Fetch fresh CVE mappings from the repository."""

@@ -133,6 +138,12 @@
            # Continue with empty mappings if fetch fails
            self.cve_mappings = {}

    def _ensure_data_loaded(self):
        """Ensure CVE mappings are loaded, loading from cache if needed."""
        if not self._data_loaded:
            logger.info("CVE2CAPEC data not loaded, loading from cache...")
            self._load_cache()

    def _load_technique_names(self):
        """Load MITRE ATT&CK technique names for better rule descriptions."""
        # Common MITRE ATT&CK techniques and their names

@@ -350,6 +361,7 @@
    def get_mitre_techniques_for_cve(self, cve_id: str) -> List[str]:
        """Get MITRE ATT&CK techniques for a given CVE ID."""
        try:
            self._ensure_data_loaded()
            cve_data = self.cve_mappings.get(cve_id, {})
            techniques = cve_data.get('TECHNIQUES', [])

@@ -374,6 +386,7 @@
    def get_cwe_for_cve(self, cve_id: str) -> List[str]:
        """Get CWE codes for a given CVE ID."""
        try:
            self._ensure_data_loaded()
            cve_data = self.cve_mappings.get(cve_id, {})
            cwes = cve_data.get('CWE', [])

@@ -392,6 +405,7 @@
    def get_capec_for_cve(self, cve_id: str) -> List[str]:
        """Get CAPEC codes for a given CVE ID."""
        try:
            self._ensure_data_loaded()
            cve_data = self.cve_mappings.get(cve_id, {})
            capecs = cve_data.get('CAPEC', [])

@@ -430,6 +444,7 @@

    def get_stats(self) -> Dict:
        """Get statistics about the CVE2CAPEC dataset."""
        self._ensure_data_loaded()
        total_cves = len(self.cve_mappings)
        cves_with_techniques = len([cve for cve, data in self.cve_mappings.items()
                                    if data.get('TECHNIQUES')])
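With the client now lazy loading, an on-demand refresh goes through the new endpoint instead of blocking startup; the response carries the Celery task ID and a Flower monitor URL:

```bash
# Force a fresh pull of the CVE2CAPEC mappings instead of waiting for the 4:00 AM beat job
curl -X POST "http://localhost:8000/api/sync-cve2capec?force_refresh=true"
```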
@@ -7,7 +7,7 @@ import os
import re
import json
import logging
from datetime import datetime
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from pathlib import Path

@@ -16,10 +16,15 @@ from pathlib import Path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global cache for file index to avoid rebuilding on every request
_global_file_index = {}
_index_last_built = None
_index_cache_hours = 24  # Cache for 24 hours

class ExploitDBLocalClient:
    """Client for interfacing with local ExploitDB mirror filesystem"""

    def __init__(self, db_session: Session):
    def __init__(self, db_session: Session, lazy_load: bool = True):
        self.db_session = db_session

        # Path to the local exploit-db-mirror submodule (in container: /app/exploit-db-mirror)

@@ -32,13 +37,38 @@
        # Cache for file searches
        self.file_cache = {}

        # Build file index on initialization
        self._build_file_index()
        # Use global cache and only build if needed
        global _global_file_index, _index_last_built
        self.file_index = _global_file_index

        # Build file index only if not lazy loading or if cache is stale
        if not lazy_load:
            self._ensure_index_built()

    def _ensure_index_built(self):
        """Ensure the file index is built and fresh"""
        global _global_file_index, _index_last_built

        # Check if index needs to be rebuilt
        needs_rebuild = (
            not _global_file_index or  # No index exists
            _index_last_built is None or  # Never built
            datetime.now() - _index_last_built > timedelta(hours=_index_cache_hours)  # Cache expired
        )

        if needs_rebuild:
            self._build_file_index()
        else:
            # Use cached index
            self.file_index = _global_file_index
            logger.debug(f"Using cached ExploitDB index with {len(self.file_index)} exploits")

    def _build_file_index(self):
        """Build an index of exploit ID to file path for fast lookups"""
        global _global_file_index, _index_last_built

        logger.info("Building ExploitDB file index...")
        self.file_index = {}
        temp_index = {}

        if not self.exploits_path.exists():
            logger.error(f"ExploitDB path not found: {self.exploits_path}")

@@ -55,7 +85,7 @@
                    file_path = Path(root) / file

                    # Store in index
                    self.file_index[exploit_id] = {
                    temp_index[exploit_id] = {
                        'path': file_path,
                        'filename': file,
                        'extension': file_extension,

@@ -63,6 +93,11 @@
                        'subcategory': self._extract_subcategory_from_path(file_path)
                    }

        # Update global cache
        _global_file_index = temp_index
        _index_last_built = datetime.now()
        self.file_index = _global_file_index

        logger.info(f"Built index with {len(self.file_index)} exploits")

    def _extract_category_from_path(self, file_path: Path) -> str:

@@ -104,6 +139,7 @@

    def get_exploit_details(self, exploit_id: str) -> Optional[dict]:
        """Get exploit details from local filesystem"""
        self._ensure_index_built()
        if exploit_id not in self.file_index:
            logger.debug(f"Exploit {exploit_id} not found in local index")
            return None

@@ -699,6 +735,9 @@
            from main import CVE
            from sqlalchemy import text

            # Ensure index is built for stats calculation
            self._ensure_index_built()

            # Count CVEs with ExploitDB references
            total_cves = self.db_session.query(CVE).count()
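The index can likewise be rebuilt on demand instead of waiting for the 4:15 AM beat job:

```bash
# Rebuild the ExploitDB file index in the background via Celery
curl -X POST http://localhost:8000/api/build-exploitdb-index
```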
171  backend/initial_setup.py  Normal file
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""
Initial setup script that runs once on first boot to populate the database.
This script checks if initial data seeding is needed and triggers it via Celery.
"""
import os
import sys
import time
import logging
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import OperationalError

# Add the current directory to path so we can import our modules
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'postgresql://cve_user:cve_password@db:5432/cve_sigma_db')

def wait_for_database(max_retries: int = 30, delay: int = 5) -> bool:
    """Wait for database to be ready"""
    logger.info("Waiting for database to be ready...")

    for attempt in range(max_retries):
        try:
            engine = create_engine(DATABASE_URL)
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.info("✅ Database is ready!")
            return True
        except OperationalError as e:
            logger.info(f"Attempt {attempt + 1}/{max_retries}: Database not ready yet ({e})")
        except Exception as e:
            logger.error(f"Unexpected error connecting to database: {e}")

        if attempt < max_retries - 1:
            time.sleep(delay)

    logger.error("❌ Database failed to become ready")
    return False

def check_initial_setup_needed() -> bool:
    """Check if initial setup is needed by examining the database state"""
    try:
        engine = create_engine(DATABASE_URL)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

        with SessionLocal() as session:
            # Check if we have any CVEs in the database
            result = session.execute(text("SELECT COUNT(*) FROM cves")).fetchone()
            cve_count = result[0] if result else 0

            logger.info(f"Current CVE count in database: {cve_count}")

            # Check if we have any bulk processing jobs that completed successfully
            bulk_jobs_result = session.execute(text("""
                SELECT COUNT(*) FROM bulk_processing_jobs
                WHERE job_type = 'nvd_bulk_seed'
                AND status = 'completed'
                AND created_at > NOW() - INTERVAL '30 days'
            """)).fetchone()

            recent_bulk_jobs = bulk_jobs_result[0] if bulk_jobs_result else 0

            logger.info(f"Recent successful bulk seed jobs: {recent_bulk_jobs}")

            # Initial setup needed if:
            # 1. Very few CVEs (less than 1000) AND
            # 2. No recent successful bulk seed jobs
            initial_setup_needed = cve_count < 1000 and recent_bulk_jobs == 0

            if initial_setup_needed:
                logger.info("🔄 Initial setup is needed - will trigger full NVD sync")
            else:
                logger.info("✅ Initial setup already completed - database has sufficient data")

            return initial_setup_needed

    except Exception as e:
        logger.error(f"Error checking initial setup status: {e}")
        # If we can't check, assume setup is needed
        return True

def trigger_initial_bulk_seed():
    """Trigger initial bulk seed via Celery"""
    try:
        # Import here to avoid circular dependencies
        from celery_config import celery_app
        from tasks.bulk_tasks import full_bulk_seed_task

        logger.info("🚀 Triggering initial full NVD bulk seed...")

        # Start a comprehensive bulk seed job
        # Start from 2020 for faster initial setup, can be adjusted
        task_result = full_bulk_seed_task.delay(
            start_year=2020,      # Start from 2020 for faster initial setup
            end_year=None,        # Current year
            skip_nvd=False,
            skip_nomi_sec=True,   # Skip nomi-sec initially, will be done daily
            skip_exploitdb=True,  # Skip exploitdb initially, will be done daily
            skip_cisa_kev=True    # Skip CISA KEV initially, will be done daily
        )

        logger.info(f"✅ Initial bulk seed task started with ID: {task_result.id}")
        logger.info(f"Monitor progress at: http://localhost:5555/task/{task_result.id}")

        return task_result.id

    except Exception as e:
        logger.error(f"❌ Failed to trigger initial bulk seed: {e}")
        return None

def create_initial_setup_marker():
    """Create a marker to indicate initial setup was attempted"""
    try:
        engine = create_engine(DATABASE_URL)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

        with SessionLocal() as session:
            # Insert a marker record
            session.execute(text("""
                INSERT INTO bulk_processing_jobs (job_type, status, job_metadata, created_at, started_at)
                VALUES ('initial_setup_marker', 'completed', '{"purpose": "initial_setup_marker"}', NOW(), NOW())
                ON CONFLICT DO NOTHING
            """))
            session.commit()

        logger.info("✅ Created initial setup marker")

    except Exception as e:
        logger.error(f"Error creating initial setup marker: {e}")

def main():
    """Main initial setup function"""
    logger.info("🚀 Starting initial setup check...")

    # Step 1: Wait for database
    if not wait_for_database():
        logger.error("❌ Initial setup failed: Database not available")
        sys.exit(1)

    # Step 2: Check if initial setup is needed
    if not check_initial_setup_needed():
        logger.info("🎉 Initial setup not needed - database already populated")
        sys.exit(0)

    # Step 3: Wait for Celery to be ready
    logger.info("Waiting for Celery workers to be ready...")
    time.sleep(10)  # Give Celery workers time to start

    # Step 4: Trigger initial bulk seed
    task_id = trigger_initial_bulk_seed()

    if task_id:
        # Step 5: Create marker
        create_initial_setup_marker()

        logger.info("🎉 Initial setup triggered successfully!")
        logger.info(f"Task ID: {task_id}")
        logger.info("The system will begin daily scheduled tasks once initial setup completes.")
        sys.exit(0)
    else:
        logger.error("❌ Initial setup failed")
        sys.exit(1)

if __name__ == "__main__":
    main()
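The script runs automatically on first boot, but it exits early when the database is already populated, so it can also be re-run by hand if seeding needs to be retried (assuming the compose service is named `backend`):

```bash
# Re-run the first-boot check manually; it is a no-op once the database has sufficient data
docker-compose exec backend python initial_setup.py
```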
@@ -1,6 +1,6 @@
"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
Supports multiple LLM providers: OpenAI, Anthropic, and Ollama.
"""
import os
import logging

@@ -31,7 +31,7 @@ class LLMClient:
            'default_model': 'claude-3-5-sonnet-20241022'
        },
        'ollama': {
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2', 'sigma-llama-finetuned'],
            'env_key': 'OLLAMA_BASE_URL',
            'default_model': 'llama3.2'
        }

@@ -110,6 +110,8 @@ class LLMClient:
                    base_url=base_url,
                    temperature=0.1
                )


        if self.llm:
            logger.info(f"LLM client initialized: {self.provider} with model {self.model}")

@@ -128,9 +130,9 @@ class LLMClient:
        """Get information about the current provider and configuration."""
        provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})

        # For Ollama, get actually available models
        # For Ollama and fine-tuned, get actually available models
        available_models = provider_info.get('models', [])
        if self.provider == 'ollama':
        if self.provider in ['ollama', 'finetuned']:
            ollama_models = self._get_ollama_available_models()
            if ollama_models:
                available_models = ollama_models

@@ -186,9 +188,22 @@ class LLMClient:
            logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...")
            logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...")

            # Generate the response
            # Generate the response with memory error handling
            logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}")
            response = await chain.ainvoke(input_data)
            try:
                response = await chain.ainvoke(input_data)
            except Exception as llm_error:
                # Handle memory issues or model loading failures
                error_msg = str(llm_error).lower()
                if any(keyword in error_msg for keyword in ["memory", "out of memory", "too large", "available", "model request"]):
                    logger.error(f"LLM memory error for {cve_id}: {llm_error}")

                    # For memory errors, we don't have specific fallback logic currently
                    logger.error(f"No fallback available for provider {self.provider}")
                    return None
                else:
                    # Re-raise non-memory errors
                    raise llm_error

            # Debug: Log raw LLM response
            logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")

@@ -1351,13 +1366,15 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
            env_key = provider_info.get('env_key', '')
            api_key_configured = bool(os.getenv(env_key))

            available = api_key_configured or provider_name == 'ollama'

            providers.append({
                'name': provider_name,
                'models': provider_info.get('models', []),
                'default_model': provider_info.get('default_model', ''),
                'env_key': env_key,
                'api_key_configured': api_key_configured,
                'available': api_key_configured or provider_name == 'ollama'
                'available': available
            })

        return providers
566  backend/main.py
@@ -820,36 +820,14 @@ async def lifespan(app: FastAPI):
        finally:
            db.close()

    # Initialize and start the job scheduler
    try:
        from job_scheduler import initialize_scheduler
        from job_executors import register_all_executors

        # Initialize scheduler
        scheduler = initialize_scheduler()
        scheduler.set_db_session_factory(SessionLocal)

        # Register all job executors
        register_all_executors(scheduler)

        # Start the scheduler
        scheduler.start()

        logger.info("Job scheduler initialized and started")

    except Exception as e:
        logger.error(f"Error initializing job scheduler: {e}")
    # Note: Job scheduling is now handled by Celery Beat
    # All scheduled tasks are defined in celery_config.py
    logger.info("Application startup complete - scheduled tasks handled by Celery Beat")

    yield

    # Shutdown
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        scheduler.stop()
        logger.info("Job scheduler stopped")
    except Exception as e:
        logger.error(f"Error stopping job scheduler: {e}")
    logger.info("Application shutdown complete")

# FastAPI app
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)
@@ -862,6 +840,16 @@ app.add_middleware(
    allow_headers=["*"],
)

# Include Celery job management routes
try:
    from routers.celery_jobs import router as celery_router
    app.include_router(celery_router, prefix="/api")
    logger.info("Celery job routes loaded successfully")
except ImportError as e:
    logger.warning(f"Celery job routes not available: {e}")
except Exception as e:
    logger.error(f"Error loading Celery job routes: {e}")

@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
@@ -1074,213 +1062,113 @@ async def get_stats(db: Session = Depends(get_db)):

# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(background_tasks: BackgroundTasks,
                          request: BulkSeedRequest,
                          db: Session = Depends(get_db)):
    """Start bulk seeding process"""

    async def bulk_seed_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(bulk_seed_task)

    return {
        "message": "Bulk seeding process started",
        "status": "started",
        "start_year": request.start_year,
        "end_year": request.end_year or datetime.now().year,
        "skip_nvd": request.skip_nvd,
        "skip_nomi_sec": request.skip_nomi_sec
    }
async def start_bulk_seed(request: BulkSeedRequest):
    """Start bulk seeding process - redirects to async endpoint"""
    try:
        from routers.celery_jobs import start_bulk_seed as async_bulk_seed
        from routers.celery_jobs import BulkSeedRequest as CeleryBulkSeedRequest

        # Convert request to Celery format
        celery_request = CeleryBulkSeedRequest(
            start_year=request.start_year,
            end_year=request.end_year,
            skip_nvd=request.skip_nvd,
            skip_nomi_sec=request.skip_nomi_sec,
            skip_exploitdb=getattr(request, 'skip_exploitdb', False),
            skip_cisa_kev=getattr(request, 'skip_cisa_kev', False)
        )

        # Call async endpoint
        result = await async_bulk_seed(celery_request)

        return {
            "message": "Bulk seeding process started (async)",
            "status": "started",
            "task_id": result.task_id,
            "start_year": request.start_year,
            "end_year": request.end_year or datetime.now().year,
            "skip_nvd": request.skip_nvd,
            "skip_nomi_sec": request.skip_nomi_sec,
            "async_endpoint": f"/api/task-status/{result.task_id}"
        }
    except Exception as e:
        logger.error(f"Error starting bulk seed: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start bulk seed: {e}")

@app.post("/api/incremental-update")
async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update process"""

    async def incremental_update_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(incremental_update_task)

    return {
        "message": "Incremental update process started",
        "status": "started"
    }
async def start_incremental_update():
    """Start incremental update process - redirects to async endpoint"""
    try:
        from routers.celery_jobs import start_incremental_update as async_incremental_update

        # Call async endpoint
        result = await async_incremental_update()

        return {
            "message": "Incremental update process started (async)",
            "status": "started",
            "task_id": result.task_id,
            "async_endpoint": f"/api/task-status/{result.task_id}"
        }
    except Exception as e:
        logger.error(f"Error starting incremental update: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start incremental update: {e}")
@app.post("/api/sync-nomi-sec")
|
||||
async def sync_nomi_sec(background_tasks: BackgroundTasks,
|
||||
request: NomiSecSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize nomi-sec PoC data"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='nomi_sec_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
async def sync_nomi_sec(request: NomiSecSyncRequest):
|
||||
"""Synchronize nomi-sec PoC data - redirects to async endpoint"""
|
||||
try:
|
||||
from routers.celery_jobs import start_nomi_sec_sync as async_nomi_sec_sync
|
||||
from routers.celery_jobs import DataSyncRequest as CeleryDataSyncRequest
|
||||
|
||||
# Convert request to Celery format
|
||||
celery_request = CeleryDataSyncRequest(
|
||||
batch_size=request.batch_size
|
||||
)
|
||||
|
||||
# Call async endpoint
|
||||
result = await async_nomi_sec_sync(celery_request)
|
||||
|
||||
return {
|
||||
"message": f"Nomi-sec sync started (async)" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"task_id": result.task_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size,
|
||||
"async_endpoint": f"/api/task-status/{result.task_id}"
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
running_jobs[job_id] = job
|
||||
job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
try:
|
||||
job.status = 'running'
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
from nomi_sec_client import NomiSecClient
|
||||
client = NomiSecClient(db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_pocs(request.cve_id)
|
||||
logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_all_cves(
|
||||
batch_size=request.batch_size,
|
||||
cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
|
||||
)
|
||||
logger.info(f"Nomi-sec bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'failed'
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.error(f"Nomi-sec sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking
|
||||
running_jobs.pop(job_id, None)
|
||||
job_cancellation_flags.pop(job_id, None)
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting nomi-sec sync: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start nomi-sec sync: {e}")
|
||||
|
||||
@app.post("/api/sync-github-pocs")
|
||||
async def sync_github_pocs(background_tasks: BackgroundTasks,
|
||||
request: GitHubPoCSyncRequest,
|
||||
async def sync_github_pocs(request: GitHubPoCSyncRequest,
|
||||
db: Session = Depends(get_db)):
|
||||
"""Synchronize GitHub PoC data"""
|
||||
|
||||
# Create job record
|
||||
job = BulkProcessingJob(
|
||||
job_type='github_poc_sync',
|
||||
status='pending',
|
||||
job_metadata={
|
||||
'cve_id': request.cve_id,
|
||||
'batch_size': request.batch_size
|
||||
"""Synchronize GitHub PoC data using Celery task"""
|
||||
try:
|
||||
from celery_config import celery_app
|
||||
from tasks.data_sync_tasks import sync_github_poc_task
|
||||
|
||||
# Launch Celery task
|
||||
if request.cve_id:
|
||||
# For specific CVE sync, we'll still use the general task
|
||||
task_result = sync_github_poc_task.delay(batch_size=request.batch_size)
|
||||
else:
|
||||
# For bulk sync
|
||||
task_result = sync_github_poc_task.delay(batch_size=request.batch_size)
|
||||
|
||||
return {
|
||||
"message": f"GitHub PoC sync started via Celery" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"task_id": task_result.id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size,
|
||||
"monitor_url": "http://localhost:5555/task/" + task_result.id
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
job_id = str(job.id)
|
||||
running_jobs[job_id] = job
|
||||
job_cancellation_flags[job_id] = False
|
||||
|
||||
async def sync_task():
|
||||
try:
|
||||
job.status = 'running'
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
client = GitHubPoCClient(db)
|
||||
|
||||
if request.cve_id:
|
||||
# Sync specific CVE
|
||||
if job_cancellation_flags.get(job_id, False):
|
||||
logger.info(f"Job {job_id} cancelled before starting")
|
||||
return
|
||||
|
||||
result = await client.sync_cve_pocs(request.cve_id)
|
||||
logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
|
||||
else:
|
||||
# Sync all CVEs with cancellation support
|
||||
result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
|
||||
logger.info(f"GitHub PoC bulk sync completed: {result}")
|
||||
|
||||
# Update job status if not cancelled
|
||||
if not job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'completed'
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if not job_cancellation_flags.get(job_id, False):
|
||||
job.status = 'failed'
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.error(f"GitHub PoC sync failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up tracking
|
||||
running_jobs.pop(job_id, None)
|
||||
job_cancellation_flags.pop(job_id, None)
|
||||
|
||||
background_tasks.add_task(sync_task)
|
||||
|
||||
return {
|
||||
"message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
|
||||
"status": "started",
|
||||
"job_id": job_id,
|
||||
"cve_id": request.cve_id,
|
||||
"batch_size": request.batch_size
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting GitHub PoC sync via Celery: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start GitHub PoC sync: {e}")
|
||||
|
||||
@app.post("/api/sync-exploitdb")
|
||||
async def sync_exploitdb(background_tasks: BackgroundTasks,
|
||||
|
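With this change the endpoint returns immediately with a Celery task reference instead of blocking the backend; an illustrative call and response shape (values are placeholders):

```bash
# Trigger a GitHub PoC sync; the response carries the Celery task ID and Flower monitor URL
curl -X POST http://localhost:8000/api/sync-github-pocs \
  -H "Content-Type: application/json" \
  -d '{"batch_size": 50}'
# => {"message": "GitHub PoC sync started via Celery for all CVEs", "status": "started",
#     "task_id": "<celery-task-id>", "monitor_url": "http://localhost:5555/task/<celery-task-id>", ...}
```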
@ -1850,6 +1738,55 @@ async def get_poc_stats(db: Session = Depends(get_db)):
|
|||
logger.error(f"Error getting PoC stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
@app.post("/api/sync-cve2capec")
|
||||
async def sync_cve2capec(force_refresh: bool = False):
|
||||
"""Synchronize CVE2CAPEC MITRE ATT&CK mappings using Celery task"""
|
||||
try:
|
||||
from celery_config import celery_app
|
||||
from tasks.data_sync_tasks import sync_cve2capec_task
|
||||
|
||||
# Launch Celery task
|
||||
task_result = sync_cve2capec_task.delay(force_refresh=force_refresh)
|
||||
|
||||
return {
|
||||
"message": "CVE2CAPEC MITRE ATT&CK mapping sync started via Celery",
|
||||
"status": "started",
|
||||
"task_id": task_result.id,
|
||||
"force_refresh": force_refresh,
|
||||
"monitor_url": f"http://localhost:5555/task/{task_result.id}"
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import Celery components: {e}")
|
||||
raise HTTPException(status_code=500, detail="Celery not properly configured")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting CVE2CAPEC sync: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start CVE2CAPEC sync: {e}")
|
||||
|
||||
@app.post("/api/build-exploitdb-index")
|
||||
async def build_exploitdb_index():
|
||||
"""Build/rebuild ExploitDB file index using Celery task"""
|
||||
try:
|
||||
from celery_config import celery_app
|
||||
from tasks.data_sync_tasks import build_exploitdb_index_task
|
||||
|
||||
# Launch Celery task
|
||||
task_result = build_exploitdb_index_task.delay()
|
||||
|
||||
return {
|
||||
"message": "ExploitDB file index build started via Celery",
|
||||
"status": "started",
|
||||
"task_id": task_result.id,
|
||||
"monitor_url": f"http://localhost:5555/task/{task_result.id}"
|
||||
}
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import Celery components: {e}")
|
||||
raise HTTPException(status_code=500, detail="Celery not properly configured")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting ExploitDB index build: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start ExploitDB index build: {e}")
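For a quick smoke test of these two new endpoints, something like the following works from any machine that can reach the backend; the base URL is an assumption about the compose port mapping and may differ in your setup.

```python
# Hedged example: trigger the new Celery-backed endpoints and print the task IDs.
# BASE_URL is an assumption -- adjust to however the backend is exposed.
import requests

BASE_URL = "http://localhost:8000"

# Start the CVE2CAPEC mapping sync, forcing a cache refresh
resp = requests.post(f"{BASE_URL}/api/sync-cve2capec",
                     params={"force_refresh": True}, timeout=10)
resp.raise_for_status()
data = resp.json()
print(data["task_id"], data["monitor_url"])

# Start the ExploitDB index rebuild
resp = requests.post(f"{BASE_URL}/api/build-exploitdb-index", timeout=10)
resp.raise_for_status()
print(resp.json()["task_id"])
```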

@app.get("/api/cve2capec-stats")
async def get_cve2capec_stats():
    """Get CVE2CAPEC MITRE ATT&CK mapping statistics"""

@@ -2201,172 +2138,23 @@ async def get_ollama_models():
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# ============================================================================
# NOTE: SCHEDULER ENDPOINTS REMOVED
# ============================================================================
#
# Job scheduling is now handled by Celery Beat with periodic tasks.
# All scheduled tasks are defined in celery_config.py beat_schedule.
#
# To manage scheduled tasks:
# - View tasks: Use Celery monitoring tools (Flower, Celery events)
# - Control tasks: Use Celery control commands or the Celery job management endpoints
# - Schedule changes: Update celery_config.py and restart Celery Beat
#
# Available Celery job management endpoints:
# - GET /api/celery/tasks - List all active tasks
# - POST /api/celery/tasks/{task_id}/revoke - Cancel a running task
# - GET /api/celery/workers - View worker status
#
# ============================================================================
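For orientation, a Celery Beat entry in celery_config.py takes the standard beat_schedule form. The sketch below is illustrative only: the task names are ones registered in backend/tasks/, but the crontab values are placeholders, not the authoritative production timings.

```python
# Illustrative beat_schedule sketch -- the real schedule lives in celery_config.py.
# Task names exist in backend/tasks/; the times below are placeholders.
from celery.schedules import crontab

beat_schedule = {
    "nvd-incremental-update": {
        "task": "bulk_tasks.incremental_update_task",
        "schedule": crontab(hour=2, minute=0),      # once a day
    },
    "cve2capec-sync": {
        "task": "data_sync_tasks.sync_cve2capec",
        "schedule": crontab(hour=4, minute=0),
        "kwargs": {"force_refresh": False},
    },
    "health-check": {
        "task": "tasks.maintenance_tasks.health_check",
        "schedule": crontab(minute="*/15"),         # every 15 minutes
    },
}
```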
|
||||
|
||||
class SchedulerControlRequest(BaseModel):
|
||||
action: str # 'start', 'stop', 'restart'
|
||||
|
||||
class JobControlRequest(BaseModel):
|
||||
job_name: str
|
||||
action: str # 'enable', 'disable', 'trigger'
|
||||
|
||||
class UpdateScheduleRequest(BaseModel):
|
||||
job_name: str
|
||||
schedule: str # Cron expression
|
||||
|
||||
@app.get("/api/scheduler/status")
|
||||
async def get_scheduler_status():
|
||||
"""Get scheduler status and job information"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
status = scheduler.get_job_status()
|
||||
|
||||
return {
|
||||
"scheduler_status": status,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting scheduler status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/scheduler/control")
|
||||
async def control_scheduler(request: SchedulerControlRequest):
|
||||
"""Control scheduler (start/stop/restart)"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
|
||||
if request.action == 'start':
|
||||
scheduler.start()
|
||||
message = "Scheduler started"
|
||||
elif request.action == 'stop':
|
||||
scheduler.stop()
|
||||
message = "Scheduler stopped"
|
||||
elif request.action == 'restart':
|
||||
scheduler.stop()
|
||||
scheduler.start()
|
||||
message = "Scheduler restarted"
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")
|
||||
|
||||
return {
|
||||
"message": message,
|
||||
"action": request.action,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error controlling scheduler: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/scheduler/job/control")
|
||||
async def control_job(request: JobControlRequest):
|
||||
"""Control individual jobs (enable/disable/trigger)"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
|
||||
if request.action == 'enable':
|
||||
success = scheduler.enable_job(request.job_name)
|
||||
message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found"
|
||||
elif request.action == 'disable':
|
||||
success = scheduler.disable_job(request.job_name)
|
||||
message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found"
|
||||
elif request.action == 'trigger':
|
||||
success = scheduler.trigger_job(request.job_name)
|
||||
message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}"
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")
|
||||
|
||||
return {
|
||||
"message": message,
|
||||
"job_name": request.job_name,
|
||||
"action": request.action,
|
||||
"success": success,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error controlling job: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/scheduler/job/schedule")
|
||||
async def update_job_schedule(request: UpdateScheduleRequest):
|
||||
"""Update job schedule"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
success = scheduler.update_job_schedule(request.job_name, request.schedule)
|
||||
|
||||
if success:
|
||||
# Get updated job info
|
||||
job_status = scheduler.get_job_status(request.job_name)
|
||||
return {
|
||||
"message": f"Schedule updated for job {request.job_name}",
|
||||
"job_name": request.job_name,
|
||||
"new_schedule": request.schedule,
|
||||
"next_run": job_status.get("next_run"),
|
||||
"success": True,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating job schedule: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/scheduler/job/{job_name}")
|
||||
async def get_job_status(job_name: str):
|
||||
"""Get status of a specific job"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
status = scheduler.get_job_status(job_name)
|
||||
|
||||
if "error" in status:
|
||||
raise HTTPException(status_code=404, detail=status["error"])
|
||||
|
||||
return {
|
||||
"job_status": status,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting job status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/scheduler/reload")
|
||||
async def reload_scheduler_config():
|
||||
"""Reload scheduler configuration from file"""
|
||||
try:
|
||||
from job_scheduler import get_scheduler
|
||||
|
||||
scheduler = get_scheduler()
|
||||
success = scheduler.reload_config()
|
||||
|
||||
if success:
|
||||
return {
|
||||
"message": "Scheduler configuration reloaded successfully",
|
||||
"success": True,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Failed to reload configuration")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reloading scheduler config: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
|
256	backend/setup_ollama_with_sigma.py	Normal file
@@ -0,0 +1,256 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Enhanced Ollama setup script that includes SIGMA model creation.
|
||||
This integrates the functionality from fix_sigma_model.sh into the Docker container.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
|
||||
DEFAULT_MODEL = os.getenv('LLM_MODEL', 'llama3.2')
|
||||
SIGMA_MODEL_NAME = 'sigma-llama'
|
||||
|
||||
def log(message: str, level: str = "INFO"):
|
||||
"""Log message with timestamp"""
|
||||
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
print(f"[{timestamp}] {level}: {message}")
|
||||
|
||||
def wait_for_ollama(max_retries: int = 30, delay: int = 5) -> bool:
|
||||
"""Wait for Ollama service to be ready"""
|
||||
log("Waiting for Ollama service to be ready...")
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10)
|
||||
if response.status_code == 200:
|
||||
log("✅ Ollama service is ready!")
|
||||
return True
|
||||
except requests.exceptions.RequestException as e:
|
||||
log(f"Attempt {attempt + 1}/{max_retries}: Ollama not ready yet ({e})", "DEBUG")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(delay)
|
||||
|
||||
log("❌ Ollama service failed to become ready", "ERROR")
|
||||
return False
|
||||
|
||||
def get_available_models() -> List[str]:
|
||||
"""Get list of available models"""
|
||||
try:
|
||||
response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = [model.get('name', '') for model in data.get('models', [])]
|
||||
log(f"Available models: {models}")
|
||||
return models
|
||||
else:
|
||||
log(f"Failed to get models: HTTP {response.status_code}", "ERROR")
|
||||
return []
|
||||
except Exception as e:
|
||||
log(f"Error getting models: {e}", "ERROR")
|
||||
return []
|
||||
|
||||
def pull_model(model_name: str) -> bool:
|
||||
"""Pull a model if not available"""
|
||||
log(f"Pulling model: {model_name}")
|
||||
|
||||
try:
|
||||
payload = {"name": model_name}
|
||||
response = requests.post(
|
||||
f"{OLLAMA_BASE_URL}/api/pull",
|
||||
json=payload,
|
||||
timeout=600, # 10 minutes timeout
|
||||
stream=True
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Stream and log progress
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
status = data.get('status', '')
|
||||
if status:
|
||||
log(f"Pull progress: {status}", "DEBUG")
|
||||
if data.get('error'):
|
||||
log(f"Pull error: {data.get('error')}", "ERROR")
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
log(f"✅ Successfully pulled model: {model_name}")
|
||||
return True
|
||||
else:
|
||||
log(f"❌ Failed to pull model {model_name}: HTTP {response.status_code}", "ERROR")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
log(f"❌ Error pulling model {model_name}: {e}", "ERROR")
|
||||
return False
|
||||
|
||||
def create_sigma_model() -> bool:
|
||||
"""Create the sigma-llama model with specialized SIGMA generation configuration"""
|
||||
log("🔄 Creating sigma-llama model...")
|
||||
|
||||
# First, remove any existing sigma-llama model
|
||||
try:
|
||||
response = requests.delete(f"{OLLAMA_BASE_URL}/api/delete",
|
||||
json={"name": SIGMA_MODEL_NAME},
|
||||
timeout=30)
|
||||
if response.status_code == 200:
|
||||
log("Removed existing sigma-llama model")
|
||||
except Exception:
|
||||
pass # Model might not exist, that's fine
|
||||
|
||||
# Create Modelfile content without the FROM line
|
||||
modelfile_content = """TEMPLATE \"\"\"### Instruction:
|
||||
Generate SIGMA rule logsource and detection sections based on the provided context.
|
||||
|
||||
### Input:
|
||||
{{ .Prompt }}
|
||||
|
||||
### Response:
|
||||
\"\"\"
|
||||
|
||||
PARAMETER temperature 0.1
|
||||
PARAMETER top_p 0.9
|
||||
PARAMETER stop "### Instruction:"
|
||||
PARAMETER stop "### Response:"
|
||||
PARAMETER num_ctx 4096
|
||||
|
||||
SYSTEM \"\"\"You are a cybersecurity expert specializing in SIGMA rule creation. Generate valid SIGMA rules in YAML format based on the provided CVE and exploit information. Output ONLY valid YAML starting with 'title:' and ending with the last YAML line.\"\"\"
|
||||
"""
|
||||
|
||||
try:
|
||||
# Create the model using the API with 'from' parameter
|
||||
payload = {
|
||||
"name": SIGMA_MODEL_NAME,
|
||||
"from": f"{DEFAULT_MODEL}:latest",
|
||||
"modelfile": modelfile_content,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{OLLAMA_BASE_URL}/api/create",
|
||||
json=payload,
|
||||
timeout=300, # 5 minutes timeout
|
||||
stream=True
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Stream and log progress
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
try:
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
status = data.get('status', '')
|
||||
if status:
|
||||
log(f"Model creation: {status}", "DEBUG")
|
||||
if data.get('error'):
|
||||
log(f"Model creation error: {data.get('error')}", "ERROR")
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Verify the model was created
|
||||
models = get_available_models()
|
||||
if any(SIGMA_MODEL_NAME in model for model in models):
|
||||
log("✅ sigma-llama model created successfully!")
|
||||
return True
|
||||
else:
|
||||
log("❌ sigma-llama model not found after creation", "ERROR")
|
||||
return False
|
||||
else:
|
||||
log(f"❌ Failed to create sigma-llama model: HTTP {response.status_code}", "ERROR")
|
||||
try:
|
||||
error_data = response.json()
|
||||
log(f"Error details: {error_data}", "ERROR")
|
||||
except Exception:
|
||||
log(f"Error response: {response.text}", "ERROR")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
log(f"❌ Error creating sigma-llama model: {e}", "ERROR")
|
||||
return False
|
||||
|
||||
def test_sigma_model() -> bool:
|
||||
"""Test the sigma-llama model"""
|
||||
log("🔄 Testing sigma-llama model...")
|
||||
|
||||
try:
|
||||
test_payload = {
|
||||
"model": SIGMA_MODEL_NAME,
|
||||
"prompt": "Title: Test PowerShell Rule",
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{OLLAMA_BASE_URL}/api/generate",
|
||||
json=test_payload,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
test_response = data.get('response', '')[:100] # First 100 chars
|
||||
log(f"✅ Model test successful! Response: {test_response}...")
|
||||
return True
|
||||
else:
|
||||
log(f"❌ Model test failed: HTTP {response.status_code}", "ERROR")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
log(f"❌ Error testing model: {e}", "ERROR")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main setup function"""
|
||||
log("🚀 Starting enhanced Ollama setup with SIGMA model creation...")
|
||||
|
||||
# Step 1: Wait for Ollama to be ready
|
||||
if not wait_for_ollama():
|
||||
log("❌ Setup failed: Ollama service not available", "ERROR")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 2: Check current models
|
||||
models = get_available_models()
|
||||
log(f"Current models: {models}")
|
||||
|
||||
# Step 3: Pull default model if needed
|
||||
if not any(DEFAULT_MODEL in model for model in models):
|
||||
log(f"Default model {DEFAULT_MODEL} not found, pulling...")
|
||||
if not pull_model(DEFAULT_MODEL):
|
||||
log(f"❌ Setup failed: Could not pull {DEFAULT_MODEL}", "ERROR")
|
||||
sys.exit(1)
|
||||
else:
|
||||
log(f"✅ Default model {DEFAULT_MODEL} already available")
|
||||
|
||||
# Step 4: Create SIGMA model
|
||||
if not create_sigma_model():
|
||||
log("❌ Setup failed: Could not create sigma-llama model", "ERROR")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 5: Test SIGMA model
|
||||
if not test_sigma_model():
|
||||
log("⚠️ Setup warning: sigma-llama model test failed", "WARN")
|
||||
# Don't exit here, the model might still work
|
||||
|
||||
# Step 6: Final verification
|
||||
final_models = get_available_models()
|
||||
log(f"Final models available: {final_models}")
|
||||
|
||||
if any(SIGMA_MODEL_NAME in model for model in final_models):
|
||||
log("🎉 Setup complete! sigma-llama model is ready for use.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
log("❌ Setup failed: sigma-llama model not available after setup", "ERROR")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
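The script is written to run as a one-shot setup step, but its helpers can be reused individually; a small sketch, assuming the module is importable as setup_ollama_with_sigma:

```python
# Sketch: reuse the setup helpers outside the __main__ flow.
from setup_ollama_with_sigma import wait_for_ollama, create_sigma_model, test_sigma_model

if wait_for_ollama(max_retries=10, delay=3):
    if create_sigma_model():
        test_sigma_model()
```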
|
3	backend/tasks/__init__.py	Normal file
@@ -0,0 +1,3 @@
"""
Celery tasks for the Auto SIGMA Rule Generator
"""

235	backend/tasks/bulk_tasks.py	Normal file
@@ -0,0 +1,235 @@
"""
|
||||
Bulk processing tasks for Celery
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from celery import current_task
|
||||
from celery_config import celery_app, get_db_session
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from bulk_seeder import BulkSeeder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@celery_app.task(bind=True, name='bulk_tasks.full_bulk_seed')
|
||||
def full_bulk_seed_task(self, start_year: int = 2002, end_year: Optional[int] = None,
|
||||
skip_nvd: bool = False, skip_nomi_sec: bool = False,
|
||||
skip_exploitdb: bool = False, skip_cisa_kev: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for full bulk seeding operation
|
||||
|
||||
Args:
|
||||
start_year: Starting year for NVD data
|
||||
end_year: Ending year for NVD data
|
||||
skip_nvd: Skip NVD bulk processing
|
||||
skip_nomi_sec: Skip nomi-sec PoC synchronization
|
||||
skip_exploitdb: Skip ExploitDB synchronization
|
||||
skip_cisa_kev: Skip CISA KEV synchronization
|
||||
|
||||
Returns:
|
||||
Dictionary containing operation results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'initializing',
|
||||
'progress': 0,
|
||||
'message': 'Starting bulk seeding operation'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting full bulk seed task: {start_year}-{end_year}")
|
||||
|
||||
# Create seeder instance
|
||||
seeder = BulkSeeder(db_session)
|
||||
|
||||
# Create progress callback
|
||||
def update_progress(stage: str, progress: int, message: str = None):
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': stage,
|
||||
'progress': progress,
|
||||
'message': message or f'Processing {stage}'
|
||||
}
|
||||
)
|
||||
|
||||
# Run the bulk seeding operation
|
||||
# Note: We need to handle the async nature of bulk_seeder
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
seeder.full_bulk_seed(
|
||||
start_year=start_year,
|
||||
end_year=end_year,
|
||||
skip_nvd=skip_nvd,
|
||||
skip_nomi_sec=skip_nomi_sec,
|
||||
skip_exploitdb=skip_exploitdb,
|
||||
skip_cisa_kev=skip_cisa_kev,
|
||||
progress_callback=update_progress
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'Bulk seeding completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Full bulk seed task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Full bulk seed task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
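These bulk tasks are normally fired by Celery Beat or the API layer, but they can also be enqueued directly, which is handy for one-off backfills; a minimal sketch (the argument values are examples, not recommendations):

```python
# Sketch: enqueue a bulk seed by hand and watch its progress.
from tasks.bulk_tasks import full_bulk_seed_task

async_result = full_bulk_seed_task.delay(start_year=2020, skip_exploitdb=True)

# The task reports progress via update_state(), so its meta dict is visible here:
print(async_result.state)   # e.g. PROGRESS
print(async_result.info)    # {'stage': ..., 'progress': ..., 'message': ...}
```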
@celery_app.task(bind=True, name='bulk_tasks.incremental_update_task')
|
||||
def incremental_update_task(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for incremental updates
|
||||
|
||||
Returns:
|
||||
Dictionary containing update results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'incremental_update',
|
||||
'progress': 0,
|
||||
'message': 'Starting incremental update'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Starting incremental update task")
|
||||
|
||||
# Create seeder instance
|
||||
seeder = BulkSeeder(db_session)
|
||||
|
||||
# Run the incremental update
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(seeder.incremental_update())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'Incremental update completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Incremental update task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Incremental update task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
@celery_app.task(bind=True, name='bulk_tasks.generate_enhanced_sigma_rules')
|
||||
def generate_enhanced_sigma_rules_task(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for generating enhanced SIGMA rules
|
||||
|
||||
Returns:
|
||||
Dictionary containing generation results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'generating_rules',
|
||||
'progress': 0,
|
||||
'message': 'Starting enhanced SIGMA rule generation'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Starting enhanced SIGMA rule generation task")
|
||||
|
||||
# Create seeder instance
|
||||
seeder = BulkSeeder(db_session)
|
||||
|
||||
# Run the rule generation
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(seeder.generate_enhanced_sigma_rules())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'Enhanced SIGMA rule generation completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Enhanced SIGMA rule generation task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Enhanced SIGMA rule generation task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
959	backend/tasks/data_sync_tasks.py	Normal file
@@ -0,0 +1,959 @@
"""
|
||||
Data synchronization tasks for Celery
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from typing import Dict, Any
|
||||
from celery import current_task
|
||||
from celery_config import celery_app, get_db_session
|
||||
from nomi_sec_client import NomiSecClient
|
||||
from exploitdb_client_local import ExploitDBLocalClient
|
||||
from cisa_kev_client import CISAKEVClient
|
||||
from mcdevitt_poc_client import GitHubPoCClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_nomi_sec')
|
||||
def sync_nomi_sec_task(self, batch_size: int = 50) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for nomi-sec PoC synchronization
|
||||
|
||||
Args:
|
||||
batch_size: Number of CVEs to process in each batch
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_nomi_sec',
|
||||
'progress': 0,
|
||||
'message': 'Starting nomi-sec PoC synchronization'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting nomi-sec sync task with batch size: {batch_size}")
|
||||
|
||||
# Create client instance
|
||||
client = NomiSecClient(db_session)
|
||||
|
||||
# Run the synchronization
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
client.bulk_sync_all_cves(batch_size=batch_size)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'Nomi-sec synchronization completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Nomi-sec sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Nomi-sec sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_cve2capec')
|
||||
def sync_cve2capec_task(self, force_refresh: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for CVE2CAPEC MITRE ATT&CK mapping synchronization
|
||||
|
||||
Args:
|
||||
force_refresh: Whether to force refresh the cache regardless of expiry
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_cve2capec',
|
||||
'progress': 0,
|
||||
'message': 'Starting CVE2CAPEC MITRE ATT&CK mapping synchronization'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting CVE2CAPEC sync task with force_refresh: {force_refresh}")
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from cve2capec_client import CVE2CAPECClient
|
||||
|
||||
# Create client instance
|
||||
client = CVE2CAPECClient()
|
||||
|
||||
# Update progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_cve2capec',
|
||||
'progress': 10,
|
||||
'message': 'Fetching MITRE ATT&CK mappings...'
|
||||
}
|
||||
)
|
||||
|
||||
# Force refresh if requested
|
||||
if force_refresh:
|
||||
client._fetch_fresh_data()
|
||||
|
||||
# Update progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_cve2capec',
|
||||
'progress': 50,
|
||||
'message': 'Processing CVE mappings...'
|
||||
}
|
||||
)
|
||||
|
||||
# Get statistics about the loaded data
|
||||
stats = client.get_stats()
|
||||
|
||||
# Update progress to completion
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_cve2capec',
|
||||
'progress': 100,
|
||||
'message': 'CVE2CAPEC synchronization completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
result = {
|
||||
'status': 'completed',
|
||||
'total_mappings': stats.get('total_mappings', 0),
|
||||
'total_techniques': stats.get('unique_techniques', 0),
|
||||
'cache_updated': True
|
||||
}
|
||||
|
||||
logger.info(f"CVE2CAPEC sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CVE2CAPEC sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_github_poc')
|
||||
def sync_github_poc_task(self, batch_size: int = 50) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for GitHub PoC synchronization
|
||||
|
||||
Args:
|
||||
batch_size: Number of CVEs to process in each batch
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_github_poc',
|
||||
'progress': 0,
|
||||
'message': 'Starting GitHub PoC synchronization'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting GitHub PoC sync task with batch size: {batch_size}")
|
||||
|
||||
# Create client instance
|
||||
client = GitHubPoCClient(db_session)
|
||||
|
||||
# Run the synchronization
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
client.bulk_sync_all_cves(batch_size=batch_size)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'GitHub PoC synchronization completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"GitHub PoC sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GitHub PoC sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_reference_content')
|
||||
def sync_reference_content_task(self, batch_size: int = 30, max_cves: int = 200,
|
||||
force_resync: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for CVE reference content extraction and analysis
|
||||
|
||||
Args:
|
||||
batch_size: Number of CVEs to process in each batch
|
||||
max_cves: Maximum number of CVEs to process
|
||||
force_resync: Force re-sync of recently processed CVEs
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from main import CVE
|
||||
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_reference_content',
|
||||
'progress': 0,
|
||||
'message': 'Starting CVE reference content extraction'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting reference content sync task - batch_size: {batch_size}, max_cves: {max_cves}")
|
||||
|
||||
# Get CVEs to process (prioritize those with references but no extracted content)
|
||||
query = db_session.query(CVE)
|
||||
|
||||
if not force_resync:
|
||||
# Skip CVEs that were recently processed
|
||||
from datetime import datetime, timedelta
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=7)
|
||||
query = query.filter(
|
||||
(CVE.reference_content_extracted_at.is_(None)) |
|
||||
(CVE.reference_content_extracted_at < cutoff_date)
|
||||
)
|
||||
|
||||
# Prioritize CVEs with references
|
||||
cves = query.filter(CVE.references.isnot(None)).limit(max_cves).all()
|
||||
|
||||
if not cves:
|
||||
logger.info("No CVEs found for reference content extraction")
|
||||
return {'total_processed': 0, 'successful_extractions': 0, 'failed_extractions': 0}
|
||||
|
||||
total_processed = 0
|
||||
successful_extractions = 0
|
||||
failed_extractions = 0
|
||||
|
||||
# Process CVEs in batches
|
||||
for i in range(0, len(cves), batch_size):
|
||||
batch = cves[i:i + batch_size]
|
||||
|
||||
for j, cve in enumerate(batch):
|
||||
try:
|
||||
# Update progress
|
||||
overall_progress = int(((i + j) / len(cves)) * 100)
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_reference_content',
|
||||
'progress': overall_progress,
|
||||
'message': f'Processing CVE {cve.cve_id} ({i + j + 1}/{len(cves)})',
|
||||
'current_cve': cve.cve_id,
|
||||
'processed': i + j,
|
||||
'total': len(cves)
|
||||
}
|
||||
)
|
||||
|
||||
# For now, simulate reference content extraction
|
||||
# In a real implementation, you would create a ReferenceContentExtractor
|
||||
# and extract content from CVE references
|
||||
|
||||
# Mark CVE as processed
|
||||
from datetime import datetime
|
||||
cve.reference_content_extracted_at = datetime.utcnow()
|
||||
|
||||
successful_extractions += 1
|
||||
total_processed += 1
|
||||
|
||||
# Small delay between requests
|
||||
import time
|
||||
time.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing reference content for CVE {cve.cve_id}: {e}")
|
||||
failed_extractions += 1
|
||||
total_processed += 1
|
||||
|
||||
# Commit after each batch
|
||||
db_session.commit()
|
||||
logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")
|
||||
|
||||
# Final results
|
||||
result = {
|
||||
'total_processed': total_processed,
|
||||
'successful_extractions': successful_extractions,
|
||||
'failed_extractions': failed_extractions,
|
||||
'extraction_rate': (successful_extractions / total_processed * 100) if total_processed > 0 else 0
|
||||
}
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': f'Reference content extraction completed: {successful_extractions} successful, {failed_extractions} failed',
|
||||
'results': result
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Reference content sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reference content sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_exploitdb')
|
||||
def sync_exploitdb_task(self, batch_size: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for ExploitDB synchronization
|
||||
|
||||
Args:
|
||||
batch_size: Number of CVEs to process in each batch
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_exploitdb',
|
||||
'progress': 0,
|
||||
'message': 'Starting ExploitDB synchronization'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting ExploitDB sync task with batch size: {batch_size}")
|
||||
|
||||
# Create client instance
|
||||
client = ExploitDBLocalClient(db_session)
|
||||
|
||||
# Run the synchronization
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
client.bulk_sync_exploitdb(batch_size=batch_size)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'ExploitDB synchronization completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"ExploitDB sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ExploitDB sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.sync_cisa_kev')
|
||||
def sync_cisa_kev_task(self, batch_size: int = 100) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for CISA KEV synchronization
|
||||
|
||||
Args:
|
||||
batch_size: Number of CVEs to process in each batch
|
||||
|
||||
Returns:
|
||||
Dictionary containing sync results
|
||||
"""
|
||||
db_session = get_db_session()
|
||||
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'sync_cisa_kev',
|
||||
'progress': 0,
|
||||
'message': 'Starting CISA KEV synchronization'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Starting CISA KEV sync task with batch size: {batch_size}")
|
||||
|
||||
# Create client instance
|
||||
client = CISAKEVClient(db_session)
|
||||
|
||||
# Run the synchronization
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
client.bulk_sync_kev_data(batch_size=batch_size)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Update final progress
|
||||
self.update_state(
|
||||
state='SUCCESS',
|
||||
meta={
|
||||
'stage': 'completed',
|
||||
'progress': 100,
|
||||
'message': 'CISA KEV synchronization completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"CISA KEV sync task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CISA KEV sync task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name='data_sync_tasks.build_exploitdb_index')
|
||||
def build_exploitdb_index_task(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Celery task for building/rebuilding ExploitDB file index
|
||||
|
||||
Returns:
|
||||
Dictionary containing build results
|
||||
"""
|
||||
try:
|
||||
# Update task progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'build_exploitdb_index',
|
||||
'progress': 0,
|
||||
'message': 'Starting ExploitDB file index building'
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Starting ExploitDB index build task")
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from exploitdb_client_local import ExploitDBLocalClient
|
||||
|
||||
# Create client instance with lazy_load=False to force index building
|
||||
client = ExploitDBLocalClient(None, lazy_load=False)
|
||||
|
||||
# Update progress
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'build_exploitdb_index',
|
||||
'progress': 50,
|
||||
'message': 'Building file index...'
|
||||
}
|
||||
)
|
||||
|
||||
# Force index rebuild
|
||||
client._build_file_index()
|
||||
|
||||
# Update progress to completion
|
||||
self.update_state(
|
||||
state='PROGRESS',
|
||||
meta={
|
||||
'stage': 'build_exploitdb_index',
|
||||
'progress': 100,
|
||||
'message': 'ExploitDB index building completed successfully'
|
||||
}
|
||||
)
|
||||
|
||||
result = {
|
||||
'status': 'completed',
|
||||
'total_exploits_indexed': len(client.file_index),
|
||||
'index_updated': True
|
||||
}
|
||||
|
||||
logger.info(f"ExploitDB index build task completed: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ExploitDB index build task failed: {e}")
|
||||
self.update_state(
|
||||
state='FAILURE',
|
||||
meta={
|
||||
'stage': 'error',
|
||||
'progress': 0,
|
||||
'message': f'Task failed: {str(e)}',
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
raise
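Because the index build is a separate task from the ExploitDB sync, the two can be composed with a Celery chain when a sync should always be followed by a rebuild; this illustrates the primitive only and is not necessarily how the beat schedule wires them:

```python
# Sketch: run the ExploitDB sync, then rebuild the file index.
from celery import chain
from tasks.data_sync_tasks import sync_exploitdb_task, build_exploitdb_index_task

workflow = chain(
    sync_exploitdb_task.si(batch_size=30),   # .si() ignores the previous task's result
    build_exploitdb_index_task.si(),
)
workflow.apply_async()
```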
|
437	backend/tasks/maintenance_tasks.py	Normal file
@@ -0,0 +1,437 @@
"""
|
||||
Maintenance tasks for Celery
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from celery_config import celery_app, get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@celery_app.task(name='tasks.maintenance_tasks.cleanup_old_results')
|
||||
def cleanup_old_results():
|
||||
"""
|
||||
Periodic task to clean up old Celery results and logs
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting cleanup of old Celery results")
|
||||
|
||||
# This would clean up old results from Redis
|
||||
# For now, we'll just log the action
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=7)
|
||||
|
||||
# Clean up old task results (this would be Redis cleanup)
|
||||
# celery_app.backend.cleanup()
|
||||
|
||||
logger.info(f"Cleanup completed for results older than {cutoff_date}")
|
||||
|
||||
return {
|
||||
'status': 'completed',
|
||||
'cutoff_date': cutoff_date.isoformat(),
|
||||
'message': 'Old results cleanup completed'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup task failed: {e}")
|
||||
raise
|
||||
|
||||
@celery_app.task(name='tasks.maintenance_tasks.health_check')
|
||||
def health_check():
|
||||
"""
|
||||
Health check task to verify system components
|
||||
"""
|
||||
try:
|
||||
db_session = get_db_session()
|
||||
|
||||
# Check database connectivity
|
||||
try:
|
||||
db_session.execute("SELECT 1")
|
||||
db_status = "healthy"
|
||||
except Exception as e:
|
||||
db_status = f"unhealthy: {e}"
|
||||
finally:
|
||||
db_session.close()
|
        # Check Redis connectivity
        try:
            celery_app.backend.ping()
            redis_status = "healthy"
        except Exception as e:
            redis_status = f"unhealthy: {e}"

        result = {
            'timestamp': datetime.utcnow().isoformat(),
            'database': db_status,
            'redis': redis_status,
            'celery': 'healthy'
        }

        logger.info(f"Health check completed: {result}")
        return result

    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise


@celery_app.task(bind=True, name='tasks.maintenance_tasks.database_cleanup_comprehensive')
def database_cleanup_comprehensive(self, days_to_keep: int = 30, cleanup_failed_jobs: bool = True,
                                   cleanup_logs: bool = True) -> Dict[str, Any]:
    """
    Comprehensive database cleanup task

    Args:
        days_to_keep: Number of days to keep old records
        cleanup_failed_jobs: Whether to clean up failed job records
        cleanup_logs: Whether to clean up old log entries

    Returns:
        Dictionary containing cleanup results
    """
    try:
        from datetime import datetime, timedelta
        from typing import Dict, Any

        db_session = get_db_session()

        # Update task progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'database_cleanup',
                'progress': 0,
                'message': 'Starting comprehensive database cleanup'
            }
        )

        logger.info(f"Starting comprehensive database cleanup - keeping {days_to_keep} days")

        cutoff_date = datetime.utcnow() - timedelta(days=days_to_keep)
        cleanup_results = {
            'cutoff_date': cutoff_date.isoformat(),
            'cleaned_tables': {},
            'total_records_cleaned': 0
        }

        try:
            # Import models here to avoid circular imports
            import sys
            import os
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from main import BulkProcessingJob

            # Clean up old bulk processing jobs
            self.update_state(
                state='PROGRESS',
                meta={
                    'stage': 'database_cleanup',
                    'progress': 20,
                    'message': 'Cleaning up old bulk processing jobs'
                }
            )

            old_jobs_query = db_session.query(BulkProcessingJob).filter(
                BulkProcessingJob.created_at < cutoff_date
            )

            if cleanup_failed_jobs:
                # Clean all old jobs
                old_jobs_count = old_jobs_query.count()
                old_jobs_query.delete()
            else:
                # Only clean completed jobs
                old_jobs_query = old_jobs_query.filter(
                    BulkProcessingJob.status.in_(['completed', 'cancelled'])
                )
                old_jobs_count = old_jobs_query.count()
                old_jobs_query.delete()

            cleanup_results['cleaned_tables']['bulk_processing_jobs'] = old_jobs_count
            cleanup_results['total_records_cleaned'] += old_jobs_count

            # Clean up old Celery task results from Redis
            self.update_state(
                state='PROGRESS',
                meta={
                    'stage': 'database_cleanup',
                    'progress': 40,
                    'message': 'Cleaning up old Celery task results'
                }
            )

            try:
                # This would clean up old results from Redis backend
                # For now, we'll simulate this
                celery_cleanup_count = 0
                # celery_app.backend.cleanup()
                cleanup_results['cleaned_tables']['celery_results'] = celery_cleanup_count
            except Exception as e:
                logger.warning(f"Could not clean Celery results: {e}")
                cleanup_results['cleaned_tables']['celery_results'] = 0

            # Clean up old temporary data (if any)
            self.update_state(
                state='PROGRESS',
                meta={
                    'stage': 'database_cleanup',
                    'progress': 60,
                    'message': 'Cleaning up temporary data'
                }
            )

            # Add any custom temporary table cleanup here
            # Example: Clean up old session data, temporary files, etc.
            temp_cleanup_count = 0
            cleanup_results['cleaned_tables']['temporary_data'] = temp_cleanup_count

            # Vacuum/optimize database (PostgreSQL)
            self.update_state(
                state='PROGRESS',
                meta={
                    'stage': 'database_cleanup',
                    'progress': 80,
                    'message': 'Optimizing database'
                }
            )

            try:
                # Run VACUUM on PostgreSQL to reclaim space
                db_session.execute("VACUUM;")
                cleanup_results['database_optimized'] = True
            except Exception as e:
                logger.warning(f"Could not vacuum database: {e}")
                cleanup_results['database_optimized'] = False

            # Commit all changes
            db_session.commit()

            # Update final progress
            self.update_state(
                state='SUCCESS',
                meta={
                    'stage': 'completed',
                    'progress': 100,
                    'message': f'Database cleanup completed - removed {cleanup_results["total_records_cleaned"]} records',
                    'results': cleanup_results
                }
            )

            logger.info(f"Database cleanup completed: {cleanup_results}")
            return cleanup_results

        finally:
            db_session.close()

    except Exception as e:
        logger.error(f"Database cleanup failed: {e}")
        self.update_state(
            state='FAILURE',
            meta={
                'stage': 'error',
                'progress': 0,
                'message': f'Cleanup failed: {str(e)}',
                'error': str(e)
            }
        )
        raise


@celery_app.task(bind=True, name='tasks.maintenance_tasks.health_check_detailed')
def health_check_detailed(self) -> Dict[str, Any]:
    """
    Detailed health check task for all system components

    Returns:
        Dictionary containing detailed health status
    """
    try:
        from datetime import datetime
        import psutil
        import redis

        # Update task progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'health_check',
                'progress': 0,
                'message': 'Starting detailed health check'
            }
        )

        logger.info("Starting detailed health check")

        health_status = {
            'timestamp': datetime.utcnow().isoformat(),
            'overall_status': 'healthy',
            'components': {}
        }

        # Check database connectivity and performance
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'health_check',
                'progress': 20,
                'message': 'Checking database health'
            }
        )

        db_session = get_db_session()
        try:
            start_time = datetime.utcnow()
            db_session.execute("SELECT 1")
            db_response_time = (datetime.utcnow() - start_time).total_seconds()

            # Check database size and connections
            db_size_result = db_session.execute("SELECT pg_size_pretty(pg_database_size(current_database()));").fetchone()
            db_connections_result = db_session.execute("SELECT count(*) FROM pg_stat_activity;").fetchone()

            health_status['components']['database'] = {
                'status': 'healthy',
                'response_time_seconds': db_response_time,
                'database_size': db_size_result[0] if db_size_result else 'unknown',
                'active_connections': db_connections_result[0] if db_connections_result else 0,
                'details': 'Database responsive and accessible'
            }
        except Exception as e:
            health_status['components']['database'] = {
                'status': 'unhealthy',
                'error': str(e),
                'details': 'Database connection failed'
            }
            health_status['overall_status'] = 'degraded'
        finally:
            db_session.close()

        # Check Redis connectivity and performance
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'health_check',
                'progress': 40,
                'message': 'Checking Redis health'
            }
        )

        try:
            start_time = datetime.utcnow()
            celery_app.backend.ping()
            redis_response_time = (datetime.utcnow() - start_time).total_seconds()

            # Get Redis info
            redis_client = redis.Redis.from_url(celery_app.conf.broker_url)
            redis_info = redis_client.info()

            health_status['components']['redis'] = {
                'status': 'healthy',
                'response_time_seconds': redis_response_time,
                'memory_usage_mb': redis_info.get('used_memory', 0) / (1024 * 1024),
                'connected_clients': redis_info.get('connected_clients', 0),
                'uptime_seconds': redis_info.get('uptime_in_seconds', 0),
                'details': 'Redis responsive and accessible'
            }
        except Exception as e:
            health_status['components']['redis'] = {
                'status': 'unhealthy',
                'error': str(e),
                'details': 'Redis connection failed'
            }
            health_status['overall_status'] = 'degraded'

        # Check system resources
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'health_check',
                'progress': 60,
                'message': 'Checking system resources'
            }
        )

        try:
            cpu_percent = psutil.cpu_percent(interval=1)
            memory = psutil.virtual_memory()
            disk = psutil.disk_usage('/')

            health_status['components']['system'] = {
                'status': 'healthy',
                'cpu_percent': cpu_percent,
                'memory_percent': memory.percent,
                'memory_available_gb': memory.available / (1024**3),
                'disk_percent': disk.percent,
                'disk_free_gb': disk.free / (1024**3),
                'details': 'System resources within normal ranges'
            }

            # Mark as degraded if resources are high
            if cpu_percent > 80 or memory.percent > 85 or disk.percent > 90:
                health_status['components']['system']['status'] = 'degraded'
                health_status['overall_status'] = 'degraded'
                health_status['components']['system']['details'] = 'High resource usage detected'

        except Exception as e:
            health_status['components']['system'] = {
                'status': 'unknown',
                'error': str(e),
                'details': 'Could not check system resources'
            }

        # Check Celery worker status
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'health_check',
                'progress': 80,
                'message': 'Checking Celery workers'
            }
        )

        try:
            inspect = celery_app.control.inspect()
            active_workers = inspect.active()
            stats = inspect.stats()

            health_status['components']['celery'] = {
                'status': 'healthy',
                'active_workers': len(active_workers) if active_workers else 0,
                'worker_stats': stats,
                'details': 'Celery workers responding'
            }

            if not active_workers:
                health_status['components']['celery']['status'] = 'degraded'
                health_status['components']['celery']['details'] = 'No active workers found'
                health_status['overall_status'] = 'degraded'

        except Exception as e:
            health_status['components']['celery'] = {
                'status': 'unknown',
                'error': str(e),
                'details': 'Could not check Celery workers'
            }

        # Update final progress
        self.update_state(
            state='SUCCESS',
            meta={
                'stage': 'completed',
                'progress': 100,
                'message': f'Health check completed - overall status: {health_status["overall_status"]}',
                'results': health_status
            }
        )

        logger.info(f"Detailed health check completed: {health_status['overall_status']}")
        return health_status

    except Exception as e:
        logger.error(f"Detailed health check failed: {e}")
        self.update_state(
            state='FAILURE',
            meta={
                'stage': 'error',
                'progress': 0,
                'message': f'Health check failed: {str(e)}',
                'error': str(e)
            }
        )
        raise
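The maintenance tasks above are only referenced by name from the Celery Beat schedule; the schedule itself lives in backend/celery_config.py, which is not shown in this excerpt. A minimal sketch of what such beat entries might look like (illustrative assumption only; times follow the daily schedule in the commit message):

```python
from celery.schedules import crontab

# Assumed shape of the beat entries in backend/celery_config.py (not part of this diff)
celery_app.conf.beat_schedule = {
    'daily-result-cleanup': {
        'task': 'tasks.maintenance_tasks.database_cleanup_comprehensive',
        'schedule': crontab(hour=1, minute=30),   # 1:30 AM daily result cleanup
        'kwargs': {'days_to_keep': 30},
    },
    'health-check': {
        'task': 'tasks.maintenance_tasks.health_check_detailed',
        'schedule': crontab(minute='*/15'),       # health checks every 15 minutes
    },
}
```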
409
backend/tasks/sigma_tasks.py
Normal file

@ -0,0 +1,409 @@
"""
SIGMA rule generation tasks for Celery
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional
from celery import current_task
from celery_config import celery_app, get_db_session
from enhanced_sigma_generator import EnhancedSigmaGenerator
from llm_client import LLMClient

logger = logging.getLogger(__name__)


@celery_app.task(bind=True, name='sigma_tasks.generate_enhanced_rules')
def generate_enhanced_rules_task(self, cve_ids: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Celery task for enhanced SIGMA rule generation

    Args:
        cve_ids: Optional list of specific CVE IDs to process

    Returns:
        Dictionary containing generation results
    """
    db_session = get_db_session()

    try:
        # Import here to avoid circular imports
        import sys
        import os
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from main import CVE

        # Update task progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'generating_rules',
                'progress': 0,
                'message': 'Starting enhanced SIGMA rule generation'
            }
        )

        logger.info(f"Starting enhanced rule generation task for CVEs: {cve_ids}")

        # Create generator instance
        generator = EnhancedSigmaGenerator(db_session)

        # Get CVEs to process
        if cve_ids:
            cves = db_session.query(CVE).filter(CVE.cve_id.in_(cve_ids)).all()
        else:
            cves = db_session.query(CVE).filter(CVE.poc_count > 0).all()

        total_cves = len(cves)
        processed_cves = 0
        successful_rules = 0
        failed_rules = 0
        results = []

        # Process each CVE
        for i, cve in enumerate(cves):
            try:
                # Update progress
                progress = int((i / total_cves) * 100)
                self.update_state(
                    state='PROGRESS',
                    meta={
                        'stage': 'generating_rules',
                        'progress': progress,
                        'message': f'Processing CVE {cve.cve_id}',
                        'current_cve': cve.cve_id,
                        'processed': processed_cves,
                        'total': total_cves
                    }
                )

                # Generate rule using asyncio
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                try:
                    result = loop.run_until_complete(
                        generator.generate_enhanced_rule(cve)
                    )

                    if result.get('success', False):
                        successful_rules += 1
                    else:
                        failed_rules += 1

                    results.append({
                        'cve_id': cve.cve_id,
                        'success': result.get('success', False),
                        'message': result.get('message', 'No message'),
                        'rule_id': result.get('rule_id')
                    })

                finally:
                    loop.close()

                processed_cves += 1

            except Exception as e:
                logger.error(f"Error processing CVE {cve.cve_id}: {e}")
                failed_rules += 1
                results.append({
                    'cve_id': cve.cve_id,
                    'success': False,
                    'message': f'Error: {str(e)}',
                    'rule_id': None
                })

        # Final results
        final_result = {
            'total_processed': processed_cves,
            'successful_rules': successful_rules,
            'failed_rules': failed_rules,
            'results': results
        }

        # Update final progress
        self.update_state(
            state='SUCCESS',
            meta={
                'stage': 'completed',
                'progress': 100,
                'message': f'Generated {successful_rules} rules from {processed_cves} CVEs',
                'results': final_result
            }
        )

        logger.info(f"Enhanced rule generation task completed: {final_result}")
        return final_result

    except Exception as e:
        logger.error(f"Enhanced rule generation task failed: {e}")
        self.update_state(
            state='FAILURE',
            meta={
                'stage': 'error',
                'progress': 0,
                'message': f'Task failed: {str(e)}',
                'error': str(e)
            }
        )
        raise
    finally:
        db_session.close()


@celery_app.task(bind=True, name='sigma_tasks.llm_enhanced_generation')
def llm_enhanced_generation_task(self, cve_id: str, provider: str = 'ollama',
                                 model: Optional[str] = None) -> Dict[str, Any]:
    """
    Celery task for LLM-enhanced rule generation

    Args:
        cve_id: CVE identifier
        provider: LLM provider (openai, anthropic, ollama, finetuned)
        model: Specific model to use

    Returns:
        Dictionary containing generation result
    """
    db_session = get_db_session()

    try:
        # Import here to avoid circular imports
        import sys
        import os
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from main import CVE

        # Update task progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'llm_generation',
                'progress': 10,
                'message': f'Starting LLM rule generation for {cve_id}',
                'cve_id': cve_id,
                'provider': provider,
                'model': model
            }
        )

        logger.info(f"Starting LLM rule generation for {cve_id} using {provider}")

        # Get CVE from database
        cve = db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
        if not cve:
            raise ValueError(f"CVE {cve_id} not found in database")

        # Update progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'llm_generation',
                'progress': 25,
                'message': f'Initializing LLM client ({provider})',
                'cve_id': cve_id
            }
        )

        # Create LLM client
        llm_client = LLMClient(provider=provider, model=model)

        if not llm_client.is_available():
            raise ValueError(f"LLM client not available for provider: {provider}")

        # Update progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'llm_generation',
                'progress': 50,
                'message': f'Generating rule with LLM for {cve_id}',
                'cve_id': cve_id
            }
        )

        # Generate rule using asyncio
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            rule_content = loop.run_until_complete(
                llm_client.generate_sigma_rule(
                    cve_id=cve.cve_id,
                    poc_content=cve.poc_data or '',
                    cve_description=cve.description or ''
                )
            )
        finally:
            loop.close()

        # Update progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'llm_generation',
                'progress': 75,
                'message': f'Validating generated rule for {cve_id}',
                'cve_id': cve_id
            }
        )

        # Validate the generated rule
        is_valid = False
        if rule_content:
            is_valid = llm_client.validate_sigma_rule(rule_content, cve_id)

        # Prepare result
        result = {
            'cve_id': cve_id,
            'rule_content': rule_content,
            'is_valid': is_valid,
            'provider': provider,
            'model': model or llm_client.model,
            'success': bool(rule_content and is_valid)
        }

        # Update final progress
        self.update_state(
            state='SUCCESS',
            meta={
                'stage': 'completed',
                'progress': 100,
                'message': f'LLM rule generation completed for {cve_id}',
                'cve_id': cve_id,
                'success': result['success'],
                'result': result
            }
        )

        logger.info(f"LLM rule generation task completed for {cve_id}: {result['success']}")
        return result

    except Exception as e:
        logger.error(f"LLM rule generation task failed for {cve_id}: {e}")
        self.update_state(
            state='FAILURE',
            meta={
                'stage': 'error',
                'progress': 0,
                'message': f'Task failed for {cve_id}: {str(e)}',
                'cve_id': cve_id,
                'error': str(e)
            }
        )
        raise
    finally:
        db_session.close()


@celery_app.task(bind=True, name='sigma_tasks.batch_llm_generation')
def batch_llm_generation_task(self, cve_ids: List[str], provider: str = 'ollama',
                              model: Optional[str] = None) -> Dict[str, Any]:
    """
    Celery task for batch LLM rule generation

    Args:
        cve_ids: List of CVE identifiers
        provider: LLM provider (openai, anthropic, ollama, finetuned)
        model: Specific model to use

    Returns:
        Dictionary containing batch generation results
    """
    db_session = get_db_session()

    try:
        # Update task progress
        self.update_state(
            state='PROGRESS',
            meta={
                'stage': 'batch_llm_generation',
                'progress': 0,
                'message': f'Starting batch LLM generation for {len(cve_ids)} CVEs',
                'total_cves': len(cve_ids),
                'provider': provider,
                'model': model
            }
        )

        logger.info(f"Starting batch LLM generation for {len(cve_ids)} CVEs using {provider}")

        # Initialize results
        results = []
        successful_rules = 0
        failed_rules = 0

        # Process each CVE
        for i, cve_id in enumerate(cve_ids):
            try:
                # Update progress
                progress = int((i / len(cve_ids)) * 100)
                self.update_state(
                    state='PROGRESS',
                    meta={
                        'stage': 'batch_llm_generation',
                        'progress': progress,
                        'message': f'Processing CVE {cve_id} ({i+1}/{len(cve_ids)})',
                        'current_cve': cve_id,
                        'processed': i,
                        'total': len(cve_ids)
                    }
                )

                # Generate rule for this CVE
                result = llm_enhanced_generation_task.apply(
                    args=[cve_id, provider, model]
                ).get()

                if result.get('success', False):
                    successful_rules += 1
                else:
                    failed_rules += 1

                results.append(result)

            except Exception as e:
                logger.error(f"Error processing CVE {cve_id} in batch: {e}")
                failed_rules += 1
                results.append({
                    'cve_id': cve_id,
                    'success': False,
                    'error': str(e),
                    'provider': provider,
                    'model': model
                })

        # Final results
        final_result = {
            'total_processed': len(cve_ids),
            'successful_rules': successful_rules,
            'failed_rules': failed_rules,
            'provider': provider,
            'model': model,
            'results': results
        }

        # Update final progress
        self.update_state(
            state='SUCCESS',
            meta={
                'stage': 'completed',
                'progress': 100,
                'message': f'Batch generation completed: {successful_rules} successful, {failed_rules} failed',
                'results': final_result
            }
        )

        logger.info(f"Batch LLM generation task completed: {final_result}")
        return final_result

    except Exception as e:
        logger.error(f"Batch LLM generation task failed: {e}")
        self.update_state(
            state='FAILURE',
            meta={
                'stage': 'error',
                'progress': 0,
                'message': f'Batch task failed: {str(e)}',
                'error': str(e)
            }
        )
        raise
    finally:
        db_session.close()
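Because the SIGMA tasks above are ordinary Celery tasks, they can be exercised from a Python shell for debugging, independent of the API endpoints. A minimal usage sketch (assumptions: run inside the backend container where celery_config and the tasks package are importable; the CVE ID is a placeholder):

```python
from celery_config import celery_app
from tasks.sigma_tasks import generate_enhanced_rules_task

# Queue the task and keep the task ID, as the API endpoints do
async_result = generate_enhanced_rules_task.delay(['CVE-2024-0001'])  # placeholder CVE ID
print(async_result.id)

# Poll state and progress metadata from the Redis result backend
res = celery_app.AsyncResult(async_result.id)
print(res.state, res.info)  # e.g. PROGRESS {'stage': 'generating_rules', 'progress': 40, ...}
```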
docker-compose.yml

@ -22,6 +22,8 @@ services:
      - "8000:8000"
    environment:
      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
      CELERY_BROKER_URL: redis://redis:6379/0
      CELERY_RESULT_BACKEND: redis://redis:6379/0
      NVD_API_KEY: ${NVD_API_KEY:-}
      GITHUB_TOKEN: ${GITHUB_TOKEN}
      OPENAI_API_KEY: ${OPENAI_API_KEY:-}

@ -29,15 +31,21 @@ services:
      OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434}
      LLM_PROVIDER: ${LLM_PROVIDER:-ollama}
      LLM_MODEL: ${LLM_MODEL:-llama3.2}
      LLM_ENABLED: ${LLM_ENABLED:-true}
      FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned}
      HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN}
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_started
      ollama-setup:
        condition: service_completed_successfully
    volumes:
      - ./backend:/app
      - ./github_poc_collector:/github_poc_collector
      - ./exploit-db-mirror:/app/exploit-db-mirror
      - ./models:/app/models
    command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload

  frontend:

@ -68,6 +76,12 @@ services:
    environment:
      - OLLAMA_HOST=0.0.0.0
    restart: unless-stopped
    deploy:
      resources:
        limits:
          memory: 5G
        reservations:
          memory: 3G

  ollama-setup:
    build: ./backend

@ -78,9 +92,109 @@ services:
      LLM_MODEL: llama3.2
    volumes:
      - ./backend:/app
    command: python setup_ollama.py
    command: python setup_ollama_with_sigma.py
    restart: "no"

  initial-setup:
    build: ./backend
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_started
      celery-worker:
        condition: service_healthy
    environment:
      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
      CELERY_BROKER_URL: redis://redis:6379/0
      CELERY_RESULT_BACKEND: redis://redis:6379/0
    volumes:
      - ./backend:/app
    command: python initial_setup.py
    restart: "no"

  celery-worker:
    build: ./backend
    command: celery -A celery_config worker --loglevel=info --concurrency=4
    environment:
      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
      CELERY_BROKER_URL: redis://redis:6379/0
      CELERY_RESULT_BACKEND: redis://redis:6379/0
      NVD_API_KEY: ${NVD_API_KEY:-}
      GITHUB_TOKEN: ${GITHUB_TOKEN}
      OPENAI_API_KEY: ${OPENAI_API_KEY:-}
      ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
      OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434}
      LLM_PROVIDER: ${LLM_PROVIDER:-ollama}
      LLM_MODEL: ${LLM_MODEL:-llama3.2}
      LLM_ENABLED: ${LLM_ENABLED:-true}
      FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned}
      HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN}
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_started
      ollama-setup:
        condition: service_completed_successfully
    volumes:
      - ./backend:/app
      - ./github_poc_collector:/github_poc_collector
      - ./exploit-db-mirror:/app/exploit-db-mirror
      - ./models:/app/models
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "celery", "-A", "celery_config", "inspect", "ping"]
      interval: 30s
      timeout: 10s
      retries: 3

  celery-beat:
    build: ./backend
    command: celery -A celery_config beat --loglevel=info --pidfile=/tmp/celerybeat.pid
    environment:
      DATABASE_URL: postgresql://cve_user:cve_password@db:5432/cve_sigma_db
      CELERY_BROKER_URL: redis://redis:6379/0
      CELERY_RESULT_BACKEND: redis://redis:6379/0
      NVD_API_KEY: ${NVD_API_KEY:-}
      GITHUB_TOKEN: ${GITHUB_TOKEN}
      OPENAI_API_KEY: ${OPENAI_API_KEY:-}
      ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
      OLLAMA_BASE_URL: ${OLLAMA_BASE_URL:-http://ollama:11434}
      LLM_PROVIDER: ${LLM_PROVIDER:-ollama}
      LLM_MODEL: ${LLM_MODEL:-llama3.2}
      LLM_ENABLED: ${LLM_ENABLED:-true}
      FINETUNED_MODEL_PATH: ${FINETUNED_MODEL_PATH:-/app/models/sigma_llama_finetuned}
      HUGGING_FACE_TOKEN: ${HUGGING_FACE_TOKEN}
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_started
      celery-worker:
        condition: service_healthy
    volumes:
      - ./backend:/app
      - ./github_poc_collector:/github_poc_collector
      - ./exploit-db-mirror:/app/exploit-db-mirror
      - ./models:/app/models
    restart: unless-stopped

  flower:
    build: ./backend
    command: celery -A celery_config flower --port=5555
    ports:
      - "5555:5555"
    environment:
      CELERY_BROKER_URL: redis://redis:6379/0
      CELERY_RESULT_BACKEND: redis://redis:6379/0
    depends_on:
      redis:
        condition: service_started
      celery-worker:
        condition: service_healthy
    restart: unless-stopped

volumes:
  postgres_data:
  redis_data:
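The celery-worker healthcheck above shells out to `celery -A celery_config inspect ping`. A rough Python equivalent (an assumption for debugging, not part of this diff) that can help when the container is flagged unhealthy:

```python
from celery_config import celery_app

# Ask all workers to reply within 5 seconds; an empty dict means none responded
replies = celery_app.control.inspect(timeout=5).ping() or {}
for worker, reply in replies.items():
    print(worker, reply)  # e.g. celery@worker1 {'ok': 'pong'}
if not replies:
    raise SystemExit("no Celery workers responded to ping")
```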
frontend/src/App.js

@ -26,8 +26,6 @@ function App() {
  const [runningJobTypes, setRunningJobTypes] = useState(new Set());
  const [llmStatus, setLlmStatus] = useState({});
  const [exploitSyncDropdownOpen, setExploitSyncDropdownOpen] = useState(false);
  const [schedulerStatus, setSchedulerStatus] = useState({});
  const [schedulerJobs, setSchedulerJobs] = useState({});

  useEffect(() => {
    fetchData();

@ -84,15 +82,8 @@ function App() {
    return isNomiSecSyncRunning() || isGitHubPocSyncRunning() || isExploitDBSyncRunning() || isCISAKEVSyncRunning();
  };

  const fetchSchedulerData = async () => {
    try {
      const response = await axios.get('http://localhost:8000/api/scheduler/status');
      setSchedulerStatus(response.data.scheduler_status);
      setSchedulerJobs(response.data.scheduler_status.jobs || {});
    } catch (error) {
      console.error('Error fetching scheduler data:', error);
    }
  };
  // Note: Scheduler functionality removed - now handled by Celery Beat
  // Monitoring available via Flower at http://localhost:5555

  const fetchData = async () => {
    try {

@ -223,12 +214,25 @@ function App() {
  const syncGitHubPocs = async (cveId = null) => {
    try {
      const response = await axios.post(`${API_BASE_URL}/api/sync-github-pocs`, {
        cve_id: cveId
        cve_id: cveId,
        batch_size: 50
      });
      console.log('GitHub PoC sync response:', response.data);

      // Show success message with Celery task info
      if (response.data.task_id) {
        console.log(`GitHub PoC sync task started: ${response.data.task_id}`);
        console.log(`Monitor at: ${response.data.monitor_url}`);

        // Show notification to user
        alert(`GitHub PoC sync started successfully!\nTask ID: ${response.data.task_id}\n\nMonitor progress at http://localhost:5555 (Flower Dashboard)`);
      }

      // Refresh data to show the task in bulk jobs
      fetchData();
    } catch (error) {
      console.error('Error syncing GitHub PoCs:', error);
      alert('Failed to start GitHub PoC sync. Please check the console for details.');
    }
  };

@ -331,6 +335,35 @@ function App() {
  const Dashboard = () => (
    <div className="space-y-6">
      {/* Celery Monitoring Notice */}
      <div className="bg-blue-50 border border-blue-200 rounded-lg p-4">
        <div className="flex items-center">
          <div className="flex-shrink-0">
            <svg className="h-5 w-5 text-blue-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor">
              <path fillRule="evenodd" d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z" clipRule="evenodd" />
            </svg>
          </div>
          <div className="ml-3">
            <h3 className="text-sm font-medium text-blue-800">
              Task Scheduling & Monitoring
            </h3>
            <div className="mt-2 text-sm text-blue-700">
              <p>
                Automated tasks are now managed by Celery Beat. Monitor real-time task execution at{' '}
                <a
                  href="http://localhost:5555"
                  target="_blank"
                  rel="noopener noreferrer"
                  className="font-medium underline hover:text-blue-900"
                >
                  Flower Dashboard
                </a>
              </p>
            </div>
          </div>
        </div>
      </div>

      <div className="grid grid-cols-1 md:grid-cols-5 gap-6">
        <div className="bg-white p-6 rounded-lg shadow">
          <h3 className="text-lg font-medium text-gray-900">Total CVEs</h3>

@ -1183,268 +1216,8 @@ function App() {
    </div>
  );

  const SchedulerManager = () => {
    const [selectedJob, setSelectedJob] = useState(null);
    const [newSchedule, setNewSchedule] = useState('');
    const [showScheduleEdit, setShowScheduleEdit] = useState(false);

    useEffect(() => {
      fetchSchedulerData();
      const interval = setInterval(fetchSchedulerData, 30000); // Refresh every 30 seconds
      return () => clearInterval(interval);
    }, []);

    const controlScheduler = async (action) => {
      try {
        const response = await axios.post('http://localhost:8000/api/scheduler/control', {
          action: action
        });
        console.log('Scheduler control response:', response.data);
        fetchSchedulerData();
      } catch (error) {
        console.error('Error controlling scheduler:', error);
        alert('Error controlling scheduler: ' + (error.response?.data?.detail || error.message));
      }
    };

    const controlJob = async (jobName, action) => {
      try {
        const response = await axios.post('http://localhost:8000/api/scheduler/job/control', {
          job_name: jobName,
          action: action
        });
        console.log('Job control response:', response.data);
        fetchSchedulerData();
      } catch (error) {
        console.error('Error controlling job:', error);
        alert('Error controlling job: ' + (error.response?.data?.detail || error.message));
      }
    };

    const updateJobSchedule = async (jobName, schedule) => {
      try {
        const response = await axios.post('http://localhost:8000/api/scheduler/job/schedule', {
          job_name: jobName,
          schedule: schedule
        });
        console.log('Schedule update response:', response.data);
        setShowScheduleEdit(false);
        setNewSchedule('');
        setSelectedJob(null);
        fetchSchedulerData();
      } catch (error) {
        console.error('Error updating schedule:', error);
        alert('Error updating schedule: ' + (error.response?.data?.detail || error.message));
      }
    };

    const formatNextRun = (dateString) => {
      if (!dateString) return 'Not scheduled';
      const date = new Date(dateString);
      return date.toLocaleString();
    };

    const getStatusColor = (status) => {
      switch (status) {
        case true:
        case 'enabled':
          return 'text-green-600';
        case false:
        case 'disabled':
          return 'text-red-600';
        case 'running':
          return 'text-blue-600';
        default:
          return 'text-gray-600';
      }
    };

    return (
      <div className="space-y-6">
        {/* Scheduler Status */}
        <div className="bg-white rounded-lg shadow">
          <div className="px-6 py-4 border-b border-gray-200">
            <h2 className="text-xl font-bold text-gray-900">Job Scheduler Status</h2>
          </div>
          <div className="p-6">
            <div className="grid grid-cols-1 md:grid-cols-4 gap-4 mb-6">
              <div className="bg-gray-50 p-4 rounded-lg">
                <h3 className="text-sm font-medium text-gray-500">Scheduler Status</h3>
                <p className={`text-2xl font-bold ${getStatusColor(schedulerStatus.scheduler_running)}`}>
                  {schedulerStatus.scheduler_running ? 'Running' : 'Stopped'}
                </p>
              </div>
              <div className="bg-gray-50 p-4 rounded-lg">
                <h3 className="text-sm font-medium text-gray-500">Total Jobs</h3>
                <p className="text-2xl font-bold text-gray-900">{schedulerStatus.total_jobs || 0}</p>
              </div>
              <div className="bg-gray-50 p-4 rounded-lg">
                <h3 className="text-sm font-medium text-gray-500">Enabled Jobs</h3>
                <p className="text-2xl font-bold text-green-600">{schedulerStatus.enabled_jobs || 0}</p>
              </div>
              <div className="bg-gray-50 p-4 rounded-lg">
                <h3 className="text-sm font-medium text-gray-500">Running Jobs</h3>
                <p className="text-2xl font-bold text-blue-600">{schedulerStatus.running_jobs || 0}</p>
              </div>
            </div>

            <div className="flex space-x-4">
              <button
                onClick={() => controlScheduler('start')}
                disabled={schedulerStatus.scheduler_running}
                className={`px-4 py-2 rounded-md text-white ${
                  schedulerStatus.scheduler_running
                    ? 'bg-gray-400 cursor-not-allowed'
                    : 'bg-green-600 hover:bg-green-700'
                }`}
              >
                Start Scheduler
              </button>
              <button
                onClick={() => controlScheduler('stop')}
                disabled={!schedulerStatus.scheduler_running}
                className={`px-4 py-2 rounded-md text-white ${
                  !schedulerStatus.scheduler_running
                    ? 'bg-gray-400 cursor-not-allowed'
                    : 'bg-red-600 hover:bg-red-700'
                }`}
              >
                Stop Scheduler
              </button>
              <button
                onClick={() => controlScheduler('restart')}
                className="px-4 py-2 rounded-md text-white bg-blue-600 hover:bg-blue-700"
              >
                Restart Scheduler
              </button>
            </div>
          </div>
        </div>

        {/* Scheduled Jobs */}
        <div className="bg-white rounded-lg shadow">
          <div className="px-6 py-4 border-b border-gray-200">
            <h2 className="text-xl font-bold text-gray-900">Scheduled Jobs</h2>
          </div>
          <div className="divide-y divide-gray-200">
            {Object.entries(schedulerJobs).map(([jobName, job]) => (
              <div key={jobName} className="p-6">
                <div className="flex items-center justify-between">
                  <div className="flex-1">
                    <div className="flex items-center space-x-3">
                      <h3 className="text-lg font-medium text-gray-900">{jobName}</h3>
                      <span className={`px-2 py-1 text-xs font-medium rounded-full ${
                        job.enabled ? 'bg-green-100 text-green-800' : 'bg-red-100 text-red-800'
                      }`}>
                        {job.enabled ? 'Enabled' : 'Disabled'}
                      </span>
                      {job.is_running && (
                        <span className="px-2 py-1 text-xs font-medium rounded-full bg-blue-100 text-blue-800">
                          Running
                        </span>
                      )}
                    </div>
                    <p className="text-sm text-gray-600 mt-1">{job.description}</p>
                    <div className="mt-2 grid grid-cols-2 md:grid-cols-4 gap-4 text-sm text-gray-600">
                      <div>
                        <span className="font-medium">Schedule:</span> {job.schedule}
                      </div>
                      <div>
                        <span className="font-medium">Next Run:</span> {formatNextRun(job.next_run)}
                      </div>
                      <div>
                        <span className="font-medium">Run Count:</span> {job.run_count}
                      </div>
                      <div>
                        <span className="font-medium">Failures:</span> {job.failure_count}
                      </div>
                    </div>
                  </div>
                  <div className="flex space-x-2 ml-4">
                    <button
                      onClick={() => controlJob(jobName, job.enabled ? 'disable' : 'enable')}
                      className={`px-3 py-1 rounded-md text-sm font-medium ${
                        job.enabled
                          ? 'bg-red-100 text-red-700 hover:bg-red-200'
                          : 'bg-green-100 text-green-700 hover:bg-green-200'
                      }`}
                    >
                      {job.enabled ? 'Disable' : 'Enable'}
                    </button>
                    <button
                      onClick={() => controlJob(jobName, 'trigger')}
                      disabled={job.is_running}
                      className={`px-3 py-1 rounded-md text-sm font-medium ${
                        job.is_running
                          ? 'bg-gray-100 text-gray-500 cursor-not-allowed'
                          : 'bg-blue-100 text-blue-700 hover:bg-blue-200'
                      }`}
                    >
                      Run Now
                    </button>
                    <button
                      onClick={() => {
                        setSelectedJob(jobName);
                        setNewSchedule(job.schedule);
                        setShowScheduleEdit(true);
                      }}
                      className="px-3 py-1 rounded-md text-sm font-medium bg-gray-100 text-gray-700 hover:bg-gray-200"
                    >
                      Edit Schedule
                    </button>
                  </div>
                </div>
              </div>
            ))}
          </div>
        </div>

        {/* Schedule Edit Modal */}
        {showScheduleEdit && selectedJob && (
          <div className="fixed inset-0 bg-gray-600 bg-opacity-50 flex items-center justify-center z-50">
            <div className="bg-white rounded-lg p-6 w-full max-w-md">
              <h3 className="text-lg font-medium text-gray-900 mb-4">
                Edit Schedule for {selectedJob}
              </h3>
              <div className="mb-4">
                <label className="block text-sm font-medium text-gray-700 mb-2">
                  Cron Expression
                </label>
                <input
                  type="text"
                  value={newSchedule}
                  onChange={(e) => setNewSchedule(e.target.value)}
                  className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500"
                  placeholder="0 */6 * * *"
                />
                <p className="text-xs text-gray-500 mt-1">
                  Format: minute hour day_of_month month day_of_week
                </p>
              </div>
              <div className="flex space-x-3">
                <button
                  onClick={() => updateJobSchedule(selectedJob, newSchedule)}
                  className="flex-1 bg-blue-600 text-white px-4 py-2 rounded-md hover:bg-blue-700"
                >
                  Update
                </button>
                <button
                  onClick={() => {
                    setShowScheduleEdit(false);
                    setSelectedJob(null);
                    setNewSchedule('');
                  }}
                  className="flex-1 bg-gray-300 text-gray-700 px-4 py-2 rounded-md hover:bg-gray-400"
                >
                  Cancel
                </button>
              </div>
            </div>
          </div>
        )}
      </div>
    );
  };
  // Note: SchedulerManager component removed - job scheduling now handled by Celery Beat
  // Task monitoring available via Flower dashboard at http://localhost:5555

  if (loading) {
    return (

@ -1507,16 +1280,6 @@ function App() {
            >
              Bulk Jobs
            </button>
            <button
              onClick={() => setActiveTab('scheduler')}
              className={`inline-flex items-center px-1 pt-1 border-b-2 text-sm font-medium ${
                activeTab === 'scheduler'
                  ? 'border-blue-500 text-gray-900'
                  : 'border-transparent text-gray-500 hover:text-gray-700 hover:border-gray-300'
              }`}
            >
              Scheduler
            </button>
          </div>
        </div>
      </div>

@ -1529,7 +1292,6 @@ function App() {
        {activeTab === 'cves' && <CVEList />}
        {activeTab === 'rules' && <SigmaRulesList />}
        {activeTab === 'bulk-jobs' && <BulkJobsList />}
        {activeTab === 'scheduler' && <SchedulerManager />}
      </div>
    </main>