auto_sigma_rule_generator/backend/main.py
bpmcdevitt 9bde1395bf Optimize performance and migrate to Celery-based scheduling
This commit introduces major performance improvements and migrates from custom job scheduling to Celery Beat for better reliability and scalability.

### 🚀 Performance Optimizations

**CVE2CAPEC Client Performance (Fixed startup blocking)**
- Implement lazy loading with a 24-hour cache for CVE2CAPEC mappings (see the sketch below)
- Add background task for CVE2CAPEC sync (data_sync_tasks.sync_cve2capec)
- Remove blocking data fetch during client initialization
- API endpoint: POST /api/sync-cve2capec
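
A minimal sketch of the lazy-load-with-TTL pattern, using illustrative names (the actual `cve2capec_client.CVE2CAPECClient` differs in detail):

```python
# Sketch only: lazy loading with a 24-hour cache instead of fetching at startup.
import json
import time
from pathlib import Path

CACHE_TTL_SECONDS = 24 * 60 * 60  # 24-hour cache expiry

class LazyCVE2CAPECCache:
    def __init__(self, cache_file: str = "cve2capec_cache.json"):
        self.cache_file = Path(cache_file)
        self._mappings = None      # loaded on first access, not during __init__
        self._loaded_at = 0.0

    def get_mappings(self) -> dict:
        """Return CVE -> ATT&CK mappings, loading or refreshing lazily."""
        if self._mappings is None or self._is_expired():
            self._load_from_disk()
        return self._mappings or {}

    def _is_expired(self) -> bool:
        return (time.time() - self._loaded_at) > CACHE_TTL_SECONDS

    def _load_from_disk(self) -> None:
        if self.cache_file.exists():
            self._mappings = json.loads(self.cache_file.read_text())
            self._loaded_at = self.cache_file.stat().st_mtime
        else:
            # No local cache yet: stay empty and let the Celery task
            # (data_sync_tasks.sync_cve2capec) fetch and write it in the background.
            self._mappings = {}
```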

**ExploitDB Client Performance (Fixed webapp request blocking)**
- Implement a global file index cache to prevent rebuilding the index on every request (see the sketch below)
- Add lazy loading with a 24-hour cache expiry for the 46K+ entry exploit index
- Background task for index building (data_sync_tasks.build_exploitdb_index)
- API endpoint: POST /api/build-exploitdb-index
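
A sketch of the process-wide index cache, assuming the index maps EDB IDs to local file paths (the actual `exploitdb_client_local` implementation may differ):

```python
# Sketch only: one shared index per process, rebuilt at most once per 24 hours.
import os
import time
import threading

_INDEX_LOCK = threading.Lock()
_FILE_INDEX: dict[str, str] = {}   # e.g. EDB-ID -> local file path
_INDEX_BUILT_AT = 0.0
INDEX_TTL = 24 * 60 * 60

def get_exploitdb_index(repo_root: str = "/opt/exploitdb") -> dict[str, str]:
    """Return the shared file index, rebuilding it only when the TTL expires."""
    global _FILE_INDEX, _INDEX_BUILT_AT
    with _INDEX_LOCK:
        if _FILE_INDEX and (time.time() - _INDEX_BUILT_AT) < INDEX_TTL:
            return _FILE_INDEX                      # fast path: reuse cached index
        index: dict[str, str] = {}
        for dirpath, _dirnames, filenames in os.walk(repo_root):
            for name in filenames:
                edb_id = os.path.splitext(name)[0]  # assumes files named <EDB-ID>.<ext>
                index[edb_id] = os.path.join(dirpath, name)
        _FILE_INDEX, _INDEX_BUILT_AT = index, time.time()
        return _FILE_INDEX
```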

### 🔄 Celery Migration & Scheduling

**Celery Beat Integration**
- Migrate from custom job scheduler to Celery Beat for reliability
- Remove 'finetuned' LLM provider (logic moved to ollama container)
- Optimize the daily workflow with proper timing and dependencies

**New Celery Tasks Structure**
- tasks/bulk_tasks.py - NVD bulk processing and SIGMA generation
- tasks/data_sync_tasks.py - All data synchronization tasks (see the module sketch after this list)
- tasks/maintenance_tasks.py - System maintenance and cleanup
- tasks/sigma_tasks.py - SIGMA rule generation tasks
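
The shape of one such task module, as a sketch (the real `tasks/data_sync_tasks.py` registers against the shared app in `celery_config.py` and wires in the actual sync clients):

```python
# Sketch of a data-sync task; a local Celery app is created here only so the
# example stands alone. Broker/backend URLs are illustrative.
from celery import Celery

celery_app = Celery(
    "auto_sigma",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)

@celery_app.task(bind=True, name="tasks.data_sync_tasks.sync_cve2capec")
def sync_cve2capec(self, force_refresh: bool = False) -> dict:
    """Refresh the CVE2CAPEC MITRE ATT&CK mapping cache."""
    self.update_state(state="PROGRESS", meta={"stage": "downloading"})
    # ... fetch the mapping data and write the 24-hour on-disk cache ...
    return {"status": "completed", "force_refresh": force_refresh}
```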

**Daily Schedule (Optimized)**
```
1:00 AM  → Weekly cleanup (Sundays)
1:30 AM  → Daily result cleanup
2:00 AM  → NVD incremental update
3:00 AM  → CISA KEV sync
3:15 AM  → Nomi-sec PoC sync
3:30 AM  → GitHub PoC sync
3:45 AM  → ExploitDB sync
4:00 AM  → CVE2CAPEC MITRE ATT&CK sync
4:15 AM  → ExploitDB index rebuild
5:00 AM  → Reference content sync
8:00 AM  → SIGMA rule generation
9:00 AM  → LLM-enhanced SIGMA generation
Every 15min → Health checks
```
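
A sketch of how a few of these entries translate into a Celery Beat `beat_schedule`; the entry names, and any task paths not named elsewhere in this commit, are assumptions:

```python
# Sketch of the Beat schedule shape; the full schedule lives in backend/celery_config.py.
from celery.schedules import crontab

beat_schedule = {
    "nvd-incremental-update": {
        "task": "tasks.bulk_tasks.incremental_update",        # assumed task path
        "schedule": crontab(hour=2, minute=0),                # 2:00 AM daily
    },
    "cve2capec-sync": {
        "task": "tasks.data_sync_tasks.sync_cve2capec",
        "schedule": crontab(hour=4, minute=0),                # 4:00 AM daily
    },
    "exploitdb-index-rebuild": {
        "task": "tasks.data_sync_tasks.build_exploitdb_index",
        "schedule": crontab(hour=4, minute=15),               # 4:15 AM daily
    },
    "weekly-cleanup": {
        "task": "tasks.maintenance_tasks.weekly_cleanup",     # assumed task path
        "schedule": crontab(hour=1, minute=0, day_of_week=0), # Sundays 1:00 AM
    },
    "health-check": {
        "task": "tasks.maintenance_tasks.health_check",       # assumed task path
        "schedule": crontab(minute="*/15"),                   # every 15 minutes
    },
}
```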

### 🐳 Docker & Infrastructure

**Enhanced Docker Setup**
- Ollama setup with integrated SIGMA model creation (setup_ollama_with_sigma.py)
- Initial database population check and trigger (initial_setup.py)
- Proper service dependencies and health checks
- Remove manual post-rebuild script requirements

**Service Architecture**
- Celery worker with a 4-queue system (default, bulk_processing, sigma_generation, data_sync); see the routing sketch below
- Flower monitoring dashboard (localhost:5555)
- Redis as message broker and result backend
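
A sketch of how task modules could be routed onto the four queues (the actual routing lives in `backend/celery_config.py`):

```python
# Sketch only: route each task module to its queue; unrouted tasks use "default".
task_routes = {
    "tasks.bulk_tasks.*": {"queue": "bulk_processing"},
    "tasks.sigma_tasks.*": {"queue": "sigma_generation"},
    "tasks.data_sync_tasks.*": {"queue": "data_sync"},
}

# A worker can then consume all four queues, e.g.:
#   celery -A celery_config worker -Q default,bulk_processing,sigma_generation,data_sync
# with Flower started against the same app for monitoring on localhost:5555.
```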

### 🎯 API Improvements

**Background Task Endpoints**
- GitHub PoC sync now uses Celery (previously blocked the backend)
- All sync operations return task IDs and monitoring URLs
- Consistent error handling and progress tracking

**New Endpoints**
- POST /api/sync-cve2capec - CVE2CAPEC mapping sync (usage example below)
- POST /api/build-exploitdb-index - ExploitDB index rebuild
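
A hedged usage example, assuming the backend listens on localhost:8000 and Flower runs on its default dashboard; the response fields match the endpoint implementation in main.py:

```python
# Kick off a CVE2CAPEC sync and print the Celery task ID and monitor URL.
import requests

resp = requests.post(
    "http://localhost:8000/api/sync-cve2capec",   # assumed backend host/port
    params={"force_refresh": False},
    timeout=30,
)
task = resp.json()
print(task["task_id"], task["monitor_url"])       # track progress in Flower
```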

### 📁 Cleanup

**Removed Files**
- fix_sigma_model.sh (replaced by setup_ollama_with_sigma.py)
- Various test_* and debug_* files no longer needed
- Old training scripts related to removed 'finetuned' provider
- Utility scripts replaced by Docker services

### 🔧 Configuration

**Key Files Added/Modified**
- backend/celery_config.py - Complete Celery configuration
- backend/initial_setup.py - First-boot database population
- backend/setup_ollama_with_sigma.py - Integrated Ollama setup
- CLAUDE.md - Project documentation and development guide

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-17 18:58:47 -05:00


from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
import base64
from github import Github
from urllib.parse import urlparse
import hashlib
import logging
import threading
from mcdevitt_poc_client import GitHubPoCClient
from cve2capec_client import CVE2CAPECClient
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global job tracking
running_jobs = {}
job_cancellation_flags = {}
# Database setup
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# Database Models
class CVE(Base):
__tablename__ = "cves"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
cve_id = Column(String(20), unique=True, nullable=False)
description = Column(Text)
cvss_score = Column(DECIMAL(3, 1))
severity = Column(String(20))
published_date = Column(TIMESTAMP)
modified_date = Column(TIMESTAMP)
affected_products = Column(ARRAY(String))
reference_urls = Column(ARRAY(String))
# Bulk processing fields
data_source = Column(String(20), default='nvd_api') # 'nvd_api', 'nvd_bulk', 'manual'
nvd_json_version = Column(String(10), default='2.0')
bulk_processed = Column(Boolean, default=False)
# nomi-sec PoC fields
poc_count = Column(Integer, default=0)
poc_data = Column(JSON) # Store nomi-sec PoC metadata
# Reference data fields
reference_data = Column(JSON) # Store extracted reference content and analysis
reference_sync_status = Column(String(20), default='pending') # 'pending', 'processing', 'completed', 'failed'
reference_last_synced = Column(TIMESTAMP)
created_at = Column(TIMESTAMP, default=datetime.utcnow)
updated_at = Column(TIMESTAMP, default=datetime.utcnow)
class SigmaRule(Base):
__tablename__ = "sigma_rules"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
cve_id = Column(String(20))
rule_name = Column(String(255), nullable=False)
rule_content = Column(Text, nullable=False)
detection_type = Column(String(50))
log_source = Column(String(100))
confidence_level = Column(String(20))
auto_generated = Column(Boolean, default=True)
exploit_based = Column(Boolean, default=False)
github_repos = Column(ARRAY(String))
exploit_indicators = Column(Text) # JSON string of extracted indicators
# Enhanced fields for new data sources
poc_source = Column(String(20), default='github_search') # 'github_search', 'nomi_sec', 'manual'
poc_quality_score = Column(Integer, default=0) # Based on star count, activity, etc.
nomi_sec_data = Column(JSON) # Store nomi-sec PoC metadata
created_at = Column(TIMESTAMP, default=datetime.utcnow)
updated_at = Column(TIMESTAMP, default=datetime.utcnow)
class RuleTemplate(Base):
__tablename__ = "rule_templates"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
template_name = Column(String(255), nullable=False)
template_content = Column(Text, nullable=False)
applicable_product_patterns = Column(ARRAY(String))
description = Column(Text)
created_at = Column(TIMESTAMP, default=datetime.utcnow)
class BulkProcessingJob(Base):
__tablename__ = "bulk_processing_jobs"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
job_type = Column(String(50), nullable=False) # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update'
status = Column(String(20), default='pending') # 'pending', 'running', 'completed', 'failed', 'cancelled'
year = Column(Integer) # For year-based processing
total_items = Column(Integer, default=0)
processed_items = Column(Integer, default=0)
failed_items = Column(Integer, default=0)
error_message = Column(Text)
job_metadata = Column(JSON) # Additional job-specific data
started_at = Column(TIMESTAMP)
completed_at = Column(TIMESTAMP)
cancelled_at = Column(TIMESTAMP)
created_at = Column(TIMESTAMP, default=datetime.utcnow)
# Pydantic models
class CVEResponse(BaseModel):
id: str
cve_id: str
description: Optional[str] = None
cvss_score: Optional[float] = None
severity: Optional[str] = None
published_date: Optional[datetime] = None
affected_products: Optional[List[str]] = None
reference_urls: Optional[List[str]] = None
class Config:
from_attributes = True
class SigmaRuleResponse(BaseModel):
id: str
cve_id: str
rule_name: str
rule_content: str
detection_type: Optional[str] = None
log_source: Optional[str] = None
confidence_level: Optional[str] = None
auto_generated: bool = True
exploit_based: bool = False
github_repos: Optional[List[str]] = None
exploit_indicators: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
# Request models
class BulkSeedRequest(BaseModel):
start_year: int = 2002
end_year: Optional[int] = None
skip_nvd: bool = False
skip_nomi_sec: bool = True
class NomiSecSyncRequest(BaseModel):
cve_id: Optional[str] = None
batch_size: int = 50
class GitHubPoCSyncRequest(BaseModel):
cve_id: Optional[str] = None
batch_size: int = 50
class ExploitDBSyncRequest(BaseModel):
cve_id: Optional[str] = None
batch_size: int = 30
class CISAKEVSyncRequest(BaseModel):
cve_id: Optional[str] = None
batch_size: int = 100
class ReferenceSyncRequest(BaseModel):
cve_id: Optional[str] = None
batch_size: int = 30
max_cves: Optional[int] = None
force_resync: bool = False
class RuleRegenRequest(BaseModel):
force: bool = False
# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
def __init__(self):
self.github_token = os.getenv("GITHUB_TOKEN")
self.github = Github(self.github_token) if self.github_token else None
async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
"""Search GitHub for exploit code related to a CVE"""
if not self.github:
print(f"No GitHub token configured, skipping exploit search for {cve_id}")
return []
try:
print(f"Searching GitHub for exploits for {cve_id}")
# Search queries to find exploit code
search_queries = [
f"{cve_id} exploit",
f"{cve_id} poc",
f"{cve_id} vulnerability",
f'"{cve_id}" exploit code',
f"{cve_id.replace('-', '_')} exploit"
]
exploits = []
seen_repos = set()
for query in search_queries[:2]: # Limit to 2 queries to avoid rate limits
try:
# Search repositories
repos = self.github.search_repositories(
query=query,
sort="updated",
order="desc"
)
# Get top 5 results per query
for repo in repos[:5]:
if repo.full_name in seen_repos:
continue
seen_repos.add(repo.full_name)
# Analyze repository
exploit_info = await self._analyze_repository(repo, cve_id)
if exploit_info:
exploits.append(exploit_info)
if len(exploits) >= 10: # Limit total exploits
break
if len(exploits) >= 10:
break
except Exception as e:
print(f"Error searching GitHub with query '{query}': {str(e)}")
continue
print(f"Found {len(exploits)} potential exploits for {cve_id}")
return exploits
except Exception as e:
print(f"Error searching GitHub for {cve_id}: {str(e)}")
return []
async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
"""Analyze a GitHub repository for exploit code"""
try:
# Check if repo name or description mentions the CVE
repo_text = f"{repo.name} {repo.description or ''}".lower()
if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text:
return None
# Get repository contents
exploit_files = []
indicators = {
'processes': set(),
'files': set(),
'registry': set(),
'network': set(),
'commands': set(),
'powershell': set(),
'urls': set()
}
try:
contents = repo.get_contents("")
for content in contents[:20]: # Limit files to analyze
if content.type == "file" and self._is_exploit_file(content.name):
file_analysis = await self._analyze_file_content(repo, content, cve_id)
if file_analysis:
exploit_files.append(file_analysis)
# Merge indicators
for key, values in file_analysis.get('indicators', {}).items():
if key in indicators:
indicators[key].update(values)
except Exception as e:
print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")
if not exploit_files:
return None
return {
'repo_name': repo.full_name,
'repo_url': repo.html_url,
'description': repo.description,
'language': repo.language,
'stars': repo.stargazers_count,
'updated': repo.updated_at.isoformat(),
'files': exploit_files,
'indicators': {k: list(v) for k, v in indicators.items()}
}
except Exception as e:
print(f"Error analyzing repository {repo.full_name}: {str(e)}")
return None
def _is_exploit_file(self, filename: str) -> bool:
"""Check if a file is likely to contain exploit code"""
exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java']
exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack']
filename_lower = filename.lower()
# Check extension
if not any(filename_lower.endswith(ext) for ext in exploit_extensions):
return False
# Check filename for exploit-related terms
return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower
async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
"""Analyze individual file content for exploit indicators"""
try:
if file_content.size > 100000: # Skip files larger than 100KB
return None
# Decode file content
content = file_content.decoded_content.decode('utf-8', errors='ignore')
# Check if file actually mentions the CVE
if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower():
return None
indicators = self._extract_indicators_from_code(content, file_content.name)
if not any(indicators.values()):
return None
return {
'filename': file_content.name,
'path': file_content.path,
'size': file_content.size,
'indicators': indicators
}
except Exception as e:
print(f"Error analyzing file {file_content.name}: {str(e)}")
return None
def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
"""Extract security indicators from exploit code"""
indicators = {
'processes': set(),
'files': set(),
'registry': set(),
'network': set(),
'commands': set(),
'powershell': set(),
'urls': set()
}
# Process patterns
process_patterns = [
r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
r'system\s*\(\s*["\']([^"\']+)["\']',
r'exec\s*\(\s*["\']([^"\']+)["\']',
r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']'
]
# File patterns
file_patterns = [
r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)'
]
# Registry patterns
registry_patterns = [
r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?'
]
# Network patterns
network_patterns = [
r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
r'(?:http|https|ftp)://([^\s"\'<>]+)',
r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)'
]
# PowerShell patterns
powershell_patterns = [
r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
r'Start-Process\s+["\']?([^"\']+)["\']?',
r'Get-Process\s+["\']?([^"\']+)["\']?'
]
# Command patterns
command_patterns = [
r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)'
]
# Extract indicators using regex patterns
patterns = {
'processes': process_patterns,
'files': file_patterns,
'registry': registry_patterns,
'powershell': powershell_patterns,
'commands': command_patterns
}
for category, pattern_list in patterns.items():
for pattern in pattern_list:
matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
for match in matches:
if isinstance(match, tuple):
indicators[category].add(match[0])
else:
indicators[category].add(match)
# Special handling for network indicators
for pattern in network_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
if isinstance(match, tuple):
if len(match) >= 2:
indicators['network'].add(f"{match[0]}:{match[1]}")
else:
indicators['network'].add(match[0])
else:
indicators['network'].add(match)
# Convert sets to lists and filter out empty/invalid indicators
cleaned_indicators = {}
for key, values in indicators.items():
cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200]
if cleaned_values:
cleaned_indicators[key] = cleaned_values[:10] # Limit to 10 per category
return cleaned_indicators
class CVESigmaService:
def __init__(self, db: Session):
self.db = db
self.nvd_api_key = os.getenv("NVD_API_KEY")
async def fetch_recent_cves(self, days_back: int = 7):
"""Fetch recent CVEs from NVD API"""
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=days_back)
url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
params = {
"pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
"pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
"resultsPerPage": 100
}
headers = {}
if self.nvd_api_key:
headers["apiKey"] = self.nvd_api_key
try:
response = requests.get(url, params=params, headers=headers, timeout=30)
response.raise_for_status()
data = response.json()
new_cves = []
for vuln in data.get("vulnerabilities", []):
cve_data = vuln.get("cve", {})
cve_id = cve_data.get("id")
# Check if CVE already exists
existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first()
if existing:
continue
# Extract CVE information
description = ""
if cve_data.get("descriptions"):
description = cve_data["descriptions"][0].get("value", "")
cvss_score = None
severity = None
if cve_data.get("metrics", {}).get("cvssMetricV31"):
cvss_data = cve_data["metrics"]["cvssMetricV31"][0]
cvss_score = cvss_data.get("cvssData", {}).get("baseScore")
severity = cvss_data.get("cvssData", {}).get("baseSeverity")
affected_products = []
if cve_data.get("configurations"):
for config in cve_data["configurations"]:
for node in config.get("nodes", []):
for cpe_match in node.get("cpeMatch", []):
if cpe_match.get("vulnerable"):
affected_products.append(cpe_match.get("criteria", ""))
reference_urls = []
if cve_data.get("references"):
reference_urls = [ref.get("url", "") for ref in cve_data["references"]]
cve_obj = CVE(
cve_id=cve_id,
description=description,
cvss_score=cvss_score,
severity=severity,
published_date=datetime.fromisoformat(cve_data.get("published", "").replace("Z", "+00:00")),
modified_date=datetime.fromisoformat(cve_data.get("lastModified", "").replace("Z", "+00:00")),
affected_products=affected_products,
reference_urls=reference_urls
)
self.db.add(cve_obj)
new_cves.append(cve_obj)
self.db.commit()
return new_cves
except Exception as e:
print(f"Error fetching CVEs: {str(e)}")
return []
def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
"""Generate SIGMA rule based on CVE data"""
if not cve.description:
return None
# Analyze CVE to determine appropriate template
description_lower = cve.description.lower()
affected_products = [p.lower() for p in (cve.affected_products or [])]
template = self._select_template(description_lower, affected_products)
if not template:
return None
# Generate rule content
rule_content = self._populate_template(cve, template)
if not rule_content:
return None
# Determine detection type and confidence
detection_type = self._determine_detection_type(description_lower)
confidence_level = self._calculate_confidence(cve)
sigma_rule = SigmaRule(
cve_id=cve.cve_id,
rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection",
rule_content=rule_content,
detection_type=detection_type,
log_source=template.template_name.lower().replace(" ", "_"),
confidence_level=confidence_level,
auto_generated=True
)
self.db.add(sigma_rule)
return sigma_rule
def _select_template(self, description: str, affected_products: List[str], exploit_indicators: dict = None):
"""Select appropriate SIGMA rule template based on CVE and exploit analysis"""
templates = self.db.query(RuleTemplate).all()
# If we have exploit indicators, use them to determine the best template
if exploit_indicators:
if exploit_indicators.get('powershell'):
powershell_template = next((t for t in templates if "PowerShell" in t.template_name), None)
if powershell_template:
return powershell_template
if exploit_indicators.get('network'):
network_template = next((t for t in templates if "Network Connection" in t.template_name), None)
if network_template:
return network_template
if exploit_indicators.get('files'):
file_template = next((t for t in templates if "File Modification" in t.template_name), None)
if file_template:
return file_template
if exploit_indicators.get('processes') or exploit_indicators.get('commands'):
process_template = next((t for t in templates if "Process Execution" in t.template_name), None)
if process_template:
return process_template
# Fallback to original logic
if any("windows" in p or "microsoft" in p for p in affected_products):
if "process" in description or "execution" in description:
return next((t for t in templates if "Process Execution" in t.template_name), None)
elif "network" in description or "remote" in description:
return next((t for t in templates if "Network Connection" in t.template_name), None)
elif "file" in description or "write" in description:
return next((t for t in templates if "File Modification" in t.template_name), None)
# Default to process execution template
return next((t for t in templates if "Process Execution" in t.template_name), None)
def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: dict = None) -> str:
"""Populate template with CVE-specific data and exploit indicators"""
try:
# Use exploit indicators if available, otherwise extract from description
if exploit_indicators:
suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
suspicious_ports = []
file_patterns = exploit_indicators.get('files', [])
# Extract ports from network indicators
for net_indicator in exploit_indicators.get('network', []):
if ':' in str(net_indicator):
try:
port = int(str(net_indicator).split(':')[-1])
suspicious_ports.append(port)
except ValueError:
pass
else:
# Fallback to original extraction
suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
file_patterns = self._extract_suspicious_indicators(cve.description, "file")
# Determine severity level
level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"
# Create enhanced description
enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description
if exploit_indicators:
enhanced_description += " [Enhanced with GitHub exploit analysis]"
# Build tags
tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()]
if exploit_indicators:
tags.append("exploit.github")
rule_content = template.template_content.format(
title=f"CVE-{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
description=enhanced_description,
rule_id=str(uuid.uuid4()),
date=datetime.utcnow().strftime("%Y/%m/%d"),
cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
cve_id=cve.cve_id.lower(),
tags="\n - ".join(tags),
suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
suspicious_ports=suspicious_ports or [4444, 8080, 9999],
file_patterns=file_patterns or ["temp", "malware", "exploit"],
level=level
)
return rule_content
except Exception as e:
print(f"Error populating template: {str(e)}")
return None
def _get_mitre_technique(self, description: str, exploit_indicators: dict = None) -> str:
"""Map CVE and exploit indicators to MITRE ATT&CK techniques"""
desc_lower = description.lower()
# Check exploit indicators first
if exploit_indicators:
if exploit_indicators.get('powershell'):
return "t1059.001" # PowerShell
elif exploit_indicators.get('commands'):
return "t1059.003" # Windows Command Shell
elif exploit_indicators.get('network'):
return "t1071.001" # Web Protocols
elif exploit_indicators.get('files'):
return "t1105" # Ingress Tool Transfer
elif exploit_indicators.get('processes'):
return "t1106" # Native API
# Fallback to description analysis
if "powershell" in desc_lower:
return "t1059.001"
elif "command" in desc_lower or "cmd" in desc_lower:
return "t1059.003"
elif "network" in desc_lower or "remote" in desc_lower:
return "t1071.001"
elif "file" in desc_lower or "upload" in desc_lower:
return "t1105"
elif "process" in desc_lower or "execution" in desc_lower:
return "t1106"
else:
return "execution" # Generic
def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> List:
"""Extract suspicious indicators from CVE description"""
if indicator_type == "process":
# Look for executable names or process patterns
exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
return exe_pattern[:5] if exe_pattern else None
elif indicator_type == "port":
# Look for port numbers
port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
return [int(p) for p in port_pattern[:3]] if port_pattern else None
elif indicator_type == "file":
# Look for file extensions or paths
file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
return file_pattern[:5] if file_pattern else None
return None
def _determine_detection_type(self, description: str, exploit_indicators: dict = None) -> str:
"""Determine detection type based on CVE description and exploit indicators"""
if exploit_indicators:
if exploit_indicators.get('powershell'):
return "powershell"
elif exploit_indicators.get('network'):
return "network"
elif exploit_indicators.get('files'):
return "file"
elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
return "process"
# Fallback to original logic
if "remote" in description or "network" in description:
return "network"
elif "process" in description or "execution" in description:
return "process"
elif "file" in description or "filesystem" in description:
return "file"
else:
return "general"
def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str:
"""Calculate confidence level for the generated rule"""
base_confidence = 0
# CVSS score contributes to confidence
if cve.cvss_score:
if cve.cvss_score >= 9.0:
base_confidence += 3
elif cve.cvss_score >= 7.0:
base_confidence += 2
else:
base_confidence += 1
# Exploit-based rules get higher confidence
if exploit_based:
base_confidence += 2
# Map to confidence levels
if base_confidence >= 4:
return "high"
elif base_confidence >= 2:
return "medium"
else:
return "low"
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
retry_count = 0
max_retries = 3
while True:
try:
db = SessionLocal()
service = CVESigmaService(db)
current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
print(f"[{current_time}] Starting CVE fetch cycle...")
# Use a longer initial period (30 days) to find CVEs
new_cves = await service.fetch_recent_cves(days_back=30)
if new_cves:
print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
rules_generated = 0
for cve in new_cves:
try:
sigma_rule = service.generate_sigma_rule(cve)
if sigma_rule:
rules_generated += 1
print(f"Generated SIGMA rule for {cve.cve_id}")
else:
print(f"Could not generate rule for {cve.cve_id} - insufficient data")
except Exception as e:
print(f"Error generating rule for {cve.cve_id}: {str(e)}")
db.commit()
print(f"Successfully generated {rules_generated} SIGMA rules")
retry_count = 0 # Reset retry count on success
else:
print("No new CVEs found in this cycle")
# After first successful run, reduce to 7 days for regular updates
if retry_count == 0:
print("Switching to 7-day lookback for future runs...")
db.close()
except Exception as e:
retry_count += 1
print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
if retry_count >= max_retries:
print(f"Max retries reached, waiting longer before next attempt...")
await asyncio.sleep(1800) # Wait 30 minutes on repeated failures
retry_count = 0
else:
await asyncio.sleep(300) # Wait 5 minutes before retry
continue
# Wait 1 hour before next fetch (or 30 minutes if there were errors)
wait_time = 3600 if retry_count == 0 else 1800
print(f"Next CVE fetch in {wait_time//60} minutes...")
await asyncio.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize database
Base.metadata.create_all(bind=engine)
# Initialize rule templates
db = SessionLocal()
try:
existing_templates = db.query(RuleTemplate).count()
if existing_templates == 0:
logger.info("No rule templates found. Database initialization will handle template creation.")
except Exception as e:
logger.error(f"Error checking rule templates: {e}")
finally:
db.close()
# Note: Job scheduling is now handled by Celery Beat
# All scheduled tasks are defined in celery_config.py
logger.info("Application startup complete - scheduled tasks handled by Celery Beat")
yield
# Shutdown
logger.info("Application shutdown complete")
# FastAPI app
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include Celery job management routes
try:
from routers.celery_jobs import router as celery_router
app.include_router(celery_router, prefix="/api")
logger.info("Celery job routes loaded successfully")
except ImportError as e:
logger.warning(f"Celery job routes not available: {e}")
except Exception as e:
logger.error(f"Error loading Celery job routes: {e}")
@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
# Convert UUID to string for each CVE
result = []
for cve in cves:
cve_dict = {
'id': str(cve.id),
'cve_id': cve.cve_id,
'description': cve.description,
'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
'severity': cve.severity,
'published_date': cve.published_date,
'affected_products': cve.affected_products,
'reference_urls': cve.reference_urls
}
result.append(CVEResponse(**cve_dict))
return result
@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
if not cve:
raise HTTPException(status_code=404, detail="CVE not found")
cve_dict = {
'id': str(cve.id),
'cve_id': cve.cve_id,
'description': cve.description,
'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
'severity': cve.severity,
'published_date': cve.published_date,
'affected_products': cve.affected_products,
'reference_urls': cve.reference_urls
}
return CVEResponse(**cve_dict)
@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all()
# Convert UUID to string for each rule
result = []
for rule in rules:
rule_dict = {
'id': str(rule.id),
'cve_id': rule.cve_id,
'rule_name': rule.rule_name,
'rule_content': rule.rule_content,
'detection_type': rule.detection_type,
'log_source': rule.log_source,
'confidence_level': rule.confidence_level,
'auto_generated': rule.auto_generated,
'exploit_based': rule.exploit_based or False,
'github_repos': rule.github_repos or [],
'exploit_indicators': rule.exploit_indicators,
'created_at': rule.created_at
}
result.append(SigmaRuleResponse(**rule_dict))
return result
@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
# Convert UUID to string for each rule
result = []
for rule in rules:
rule_dict = {
'id': str(rule.id),
'cve_id': rule.cve_id,
'rule_name': rule.rule_name,
'rule_content': rule.rule_content,
'detection_type': rule.detection_type,
'log_source': rule.log_source,
'confidence_level': rule.confidence_level,
'auto_generated': rule.auto_generated,
'exploit_based': rule.exploit_based or False,
'github_repos': rule.github_repos or [],
'exploit_indicators': rule.exploit_indicators,
'created_at': rule.created_at
}
result.append(SigmaRuleResponse(**rule_dict))
return result
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
async def fetch_task():
try:
service = CVESigmaService(db)
print("Manual CVE fetch initiated...")
# Use 30 days for manual fetch to get more results
new_cves = await service.fetch_recent_cves(days_back=30)
rules_generated = 0
for cve in new_cves:
sigma_rule = service.generate_sigma_rule(cve)
if sigma_rule:
rules_generated += 1
db.commit()
print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
except Exception as e:
print(f"Manual fetch error: {str(e)}")
import traceback
traceback.print_exc()
background_tasks.add_task(fetch_task)
return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
@app.get("/api/test-nvd")
async def test_nvd_connection():
"""Test endpoint to check NVD API connectivity"""
try:
# Test with a simple request using current date
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=30)
url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
params = {
"lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
"lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
"resultsPerPage": 5,
"startIndex": 0
}
headers = {
"User-Agent": "CVE-SIGMA-Generator/1.0",
"Accept": "application/json"
}
nvd_api_key = os.getenv("NVD_API_KEY")
if nvd_api_key:
headers["apiKey"] = nvd_api_key
print(f"Testing NVD API with URL: {url}")
print(f"Test params: {params}")
print(f"Test headers: {headers}")
response = requests.get(url, params=params, headers=headers, timeout=15)
result = {
"status": "success" if response.status_code == 200 else "error",
"status_code": response.status_code,
"has_api_key": bool(nvd_api_key),
"request_url": f"{url}?{requests.compat.urlencode(params)}",
"response_headers": dict(response.headers)
}
if response.status_code == 200:
data = response.json()
result.update({
"total_results": data.get("totalResults", 0),
"results_per_page": data.get("resultsPerPage", 0),
"vulnerabilities_returned": len(data.get("vulnerabilities", [])),
"message": "NVD API is accessible and returning data"
})
else:
result.update({
"error_message": response.text[:200],
"message": f"NVD API returned {response.status_code}"
})
# Try fallback without date filters if we get 404
if response.status_code == 404:
print("Trying fallback without date filters...")
fallback_params = {
"resultsPerPage": 5,
"startIndex": 0
}
fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
result["fallback_status_code"] = fallback_response.status_code
if fallback_response.status_code == 200:
fallback_data = fallback_response.json()
result.update({
"fallback_success": True,
"fallback_total_results": fallback_data.get("totalResults", 0),
"message": "NVD API works without date filters"
})
return result
except Exception as e:
print(f"NVD API test error: {str(e)}")
return {
"status": "error",
"message": f"Failed to connect to NVD API: {str(e)}"
}
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
total_cves = db.query(CVE).count()
total_rules = db.query(SigmaRule).count()
recent_cves = db.query(CVE).filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7)).count()
# Enhanced stats with bulk processing info
bulk_processed_cves = db.query(CVE).filter(CVE.bulk_processed == True).count()
cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()
nomi_sec_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'nomi_sec').count()
return {
"total_cves": total_cves,
"total_sigma_rules": total_rules,
"recent_cves_7_days": recent_cves,
"bulk_processed_cves": bulk_processed_cves,
"cves_with_pocs": cves_with_pocs,
"nomi_sec_rules": nomi_sec_rules,
"poc_coverage": (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0,
"nomi_sec_coverage": (nomi_sec_rules / total_rules * 100) if total_rules > 0 else 0
}
# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(request: BulkSeedRequest):
"""Start bulk seeding process - redirects to async endpoint"""
try:
from routers.celery_jobs import start_bulk_seed as async_bulk_seed
from routers.celery_jobs import BulkSeedRequest as CeleryBulkSeedRequest
# Convert request to Celery format
celery_request = CeleryBulkSeedRequest(
start_year=request.start_year,
end_year=request.end_year,
skip_nvd=request.skip_nvd,
skip_nomi_sec=request.skip_nomi_sec,
skip_exploitdb=getattr(request, 'skip_exploitdb', False),
skip_cisa_kev=getattr(request, 'skip_cisa_kev', False)
)
# Call async endpoint
result = await async_bulk_seed(celery_request)
return {
"message": "Bulk seeding process started (async)",
"status": "started",
"task_id": result.task_id,
"start_year": request.start_year,
"end_year": request.end_year or datetime.now().year,
"skip_nvd": request.skip_nvd,
"skip_nomi_sec": request.skip_nomi_sec,
"async_endpoint": f"/api/task-status/{result.task_id}"
}
except Exception as e:
logger.error(f"Error starting bulk seed: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start bulk seed: {e}")
@app.post("/api/incremental-update")
async def start_incremental_update():
"""Start incremental update process - redirects to async endpoint"""
try:
from routers.celery_jobs import start_incremental_update as async_incremental_update
# Call async endpoint
result = await async_incremental_update()
return {
"message": "Incremental update process started (async)",
"status": "started",
"task_id": result.task_id,
"async_endpoint": f"/api/task-status/{result.task_id}"
}
except Exception as e:
logger.error(f"Error starting incremental update: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start incremental update: {e}")
@app.post("/api/sync-nomi-sec")
async def sync_nomi_sec(request: NomiSecSyncRequest):
"""Synchronize nomi-sec PoC data - redirects to async endpoint"""
try:
from routers.celery_jobs import start_nomi_sec_sync as async_nomi_sec_sync
from routers.celery_jobs import DataSyncRequest as CeleryDataSyncRequest
# Convert request to Celery format
celery_request = CeleryDataSyncRequest(
batch_size=request.batch_size
)
# Call async endpoint
result = await async_nomi_sec_sync(celery_request)
return {
"message": f"Nomi-sec sync started (async)" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"task_id": result.task_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size,
"async_endpoint": f"/api/task-status/{result.task_id}"
}
except Exception as e:
logger.error(f"Error starting nomi-sec sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start nomi-sec sync: {e}")
@app.post("/api/sync-github-pocs")
async def sync_github_pocs(request: GitHubPoCSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize GitHub PoC data using Celery task"""
try:
from celery_config import celery_app
from tasks.data_sync_tasks import sync_github_poc_task
# Launch Celery task
if request.cve_id:
# For specific CVE sync, we'll still use the general task
task_result = sync_github_poc_task.delay(batch_size=request.batch_size)
else:
# For bulk sync
task_result = sync_github_poc_task.delay(batch_size=request.batch_size)
return {
"message": f"GitHub PoC sync started via Celery" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"task_id": task_result.id,
"cve_id": request.cve_id,
"batch_size": request.batch_size,
"monitor_url": "http://localhost:5555/task/" + task_result.id
}
except Exception as e:
logger.error(f"Error starting GitHub PoC sync via Celery: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start GitHub PoC sync: {e}")
@app.post("/api/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
request: ExploitDBSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize ExploitDB data from git mirror"""
# Create job record
job = BulkProcessingJob(
job_type='exploitdb_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
running_jobs[job_id] = job
job_cancellation_flags[job_id] = False
async def sync_task():
# Create a new database session for the background task
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
from exploitdb_client_local import ExploitDBLocalClient
client = ExploitDBLocalClient(task_db)
if request.cve_id:
# Sync specific CVE
if job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_exploits(request.cve_id)
logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_exploitdb(
batch_size=request.batch_size,
cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
)
logger.info(f"ExploitDB bulk sync completed: {result}")
# Update job status if not cancelled
if not job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
except Exception as e:
if not job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"ExploitDB sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking and close the task session
running_jobs.pop(job_id, None)
job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(sync_task)
return {
"message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@app.post("/api/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
request: CISAKEVSyncRequest,
db: Session = Depends(get_db)):
"""Synchronize CISA Known Exploited Vulnerabilities data"""
# Create job record
job = BulkProcessingJob(
job_type='cisa_kev_sync',
status='pending',
job_metadata={
'cve_id': request.cve_id,
'batch_size': request.batch_size
}
)
db.add(job)
db.commit()
db.refresh(job)
job_id = str(job.id)
running_jobs[job_id] = job
job_cancellation_flags[job_id] = False
async def sync_task():
# Create a new database session for the background task
task_db = SessionLocal()
try:
# Get the job in the new session
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if not task_job:
logger.error(f"Job {job_id} not found in task session")
return
task_job.status = 'running'
task_job.started_at = datetime.utcnow()
task_db.commit()
from cisa_kev_client import CISAKEVClient
client = CISAKEVClient(task_db)
if request.cve_id:
# Sync specific CVE
if job_cancellation_flags.get(job_id, False):
logger.info(f"Job {job_id} cancelled before starting")
return
result = await client.sync_cve_kev_data(request.cve_id)
logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
else:
# Sync all CVEs with cancellation support
result = await client.bulk_sync_kev_data(
batch_size=request.batch_size,
cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
)
logger.info(f"CISA KEV bulk sync completed: {result}")
# Update job status if not cancelled
if not job_cancellation_flags.get(job_id, False):
task_job.status = 'completed'
task_job.completed_at = datetime.utcnow()
task_db.commit()
except Exception as e:
if not job_cancellation_flags.get(job_id, False):
# Get the job again in case it was modified
task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
if task_job:
task_job.status = 'failed'
task_job.error_message = str(e)
task_job.completed_at = datetime.utcnow()
task_db.commit()
logger.error(f"CISA KEV sync failed: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up tracking and close the task session
running_jobs.pop(job_id, None)
job_cancellation_flags.pop(job_id, None)
task_db.close()
background_tasks.add_task(sync_task)
return {
"message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size
}
@app.post("/api/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Start reference data synchronization"""
try:
from reference_client import ReferenceClient
client = ReferenceClient(db)
# Create job ID
job_id = str(uuid.uuid4())
# Add job to tracking
running_jobs[job_id] = {
'type': 'reference_sync',
'status': 'running',
'cve_id': request.cve_id,
'batch_size': request.batch_size,
'max_cves': request.max_cves,
'force_resync': request.force_resync,
'started_at': datetime.utcnow()
}
# Create cancellation flag
job_cancellation_flags[job_id] = False
async def sync_task():
try:
if request.cve_id:
# Single CVE sync
result = await client.sync_cve_references(request.cve_id)
running_jobs[job_id]['result'] = result
running_jobs[job_id]['status'] = 'completed'
else:
# Bulk sync
result = await client.bulk_sync_references(
batch_size=request.batch_size,
max_cves=request.max_cves,
force_resync=request.force_resync,
cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
)
running_jobs[job_id]['result'] = result
running_jobs[job_id]['status'] = 'completed'
running_jobs[job_id]['completed_at'] = datetime.utcnow()
except Exception as e:
logger.error(f"Reference sync task failed: {e}")
running_jobs[job_id]['status'] = 'failed'
running_jobs[job_id]['error'] = str(e)
running_jobs[job_id]['completed_at'] = datetime.utcnow()
finally:
# Clean up cancellation flag
job_cancellation_flags.pop(job_id, None)
background_tasks.add_task(sync_task)
return {
"message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
"status": "started",
"job_id": job_id,
"cve_id": request.cve_id,
"batch_size": request.batch_size,
"max_cves": request.max_cves,
"force_resync": request.force_resync
}
except Exception as e:
logger.error(f"Failed to start reference sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")
@app.get("/api/reference-stats")
async def get_reference_stats(db: Session = Depends(get_db)):
"""Get reference synchronization statistics"""
try:
from reference_client import ReferenceClient
client = ReferenceClient(db)
# Get sync status
status = await client.get_reference_sync_status()
# Get quality distribution from reference data
quality_distribution = {}
from sqlalchemy import text
cves_with_references = db.query(CVE).filter(
text("reference_data::text LIKE '%\"reference_analysis\"%'")
).all()
for cve in cves_with_references:
if cve.reference_data and 'reference_analysis' in cve.reference_data:
ref_analysis = cve.reference_data['reference_analysis']
high_conf_refs = ref_analysis.get('high_confidence_references', 0)
total_refs = ref_analysis.get('reference_count', 0)
if total_refs > 0:
quality_ratio = high_conf_refs / total_refs
if quality_ratio >= 0.8:
quality_tier = 'excellent'
elif quality_ratio >= 0.6:
quality_tier = 'good'
elif quality_ratio >= 0.4:
quality_tier = 'fair'
else:
quality_tier = 'poor'
quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1
# Get reference type distribution
reference_type_distribution = {}
for cve in cves_with_references:
if cve.reference_data and 'reference_analysis' in cve.reference_data:
ref_analysis = cve.reference_data['reference_analysis']
ref_types = ref_analysis.get('reference_types', [])
for ref_type in ref_types:
reference_type_distribution[ref_type] = reference_type_distribution.get(ref_type, 0) + 1
return {
'reference_sync_status': status,
'quality_distribution': quality_distribution,
'reference_type_distribution': reference_type_distribution,
'total_with_reference_analysis': len(cves_with_references),
'source': 'reference_extraction'
}
except Exception as e:
logger.error(f"Failed to get reference stats: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get reference stats: {str(e)}")
@app.get("/api/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
"""Get ExploitDB-related statistics"""
try:
from exploitdb_client_local import ExploitDBLocalClient
client = ExploitDBLocalClient(db)
# Get sync status
status = await client.get_exploitdb_sync_status()
# Get quality distribution from ExploitDB data
quality_distribution = {}
from sqlalchemy import text
cves_with_exploitdb = db.query(CVE).filter(
text("poc_data::text LIKE '%\"exploitdb\"%'")
).all()
for cve in cves_with_exploitdb:
if cve.poc_data and 'exploitdb' in cve.poc_data:
exploits = cve.poc_data['exploitdb'].get('exploits', [])
for exploit in exploits:
quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1
# Get category distribution
category_distribution = {}
for cve in cves_with_exploitdb:
if cve.poc_data and 'exploitdb' in cve.poc_data:
exploits = cve.poc_data['exploitdb'].get('exploits', [])
for exploit in exploits:
category = exploit.get('category', 'unknown')
category_distribution[category] = category_distribution.get(category, 0) + 1
return {
"exploitdb_sync_status": status,
"quality_distribution": quality_distribution,
"category_distribution": category_distribution,
"total_exploitdb_cves": len(cves_with_exploitdb),
"total_exploits": sum(
len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
for cve in cves_with_exploitdb
if cve.poc_data and 'exploitdb' in cve.poc_data
)
}
except Exception as e:
logger.error(f"Error getting ExploitDB stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
"""Get GitHub PoC-related statistics"""
try:
# Get basic statistics
github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()
cves_with_github_pocs = db.query(CVE).filter(
CVE.poc_data.isnot(None), # Check if poc_data exists
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).count()
# Get quality distribution
quality_distribution = {}
try:
quality_results = db.query(
func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
func.count().label('count')
).filter(
CVE.poc_data.isnot(None),
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).group_by('tier').all()
for tier, count in quality_results:
if tier:
quality_distribution[tier] = count
except Exception as e:
logger.warning(f"Error getting quality distribution: {e}")
quality_distribution = {}
# Calculate average quality score
try:
avg_quality = db.query(
func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
).filter(
CVE.poc_data.isnot(None),
func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
).scalar() or 0
except Exception as e:
logger.warning(f"Error calculating average quality: {e}")
avg_quality = 0
return {
'github_poc_rules': github_poc_rules,
'cves_with_github_pocs': cves_with_github_pocs,
'quality_distribution': quality_distribution,
'average_quality_score': float(avg_quality) if avg_quality else 0,
'source': 'github_poc'
}
except Exception as e:
logger.error(f"Error getting GitHub PoC stats: {e}")
return {"error": str(e)}
@app.get("/api/github-poc-status")
async def get_github_poc_status(db: Session = Depends(get_db)):
"""Get GitHub PoC data availability status"""
try:
client = GitHubPoCClient(db)
# Check if GitHub PoC data is available
github_poc_data = client.load_github_poc_data()
return {
'github_poc_data_available': len(github_poc_data) > 0,
'total_cves_with_pocs': len(github_poc_data),
'sample_cve_ids': list(github_poc_data.keys())[:10], # First 10 CVE IDs
'data_path': str(client.github_poc_path),
'path_exists': client.github_poc_path.exists()
}
except Exception as e:
logger.error(f"Error checking GitHub PoC status: {e}")
return {"error": str(e)}
@app.get("/api/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
"""Get CISA KEV-related statistics"""
try:
from cisa_kev_client import CISAKEVClient
client = CISAKEVClient(db)
# Get sync status
status = await client.get_kev_sync_status()
# Get threat level distribution from CISA KEV data
threat_level_distribution = {}
from sqlalchemy import text
cves_with_kev = db.query(CVE).filter(
text("poc_data::text LIKE '%\"cisa_kev\"%'")
).all()
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
threat_level = vuln_data.get('threat_level', 'unknown')
threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1
# Get vulnerability category distribution
category_distribution = {}
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
category = vuln_data.get('vulnerability_category', 'unknown')
category_distribution[category] = category_distribution.get(category, 0) + 1
# Get ransomware usage statistics
ransomware_stats = {'known': 0, 'unknown': 0}
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
if ransomware_use == 'known':
ransomware_stats['known'] += 1
else:
ransomware_stats['unknown'] += 1
# Calculate average threat score
threat_scores = []
for cve in cves_with_kev:
if cve.poc_data and 'cisa_kev' in cve.poc_data:
vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})
threat_score = vuln_data.get('threat_score', 0)
if threat_score:
threat_scores.append(threat_score)
avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0
return {
"cisa_kev_sync_status": status,
"threat_level_distribution": threat_level_distribution,
"category_distribution": category_distribution,
"ransomware_stats": ransomware_stats,
"average_threat_score": round(avg_threat_score, 2),
"total_kev_cves": len(cves_with_kev),
"total_with_threat_scores": len(threat_scores)
}
except Exception as e:
logger.error(f"Error getting CISA KEV stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/bulk-jobs")
async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)):
"""Get bulk processing job status"""
jobs = db.query(BulkProcessingJob).order_by(
BulkProcessingJob.created_at.desc()
).limit(limit).all()
result = []
for job in jobs:
job_dict = {
'id': str(job.id),
'job_type': job.job_type,
'status': job.status,
'year': job.year,
'total_items': job.total_items,
'processed_items': job.processed_items,
'failed_items': job.failed_items,
'error_message': job.error_message,
'metadata': job.job_metadata,
'started_at': job.started_at,
'completed_at': job.completed_at,
'created_at': job.created_at
}
result.append(job_dict)
return result
@app.get("/api/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
"""Get comprehensive bulk processing status"""
try:
from bulk_seeder import BulkSeeder
seeder = BulkSeeder(db)
status = await seeder.get_seeding_status()
return status
except Exception as e:
logger.error(f"Error getting bulk status: {e}")
return {"error": str(e)}
@app.get("/api/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
"""Get PoC-related statistics"""
try:
from nomi_sec_client import NomiSecClient
client = NomiSecClient(db)
stats = await client.get_sync_status()
# Additional PoC statistics
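        # Sketch of the nomi-sec-style poc_data layout this filter assumes; other
        # sources (e.g. cisa_kev in the endpoint above) store different shapes:
        #   [{"quality_analysis": {"quality_score": 85, ...}, ...}, ...]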
high_quality_cves = db.query(CVE).filter(
CVE.poc_count > 0,
func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer) > 60
).count()
stats.update({
'high_quality_cves': high_quality_cves,
'avg_poc_count': db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0
})
return stats
except Exception as e:
logger.error(f"Error getting PoC stats: {e}")
return {"error": str(e)}
@app.post("/api/sync-cve2capec")
async def sync_cve2capec(force_refresh: bool = False):
"""Synchronize CVE2CAPEC MITRE ATT&CK mappings using Celery task"""
try:
from celery_config import celery_app
from tasks.data_sync_tasks import sync_cve2capec_task
# Launch Celery task
task_result = sync_cve2capec_task.delay(force_refresh=force_refresh)
return {
"message": "CVE2CAPEC MITRE ATT&CK mapping sync started via Celery",
"status": "started",
"task_id": task_result.id,
"force_refresh": force_refresh,
"monitor_url": f"http://localhost:5555/task/{task_result.id}"
}
except ImportError as e:
logger.error(f"Failed to import Celery components: {e}")
raise HTTPException(status_code=500, detail="Celery not properly configured")
except Exception as e:
logger.error(f"Error starting CVE2CAPEC sync: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start CVE2CAPEC sync: {e}")
@app.post("/api/build-exploitdb-index")
async def build_exploitdb_index():
"""Build/rebuild ExploitDB file index using Celery task"""
try:
from celery_config import celery_app
from tasks.data_sync_tasks import build_exploitdb_index_task
# Launch Celery task
task_result = build_exploitdb_index_task.delay()
return {
"message": "ExploitDB file index build started via Celery",
"status": "started",
"task_id": task_result.id,
"monitor_url": f"http://localhost:5555/task/{task_result.id}"
}
except ImportError as e:
logger.error(f"Failed to import Celery components: {e}")
raise HTTPException(status_code=500, detail="Celery not properly configured")
except Exception as e:
logger.error(f"Error starting ExploitDB index build: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start ExploitDB index build: {e}")
@app.get("/api/cve2capec-stats")
async def get_cve2capec_stats():
"""Get CVE2CAPEC MITRE ATT&CK mapping statistics"""
try:
client = CVE2CAPECClient()
stats = client.get_stats()
return {
"status": "success",
"data": stats,
"description": "CVE to MITRE ATT&CK technique mappings from CVE2CAPEC repository"
}
except Exception as e:
logger.error(f"Error getting CVE2CAPEC stats: {e}")
return {"error": str(e)}
@app.post("/api/regenerate-rules")
async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
request: RuleRegenRequest,
db: Session = Depends(get_db)):
"""Regenerate SIGMA rules using enhanced nomi-sec data"""
async def regenerate_task():
try:
from enhanced_sigma_generator import EnhancedSigmaGenerator
generator = EnhancedSigmaGenerator(db)
# Get CVEs with PoC data
cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).all()
rules_generated = 0
rules_updated = 0
for cve in cves_with_pocs:
# Check if we should regenerate
existing_rule = db.query(SigmaRule).filter(
SigmaRule.cve_id == cve.cve_id
).first()
if existing_rule and existing_rule.poc_source == 'nomi_sec' and not request.force:
continue
# Generate enhanced rule
result = await generator.generate_enhanced_rule(cve)
if result['success']:
if existing_rule:
rules_updated += 1
else:
rules_generated += 1
logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated")
except Exception as e:
logger.error(f"Rule regeneration failed: {e}")
import traceback
traceback.print_exc()
background_tasks.add_task(regenerate_task)
return {
"message": "SIGMA rule regeneration started",
"status": "started",
"force": request.force
}
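# Example request (a sketch; RuleRegenRequest is defined earlier in this module and
# at minimum carries the `force` flag used above):
#   requests.post("http://localhost:8000/api/regenerate-rules", json={"force": True})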
@app.post("/api/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Generate SIGMA rules using LLM API for enhanced analysis"""
# Parse request parameters
cve_id = request.get('cve_id')
force = request.get('force', False)
llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
llm_model = request.get('model', os.getenv('LLM_MODEL'))
# Validation
if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id):
raise HTTPException(status_code=400, detail="Invalid CVE ID format")
async def llm_generation_task():
"""Background task for LLM-enhanced rule generation"""
try:
from enhanced_sigma_generator import EnhancedSigmaGenerator
generator = EnhancedSigmaGenerator(db, llm_provider, llm_model)
# Process specific CVE or all CVEs with PoC data
if cve_id:
cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
if not cve:
logger.error(f"CVE {cve_id} not found")
return
cves_to_process = [cve]
else:
# Process CVEs with PoC data that either have no rules or force update
query = db.query(CVE).filter(CVE.poc_count > 0)
if not force:
# Only process CVEs without existing LLM-generated rules
existing_llm_rules = db.query(SigmaRule).filter(
SigmaRule.detection_type.like('llm_%')
).all()
existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids]
else:
cves_to_process = query.all()
logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")
rules_generated = 0
rules_updated = 0
failures = 0
for cve in cves_to_process:
try:
# Check if CVE has sufficient PoC data
if not cve.poc_data or not cve.poc_count:
logger.debug(f"Skipping {cve.cve_id} - no PoC data")
continue
# Generate LLM-enhanced rule
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
if result.get('updated'):
rules_updated += 1
else:
rules_generated += 1
logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
else:
failures += 1
logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")
except Exception as e:
failures += 1
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
continue
logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
except Exception as e:
logger.error(f"LLM-enhanced rule generation failed: {e}")
import traceback
traceback.print_exc()
background_tasks.add_task(llm_generation_task)
return {
"message": "LLM-enhanced SIGMA rule generation started",
"status": "started",
"cve_id": cve_id,
"force": force,
"provider": llm_provider,
"model": llm_model,
"note": "Requires appropriate LLM API key to be set"
}
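# Example request (a sketch; only the keys read above are shown, and the provider and
# model values are placeholders that must match a configured LLM backend):
#   requests.post("http://localhost:8000/api/llm-enhanced-rules",
#                 json={"cve_id": "CVE-2021-44228", "provider": "ollama",
#                       "model": "llama3", "force": False})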
@app.get("/api/llm-status")
async def get_llm_status():
"""Check LLM API availability status"""
try:
from llm_client import LLMClient
# Get current provider configuration
provider = os.getenv('LLM_PROVIDER')
model = os.getenv('LLM_MODEL')
client = LLMClient(provider=provider, model=model)
provider_info = client.get_provider_info()
# Get all available providers
all_providers = LLMClient.get_available_providers()
return {
"current_provider": provider_info,
"available_providers": all_providers,
"status": "ready" if client.is_available() else "unavailable"
}
except Exception as e:
logger.error(f"Error checking LLM status: {e}")
return {
"current_provider": {"provider": "unknown", "available": False},
"available_providers": [],
"status": "error",
"error": str(e)
}
@app.post("/api/llm-switch")
async def switch_llm_provider(request: dict):
"""Switch LLM provider and model"""
try:
from llm_client import LLMClient
provider = request.get('provider')
model = request.get('model')
if not provider:
raise HTTPException(status_code=400, detail="Provider is required")
# Validate provider
if provider not in LLMClient.SUPPORTED_PROVIDERS:
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
# Test the new configuration
client = LLMClient(provider=provider, model=model)
if not client.is_available():
raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured")
# Update environment variables (note: this only affects the current session)
os.environ['LLM_PROVIDER'] = provider
if model:
os.environ['LLM_MODEL'] = model
provider_info = client.get_provider_info()
return {
"message": f"Switched to {provider}",
"provider_info": provider_info,
"status": "success"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error switching LLM provider: {e}")
raise HTTPException(status_code=500, detail=str(e))
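# Example request (a sketch; the provider must be in LLMClient.SUPPORTED_PROVIDERS and
# the model name is a placeholder):
#   requests.post("http://localhost:8000/api/llm-switch",
#                 json={"provider": "ollama", "model": "llama3"})
# Note that the switch only mutates this process's environment variables, so it does
# not persist across restarts or propagate to separately running Celery workers.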
@app.post("/api/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
"""Cancel a running job"""
try:
# Find the job in the database
job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.status not in ['pending', 'running']:
raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")
# Set cancellation flag
job_cancellation_flags[job_id] = True
# Update job status
job.status = 'cancelled'
job.cancelled_at = datetime.utcnow()
job.error_message = "Job cancelled by user"
db.commit()
logger.info(f"Job {job_id} cancellation requested")
return {
"message": f"Job {job_id} cancellation requested",
"status": "cancelled",
"job_id": job_id
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error cancelling job {job_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
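# Cancellation is cooperative: setting job_cancellation_flags[job_id] above only takes
# effect if the long-running job periodically checks the flag. A minimal sketch of that
# consumer side (assuming the job loop runs in this process and sees the same dict):
#   for item in items:
#       if job_cancellation_flags.get(job_id):
#           break  # stop work; the job row was already marked 'cancelled' above
#       process(item)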
@app.get("/api/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
"""Get all currently running jobs"""
try:
jobs = db.query(BulkProcessingJob).filter(
BulkProcessingJob.status.in_(['pending', 'running'])
).order_by(BulkProcessingJob.created_at.desc()).all()
result = []
for job in jobs:
result.append({
'id': str(job.id),
'job_type': job.job_type,
'status': job.status,
'year': job.year,
'total_items': job.total_items,
'processed_items': job.processed_items,
'failed_items': job.failed_items,
'error_message': job.error_message,
'started_at': job.started_at,
'created_at': job.created_at,
'can_cancel': job.status in ['pending', 'running']
})
return result
except Exception as e:
logger.error(f"Error getting running jobs: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ollama-pull-model")
async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
"""Pull an Ollama model"""
try:
from llm_client import LLMClient
model = request.get('model')
if not model:
raise HTTPException(status_code=400, detail="Model name is required")
# Create a background task to pull the model
def pull_model_task():
try:
client = LLMClient(provider='ollama', model=model)
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
if client._pull_ollama_model(base_url, model):
logger.info(f"Successfully pulled Ollama model: {model}")
else:
logger.error(f"Failed to pull Ollama model: {model}")
except Exception as e:
logger.error(f"Error in model pull task: {e}")
background_tasks.add_task(pull_model_task)
return {
"message": f"Started pulling model {model}",
"status": "started",
"model": model
}
except Exception as e:
logger.error(f"Error starting model pull: {e}")
raise HTTPException(status_code=500, detail=str(e))
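# Example request (a sketch; the model name is a placeholder and must exist in the
# Ollama registry for the pull to succeed):
#   requests.post("http://localhost:8000/api/ollama-pull-model", json={"model": "llama3"})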
@app.get("/api/ollama-models")
async def get_ollama_models():
"""Get available Ollama models"""
try:
from llm_client import LLMClient
client = LLMClient(provider='ollama')
available_models = client._get_ollama_available_models()
return {
"available_models": available_models,
"total_models": len(available_models),
"status": "success"
}
except Exception as e:
logger.error(f"Error getting Ollama models: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# NOTE: SCHEDULER ENDPOINTS REMOVED
# ============================================================================
#
# Job scheduling is now handled by Celery Beat with periodic tasks.
# All scheduled tasks are defined in celery_config.py beat_schedule.
#
# To manage scheduled tasks:
# - View tasks: Use Celery monitoring tools (Flower, Celery events)
# - Control tasks: Use Celery control commands or the Celery job management endpoints below
# - Schedule changes: Update celery_config.py and restart Celery Beat
#
# Available Celery job management endpoints:
# - GET /api/celery/tasks - List all active tasks
# - POST /api/celery/tasks/{task_id}/revoke - Cancel a running task
# - GET /api/celery/workers - View worker status
#
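# Example (a sketch using the endpoints listed above; "<task_id>" is a placeholder for
# an id returned by any POST endpoint in this module that launches a Celery task):
#   requests.get("http://localhost:8000/api/celery/tasks")
#   requests.post("http://localhost:8000/api/celery/tasks/<task_id>/revoke")
#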
# ============================================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)