diff --git a/backend/config/database.py b/backend/config/database.py index e55a0e2..c9acf1a 100644 --- a/backend/config/database.py +++ b/backend/config/database.py @@ -1,10 +1,20 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.pool import QueuePool from .settings import settings -# Database setup -engine = create_engine(settings.DATABASE_URL) +# Database setup with connection pooling +engine = create_engine( + settings.DATABASE_URL, + poolclass=QueuePool, + pool_size=10, # Number of connections to maintain in the pool + max_overflow=20, # Additional connections that can be created on demand + pool_timeout=30, # Timeout for getting connection from pool + pool_recycle=3600, # Recycle connections after 1 hour + pool_pre_ping=True, # Validate connections before use + echo=False # Set to True for SQL query logging +) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/backend/delete_sigma_rules.py b/backend/delete_sigma_rules.py index 6d0baa0..57557c3 100644 --- a/backend/delete_sigma_rules.py +++ b/backend/delete_sigma_rules.py @@ -4,7 +4,8 @@ Script to delete all SIGMA rules from the database This will clear existing rules so they can be regenerated with the improved LLM client """ -from models import SigmaRule, SessionLocal +from models import SigmaRule +from config.database import SessionLocal import logging # Setup logging diff --git a/backend/llm_client.py b/backend/llm_client.py index f946ab1..5b1648a 100644 --- a/backend/llm_client.py +++ b/backend/llm_client.py @@ -108,7 +108,10 @@ class LLMClient: self.llm = Ollama( model=self.model, base_url=base_url, - temperature=0.1 + temperature=0.1, + num_ctx=4096, # Context window size + top_p=0.9, + top_k=40 ) if self.llm: @@ -186,9 +189,22 @@ class LLMClient: logger.info(f"CVE Description for {cve_id}: {cve_description[:200]}...") logger.info(f"PoC Content sample for {cve_id}: {poc_content[:200]}...") - # Generate the response + # Generate the response with timeout handling logger.info(f"Final prompt variables for {cve_id}: {list(input_data.keys())}") - response = await chain.ainvoke(input_data) + + import asyncio + try: + # Add timeout wrapper around the LLM call + response = await asyncio.wait_for( + chain.ainvoke(input_data), + timeout=150 # 2.5 minutes total timeout + ) + except asyncio.TimeoutError: + logger.error(f"LLM request timed out for {cve_id}") + return None + except Exception as llm_error: + logger.error(f"LLM generation error for {cve_id}: {llm_error}") + return None # Debug: Log raw LLM response logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...") @@ -228,36 +244,42 @@ class LLMClient: system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification. -**CRITICAL: You must follow the exact SIGMA specification format:** +**OFFICIAL SIGMA RULE SPECIFICATION JSON SCHEMA:** +The official SIGMA rule specification (v2.0.0) defines these requirements: -1. **YAML Structure Requirements:** - - Use UTF-8 encoding with LF line breaks - - Indent with 4 spaces (no tabs) - - Use lowercase keys only - - Use single quotes for string values - - No quotes for numeric values - - Follow proper YAML syntax +**MANDATORY Fields (must include):** +- title: Brief description (max 256 chars) - string +- logsource: Log data source specification - object with category/product/service +- detection: Search identifiers and conditions - object with selections and condition -2. 
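The pooled engine above is consumed through SessionLocal; below is a minimal sketch, assuming the config.database module path introduced in this diff, of how callers typically check a session out of the QueuePool and hand it back when done. The get_db_session helper is illustrative only and mirrors the get_db dependency pattern used elsewhere in this codebase; it is not part of this change.

from contextlib import contextmanager

from config.database import SessionLocal  # pooled sessionmaker defined above

@contextmanager
def get_db_session():
    """Check a session out of the connection pool and always return it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        # close() returns the underlying connection to the QueuePool;
        # pool_pre_ping revalidates it before the next checkout.
        db.close()

# Usage:
# with get_db_session() as db:
#     db.query(...).all()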
**MANDATORY Fields (must include):** - - title: Brief description (max 256 chars) - - logsource: Log data source specification - - detection: Search identifiers and conditions - - condition: How detection elements combine +**RECOMMENDED Fields:** +- id: Unique UUID (version 4) - string with UUID format +- status: Rule state - enum: "stable", "test", "experimental", "deprecated", "unsupported" +- description: Detailed explanation - string +- author: Rule creator - string (use "AI Generated") +- date: Creation date - string in YYYY/MM/DD format +- modified: Last modification date - string in YYYY/MM/DD format +- references: Sources for rule derivation - array of strings (URLs) +- tags: MITRE ATT&CK techniques - array of strings +- level: Rule severity - enum: "informational", "low", "medium", "high", "critical" +- falsepositives: Known false positives - array of strings +- fields: Related fields - array of strings +- related: Related rules - array of objects with type and id -3. **RECOMMENDED Fields:** - - id: Unique UUID - - status: 'experimental' (for new rules) - - description: Detailed explanation - - author: 'AI Generated' - - date: Current date (YYYY/MM/DD) - - references: Array with CVE link - - tags: MITRE ATT&CK techniques +**YAML Structure Requirements:** +- Use UTF-8 encoding with LF line breaks +- Indent with 4 spaces (no tabs) +- Use lowercase keys only +- Use single quotes for string values +- No quotes for numeric values +- Follow proper YAML syntax -4. **Detection Structure:** - - Use selection blocks (selection, selection1, etc.) - - Condition references these selections - - Use proper field names (Image, CommandLine, ProcessName, etc.) - - Support wildcards (*) and value lists +**Detection Structure:** +- Use selection blocks (selection, selection1, etc.) +- Condition references these selections +- Use proper field names (Image, CommandLine, ProcessName, etc.) 
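To make the mandatory and recommended fields listed above concrete, here is a minimal sketch (Python, assuming PyYAML is installed) that assembles a conforming rule skeleton and serializes it to YAML. The CVE ID, selection values, and date are placeholders for illustration, not output of this generator.

import uuid
import yaml  # assumes PyYAML is available

# Rule skeleton with the mandatory fields (title, logsource, detection)
# plus the recommended metadata listed in the prompt; values are placeholders.
rule = {
    "title": "Example CVE-XXXX-XXXXX Suspicious Process Execution",
    "id": str(uuid.uuid4()),
    "status": "experimental",
    "description": "Detects a placeholder process pattern for illustration only.",
    "author": "AI Generated",
    "date": "2024/01/01",
    "references": ["https://nvd.nist.gov/vuln/detail/CVE-XXXX-XXXXX"],
    "tags": ["attack.execution", "attack.t1059"],
    "logsource": {"category": "process_creation", "product": "windows"},
    "detection": {
        "selection": {
            "Image|endswith": "\\example.exe",
            "CommandLine|contains": "--exploit-flag",
        },
        "condition": "selection",
    },
    "falsepositives": ["Unknown"],
    "level": "high",
}

print(yaml.safe_dump(rule, sort_keys=False, default_flow_style=False))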
+- Support wildcards (*) and value lists +- Condition can be string expression or object with keywords **ABSOLUTE REQUIREMENTS:** - Output ONLY valid YAML @@ -1253,4 +1275,115 @@ Output ONLY the enhanced SIGMA rule in valid YAML format.""" return [] except Exception as e: logger.error(f"Error getting Ollama models: {e}") - return [] \ No newline at end of file + return [] + + async def test_connection(self) -> Dict[str, Any]: + """Test connection to the configured LLM provider.""" + try: + if self.provider == 'openai': + api_key = os.getenv('OPENAI_API_KEY') + if not api_key: + return { + "available": False, + "error": "OpenAI API key not configured", + "models": [], + "current_model": self.model, + "has_api_key": False + } + + # Test OpenAI connection without actual API call to avoid timeouts + if self.llm: + return { + "available": True, + "models": self.SUPPORTED_PROVIDERS['openai']['models'], + "current_model": self.model, + "has_api_key": True + } + else: + return { + "available": False, + "error": "OpenAI client not initialized", + "models": [], + "current_model": self.model, + "has_api_key": True + } + + elif self.provider == 'anthropic': + api_key = os.getenv('ANTHROPIC_API_KEY') + if not api_key: + return { + "available": False, + "error": "Anthropic API key not configured", + "models": [], + "current_model": self.model, + "has_api_key": False + } + + # Test Anthropic connection without actual API call to avoid timeouts + if self.llm: + return { + "available": True, + "models": self.SUPPORTED_PROVIDERS['anthropic']['models'], + "current_model": self.model, + "has_api_key": True + } + else: + return { + "available": False, + "error": "Anthropic client not initialized", + "models": [], + "current_model": self.model, + "has_api_key": True + } + + elif self.provider == 'ollama': + base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') + + # Test Ollama connection + try: + import requests + response = requests.get(f"{base_url}/api/tags", timeout=10) + if response.status_code == 200: + available_models = self._get_ollama_available_models() + # Check if model is available using proper model name matching + model_available = self._check_ollama_model_available(base_url, self.model) + + return { + "available": model_available, + "models": available_models, + "current_model": self.model, + "base_url": base_url, + "error": None if model_available else f"Model {self.model} not available" + } + else: + return { + "available": False, + "error": f"Ollama server not responding (HTTP {response.status_code})", + "models": [], + "current_model": self.model, + "base_url": base_url + } + except Exception as e: + return { + "available": False, + "error": f"Cannot connect to Ollama server: {str(e)}", + "models": [], + "current_model": self.model, + "base_url": base_url + } + + else: + return { + "available": False, + "error": f"Unsupported provider: {self.provider}", + "models": [], + "current_model": self.model + } + + except Exception as e: + return { + "available": False, + "error": f"Connection test failed: {str(e)}", + "models": [], + "current_model": self.model + } \ No newline at end of file diff --git a/backend/main_legacy.py b/backend/main_legacy.py deleted file mode 100644 index 78dae8c..0000000 --- a/backend/main_legacy.py +++ /dev/null @@ -1,2373 +0,0 @@ -from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from sqlalchemy import create_engine, Column, String, Text, 
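As a usage note for the new test_connection() method added to LLMClient above, here is a minimal sketch of how its result dictionary might be consumed, for example behind a health-check endpoint. The surrounding script (constructor arguments, the report_llm_status helper) is assumed for illustration and is not part of this diff; only the returned keys (available, current_model, error, models) come from the method above.

import asyncio

from llm_client import LLMClient  # class extended in this diff

async def report_llm_status() -> None:
    client = LLMClient()  # assumes default provider/model resolution
    status = await client.test_connection()
    if status["available"]:
        print(f"LLM ready: {status['current_model']}")
    else:
        print(f"LLM unavailable: {status.get('error')}")
        print(f"Known models: {status.get('models', [])}")

if __name__ == "__main__":
    asyncio.run(report_llm_status())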
DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session -from sqlalchemy.dialects.postgresql import UUID -import uuid -from datetime import datetime, timedelta -import requests -import json -import re -import os -from typing import List, Optional -from pydantic import BaseModel -import asyncio -from contextlib import asynccontextmanager -import base64 -from github import Github -from urllib.parse import urlparse -import hashlib -import logging -import threading -from mcdevitt_poc_client import GitHubPoCClient -from cve2capec_client import CVE2CAPECClient - -# Setup logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Global job tracking -running_jobs = {} -job_cancellation_flags = {} - -# Database setup -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db") -engine = create_engine(DATABASE_URL) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() - -# Database Models -class CVE(Base): - __tablename__ = "cves" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - cve_id = Column(String(20), unique=True, nullable=False) - description = Column(Text) - cvss_score = Column(DECIMAL(3, 1)) - severity = Column(String(20)) - published_date = Column(TIMESTAMP) - modified_date = Column(TIMESTAMP) - affected_products = Column(ARRAY(String)) - reference_urls = Column(ARRAY(String)) - # Bulk processing fields - data_source = Column(String(20), default='nvd_api') # 'nvd_api', 'nvd_bulk', 'manual' - nvd_json_version = Column(String(10), default='2.0') - bulk_processed = Column(Boolean, default=False) - # nomi-sec PoC fields - poc_count = Column(Integer, default=0) - poc_data = Column(JSON) # Store nomi-sec PoC metadata - # Reference data fields - reference_data = Column(JSON) # Store extracted reference content and analysis - reference_sync_status = Column(String(20), default='pending') # 'pending', 'processing', 'completed', 'failed' - reference_last_synced = Column(TIMESTAMP) - created_at = Column(TIMESTAMP, default=datetime.utcnow) - updated_at = Column(TIMESTAMP, default=datetime.utcnow) - -class SigmaRule(Base): - __tablename__ = "sigma_rules" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - cve_id = Column(String(20)) - rule_name = Column(String(255), nullable=False) - rule_content = Column(Text, nullable=False) - detection_type = Column(String(50)) - log_source = Column(String(100)) - confidence_level = Column(String(20)) - auto_generated = Column(Boolean, default=True) - exploit_based = Column(Boolean, default=False) - github_repos = Column(ARRAY(String)) - exploit_indicators = Column(Text) # JSON string of extracted indicators - # Enhanced fields for new data sources - poc_source = Column(String(20), default='github_search') # 'github_search', 'nomi_sec', 'manual' - poc_quality_score = Column(Integer, default=0) # Based on star count, activity, etc. 
- nomi_sec_data = Column(JSON) # Store nomi-sec PoC metadata - created_at = Column(TIMESTAMP, default=datetime.utcnow) - updated_at = Column(TIMESTAMP, default=datetime.utcnow) - -class RuleTemplate(Base): - __tablename__ = "rule_templates" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - template_name = Column(String(255), nullable=False) - template_content = Column(Text, nullable=False) - applicable_product_patterns = Column(ARRAY(String)) - description = Column(Text) - created_at = Column(TIMESTAMP, default=datetime.utcnow) - -class BulkProcessingJob(Base): - __tablename__ = "bulk_processing_jobs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - job_type = Column(String(50), nullable=False) # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update' - status = Column(String(20), default='pending') # 'pending', 'running', 'completed', 'failed', 'cancelled' - year = Column(Integer) # For year-based processing - total_items = Column(Integer, default=0) - processed_items = Column(Integer, default=0) - failed_items = Column(Integer, default=0) - error_message = Column(Text) - job_metadata = Column(JSON) # Additional job-specific data - started_at = Column(TIMESTAMP) - completed_at = Column(TIMESTAMP) - cancelled_at = Column(TIMESTAMP) - created_at = Column(TIMESTAMP, default=datetime.utcnow) - -# Pydantic models -class CVEResponse(BaseModel): - id: str - cve_id: str - description: Optional[str] = None - cvss_score: Optional[float] = None - severity: Optional[str] = None - published_date: Optional[datetime] = None - affected_products: Optional[List[str]] = None - reference_urls: Optional[List[str]] = None - - class Config: - from_attributes = True - -class SigmaRuleResponse(BaseModel): - id: str - cve_id: str - rule_name: str - rule_content: str - detection_type: Optional[str] = None - log_source: Optional[str] = None - confidence_level: Optional[str] = None - auto_generated: bool = True - exploit_based: bool = False - github_repos: Optional[List[str]] = None - exploit_indicators: Optional[str] = None - created_at: datetime - - class Config: - from_attributes = True - -# Request models -class BulkSeedRequest(BaseModel): - start_year: int = 2002 - end_year: Optional[int] = None - skip_nvd: bool = False - skip_nomi_sec: bool = True - -class NomiSecSyncRequest(BaseModel): - cve_id: Optional[str] = None - batch_size: int = 50 - -class GitHubPoCSyncRequest(BaseModel): - cve_id: Optional[str] = None - batch_size: int = 50 - -class ExploitDBSyncRequest(BaseModel): - cve_id: Optional[str] = None - batch_size: int = 30 - -class CISAKEVSyncRequest(BaseModel): - cve_id: Optional[str] = None - batch_size: int = 100 - -class ReferenceSyncRequest(BaseModel): - cve_id: Optional[str] = None - batch_size: int = 30 - max_cves: Optional[int] = None - force_resync: bool = False - -class RuleRegenRequest(BaseModel): - force: bool = False - -# GitHub Exploit Analysis Service -class GitHubExploitAnalyzer: - def __init__(self): - self.github_token = os.getenv("GITHUB_TOKEN") - self.github = Github(self.github_token) if self.github_token else None - - async def search_exploits_for_cve(self, cve_id: str) -> List[dict]: - """Search GitHub for exploit code related to a CVE""" - if not self.github: - print(f"No GitHub token configured, skipping exploit search for {cve_id}") - return [] - - try: - print(f"Searching GitHub for exploits for {cve_id}") - - # Search queries to find exploit code - search_queries = [ - f"{cve_id} exploit", - f"{cve_id} poc", - f"{cve_id} 
vulnerability", - f'"{cve_id}" exploit code', - f"{cve_id.replace('-', '_')} exploit" - ] - - exploits = [] - seen_repos = set() - - for query in search_queries[:2]: # Limit to 2 queries to avoid rate limits - try: - # Search repositories - repos = self.github.search_repositories( - query=query, - sort="updated", - order="desc" - ) - - # Get top 5 results per query - for repo in repos[:5]: - if repo.full_name in seen_repos: - continue - seen_repos.add(repo.full_name) - - # Analyze repository - exploit_info = await self._analyze_repository(repo, cve_id) - if exploit_info: - exploits.append(exploit_info) - - if len(exploits) >= 10: # Limit total exploits - break - - if len(exploits) >= 10: - break - - except Exception as e: - print(f"Error searching GitHub with query '{query}': {str(e)}") - continue - - print(f"Found {len(exploits)} potential exploits for {cve_id}") - return exploits - - except Exception as e: - print(f"Error searching GitHub for {cve_id}: {str(e)}") - return [] - - async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]: - """Analyze a GitHub repository for exploit code""" - try: - # Check if repo name or description mentions the CVE - repo_text = f"{repo.name} {repo.description or ''}".lower() - if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text: - return None - - # Get repository contents - exploit_files = [] - indicators = { - 'processes': set(), - 'files': set(), - 'registry': set(), - 'network': set(), - 'commands': set(), - 'powershell': set(), - 'urls': set() - } - - try: - contents = repo.get_contents("") - for content in contents[:20]: # Limit files to analyze - if content.type == "file" and self._is_exploit_file(content.name): - file_analysis = await self._analyze_file_content(repo, content, cve_id) - if file_analysis: - exploit_files.append(file_analysis) - # Merge indicators - for key, values in file_analysis.get('indicators', {}).items(): - if key in indicators: - indicators[key].update(values) - - except Exception as e: - print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}") - - if not exploit_files: - return None - - return { - 'repo_name': repo.full_name, - 'repo_url': repo.html_url, - 'description': repo.description, - 'language': repo.language, - 'stars': repo.stargazers_count, - 'updated': repo.updated_at.isoformat(), - 'files': exploit_files, - 'indicators': {k: list(v) for k, v in indicators.items()} - } - - except Exception as e: - print(f"Error analyzing repository {repo.full_name}: {str(e)}") - return None - - def _is_exploit_file(self, filename: str) -> bool: - """Check if a file is likely to contain exploit code""" - exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java'] - exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack'] - - filename_lower = filename.lower() - - # Check extension - if not any(filename_lower.endswith(ext) for ext in exploit_extensions): - return False - - # Check filename for exploit-related terms - return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower - - async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]: - """Analyze individual file content for exploit indicators""" - try: - if file_content.size > 100000: # Skip files larger than 100KB - return None - - # Decode file content - content = file_content.decoded_content.decode('utf-8', errors='ignore') - - # Check if file actually mentions the CVE - if cve_id.lower() 
not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower(): - return None - - indicators = self._extract_indicators_from_code(content, file_content.name) - - if not any(indicators.values()): - return None - - return { - 'filename': file_content.name, - 'path': file_content.path, - 'size': file_content.size, - 'indicators': indicators - } - - except Exception as e: - print(f"Error analyzing file {file_content.name}: {str(e)}") - return None - - def _extract_indicators_from_code(self, content: str, filename: str) -> dict: - """Extract security indicators from exploit code""" - indicators = { - 'processes': set(), - 'files': set(), - 'registry': set(), - 'network': set(), - 'commands': set(), - 'powershell': set(), - 'urls': set() - } - - # Process patterns - process_patterns = [ - r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']', - r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']', - r'system\s*\(\s*["\']([^"\']+)["\']', - r'exec\s*\(\s*["\']([^"\']+)["\']', - r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']' - ] - - # File patterns - file_patterns = [ - r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']', - r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?', - r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)', - r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)' - ] - - # Registry patterns - registry_patterns = [ - r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']', - r'HKEY_[A-Z_]+\\\\([^"\'\\]+)', - r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?' - ] - - # Network patterns - network_patterns = [ - r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)', - r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)', - r'(?:http|https|ftp)://([^\s"\'<>]+)', - r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)' - ] - - # PowerShell patterns - powershell_patterns = [ - r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?', - r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?', - r'Start-Process\s+["\']?([^"\']+)["\']?', - r'Get-Process\s+["\']?([^"\']+)["\']?' 
- ] - - # Command patterns - command_patterns = [ - r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?', - r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)', - r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)' - ] - - # Extract indicators using regex patterns - patterns = { - 'processes': process_patterns, - 'files': file_patterns, - 'registry': registry_patterns, - 'powershell': powershell_patterns, - 'commands': command_patterns - } - - for category, pattern_list in patterns.items(): - for pattern in pattern_list: - matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE) - for match in matches: - if isinstance(match, tuple): - indicators[category].add(match[0]) - else: - indicators[category].add(match) - - # Special handling for network indicators - for pattern in network_patterns: - matches = re.findall(pattern, content, re.IGNORECASE) - for match in matches: - if isinstance(match, tuple): - if len(match) >= 2: - indicators['network'].add(f"{match[0]}:{match[1]}") - else: - indicators['network'].add(match[0]) - else: - indicators['network'].add(match) - - # Convert sets to lists and filter out empty/invalid indicators - cleaned_indicators = {} - for key, values in indicators.items(): - cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200] - if cleaned_values: - cleaned_indicators[key] = cleaned_values[:10] # Limit to 10 per category - - return cleaned_indicators -class CVESigmaService: - def __init__(self, db: Session): - self.db = db - self.nvd_api_key = os.getenv("NVD_API_KEY") - - async def fetch_recent_cves(self, days_back: int = 7): - """Fetch recent CVEs from NVD API""" - end_date = datetime.utcnow() - start_date = end_date - timedelta(days=days_back) - - url = "https://services.nvd.nist.gov/rest/json/cves/2.0" - params = { - "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", - "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", - "resultsPerPage": 100 - } - - headers = {} - if self.nvd_api_key: - headers["apiKey"] = self.nvd_api_key - - try: - response = requests.get(url, params=params, headers=headers, timeout=30) - response.raise_for_status() - data = response.json() - - new_cves = [] - for vuln in data.get("vulnerabilities", []): - cve_data = vuln.get("cve", {}) - cve_id = cve_data.get("id") - - # Check if CVE already exists - existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first() - if existing: - continue - - # Extract CVE information - description = "" - if cve_data.get("descriptions"): - description = cve_data["descriptions"][0].get("value", "") - - cvss_score = None - severity = None - if cve_data.get("metrics", {}).get("cvssMetricV31"): - cvss_data = cve_data["metrics"]["cvssMetricV31"][0] - cvss_score = cvss_data.get("cvssData", {}).get("baseScore") - severity = cvss_data.get("cvssData", {}).get("baseSeverity") - - affected_products = [] - if cve_data.get("configurations"): - for config in cve_data["configurations"]: - for node in config.get("nodes", []): - for cpe_match in node.get("cpeMatch", []): - if cpe_match.get("vulnerable"): - affected_products.append(cpe_match.get("criteria", "")) - - reference_urls = [] - if cve_data.get("references"): - reference_urls = [ref.get("url", "") for ref in cve_data["references"]] - - cve_obj = CVE( - cve_id=cve_id, - description=description, - cvss_score=cvss_score, - severity=severity, - published_date=datetime.fromisoformat(cve_data.get("published", "").replace("Z", "+00:00")), - 
modified_date=datetime.fromisoformat(cve_data.get("lastModified", "").replace("Z", "+00:00")), - affected_products=affected_products, - reference_urls=reference_urls - ) - - self.db.add(cve_obj) - new_cves.append(cve_obj) - - self.db.commit() - return new_cves - - except Exception as e: - print(f"Error fetching CVEs: {str(e)}") - return [] - - def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]: - """Generate SIGMA rule based on CVE data""" - if not cve.description: - return None - - # Analyze CVE to determine appropriate template - description_lower = cve.description.lower() - affected_products = [p.lower() for p in (cve.affected_products or [])] - - template = self._select_template(description_lower, affected_products) - if not template: - return None - - # Generate rule content - rule_content = self._populate_template(cve, template) - if not rule_content: - return None - - # Determine detection type and confidence - detection_type = self._determine_detection_type(description_lower) - confidence_level = self._calculate_confidence(cve) - - sigma_rule = SigmaRule( - cve_id=cve.cve_id, - rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection", - rule_content=rule_content, - detection_type=detection_type, - log_source=template.template_name.lower().replace(" ", "_"), - confidence_level=confidence_level, - auto_generated=True - ) - - self.db.add(sigma_rule) - return sigma_rule - - def _select_template(self, description: str, affected_products: List[str], exploit_indicators: dict = None): - """Select appropriate SIGMA rule template based on CVE and exploit analysis""" - templates = self.db.query(RuleTemplate).all() - - # If we have exploit indicators, use them to determine the best template - if exploit_indicators: - if exploit_indicators.get('powershell'): - powershell_template = next((t for t in templates if "PowerShell" in t.template_name), None) - if powershell_template: - return powershell_template - - if exploit_indicators.get('network'): - network_template = next((t for t in templates if "Network Connection" in t.template_name), None) - if network_template: - return network_template - - if exploit_indicators.get('files'): - file_template = next((t for t in templates if "File Modification" in t.template_name), None) - if file_template: - return file_template - - if exploit_indicators.get('processes') or exploit_indicators.get('commands'): - process_template = next((t for t in templates if "Process Execution" in t.template_name), None) - if process_template: - return process_template - - # Fallback to original logic - if any("windows" in p or "microsoft" in p for p in affected_products): - if "process" in description or "execution" in description: - return next((t for t in templates if "Process Execution" in t.template_name), None) - elif "network" in description or "remote" in description: - return next((t for t in templates if "Network Connection" in t.template_name), None) - elif "file" in description or "write" in description: - return next((t for t in templates if "File Modification" in t.template_name), None) - - # Default to process execution template - return next((t for t in templates if "Process Execution" in t.template_name), None) - - def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: dict = None) -> str: - """Populate template with CVE-specific data and exploit indicators""" - try: - # Use exploit indicators if available, otherwise extract from description - if exploit_indicators: - suspicious_processes = 
exploit_indicators.get('processes', []) + exploit_indicators.get('commands', []) - suspicious_ports = [] - file_patterns = exploit_indicators.get('files', []) - - # Extract ports from network indicators - for net_indicator in exploit_indicators.get('network', []): - if ':' in str(net_indicator): - try: - port = int(str(net_indicator).split(':')[-1]) - suspicious_ports.append(port) - except ValueError: - pass - else: - # Fallback to original extraction - suspicious_processes = self._extract_suspicious_indicators(cve.description, "process") - suspicious_ports = self._extract_suspicious_indicators(cve.description, "port") - file_patterns = self._extract_suspicious_indicators(cve.description, "file") - - # Determine severity level - level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium" - - # Create enhanced description - enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description - if exploit_indicators: - enhanced_description += " [Enhanced with GitHub exploit analysis]" - - # Build tags - tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()] - if exploit_indicators: - tags.append("exploit.github") - - rule_content = template.template_content.format( - title=f"CVE-{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection", - description=enhanced_description, - rule_id=str(uuid.uuid4()), - date=datetime.utcnow().strftime("%Y/%m/%d"), - cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}", - cve_id=cve.cve_id.lower(), - tags="\n - ".join(tags), - suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"], - suspicious_ports=suspicious_ports or [4444, 8080, 9999], - file_patterns=file_patterns or ["temp", "malware", "exploit"], - level=level - ) - - return rule_content - - except Exception as e: - print(f"Error populating template: {str(e)}") - return None - - def _get_mitre_technique(self, description: str, exploit_indicators: dict = None) -> str: - """Map CVE and exploit indicators to MITRE ATT&CK techniques""" - desc_lower = description.lower() - - # Check exploit indicators first - if exploit_indicators: - if exploit_indicators.get('powershell'): - return "t1059.001" # PowerShell - elif exploit_indicators.get('commands'): - return "t1059.003" # Windows Command Shell - elif exploit_indicators.get('network'): - return "t1071.001" # Web Protocols - elif exploit_indicators.get('files'): - return "t1105" # Ingress Tool Transfer - elif exploit_indicators.get('processes'): - return "t1106" # Native API - - # Fallback to description analysis - if "powershell" in desc_lower: - return "t1059.001" - elif "command" in desc_lower or "cmd" in desc_lower: - return "t1059.003" - elif "network" in desc_lower or "remote" in desc_lower: - return "t1071.001" - elif "file" in desc_lower or "upload" in desc_lower: - return "t1105" - elif "process" in desc_lower or "execution" in desc_lower: - return "t1106" - else: - return "execution" # Generic - - def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> List: - """Extract suspicious indicators from CVE description""" - if indicator_type == "process": - # Look for executable names or process patterns - exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE) - return exe_pattern[:5] if exe_pattern else None - - elif indicator_type == "port": - # Look for port numbers - port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE) - return [int(p) for p in 
port_pattern[:3]] if port_pattern else None - - elif indicator_type == "file": - # Look for file extensions or paths - file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE) - return file_pattern[:5] if file_pattern else None - - return None - - def _determine_detection_type(self, description: str, exploit_indicators: dict = None) -> str: - """Determine detection type based on CVE description and exploit indicators""" - if exploit_indicators: - if exploit_indicators.get('powershell'): - return "powershell" - elif exploit_indicators.get('network'): - return "network" - elif exploit_indicators.get('files'): - return "file" - elif exploit_indicators.get('processes') or exploit_indicators.get('commands'): - return "process" - - # Fallback to original logic - if "remote" in description or "network" in description: - return "network" - elif "process" in description or "execution" in description: - return "process" - elif "file" in description or "filesystem" in description: - return "file" - else: - return "general" - - def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str: - """Calculate confidence level for the generated rule""" - base_confidence = 0 - - # CVSS score contributes to confidence - if cve.cvss_score: - if cve.cvss_score >= 9.0: - base_confidence += 3 - elif cve.cvss_score >= 7.0: - base_confidence += 2 - else: - base_confidence += 1 - - # Exploit-based rules get higher confidence - if exploit_based: - base_confidence += 2 - - # Map to confidence levels - if base_confidence >= 4: - return "high" - elif base_confidence >= 2: - return "medium" - else: - return "low" - -# Dependency -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - -# Background task to fetch CVEs and generate rules -async def background_cve_fetch(): - retry_count = 0 - max_retries = 3 - - while True: - try: - db = SessionLocal() - service = CVESigmaService(db) - current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') - print(f"[{current_time}] Starting CVE fetch cycle...") - - # Use a longer initial period (30 days) to find CVEs - new_cves = await service.fetch_recent_cves(days_back=30) - - if new_cves: - print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...") - rules_generated = 0 - for cve in new_cves: - try: - sigma_rule = service.generate_sigma_rule(cve) - if sigma_rule: - rules_generated += 1 - print(f"Generated SIGMA rule for {cve.cve_id}") - else: - print(f"Could not generate rule for {cve.cve_id} - insufficient data") - except Exception as e: - print(f"Error generating rule for {cve.cve_id}: {str(e)}") - - db.commit() - print(f"Successfully generated {rules_generated} SIGMA rules") - retry_count = 0 # Reset retry count on success - else: - print("No new CVEs found in this cycle") - # After first successful run, reduce to 7 days for regular updates - if retry_count == 0: - print("Switching to 7-day lookback for future runs...") - - db.close() - - except Exception as e: - retry_count += 1 - print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}") - if retry_count >= max_retries: - print(f"Max retries reached, waiting longer before next attempt...") - await asyncio.sleep(1800) # Wait 30 minutes on repeated failures - retry_count = 0 - else: - await asyncio.sleep(300) # Wait 5 minutes before retry - continue - - # Wait 1 hour before next fetch (or 30 minutes if there were errors) - wait_time = 3600 if retry_count == 0 else 1800 - print(f"Next CVE fetch in {wait_time//60} minutes...") - await 
asyncio.sleep(wait_time) - -@asynccontextmanager -async def lifespan(app: FastAPI): - # Initialize database - Base.metadata.create_all(bind=engine) - - # Initialize rule templates - db = SessionLocal() - try: - existing_templates = db.query(RuleTemplate).count() - if existing_templates == 0: - logger.info("No rule templates found. Database initialization will handle template creation.") - except Exception as e: - logger.error(f"Error checking rule templates: {e}") - finally: - db.close() - - # Initialize and start the job scheduler - try: - from job_scheduler import initialize_scheduler - from job_executors import register_all_executors - - # Initialize scheduler - scheduler = initialize_scheduler() - scheduler.set_db_session_factory(SessionLocal) - - # Register all job executors - register_all_executors(scheduler) - - # Start the scheduler - scheduler.start() - - logger.info("Job scheduler initialized and started") - - except Exception as e: - logger.error(f"Error initializing job scheduler: {e}") - - yield - - # Shutdown - try: - from job_scheduler import get_scheduler - scheduler = get_scheduler() - scheduler.stop() - logger.info("Job scheduler stopped") - except Exception as e: - logger.error(f"Error stopping job scheduler: {e}") - -# FastAPI app -app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan) - -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:3000"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -@app.get("/api/cves", response_model=List[CVEResponse]) -async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)): - cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all() - # Convert UUID to string for each CVE - result = [] - for cve in cves: - cve_dict = { - 'id': str(cve.id), - 'cve_id': cve.cve_id, - 'description': cve.description, - 'cvss_score': float(cve.cvss_score) if cve.cvss_score else None, - 'severity': cve.severity, - 'published_date': cve.published_date, - 'affected_products': cve.affected_products, - 'reference_urls': cve.reference_urls - } - result.append(CVEResponse(**cve_dict)) - return result - -@app.get("/api/cves/{cve_id}", response_model=CVEResponse) -async def get_cve(cve_id: str, db: Session = Depends(get_db)): - cve = db.query(CVE).filter(CVE.cve_id == cve_id).first() - if not cve: - raise HTTPException(status_code=404, detail="CVE not found") - - cve_dict = { - 'id': str(cve.id), - 'cve_id': cve.cve_id, - 'description': cve.description, - 'cvss_score': float(cve.cvss_score) if cve.cvss_score else None, - 'severity': cve.severity, - 'published_date': cve.published_date, - 'affected_products': cve.affected_products, - 'reference_urls': cve.reference_urls - } - return CVEResponse(**cve_dict) - -@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse]) -async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)): - rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all() - # Convert UUID to string for each rule - result = [] - for rule in rules: - rule_dict = { - 'id': str(rule.id), - 'cve_id': rule.cve_id, - 'rule_name': rule.rule_name, - 'rule_content': rule.rule_content, - 'detection_type': rule.detection_type, - 'log_source': rule.log_source, - 'confidence_level': rule.confidence_level, - 'auto_generated': rule.auto_generated, - 'exploit_based': rule.exploit_based or False, - 'github_repos': rule.github_repos or [], - 'exploit_indicators': 
rule.exploit_indicators, - 'created_at': rule.created_at - } - result.append(SigmaRuleResponse(**rule_dict)) - return result - -@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse]) -async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)): - rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all() - # Convert UUID to string for each rule - result = [] - for rule in rules: - rule_dict = { - 'id': str(rule.id), - 'cve_id': rule.cve_id, - 'rule_name': rule.rule_name, - 'rule_content': rule.rule_content, - 'detection_type': rule.detection_type, - 'log_source': rule.log_source, - 'confidence_level': rule.confidence_level, - 'auto_generated': rule.auto_generated, - 'exploit_based': rule.exploit_based or False, - 'github_repos': rule.github_repos or [], - 'exploit_indicators': rule.exploit_indicators, - 'created_at': rule.created_at - } - result.append(SigmaRuleResponse(**rule_dict)) - return result - -@app.post("/api/fetch-cves") -async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)): - async def fetch_task(): - try: - service = CVESigmaService(db) - print("Manual CVE fetch initiated...") - # Use 30 days for manual fetch to get more results - new_cves = await service.fetch_recent_cves(days_back=30) - - rules_generated = 0 - for cve in new_cves: - sigma_rule = service.generate_sigma_rule(cve) - if sigma_rule: - rules_generated += 1 - - db.commit() - print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated") - except Exception as e: - print(f"Manual fetch error: {str(e)}") - import traceback - traceback.print_exc() - - background_tasks.add_task(fetch_task) - return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"} - -@app.get("/api/test-nvd") -async def test_nvd_connection(): - """Test endpoint to check NVD API connectivity""" - try: - # Test with a simple request using current date - end_date = datetime.utcnow() - start_date = end_date - timedelta(days=30) - - url = "https://services.nvd.nist.gov/rest/json/cves/2.0/" - params = { - "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"), - "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"), - "resultsPerPage": 5, - "startIndex": 0 - } - - headers = { - "User-Agent": "CVE-SIGMA-Generator/1.0", - "Accept": "application/json" - } - - nvd_api_key = os.getenv("NVD_API_KEY") - if nvd_api_key: - headers["apiKey"] = nvd_api_key - - print(f"Testing NVD API with URL: {url}") - print(f"Test params: {params}") - print(f"Test headers: {headers}") - - response = requests.get(url, params=params, headers=headers, timeout=15) - - result = { - "status": "success" if response.status_code == 200 else "error", - "status_code": response.status_code, - "has_api_key": bool(nvd_api_key), - "request_url": f"{url}?{requests.compat.urlencode(params)}", - "response_headers": dict(response.headers) - } - - if response.status_code == 200: - data = response.json() - result.update({ - "total_results": data.get("totalResults", 0), - "results_per_page": data.get("resultsPerPage", 0), - "vulnerabilities_returned": len(data.get("vulnerabilities", [])), - "message": "NVD API is accessible and returning data" - }) - else: - result.update({ - "error_message": response.text[:200], - "message": f"NVD API returned {response.status_code}" - }) - - # Try fallback without date filters if we get 404 - if response.status_code == 404: - print("Trying fallback without date filters...") - fallback_params = { - 
"resultsPerPage": 5, - "startIndex": 0 - } - fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15) - result["fallback_status_code"] = fallback_response.status_code - - if fallback_response.status_code == 200: - fallback_data = fallback_response.json() - result.update({ - "fallback_success": True, - "fallback_total_results": fallback_data.get("totalResults", 0), - "message": "NVD API works without date filters" - }) - - return result - - except Exception as e: - print(f"NVD API test error: {str(e)}") - return { - "status": "error", - "message": f"Failed to connect to NVD API: {str(e)}" - } - -@app.get("/api/stats") -async def get_stats(db: Session = Depends(get_db)): - total_cves = db.query(CVE).count() - total_rules = db.query(SigmaRule).count() - recent_cves = db.query(CVE).filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7)).count() - - # Enhanced stats with bulk processing info - bulk_processed_cves = db.query(CVE).filter(CVE.bulk_processed == True).count() - cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count() - nomi_sec_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'nomi_sec').count() - - return { - "total_cves": total_cves, - "total_sigma_rules": total_rules, - "recent_cves_7_days": recent_cves, - "bulk_processed_cves": bulk_processed_cves, - "cves_with_pocs": cves_with_pocs, - "nomi_sec_rules": nomi_sec_rules, - "poc_coverage": (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0, - "nomi_sec_coverage": (nomi_sec_rules / total_rules * 100) if total_rules > 0 else 0 - } - -# New bulk processing endpoints -@app.post("/api/bulk-seed") -async def start_bulk_seed(background_tasks: BackgroundTasks, - request: BulkSeedRequest, - db: Session = Depends(get_db)): - """Start bulk seeding process""" - - async def bulk_seed_task(): - try: - from bulk_seeder import BulkSeeder - seeder = BulkSeeder(db) - result = await seeder.full_bulk_seed( - start_year=request.start_year, - end_year=request.end_year, - skip_nvd=request.skip_nvd, - skip_nomi_sec=request.skip_nomi_sec - ) - logger.info(f"Bulk seed completed: {result}") - except Exception as e: - logger.error(f"Bulk seed failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(bulk_seed_task) - - return { - "message": "Bulk seeding process started", - "status": "started", - "start_year": request.start_year, - "end_year": request.end_year or datetime.now().year, - "skip_nvd": request.skip_nvd, - "skip_nomi_sec": request.skip_nomi_sec - } - -@app.post("/api/incremental-update") -async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)): - """Start incremental update process""" - - async def incremental_update_task(): - try: - from bulk_seeder import BulkSeeder - seeder = BulkSeeder(db) - result = await seeder.incremental_update() - logger.info(f"Incremental update completed: {result}") - except Exception as e: - logger.error(f"Incremental update failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(incremental_update_task) - - return { - "message": "Incremental update process started", - "status": "started" - } - -@app.post("/api/sync-nomi-sec") -async def sync_nomi_sec(background_tasks: BackgroundTasks, - request: NomiSecSyncRequest, - db: Session = Depends(get_db)): - """Synchronize nomi-sec PoC data""" - - # Create job record - job = BulkProcessingJob( - job_type='nomi_sec_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': 
request.batch_size - } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - try: - job.status = 'running' - job.started_at = datetime.utcnow() - db.commit() - - from nomi_sec_client import NomiSecClient - client = NomiSecClient(db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_pocs(request.cve_id) - logger.info(f"Nomi-sec sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_all_cves( - batch_size=request.batch_size, - cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) - ) - logger.info(f"Nomi-sec bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - job.status = 'completed' - job.completed_at = datetime.utcnow() - db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - job.status = 'failed' - job.error_message = str(e) - job.completed_at = datetime.utcnow() - db.commit() - - logger.error(f"Nomi-sec sync failed: {e}") - import traceback - traceback.print_exc() - finally: - # Clean up tracking - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - - background_tasks.add_task(sync_task) - - return { - "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } - -@app.post("/api/sync-github-pocs") -async def sync_github_pocs(background_tasks: BackgroundTasks, - request: GitHubPoCSyncRequest, - db: Session = Depends(get_db)): - """Synchronize GitHub PoC data""" - - # Create job record - job = BulkProcessingJob( - job_type='github_poc_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': request.batch_size - } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - try: - job.status = 'running' - job.started_at = datetime.utcnow() - db.commit() - - client = GitHubPoCClient(db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_pocs(request.cve_id) - logger.info(f"GitHub PoC sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_all_cves(batch_size=request.batch_size) - logger.info(f"GitHub PoC bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - job.status = 'completed' - job.completed_at = datetime.utcnow() - db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - job.status = 'failed' - job.error_message = str(e) - job.completed_at = datetime.utcnow() - db.commit() - - logger.error(f"GitHub PoC sync failed: {e}") - import traceback - traceback.print_exc() - finally: - # Clean up tracking - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - - background_tasks.add_task(sync_task) - - return { - "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if 
request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } - -@app.post("/api/sync-exploitdb") -async def sync_exploitdb(background_tasks: BackgroundTasks, - request: ExploitDBSyncRequest, - db: Session = Depends(get_db)): - """Synchronize ExploitDB data from git mirror""" - - # Create job record - job = BulkProcessingJob( - job_type='exploitdb_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': request.batch_size - } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - # Create a new database session for the background task - task_db = SessionLocal() - try: - # Get the job in the new session - task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() - if not task_job: - logger.error(f"Job {job_id} not found in task session") - return - - task_job.status = 'running' - task_job.started_at = datetime.utcnow() - task_db.commit() - - from exploitdb_client_local import ExploitDBLocalClient - client = ExploitDBLocalClient(task_db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_exploits(request.cve_id) - logger.info(f"ExploitDB sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_exploitdb( - batch_size=request.batch_size, - cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) - ) - logger.info(f"ExploitDB bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - task_job.status = 'completed' - task_job.completed_at = datetime.utcnow() - task_db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - # Get the job again in case it was modified - task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() - if task_job: - task_job.status = 'failed' - task_job.error_message = str(e) - task_job.completed_at = datetime.utcnow() - task_db.commit() - - logger.error(f"ExploitDB sync failed: {e}") - import traceback - traceback.print_exc() - finally: - # Clean up tracking and close the task session - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - task_db.close() - - background_tasks.add_task(sync_task) - - return { - "message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } - -@app.post("/api/sync-cisa-kev") -async def sync_cisa_kev(background_tasks: BackgroundTasks, - request: CISAKEVSyncRequest, - db: Session = Depends(get_db)): - """Synchronize CISA Known Exploited Vulnerabilities data""" - - # Create job record - job = BulkProcessingJob( - job_type='cisa_kev_sync', - status='pending', - job_metadata={ - 'cve_id': request.cve_id, - 'batch_size': request.batch_size - } - ) - db.add(job) - db.commit() - db.refresh(job) - - job_id = str(job.id) - running_jobs[job_id] = job - job_cancellation_flags[job_id] = False - - async def sync_task(): - # Create a new database session for the background task - task_db = SessionLocal() - try: - # Get the job in the new session - task_job = 
task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() - if not task_job: - logger.error(f"Job {job_id} not found in task session") - return - - task_job.status = 'running' - task_job.started_at = datetime.utcnow() - task_db.commit() - - from cisa_kev_client import CISAKEVClient - client = CISAKEVClient(task_db) - - if request.cve_id: - # Sync specific CVE - if job_cancellation_flags.get(job_id, False): - logger.info(f"Job {job_id} cancelled before starting") - return - - result = await client.sync_cve_kev_data(request.cve_id) - logger.info(f"CISA KEV sync for {request.cve_id}: {result}") - else: - # Sync all CVEs with cancellation support - result = await client.bulk_sync_kev_data( - batch_size=request.batch_size, - cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) - ) - logger.info(f"CISA KEV bulk sync completed: {result}") - - # Update job status if not cancelled - if not job_cancellation_flags.get(job_id, False): - task_job.status = 'completed' - task_job.completed_at = datetime.utcnow() - task_db.commit() - - except Exception as e: - if not job_cancellation_flags.get(job_id, False): - # Get the job again in case it was modified - task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() - if task_job: - task_job.status = 'failed' - task_job.error_message = str(e) - task_job.completed_at = datetime.utcnow() - task_db.commit() - - logger.error(f"CISA KEV sync failed: {e}") - import traceback - traceback.print_exc() - finally: - # Clean up tracking and close the task session - running_jobs.pop(job_id, None) - job_cancellation_flags.pop(job_id, None) - task_db.close() - - background_tasks.add_task(sync_task) - - return { - "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size - } - -@app.post("/api/sync-references") -async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): - """Start reference data synchronization""" - - try: - from reference_client import ReferenceClient - client = ReferenceClient(db) - - # Create job ID - job_id = str(uuid.uuid4()) - - # Add job to tracking - running_jobs[job_id] = { - 'type': 'reference_sync', - 'status': 'running', - 'cve_id': request.cve_id, - 'batch_size': request.batch_size, - 'max_cves': request.max_cves, - 'force_resync': request.force_resync, - 'started_at': datetime.utcnow() - } - - # Create cancellation flag - job_cancellation_flags[job_id] = False - - async def sync_task(): - try: - if request.cve_id: - # Single CVE sync - result = await client.sync_cve_references(request.cve_id) - running_jobs[job_id]['result'] = result - running_jobs[job_id]['status'] = 'completed' - else: - # Bulk sync - result = await client.bulk_sync_references( - batch_size=request.batch_size, - max_cves=request.max_cves, - force_resync=request.force_resync, - cancellation_flag=lambda: job_cancellation_flags.get(job_id, False) - ) - running_jobs[job_id]['result'] = result - running_jobs[job_id]['status'] = 'completed' - - running_jobs[job_id]['completed_at'] = datetime.utcnow() - - except Exception as e: - logger.error(f"Reference sync task failed: {e}") - running_jobs[job_id]['status'] = 'failed' - running_jobs[job_id]['error'] = str(e) - running_jobs[job_id]['completed_at'] = datetime.utcnow() - finally: - # Clean up cancellation flag - job_cancellation_flags.pop(job_id, 
None) - - background_tasks.add_task(sync_task) - - return { - "message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), - "status": "started", - "job_id": job_id, - "cve_id": request.cve_id, - "batch_size": request.batch_size, - "max_cves": request.max_cves, - "force_resync": request.force_resync - } - - except Exception as e: - logger.error(f"Failed to start reference sync: {e}") - raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}") - -@app.get("/api/reference-stats") -async def get_reference_stats(db: Session = Depends(get_db)): - """Get reference synchronization statistics""" - - try: - from reference_client import ReferenceClient - client = ReferenceClient(db) - - # Get sync status - status = await client.get_reference_sync_status() - - # Get quality distribution from reference data - quality_distribution = {} - from sqlalchemy import text - cves_with_references = db.query(CVE).filter( - text("reference_data::text LIKE '%\"reference_analysis\"%'") - ).all() - - for cve in cves_with_references: - if cve.reference_data and 'reference_analysis' in cve.reference_data: - ref_analysis = cve.reference_data['reference_analysis'] - high_conf_refs = ref_analysis.get('high_confidence_references', 0) - total_refs = ref_analysis.get('reference_count', 0) - - if total_refs > 0: - quality_ratio = high_conf_refs / total_refs - if quality_ratio >= 0.8: - quality_tier = 'excellent' - elif quality_ratio >= 0.6: - quality_tier = 'good' - elif quality_ratio >= 0.4: - quality_tier = 'fair' - else: - quality_tier = 'poor' - - quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1 - - # Get reference type distribution - reference_type_distribution = {} - for cve in cves_with_references: - if cve.reference_data and 'reference_analysis' in cve.reference_data: - ref_analysis = cve.reference_data['reference_analysis'] - ref_types = ref_analysis.get('reference_types', []) - for ref_type in ref_types: - reference_type_distribution[ref_type] = reference_type_distribution.get(ref_type, 0) + 1 - - return { - 'reference_sync_status': status, - 'quality_distribution': quality_distribution, - 'reference_type_distribution': reference_type_distribution, - 'total_with_reference_analysis': len(cves_with_references), - 'source': 'reference_extraction' - } - - except Exception as e: - logger.error(f"Failed to get reference stats: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get reference stats: {str(e)}") - -@app.get("/api/exploitdb-stats") -async def get_exploitdb_stats(db: Session = Depends(get_db)): - """Get ExploitDB-related statistics""" - - try: - from exploitdb_client_local import ExploitDBLocalClient - client = ExploitDBLocalClient(db) - - # Get sync status - status = await client.get_exploitdb_sync_status() - - # Get quality distribution from ExploitDB data - quality_distribution = {} - from sqlalchemy import text - cves_with_exploitdb = db.query(CVE).filter( - text("poc_data::text LIKE '%\"exploitdb\"%'") - ).all() - - for cve in cves_with_exploitdb: - if cve.poc_data and 'exploitdb' in cve.poc_data: - exploits = cve.poc_data['exploitdb'].get('exploits', []) - for exploit in exploits: - quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown') - quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1 - - # Get category distribution - category_distribution = {} - for cve in cves_with_exploitdb: - if cve.poc_data and 'exploitdb' in 
cve.poc_data: - exploits = cve.poc_data['exploitdb'].get('exploits', []) - for exploit in exploits: - category = exploit.get('category', 'unknown') - category_distribution[category] = category_distribution.get(category, 0) + 1 - - return { - "exploitdb_sync_status": status, - "quality_distribution": quality_distribution, - "category_distribution": category_distribution, - "total_exploitdb_cves": len(cves_with_exploitdb), - "total_exploits": sum( - len(cve.poc_data.get('exploitdb', {}).get('exploits', [])) - for cve in cves_with_exploitdb - if cve.poc_data and 'exploitdb' in cve.poc_data - ) - } - - except Exception as e: - logger.error(f"Error getting ExploitDB stats: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/github-poc-stats") -async def get_github_poc_stats(db: Session = Depends(get_db)): - """Get GitHub PoC-related statistics""" - - try: - # Get basic statistics - github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count() - cves_with_github_pocs = db.query(CVE).filter( - CVE.poc_data.isnot(None), # Check if poc_data exists - func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' - ).count() - - # Get quality distribution - quality_distribution = {} - try: - quality_results = db.query( - func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'), - func.count().label('count') - ).filter( - CVE.poc_data.isnot(None), - func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' - ).group_by('tier').all() - - for tier, count in quality_results: - if tier: - quality_distribution[tier] = count - except Exception as e: - logger.warning(f"Error getting quality distribution: {e}") - quality_distribution = {} - - # Calculate average quality score - try: - avg_quality = db.query( - func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer)) - ).filter( - CVE.poc_data.isnot(None), - func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' - ).scalar() or 0 - except Exception as e: - logger.warning(f"Error calculating average quality: {e}") - avg_quality = 0 - - return { - 'github_poc_rules': github_poc_rules, - 'cves_with_github_pocs': cves_with_github_pocs, - 'quality_distribution': quality_distribution, - 'average_quality_score': float(avg_quality) if avg_quality else 0, - 'source': 'github_poc' - } - except Exception as e: - logger.error(f"Error getting GitHub PoC stats: {e}") - return {"error": str(e)} - -@app.get("/api/github-poc-status") -async def get_github_poc_status(db: Session = Depends(get_db)): - """Get GitHub PoC data availability status""" - - try: - client = GitHubPoCClient(db) - - # Check if GitHub PoC data is available - github_poc_data = client.load_github_poc_data() - - return { - 'github_poc_data_available': len(github_poc_data) > 0, - 'total_cves_with_pocs': len(github_poc_data), - 'sample_cve_ids': list(github_poc_data.keys())[:10], # First 10 CVE IDs - 'data_path': str(client.github_poc_path), - 'path_exists': client.github_poc_path.exists() - } - except Exception as e: - logger.error(f"Error checking GitHub PoC status: {e}") - return {"error": str(e)} - -@app.get("/api/cisa-kev-stats") -async def get_cisa_kev_stats(db: Session = Depends(get_db)): - """Get CISA KEV-related statistics""" - - try: - from cisa_kev_client import CISAKEVClient - client = CISAKEVClient(db) - - # Get sync status - status = await client.get_kev_sync_status() - - # Get threat level distribution from 
CISA KEV data - threat_level_distribution = {} - from sqlalchemy import text - cves_with_kev = db.query(CVE).filter( - text("poc_data::text LIKE '%\"cisa_kev\"%'") - ).all() - - for cve in cves_with_kev: - if cve.poc_data and 'cisa_kev' in cve.poc_data: - vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) - threat_level = vuln_data.get('threat_level', 'unknown') - threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1 - - # Get vulnerability category distribution - category_distribution = {} - for cve in cves_with_kev: - if cve.poc_data and 'cisa_kev' in cve.poc_data: - vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) - category = vuln_data.get('vulnerability_category', 'unknown') - category_distribution[category] = category_distribution.get(category, 0) + 1 - - # Get ransomware usage statistics - ransomware_stats = {'known': 0, 'unknown': 0} - for cve in cves_with_kev: - if cve.poc_data and 'cisa_kev' in cve.poc_data: - vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) - ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower() - if ransomware_use == 'known': - ransomware_stats['known'] += 1 - else: - ransomware_stats['unknown'] += 1 - - # Calculate average threat score - threat_scores = [] - for cve in cves_with_kev: - if cve.poc_data and 'cisa_kev' in cve.poc_data: - vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) - threat_score = vuln_data.get('threat_score', 0) - if threat_score: - threat_scores.append(threat_score) - - avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0 - - return { - "cisa_kev_sync_status": status, - "threat_level_distribution": threat_level_distribution, - "category_distribution": category_distribution, - "ransomware_stats": ransomware_stats, - "average_threat_score": round(avg_threat_score, 2), - "total_kev_cves": len(cves_with_kev), - "total_with_threat_scores": len(threat_scores) - } - - except Exception as e: - logger.error(f"Error getting CISA KEV stats: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/bulk-jobs") -async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)): - """Get bulk processing job status""" - - jobs = db.query(BulkProcessingJob).order_by( - BulkProcessingJob.created_at.desc() - ).limit(limit).all() - - result = [] - for job in jobs: - job_dict = { - 'id': str(job.id), - 'job_type': job.job_type, - 'status': job.status, - 'year': job.year, - 'total_items': job.total_items, - 'processed_items': job.processed_items, - 'failed_items': job.failed_items, - 'error_message': job.error_message, - 'metadata': job.job_metadata, - 'started_at': job.started_at, - 'completed_at': job.completed_at, - 'created_at': job.created_at - } - result.append(job_dict) - - return result - -@app.get("/api/bulk-status") -async def get_bulk_status(db: Session = Depends(get_db)): - """Get comprehensive bulk processing status""" - - try: - from bulk_seeder import BulkSeeder - seeder = BulkSeeder(db) - status = await seeder.get_seeding_status() - return status - except Exception as e: - logger.error(f"Error getting bulk status: {e}") - return {"error": str(e)} - -@app.get("/api/poc-stats") -async def get_poc_stats(db: Session = Depends(get_db)): - """Get PoC-related statistics""" - - try: - from nomi_sec_client import NomiSecClient - client = NomiSecClient(db) - stats = await client.get_sync_status() - - # Additional PoC statistics - high_quality_cves = db.query(CVE).filter( - 
CVE.poc_count > 0, - func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer) > 60 - ).count() - - stats.update({ - 'high_quality_cves': high_quality_cves, - 'avg_poc_count': db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0 - }) - - return stats - except Exception as e: - logger.error(f"Error getting PoC stats: {e}") - return {"error": str(e)} - -@app.get("/api/cve2capec-stats") -async def get_cve2capec_stats(): - """Get CVE2CAPEC MITRE ATT&CK mapping statistics""" - - try: - client = CVE2CAPECClient() - stats = client.get_stats() - - return { - "status": "success", - "data": stats, - "description": "CVE to MITRE ATT&CK technique mappings from CVE2CAPEC repository" - } - except Exception as e: - logger.error(f"Error getting CVE2CAPEC stats: {e}") - return {"error": str(e)} - -@app.post("/api/regenerate-rules") -async def regenerate_sigma_rules(background_tasks: BackgroundTasks, - request: RuleRegenRequest, - db: Session = Depends(get_db)): - """Regenerate SIGMA rules using enhanced nomi-sec data""" - - async def regenerate_task(): - try: - from enhanced_sigma_generator import EnhancedSigmaGenerator - generator = EnhancedSigmaGenerator(db) - - # Get CVEs with PoC data - cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).all() - - rules_generated = 0 - rules_updated = 0 - - for cve in cves_with_pocs: - # Check if we should regenerate - existing_rule = db.query(SigmaRule).filter( - SigmaRule.cve_id == cve.cve_id - ).first() - - if existing_rule and existing_rule.poc_source == 'nomi_sec' and not request.force: - continue - - # Generate enhanced rule - result = await generator.generate_enhanced_rule(cve) - - if result['success']: - if existing_rule: - rules_updated += 1 - else: - rules_generated += 1 - - logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated") - - except Exception as e: - logger.error(f"Rule regeneration failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(regenerate_task) - - return { - "message": "SIGMA rule regeneration started", - "status": "started", - "force": request.force - } - -@app.post("/api/llm-enhanced-rules") -async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): - """Generate SIGMA rules using LLM API for enhanced analysis""" - - # Parse request parameters - cve_id = request.get('cve_id') - force = request.get('force', False) - llm_provider = request.get('provider', os.getenv('LLM_PROVIDER')) - llm_model = request.get('model', os.getenv('LLM_MODEL')) - - # Validation - if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id): - raise HTTPException(status_code=400, detail="Invalid CVE ID format") - - async def llm_generation_task(): - """Background task for LLM-enhanced rule generation""" - try: - from enhanced_sigma_generator import EnhancedSigmaGenerator - - generator = EnhancedSigmaGenerator(db, llm_provider, llm_model) - - # Process specific CVE or all CVEs with PoC data - if cve_id: - cve = db.query(CVE).filter(CVE.cve_id == cve_id).first() - if not cve: - logger.error(f"CVE {cve_id} not found") - return - - cves_to_process = [cve] - else: - # Process CVEs with PoC data that either have no rules or force update - query = db.query(CVE).filter(CVE.poc_count > 0) - - if not force: - # Only process CVEs without existing LLM-generated rules - existing_llm_rules = db.query(SigmaRule).filter( - SigmaRule.detection_type.like('llm_%') - ).all() - existing_cve_ids 
= {rule.cve_id for rule in existing_llm_rules} - cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids] - else: - cves_to_process = query.all() - - logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}") - - rules_generated = 0 - rules_updated = 0 - failures = 0 - - for cve in cves_to_process: - try: - # Check if CVE has sufficient PoC data - if not cve.poc_data or not cve.poc_count: - logger.debug(f"Skipping {cve.cve_id} - no PoC data") - continue - - # Generate LLM-enhanced rule - result = await generator.generate_enhanced_rule(cve, use_llm=True) - - if result.get('success'): - if result.get('updated'): - rules_updated += 1 - else: - rules_generated += 1 - - logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}") - else: - failures += 1 - logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}") - - except Exception as e: - failures += 1 - logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}") - continue - - logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures") - - except Exception as e: - logger.error(f"LLM-enhanced rule generation failed: {e}") - import traceback - traceback.print_exc() - - background_tasks.add_task(llm_generation_task) - - return { - "message": "LLM-enhanced SIGMA rule generation started", - "status": "started", - "cve_id": cve_id, - "force": force, - "provider": llm_provider, - "model": llm_model, - "note": "Requires appropriate LLM API key to be set" - } - -@app.get("/api/llm-status") -async def get_llm_status(): - """Check LLM API availability status""" - try: - from llm_client import LLMClient - - # Get current provider configuration - provider = os.getenv('LLM_PROVIDER') - model = os.getenv('LLM_MODEL') - - client = LLMClient(provider=provider, model=model) - provider_info = client.get_provider_info() - - # Get all available providers - all_providers = LLMClient.get_available_providers() - - return { - "current_provider": provider_info, - "available_providers": all_providers, - "status": "ready" if client.is_available() else "unavailable" - } - except Exception as e: - logger.error(f"Error checking LLM status: {e}") - return { - "current_provider": {"provider": "unknown", "available": False}, - "available_providers": [], - "status": "error", - "error": str(e) - } - -@app.post("/api/llm-switch") -async def switch_llm_provider(request: dict): - """Switch LLM provider and model""" - try: - from llm_client import LLMClient - - provider = request.get('provider') - model = request.get('model') - - if not provider: - raise HTTPException(status_code=400, detail="Provider is required") - - # Validate provider - if provider not in LLMClient.SUPPORTED_PROVIDERS: - raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}") - - # Test the new configuration - client = LLMClient(provider=provider, model=model) - - if not client.is_available(): - raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured") - - # Update environment variables (note: this only affects the current session) - os.environ['LLM_PROVIDER'] = provider - if model: - os.environ['LLM_MODEL'] = model - - provider_info = client.get_provider_info() - - return { - "message": f"Switched to {provider}", - "provider_info": provider_info, - "status": "success" - } - - except HTTPException: - raise - except Exception as e: - 
logger.error(f"Error switching LLM provider: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/cancel-job/{job_id}") -async def cancel_job(job_id: str, db: Session = Depends(get_db)): - """Cancel a running job""" - try: - # Find the job in the database - job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first() - if not job: - raise HTTPException(status_code=404, detail="Job not found") - - if job.status not in ['pending', 'running']: - raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}") - - # Set cancellation flag - job_cancellation_flags[job_id] = True - - # Update job status - job.status = 'cancelled' - job.cancelled_at = datetime.utcnow() - job.error_message = "Job cancelled by user" - - db.commit() - - logger.info(f"Job {job_id} cancellation requested") - - return { - "message": f"Job {job_id} cancellation requested", - "status": "cancelled", - "job_id": job_id - } - except HTTPException: - raise - except Exception as e: - logger.error(f"Error cancelling job {job_id}: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/running-jobs") -async def get_running_jobs(db: Session = Depends(get_db)): - """Get all currently running jobs""" - try: - jobs = db.query(BulkProcessingJob).filter( - BulkProcessingJob.status.in_(['pending', 'running']) - ).order_by(BulkProcessingJob.created_at.desc()).all() - - result = [] - for job in jobs: - result.append({ - 'id': str(job.id), - 'job_type': job.job_type, - 'status': job.status, - 'year': job.year, - 'total_items': job.total_items, - 'processed_items': job.processed_items, - 'failed_items': job.failed_items, - 'error_message': job.error_message, - 'started_at': job.started_at, - 'created_at': job.created_at, - 'can_cancel': job.status in ['pending', 'running'] - }) - - return result - except Exception as e: - logger.error(f"Error getting running jobs: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/ollama-pull-model") -async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks): - """Pull an Ollama model""" - try: - from llm_client import LLMClient - - model = request.get('model') - if not model: - raise HTTPException(status_code=400, detail="Model name is required") - - # Create a background task to pull the model - def pull_model_task(): - try: - client = LLMClient(provider='ollama', model=model) - base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') - - if client._pull_ollama_model(base_url, model): - logger.info(f"Successfully pulled Ollama model: {model}") - else: - logger.error(f"Failed to pull Ollama model: {model}") - except Exception as e: - logger.error(f"Error in model pull task: {e}") - - background_tasks.add_task(pull_model_task) - - return { - "message": f"Started pulling model {model}", - "status": "started", - "model": model - } - - except Exception as e: - logger.error(f"Error starting model pull: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/ollama-models") -async def get_ollama_models(): - """Get available Ollama models""" - try: - from llm_client import LLMClient - - client = LLMClient(provider='ollama') - available_models = client._get_ollama_available_models() - - return { - "available_models": available_models, - "total_models": len(available_models), - "status": "success" - } - - except Exception as e: - logger.error(f"Error getting Ollama models: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -# 
============================================================================ -# SCHEDULER ENDPOINTS -# ============================================================================ - -class SchedulerControlRequest(BaseModel): - action: str # 'start', 'stop', 'restart' - -class JobControlRequest(BaseModel): - job_name: str - action: str # 'enable', 'disable', 'trigger' - -class UpdateScheduleRequest(BaseModel): - job_name: str - schedule: str # Cron expression - -@app.get("/api/scheduler/status") -async def get_scheduler_status(): - """Get scheduler status and job information""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - status = scheduler.get_job_status() - - return { - "scheduler_status": status, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error getting scheduler status: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/control") -async def control_scheduler(request: SchedulerControlRequest): - """Control scheduler (start/stop/restart)""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - - if request.action == 'start': - scheduler.start() - message = "Scheduler started" - elif request.action == 'stop': - scheduler.stop() - message = "Scheduler stopped" - elif request.action == 'restart': - scheduler.stop() - scheduler.start() - message = "Scheduler restarted" - else: - raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") - - return { - "message": message, - "action": request.action, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error controlling scheduler: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/job/control") -async def control_job(request: JobControlRequest): - """Control individual jobs (enable/disable/trigger)""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - - if request.action == 'enable': - success = scheduler.enable_job(request.job_name) - message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found" - elif request.action == 'disable': - success = scheduler.disable_job(request.job_name) - message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found" - elif request.action == 'trigger': - success = scheduler.trigger_job(request.job_name) - message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}" - else: - raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") - - return { - "message": message, - "job_name": request.job_name, - "action": request.action, - "success": success, - "timestamp": datetime.utcnow().isoformat() - } - - except Exception as e: - logger.error(f"Error controlling job: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/job/schedule") -async def update_job_schedule(request: UpdateScheduleRequest): - """Update job schedule""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - success = scheduler.update_job_schedule(request.job_name, request.schedule) - - if success: - # Get updated job info - job_status = scheduler.get_job_status(request.job_name) - return { - "message": f"Schedule updated for job {request.job_name}", - "job_name": request.job_name, - "new_schedule": request.schedule, - "next_run": job_status.get("next_run"), - "success": 
True, - "timestamp": datetime.utcnow().isoformat() - } - else: - raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}") - - except Exception as e: - logger.error(f"Error updating job schedule: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/api/scheduler/job/{job_name}") -async def get_job_status(job_name: str): - """Get status of a specific job""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - status = scheduler.get_job_status(job_name) - - if "error" in status: - raise HTTPException(status_code=404, detail=status["error"]) - - return { - "job_status": status, - "timestamp": datetime.utcnow().isoformat() - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting job status: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@app.post("/api/scheduler/reload") -async def reload_scheduler_config(): - """Reload scheduler configuration from file""" - try: - from job_scheduler import get_scheduler - - scheduler = get_scheduler() - success = scheduler.reload_config() - - if success: - return { - "message": "Scheduler configuration reloaded successfully", - "success": True, - "timestamp": datetime.utcnow().isoformat() - } - else: - raise HTTPException(status_code=500, detail="Failed to reload configuration") - - except Exception as e: - logger.error(f"Error reloading scheduler config: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/routers/bulk_operations.py b/backend/routers/bulk_operations.py index 7e971d8..8a221ae 100644 --- a/backend/routers/bulk_operations.py +++ b/backend/routers/bulk_operations.py @@ -1,14 +1,23 @@ from typing import List, Optional from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks from sqlalchemy.orm import Session +from sqlalchemy import func, text +import uuid +from datetime import datetime +import logging -from config.database import get_db +from config.database import get_db, SessionLocal from models import BulkProcessingJob, CVE, SigmaRule from schemas import BulkSeedRequest, NomiSecSyncRequest, GitHubPoCSyncRequest, ExploitDBSyncRequest, CISAKEVSyncRequest, ReferenceSyncRequest from services import CVEService, SigmaRuleService +# Import global job tracking from main.py +import main + router = APIRouter(prefix="/api", tags=["bulk-operations"]) +logger = logging.getLogger(__name__) + @router.post("/bulk-seed") async def bulk_seed(request: BulkSeedRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): @@ -117,4 +126,657 @@ async def get_poc_stats(db: Session = Depends(get_db)): "total_rules": total_rules, "exploit_based_rules": exploit_based_rules, "exploit_based_percentage": round((exploit_based_rules / total_rules * 100), 2) if total_rules > 0 else 0 - } \ No newline at end of file + } + + +@router.post("/sync-nomi-sec") +async def sync_nomi_sec(background_tasks: BackgroundTasks, + request: NomiSecSyncRequest, + db: Session = Depends(get_db)): + """Synchronize nomi-sec PoC data""" + + # Create job record + job = BulkProcessingJob( + job_type='nomi_sec_sync', + status='pending', + job_metadata={ + 'cve_id': request.cve_id, + 'batch_size': request.batch_size + } + ) + db.add(job) + db.commit() + db.refresh(job) + + job_id = str(job.id) + main.running_jobs[job_id] = job + main.job_cancellation_flags[job_id] = False + + async def sync_task(): + try: + 
job.status = 'running' + job.started_at = datetime.utcnow() + db.commit() + + from nomi_sec_client import NomiSecClient + client = NomiSecClient(db) + + if request.cve_id: + # Sync specific CVE + if main.job_cancellation_flags.get(job_id, False): + logger.info(f"Job {job_id} cancelled before starting") + return + + result = await client.sync_cve_pocs(request.cve_id) + logger.info(f"Nomi-sec sync for {request.cve_id}: {result}") + else: + # Sync all CVEs with cancellation support + result = await client.bulk_sync_all_cves( + batch_size=request.batch_size, + cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False) + ) + logger.info(f"Nomi-sec bulk sync completed: {result}") + + # Update job status if not cancelled + if not main.job_cancellation_flags.get(job_id, False): + job.status = 'completed' + job.completed_at = datetime.utcnow() + db.commit() + + except Exception as e: + if not main.job_cancellation_flags.get(job_id, False): + job.status = 'failed' + job.error_message = str(e) + job.completed_at = datetime.utcnow() + db.commit() + + logger.error(f"Nomi-sec sync failed: {e}") + import traceback + traceback.print_exc() + finally: + # Clean up tracking + main.running_jobs.pop(job_id, None) + main.job_cancellation_flags.pop(job_id, None) + + background_tasks.add_task(sync_task) + + return { + "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "job_id": job_id, + "cve_id": request.cve_id, + "batch_size": request.batch_size + } + + +@router.post("/sync-github-pocs") +async def sync_github_pocs(background_tasks: BackgroundTasks, + request: GitHubPoCSyncRequest, + db: Session = Depends(get_db)): + """Synchronize GitHub PoC data""" + + # Create job record + job = BulkProcessingJob( + job_type='github_poc_sync', + status='pending', + job_metadata={ + 'cve_id': request.cve_id, + 'batch_size': request.batch_size + } + ) + db.add(job) + db.commit() + db.refresh(job) + + job_id = str(job.id) + main.running_jobs[job_id] = job + main.job_cancellation_flags[job_id] = False + + async def sync_task(): + try: + job.status = 'running' + job.started_at = datetime.utcnow() + db.commit() + + from mcdevitt_poc_client import GitHubPoCClient + client = GitHubPoCClient(db) + + if request.cve_id: + # Sync specific CVE + if main.job_cancellation_flags.get(job_id, False): + logger.info(f"Job {job_id} cancelled before starting") + return + + result = await client.sync_cve_pocs(request.cve_id) + logger.info(f"GitHub PoC sync for {request.cve_id}: {result}") + else: + # Sync all CVEs with cancellation support + result = await client.bulk_sync_all_cves(batch_size=request.batch_size) + logger.info(f"GitHub PoC bulk sync completed: {result}") + + # Update job status if not cancelled + if not main.job_cancellation_flags.get(job_id, False): + job.status = 'completed' + job.completed_at = datetime.utcnow() + db.commit() + + except Exception as e: + if not main.job_cancellation_flags.get(job_id, False): + job.status = 'failed' + job.error_message = str(e) + job.completed_at = datetime.utcnow() + db.commit() + + logger.error(f"GitHub PoC sync failed: {e}") + import traceback + traceback.print_exc() + finally: + # Clean up tracking + main.running_jobs.pop(job_id, None) + main.job_cancellation_flags.pop(job_id, None) + + background_tasks.add_task(sync_task) + + return { + "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "job_id": job_id, + "cve_id": 
request.cve_id, + "batch_size": request.batch_size + } + + +@router.post("/sync-exploitdb") +async def sync_exploitdb(background_tasks: BackgroundTasks, + request: ExploitDBSyncRequest, + db: Session = Depends(get_db)): + """Synchronize ExploitDB data from git mirror""" + + # Create job record + job = BulkProcessingJob( + job_type='exploitdb_sync', + status='pending', + job_metadata={ + 'cve_id': request.cve_id, + 'batch_size': request.batch_size + } + ) + db.add(job) + db.commit() + db.refresh(job) + + job_id = str(job.id) + main.running_jobs[job_id] = job + main.job_cancellation_flags[job_id] = False + + async def sync_task(): + # Create a new database session for the background task + task_db = SessionLocal() + try: + # Get the job in the new session + task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if not task_job: + logger.error(f"Job {job_id} not found in task session") + return + + task_job.status = 'running' + task_job.started_at = datetime.utcnow() + task_db.commit() + + from exploitdb_client_local import ExploitDBLocalClient + client = ExploitDBLocalClient(task_db) + + if request.cve_id: + # Sync specific CVE + if main.job_cancellation_flags.get(job_id, False): + logger.info(f"Job {job_id} cancelled before starting") + return + + result = await client.sync_cve_exploits(request.cve_id) + logger.info(f"ExploitDB sync for {request.cve_id}: {result}") + else: + # Sync all CVEs with cancellation support + result = await client.bulk_sync_exploitdb( + batch_size=request.batch_size, + cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False) + ) + logger.info(f"ExploitDB bulk sync completed: {result}") + + # Update job status if not cancelled + if not main.job_cancellation_flags.get(job_id, False): + task_job.status = 'completed' + task_job.completed_at = datetime.utcnow() + task_db.commit() + + except Exception as e: + if not main.job_cancellation_flags.get(job_id, False): + # Get the job again in case it was modified + task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if task_job: + task_job.status = 'failed' + task_job.error_message = str(e) + task_job.completed_at = datetime.utcnow() + task_db.commit() + + logger.error(f"ExploitDB sync failed: {e}") + import traceback + traceback.print_exc() + finally: + # Clean up tracking and close the task session + main.running_jobs.pop(job_id, None) + main.job_cancellation_flags.pop(job_id, None) + task_db.close() + + background_tasks.add_task(sync_task) + + return { + "message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "job_id": job_id, + "cve_id": request.cve_id, + "batch_size": request.batch_size + } + + +@router.post("/sync-cisa-kev") +async def sync_cisa_kev(background_tasks: BackgroundTasks, + request: CISAKEVSyncRequest, + db: Session = Depends(get_db)): + """Synchronize CISA Known Exploited Vulnerabilities data""" + + # Create job record + job = BulkProcessingJob( + job_type='cisa_kev_sync', + status='pending', + job_metadata={ + 'cve_id': request.cve_id, + 'batch_size': request.batch_size + } + ) + db.add(job) + db.commit() + db.refresh(job) + + job_id = str(job.id) + main.running_jobs[job_id] = job + main.job_cancellation_flags[job_id] = False + + async def sync_task(): + # Create a new database session for the background task + task_db = SessionLocal() + try: + # Get the job in the new session + task_job = 
task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if not task_job: + logger.error(f"Job {job_id} not found in task session") + return + + task_job.status = 'running' + task_job.started_at = datetime.utcnow() + task_db.commit() + + from cisa_kev_client import CISAKEVClient + client = CISAKEVClient(task_db) + + if request.cve_id: + # Sync specific CVE + if main.job_cancellation_flags.get(job_id, False): + logger.info(f"Job {job_id} cancelled before starting") + return + + result = await client.sync_cve_kev_data(request.cve_id) + logger.info(f"CISA KEV sync for {request.cve_id}: {result}") + else: + # Sync all CVEs with cancellation support + result = await client.bulk_sync_kev_data( + batch_size=request.batch_size, + cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False) + ) + logger.info(f"CISA KEV bulk sync completed: {result}") + + # Update job status if not cancelled + if not main.job_cancellation_flags.get(job_id, False): + task_job.status = 'completed' + task_job.completed_at = datetime.utcnow() + task_db.commit() + + except Exception as e: + if not main.job_cancellation_flags.get(job_id, False): + # Get the job again in case it was modified + task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if task_job: + task_job.status = 'failed' + task_job.error_message = str(e) + task_job.completed_at = datetime.utcnow() + task_db.commit() + + logger.error(f"CISA KEV sync failed: {e}") + import traceback + traceback.print_exc() + finally: + # Clean up tracking and close the task session + main.running_jobs.pop(job_id, None) + main.job_cancellation_flags.pop(job_id, None) + task_db.close() + + background_tasks.add_task(sync_task) + + return { + "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "job_id": job_id, + "cve_id": request.cve_id, + "batch_size": request.batch_size + } + + +@router.post("/sync-references") +async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): + """Start reference data synchronization""" + + try: + from reference_client import ReferenceClient + client = ReferenceClient(db) + + # Create job ID + job_id = str(uuid.uuid4()) + + # Add job to tracking + main.running_jobs[job_id] = { + 'type': 'reference_sync', + 'status': 'running', + 'cve_id': request.cve_id, + 'batch_size': request.batch_size, + 'max_cves': request.max_cves, + 'force_resync': request.force_resync, + 'started_at': datetime.utcnow() + } + + # Create cancellation flag + main.job_cancellation_flags[job_id] = False + + async def sync_task(): + try: + if request.cve_id: + # Single CVE sync + result = await client.sync_cve_references(request.cve_id) + main.running_jobs[job_id]['result'] = result + main.running_jobs[job_id]['status'] = 'completed' + else: + # Bulk sync + result = await client.bulk_sync_references( + batch_size=request.batch_size, + max_cves=request.max_cves, + force_resync=request.force_resync, + cancellation_flag=lambda: main.job_cancellation_flags.get(job_id, False) + ) + main.running_jobs[job_id]['result'] = result + main.running_jobs[job_id]['status'] = 'completed' + + main.running_jobs[job_id]['completed_at'] = datetime.utcnow() + + except Exception as e: + logger.error(f"Reference sync task failed: {e}") + main.running_jobs[job_id]['status'] = 'failed' + main.running_jobs[job_id]['error'] = str(e) + main.running_jobs[job_id]['completed_at'] = 
datetime.utcnow() + finally: + # Clean up cancellation flag + main.job_cancellation_flags.pop(job_id, None) + + background_tasks.add_task(sync_task) + + return { + "message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"), + "status": "started", + "job_id": job_id, + "cve_id": request.cve_id, + "batch_size": request.batch_size, + "max_cves": request.max_cves, + "force_resync": request.force_resync + } + + except Exception as e: + logger.error(f"Failed to start reference sync: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}") + + +@router.get("/exploitdb-stats") +async def get_exploitdb_stats(db: Session = Depends(get_db)): + """Get ExploitDB-related statistics""" + + try: + from exploitdb_client_local import ExploitDBLocalClient + client = ExploitDBLocalClient(db) + + # Get sync status + status = await client.get_exploitdb_sync_status() + + # Get quality distribution from ExploitDB data + quality_distribution = {} + cves_with_exploitdb = db.query(CVE).filter( + text("poc_data::text LIKE '%\"exploitdb\"%'") + ).all() + + for cve in cves_with_exploitdb: + if cve.poc_data and 'exploitdb' in cve.poc_data: + exploits = cve.poc_data['exploitdb'].get('exploits', []) + for exploit in exploits: + quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown') + quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1 + + # Get category distribution + category_distribution = {} + for cve in cves_with_exploitdb: + if cve.poc_data and 'exploitdb' in cve.poc_data: + exploits = cve.poc_data['exploitdb'].get('exploits', []) + for exploit in exploits: + category = exploit.get('category', 'unknown') + category_distribution[category] = category_distribution.get(category, 0) + 1 + + return { + "exploitdb_sync_status": status, + "quality_distribution": quality_distribution, + "category_distribution": category_distribution, + "total_exploitdb_cves": len(cves_with_exploitdb), + "total_exploits": sum( + len(cve.poc_data.get('exploitdb', {}).get('exploits', [])) + for cve in cves_with_exploitdb + if cve.poc_data and 'exploitdb' in cve.poc_data + ) + } + + except Exception as e: + logger.error(f"Error getting ExploitDB stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/github-poc-stats") +async def get_github_poc_stats(db: Session = Depends(get_db)): + """Get GitHub PoC-related statistics""" + + try: + # Get basic statistics + github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count() + cves_with_github_pocs = db.query(CVE).filter( + CVE.poc_data.isnot(None), # Check if poc_data exists + func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' + ).count() + + # Get quality distribution + quality_distribution = {} + try: + quality_results = db.query( + func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'), + func.count().label('count') + ).filter( + CVE.poc_data.isnot(None), + func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' + ).group_by('tier').all() + + for tier, count in quality_results: + if tier: + quality_distribution[tier] = count + except Exception as e: + logger.warning(f"Error getting quality distribution: {e}") + quality_distribution = {} + + # Calculate average quality score + try: + from sqlalchemy import Integer + avg_quality = db.query( + func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 
'quality_score').cast(Integer)) + ).filter( + CVE.poc_data.isnot(None), + func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc' + ).scalar() or 0 + except Exception as e: + logger.warning(f"Error calculating average quality: {e}") + avg_quality = 0 + + return { + 'github_poc_rules': github_poc_rules, + 'cves_with_github_pocs': cves_with_github_pocs, + 'quality_distribution': quality_distribution, + 'average_quality_score': float(avg_quality) if avg_quality else 0, + 'source': 'github_poc' + } + except Exception as e: + logger.error(f"Error getting GitHub PoC stats: {e}") + return {"error": str(e)} + + +@router.get("/cisa-kev-stats") +async def get_cisa_kev_stats(db: Session = Depends(get_db)): + """Get CISA KEV-related statistics""" + + try: + from cisa_kev_client import CISAKEVClient + client = CISAKEVClient(db) + + # Get sync status + status = await client.get_kev_sync_status() + + # Get threat level distribution from CISA KEV data + threat_level_distribution = {} + cves_with_kev = db.query(CVE).filter( + text("poc_data::text LIKE '%\"cisa_kev\"%'") + ).all() + + for cve in cves_with_kev: + if cve.poc_data and 'cisa_kev' in cve.poc_data: + vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) + threat_level = vuln_data.get('threat_level', 'unknown') + threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1 + + # Get vulnerability category distribution + category_distribution = {} + for cve in cves_with_kev: + if cve.poc_data and 'cisa_kev' in cve.poc_data: + vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) + category = vuln_data.get('vulnerability_category', 'unknown') + category_distribution[category] = category_distribution.get(category, 0) + 1 + + # Get ransomware usage statistics + ransomware_stats = {'known': 0, 'unknown': 0} + for cve in cves_with_kev: + if cve.poc_data and 'cisa_kev' in cve.poc_data: + vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) + ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower() + if ransomware_use == 'known': + ransomware_stats['known'] += 1 + else: + ransomware_stats['unknown'] += 1 + + # Calculate average threat score + threat_scores = [] + for cve in cves_with_kev: + if cve.poc_data and 'cisa_kev' in cve.poc_data: + vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {}) + threat_score = vuln_data.get('threat_score', 0) + if threat_score: + threat_scores.append(threat_score) + + avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0 + + return { + "cisa_kev_sync_status": status, + "threat_level_distribution": threat_level_distribution, + "category_distribution": category_distribution, + "ransomware_stats": ransomware_stats, + "average_threat_score": round(avg_threat_score, 2), + "total_kev_cves": len(cves_with_kev), + "total_with_threat_scores": len(threat_scores) + } + + except Exception as e: + logger.error(f"Error getting CISA KEV stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cancel-job/{job_id}") +async def cancel_job(job_id: str, db: Session = Depends(get_db)): + """Cancel a running job""" + try: + # Find the job in the database + job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first() + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + if job.status not in ['pending', 'running']: + raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}") + + # Set 
cancellation flag + main.job_cancellation_flags[job_id] = True + + # Update job status + job.status = 'cancelled' + job.cancelled_at = datetime.utcnow() + db.commit() + + logger.info(f"Job {job_id} ({job.job_type}) cancelled by user") + + return { + "success": True, + "message": f"Job {job_id} cancelled successfully", + "job_id": job_id, + "job_type": job.job_type + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error cancelling job {job_id}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/running-jobs") +async def get_running_jobs(db: Session = Depends(get_db)): + """Get currently running jobs""" + try: + # Get running jobs from database + running_jobs_db = db.query(BulkProcessingJob).filter( + BulkProcessingJob.status.in_(['pending', 'running']) + ).order_by(BulkProcessingJob.created_at.desc()).all() + + result = [] + for job in running_jobs_db: + result.append({ + 'id': str(job.id), + 'job_type': job.job_type, + 'status': job.status, + 'year': job.year, + 'total_items': job.total_items, + 'processed_items': job.processed_items, + 'failed_items': job.failed_items, + 'error_message': job.error_message, + 'started_at': job.started_at, + 'created_at': job.created_at, + 'can_cancel': job.status in ['pending', 'running'] + }) + + return result + + except Exception as e: + logger.error(f"Error getting running jobs: {e}") + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/backend/routers/llm_operations.py b/backend/routers/llm_operations.py index 1b9b290..b1b26df 100644 --- a/backend/routers/llm_operations.py +++ b/backend/routers/llm_operations.py @@ -1,17 +1,20 @@ from typing import Dict, Any -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks from sqlalchemy.orm import Session from pydantic import BaseModel +import logging from config.database import get_db from models import CVE, SigmaRule router = APIRouter(prefix="/api", tags=["llm-operations"]) +logger = logging.getLogger(__name__) class LLMRuleRequest(BaseModel): - cve_id: str + cve_id: str = None # Optional for bulk operations poc_content: str = "" + force: bool = False # For bulk operations class LLMSwitchRequest(BaseModel): @@ -20,32 +23,170 @@ class LLMSwitchRequest(BaseModel): @router.post("/llm-enhanced-rules") -async def generate_llm_enhanced_rules(request: LLMRuleRequest, db: Session = Depends(get_db)): +async def generate_llm_enhanced_rules(request: LLMRuleRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)): """Generate SIGMA rules using LLM AI analysis""" try: from enhanced_sigma_generator import EnhancedSigmaGenerator - # Get CVE - cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first() - if not cve: - raise HTTPException(status_code=404, detail="CVE not found") - - # Generate enhanced rule using LLM - generator = EnhancedSigmaGenerator(db) - result = await generator.generate_enhanced_rule(cve, use_llm=True) - - if result.get('success'): + if request.cve_id: + # Single CVE operation + cve = db.query(CVE).filter(CVE.cve_id == request.cve_id).first() + if not cve: + raise HTTPException(status_code=404, detail="CVE not found") + + # Generate enhanced rule using LLM + generator = EnhancedSigmaGenerator(db) + result = await generator.generate_enhanced_rule(cve, use_llm=True) + + if result.get('success'): + return { + "success": True, + "message": f"Generated LLM-enhanced rule for {request.cve_id}", + "rule_id": 
result.get('rule_id'), + "generation_method": "llm_enhanced" + } + else: + return { + "success": False, + "error": result.get('error', 'Unknown error'), + "cve_id": request.cve_id + } + else: + # Bulk operation - run in background with job tracking + from models import BulkProcessingJob + import uuid + from datetime import datetime + import main + + # Create job record + job = BulkProcessingJob( + job_type='llm_rule_generation', + status='pending', + job_metadata={ + 'force': request.force + } + ) + db.add(job) + db.commit() + db.refresh(job) + + job_id = str(job.id) + main.running_jobs[job_id] = job + main.job_cancellation_flags[job_id] = False + + async def bulk_llm_generation_task(): + # Create a new database session for the background task + from config.database import SessionLocal + task_db = SessionLocal() + try: + # Get the job in the new session + task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if not task_job: + logger.error(f"Job {job_id} not found in task session") + return + + task_job.status = 'running' + task_job.started_at = datetime.utcnow() + task_db.commit() + + generator = EnhancedSigmaGenerator(task_db) + + # Get CVEs with PoC data - limit to small batch initially + if request.force: + # Process all CVEs with PoC data - but limit to prevent system overload + cves_to_process = task_db.query(CVE).filter(CVE.poc_count > 0).limit(50).all() + else: + # Only process CVEs without existing LLM-generated rules - small batch + cves_to_process = task_db.query(CVE).filter( + CVE.poc_count > 0 + ).limit(10).all() + + # Filter out CVEs that already have LLM-generated rules + existing_llm_rules = task_db.query(SigmaRule).filter( + SigmaRule.detection_type.like('llm_%') + ).all() + existing_cve_ids = {rule.cve_id for rule in existing_llm_rules} + cves_to_process = [cve for cve in cves_to_process if cve.cve_id not in existing_cve_ids] + + # Update job with total items + task_job.total_items = len(cves_to_process) + task_db.commit() + + rules_generated = 0 + rules_updated = 0 + failures = 0 + + logger.info(f"Starting bulk LLM rule generation for {len(cves_to_process)} CVEs (job {job_id})") + + for i, cve in enumerate(cves_to_process): + # Check for cancellation + if main.job_cancellation_flags.get(job_id, False): + logger.info(f"Job {job_id} cancelled, stopping LLM generation") + break + + try: + logger.info(f"Processing CVE {i+1}/{len(cves_to_process)}: {cve.cve_id}") + result = await generator.generate_enhanced_rule(cve, use_llm=True) + + if result.get('success'): + if result.get('updated'): + rules_updated += 1 + else: + rules_generated += 1 + logger.info(f"Successfully generated rule for {cve.cve_id}") + else: + failures += 1 + logger.warning(f"Failed to generate rule for {cve.cve_id}: {result.get('error')}") + + # Update progress + task_job.processed_items = i + 1 + task_job.failed_items = failures + task_db.commit() + + except Exception as e: + failures += 1 + logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}") + + # Update progress + task_job.processed_items = i + 1 + task_job.failed_items = failures + task_db.commit() + + # Continue with next CVE even if one fails + continue + + # Update job status if not cancelled + if not main.job_cancellation_flags.get(job_id, False): + task_job.status = 'completed' + task_job.completed_at = datetime.utcnow() + task_db.commit() + logger.info(f"Bulk LLM rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures") + + except Exception as e: + 
if not main.job_cancellation_flags.get(job_id, False): + # Get the job again in case it was modified + task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first() + if task_job: + task_job.status = 'failed' + task_job.error_message = str(e) + task_job.completed_at = datetime.utcnow() + task_db.commit() + + logger.error(f"Bulk LLM rule generation failed: {e}") + finally: + # Clean up tracking and close the task session + main.running_jobs.pop(job_id, None) + main.job_cancellation_flags.pop(job_id, None) + task_db.close() + + background_tasks.add_task(bulk_llm_generation_task) + return { "success": True, - "message": f"Generated LLM-enhanced rule for {request.cve_id}", - "rule_id": result.get('rule_id'), - "generation_method": "llm_enhanced" - } - else: - return { - "success": False, - "error": result.get('error', 'Unknown error'), - "cve_id": request.cve_id + "message": "Bulk LLM-enhanced rule generation started", + "status": "started", + "job_id": job_id, + "force": request.force } except Exception as e: @@ -58,68 +199,97 @@ async def get_llm_status(): try: from llm_client import LLMClient - # Test all providers - providers_status = {} + # Get current configuration first + current_client = LLMClient() + + # Build available providers list in the format frontend expects + available_providers = [] # Test Ollama try: ollama_client = LLMClient(provider="ollama") ollama_status = await ollama_client.test_connection() - providers_status["ollama"] = { + available_providers.append({ + "name": "ollama", "available": ollama_status.get("available", False), + "default_model": ollama_status.get("current_model", "llama3.2"), "models": ollama_status.get("models", []), - "current_model": ollama_status.get("current_model"), "base_url": ollama_status.get("base_url") - } + }) except Exception as e: - providers_status["ollama"] = {"available": False, "error": str(e)} + available_providers.append({ + "name": "ollama", + "available": False, + "default_model": "llama3.2", + "models": [], + "error": str(e) + }) # Test OpenAI try: openai_client = LLMClient(provider="openai") openai_status = await openai_client.test_connection() - providers_status["openai"] = { + available_providers.append({ + "name": "openai", "available": openai_status.get("available", False), + "default_model": openai_status.get("current_model", "gpt-4o-mini"), "models": openai_status.get("models", []), - "current_model": openai_status.get("current_model"), "has_api_key": openai_status.get("has_api_key", False) - } + }) except Exception as e: - providers_status["openai"] = {"available": False, "error": str(e)} + available_providers.append({ + "name": "openai", + "available": False, + "default_model": "gpt-4o-mini", + "models": [], + "has_api_key": False, + "error": str(e) + }) # Test Anthropic try: anthropic_client = LLMClient(provider="anthropic") anthropic_status = await anthropic_client.test_connection() - providers_status["anthropic"] = { + available_providers.append({ + "name": "anthropic", "available": anthropic_status.get("available", False), + "default_model": anthropic_status.get("current_model", "claude-3-5-sonnet-20241022"), "models": anthropic_status.get("models", []), - "current_model": anthropic_status.get("current_model"), "has_api_key": anthropic_status.get("has_api_key", False) - } + }) except Exception as e: - providers_status["anthropic"] = {"available": False, "error": str(e)} + available_providers.append({ + "name": "anthropic", + "available": False, + "default_model": 
"claude-3-5-sonnet-20241022", + "models": [], + "has_api_key": False, + "error": str(e) + }) - # Get current configuration - current_client = LLMClient() - current_config = { - "current_provider": current_client.provider, - "current_model": current_client.model, - "default_provider": "ollama" - } + # Determine overall status + any_available = any(p.get("available") for p in available_providers) + status = "ready" if any_available else "not_ready" + # Return in the format the frontend expects return { - "providers": providers_status, - "configuration": current_config, - "status": "operational" if any(p.get("available") for p in providers_status.values()) else "no_providers_available" + "status": status, + "current_provider": { + "provider": current_client.provider, + "model": current_client.model + }, + "available_providers": available_providers } except Exception as e: return { "status": "error", "error": str(e), - "providers": {}, - "configuration": {} + "current_provider": { + "provider": "unknown", + "model": "unknown" + }, + "available_providers": [] } diff --git a/frontend/src/App.js b/frontend/src/App.js index 4f24ccd..489d28a 100644 --- a/frontend/src/App.js +++ b/frontend/src/App.js @@ -52,6 +52,22 @@ function App() { return runningJobTypes.has(jobType); }; + // Helper function to format job types for display + const formatJobType = (jobType) => { + const jobTypeMap = { + 'llm_rule_generation': 'LLM Rule Generation', + 'rule_regeneration': 'Rule Regeneration', + 'bulk_seed': 'Bulk Seed', + 'incremental_update': 'Incremental Update', + 'nomi_sec_sync': 'Nomi-Sec Sync', + 'github_poc_sync': 'GitHub PoC Sync', + 'exploitdb_sync': 'ExploitDB Sync', + 'cisa_kev_sync': 'CISA KEV Sync', + 'reference_sync': 'Reference Sync' + }; + return jobTypeMap[jobType] || jobType; + }; + const isBulkSeedRunning = () => { return isJobTypeRunning('nvd_bulk_seed') || isJobTypeRunning('bulk_seed'); }; @@ -262,7 +278,7 @@ function App() { try { const response = await axios.post('http://localhost:8000/api/sync-references', { batch_size: 30, - max_cves: 100, + max_cves: null, force_resync: false }); console.log('Reference sync response:', response.data); @@ -291,9 +307,20 @@ function App() { force: force }); console.log('LLM rule generation response:', response.data); - fetchData(); + + // For bulk operations, the job runs in background - no need to wait + if (response.data.status === 'started') { + console.log(`LLM rule generation started as background job: ${response.data.job_id}`); + alert(`LLM rule generation started in background. You can monitor progress in the Bulk Jobs tab.`); + // Refresh data to show the new job status + fetchData(); + } else { + // For single CVE operations, refresh data after completion + fetchData(); + } } catch (error) { console.error('Error generating LLM-enhanced rules:', error); + alert('Error generating LLM-enhanced rules. Please check the console for details.'); } }; @@ -1062,7 +1089,7 @@ function App() {
-                        {job.job_type}
+                        {formatJobType(job.job_type)}