from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY, Integer, JSON, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timedelta
import requests
import json
import re
import os
from typing import List, Optional
from pydantic import BaseModel
import asyncio
from contextlib import asynccontextmanager
import base64
from github import Github
from urllib.parse import urlparse
import hashlib
import logging
import threading
from mcdevitt_poc_client import GitHubPoCClient
from cve2capec_client import CVE2CAPECClient

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global job tracking
running_jobs = {}
job_cancellation_flags = {}

# Database setup
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db")
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
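# Environment variables read by this module: DATABASE_URL (PostgreSQL DSN, falling back
# to the local cve_sigma_db default above), NVD_API_KEY (optional, sent as the apiKey
# header on NVD requests), and GITHUB_TOKEN (optional; without it the GitHub exploit
# search is skipped).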
# Database Models
class CVE(Base):
    __tablename__ = "cves"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    # Bulk processing fields
    data_source = Column(String(20), default='nvd_api')  # 'nvd_api', 'nvd_bulk', 'manual'
    nvd_json_version = Column(String(10), default='2.0')
    bulk_processed = Column(Boolean, default=False)
    # nomi-sec PoC fields
    poc_count = Column(Integer, default=0)
    poc_data = Column(JSON)  # Store nomi-sec PoC metadata
    # Reference data fields
    reference_data = Column(JSON)  # Store extracted reference content and analysis
    reference_sync_status = Column(String(20), default='pending')  # 'pending', 'processing', 'completed', 'failed'
    reference_last_synced = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)

class SigmaRule(Base):
    __tablename__ = "sigma_rules"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    exploit_based = Column(Boolean, default=False)
    github_repos = Column(ARRAY(String))
    exploit_indicators = Column(Text)  # JSON string of extracted indicators
    # Enhanced fields for new data sources
    poc_source = Column(String(20), default='github_search')  # 'github_search', 'nomi_sec', 'manual'
    poc_quality_score = Column(Integer, default=0)  # Based on star count, activity, etc.
    nomi_sec_data = Column(JSON)  # Store nomi-sec PoC metadata
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    updated_at = Column(TIMESTAMP, default=datetime.utcnow)

class RuleTemplate(Base):
    __tablename__ = "rule_templates"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    template_name = Column(String(255), nullable=False)
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)

class BulkProcessingJob(Base):
    __tablename__ = "bulk_processing_jobs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    job_type = Column(String(50), nullable=False)  # 'nvd_bulk_seed', 'nomi_sec_sync', 'incremental_update'
    status = Column(String(20), default='pending')  # 'pending', 'running', 'completed', 'failed', 'cancelled'
    year = Column(Integer)  # For year-based processing
    total_items = Column(Integer, default=0)
    processed_items = Column(Integer, default=0)
    failed_items = Column(Integer, default=0)
    error_message = Column(Text)
    job_metadata = Column(JSON)  # Additional job-specific data
    started_at = Column(TIMESTAMP)
    completed_at = Column(TIMESTAMP)
    cancelled_at = Column(TIMESTAMP)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
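# Note on CVE.poc_data: the layout below is assumed from how the stats endpoints later
# in this file read it, and it varies by source. The GitHub PoC sync appears to store a
# JSON list of entries (each with 'source' and 'quality_analysis'), while the ExploitDB
# and CISA KEV syncs store keyed sub-objects, e.g. (illustrative only):
#   {"exploitdb": {"exploits": [{"category": "...", "quality_analysis": {"quality_tier": "..."}}]},
#    "cisa_kev": {"vulnerability_data": {"threat_level": "...", "threat_score": 0}}}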
# Pydantic models
class CVEResponse(BaseModel):
    id: str
    cve_id: str
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        from_attributes = True

class SigmaRuleResponse(BaseModel):
    id: str
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    exploit_based: bool = False
    github_repos: Optional[List[str]] = None
    exploit_indicators: Optional[str] = None
    created_at: datetime

    class Config:
        from_attributes = True
# Request models
class BulkSeedRequest(BaseModel):
    start_year: int = 2002
    end_year: Optional[int] = None
    skip_nvd: bool = False
    skip_nomi_sec: bool = True

class NomiSecSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 50

class GitHubPoCSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 50

class ExploitDBSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 30

class CISAKEVSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 100

class ReferenceSyncRequest(BaseModel):
    cve_id: Optional[str] = None
    batch_size: int = 30
    max_cves: Optional[int] = None
    force_resync: bool = False

class RuleRegenRequest(BaseModel):
    force: bool = False
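# Illustrative request payloads for the sync endpoints defined further below
# (the CVE ID and year values are examples only):
#   POST /api/bulk-seed        {"start_year": 2020, "end_year": 2024, "skip_nvd": false, "skip_nomi_sec": true}
#   POST /api/sync-nomi-sec    {"cve_id": "CVE-2024-12345", "batch_size": 50}
#   POST /api/sync-references  {"batch_size": 30, "max_cves": 100, "force_resync": false}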
# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
    def __init__(self):
        self.github_token = os.getenv("GITHUB_TOKEN")
        self.github = Github(self.github_token) if self.github_token else None

    async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
        """Search GitHub for exploit code related to a CVE"""
        if not self.github:
            print(f"No GitHub token configured, skipping exploit search for {cve_id}")
            return []

        try:
            print(f"Searching GitHub for exploits for {cve_id}")

            # Search queries to find exploit code
            search_queries = [
                f"{cve_id} exploit",
                f"{cve_id} poc",
                f"{cve_id} vulnerability",
                f'"{cve_id}" exploit code',
                f"{cve_id.replace('-', '_')} exploit"
            ]

            exploits = []
            seen_repos = set()

            for query in search_queries[:2]:  # Limit to 2 queries to avoid rate limits
                try:
                    # Search repositories
                    repos = self.github.search_repositories(
                        query=query,
                        sort="updated",
                        order="desc"
                    )

                    # Get top 5 results per query
                    for repo in repos[:5]:
                        if repo.full_name in seen_repos:
                            continue
                        seen_repos.add(repo.full_name)

                        # Analyze repository
                        exploit_info = await self._analyze_repository(repo, cve_id)
                        if exploit_info:
                            exploits.append(exploit_info)

                        if len(exploits) >= 10:  # Limit total exploits
                            break

                    if len(exploits) >= 10:
                        break

                except Exception as e:
                    print(f"Error searching GitHub with query '{query}': {str(e)}")
                    continue

            print(f"Found {len(exploits)} potential exploits for {cve_id}")
            return exploits

        except Exception as e:
            print(f"Error searching GitHub for {cve_id}: {str(e)}")
            return []

    async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
        """Analyze a GitHub repository for exploit code"""
        try:
            # Check if repo name or description mentions the CVE
            repo_text = f"{repo.name} {repo.description or ''}".lower()
            if cve_id.lower() not in repo_text and cve_id.replace('-', '_').lower() not in repo_text:
                return None

            # Get repository contents
            exploit_files = []
            indicators = {
                'processes': set(),
                'files': set(),
                'registry': set(),
                'network': set(),
                'commands': set(),
                'powershell': set(),
                'urls': set()
            }

            try:
                contents = repo.get_contents("")
                for content in contents[:20]:  # Limit files to analyze
                    if content.type == "file" and self._is_exploit_file(content.name):
                        file_analysis = await self._analyze_file_content(repo, content, cve_id)
                        if file_analysis:
                            exploit_files.append(file_analysis)
                            # Merge indicators
                            for key, values in file_analysis.get('indicators', {}).items():
                                if key in indicators:
                                    indicators[key].update(values)

            except Exception as e:
                print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")

            if not exploit_files:
                return None

            return {
                'repo_name': repo.full_name,
                'repo_url': repo.html_url,
                'description': repo.description,
                'language': repo.language,
                'stars': repo.stargazers_count,
                'updated': repo.updated_at.isoformat(),
                'files': exploit_files,
                'indicators': {k: list(v) for k, v in indicators.items()}
            }

        except Exception as e:
            print(f"Error analyzing repository {repo.full_name}: {str(e)}")
            return None

    def _is_exploit_file(self, filename: str) -> bool:
        """Check if a file is likely to contain exploit code"""
        exploit_extensions = ['.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java']
        exploit_names = ['exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack']

        filename_lower = filename.lower()

        # Check extension
        if not any(filename_lower.endswith(ext) for ext in exploit_extensions):
            return False

        # Check filename for exploit-related terms
        return any(term in filename_lower for term in exploit_names) or 'cve' in filename_lower

    async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
        """Analyze individual file content for exploit indicators"""
        try:
            if file_content.size > 100000:  # Skip files larger than 100KB
                return None

            # Decode file content
            content = file_content.decoded_content.decode('utf-8', errors='ignore')

            # Check if file actually mentions the CVE
            if cve_id.lower() not in content.lower() and cve_id.replace('-', '_').lower() not in content.lower():
                return None

            indicators = self._extract_indicators_from_code(content, file_content.name)

            if not any(indicators.values()):
                return None

            return {
                'filename': file_content.name,
                'path': file_content.path,
                'size': file_content.size,
                'indicators': indicators
            }

        except Exception as e:
            print(f"Error analyzing file {file_content.name}: {str(e)}")
            return None

    def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
        """Extract security indicators from exploit code"""
        indicators = {
            'processes': set(),
            'files': set(),
            'registry': set(),
            'network': set(),
            'commands': set(),
            'powershell': set(),
            'urls': set()
        }

        # Process patterns
        process_patterns = [
            r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
            r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'system\s*\(\s*["\']([^"\']+)["\']',
            r'exec\s*\(\s*["\']([^"\']+)["\']',
            r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']'
        ]

        # File patterns
        file_patterns = [
            r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
            r'(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
            r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
            r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)'
        ]

        # Registry patterns
        registry_patterns = [
            r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
            r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?'
        ]

        # Network patterns
        network_patterns = [
            r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
            r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
            r'(?:http|https|ftp)://([^\s"\'<>]+)',
            r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)'
        ]

        # PowerShell patterns
        powershell_patterns = [
            r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
            r'Start-Process\s+["\']?([^"\']+)["\']?',
            r'Get-Process\s+["\']?([^"\']+)["\']?'
        ]

        # Command patterns
        command_patterns = [
            r'(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
            r'(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)'
        ]

        # Extract indicators using regex patterns
        patterns = {
            'processes': process_patterns,
            'files': file_patterns,
            'registry': registry_patterns,
            'powershell': powershell_patterns,
            'commands': command_patterns
        }

        for category, pattern_list in patterns.items():
            for pattern in pattern_list:
                matches = re.findall(pattern, content, re.IGNORECASE | re.MULTILINE)
                for match in matches:
                    if isinstance(match, tuple):
                        indicators[category].add(match[0])
                    else:
                        indicators[category].add(match)

        # Special handling for network indicators
        for pattern in network_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                if isinstance(match, tuple):
                    if len(match) >= 2:
                        indicators['network'].add(f"{match[0]}:{match[1]}")
                    else:
                        indicators['network'].add(match[0])
                else:
                    indicators['network'].add(match)

        # Convert sets to lists and filter out empty/invalid indicators
        cleaned_indicators = {}
        for key, values in indicators.items():
            cleaned_values = [v for v in values if v and len(v.strip()) > 2 and len(v) < 200]
            if cleaned_values:
                cleaned_indicators[key] = cleaned_values[:10]  # Limit to 10 per category

        return cleaned_indicators
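# Illustrative example (not executed): for file content containing the line
#     subprocess.call("whoami /all")
# _extract_indicators_from_code() returns roughly {'processes': ['whoami /all']},
# which _analyze_repository() then merges into its per-repository indicator sets.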
class CVESigmaService:
    def __init__(self, db: Session):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7):
        """Fetch recent CVEs from NVD API"""
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
        params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": 100
        }
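        # Note: only a single page of results (resultsPerPage=100) is requested here;
        # this helper performs no startIndex pagination.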
        headers = {}
        if self.nvd_api_key:
            headers["apiKey"] = self.nvd_api_key

        try:
            response = requests.get(url, params=params, headers=headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            new_cves = []
            for vuln in data.get("vulnerabilities", []):
                cve_data = vuln.get("cve", {})
                cve_id = cve_data.get("id")

                # Check if CVE already exists
                existing = self.db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if existing:
                    continue

                # Extract CVE information
                description = ""
                if cve_data.get("descriptions"):
                    description = cve_data["descriptions"][0].get("value", "")

                cvss_score = None
                severity = None
                if cve_data.get("metrics", {}).get("cvssMetricV31"):
                    cvss_data = cve_data["metrics"]["cvssMetricV31"][0]
                    cvss_score = cvss_data.get("cvssData", {}).get("baseScore")
                    severity = cvss_data.get("cvssData", {}).get("baseSeverity")

                affected_products = []
                if cve_data.get("configurations"):
                    for config in cve_data["configurations"]:
                        for node in config.get("nodes", []):
                            for cpe_match in node.get("cpeMatch", []):
                                if cpe_match.get("vulnerable"):
                                    affected_products.append(cpe_match.get("criteria", ""))

                reference_urls = []
                if cve_data.get("references"):
                    reference_urls = [ref.get("url", "") for ref in cve_data["references"]]

                cve_obj = CVE(
                    cve_id=cve_id,
                    description=description,
                    cvss_score=cvss_score,
                    severity=severity,
                    published_date=datetime.fromisoformat(cve_data.get("published", "").replace("Z", "+00:00")),
                    modified_date=datetime.fromisoformat(cve_data.get("lastModified", "").replace("Z", "+00:00")),
                    affected_products=affected_products,
                    reference_urls=reference_urls
                )

                self.db.add(cve_obj)
                new_cves.append(cve_obj)

            self.db.commit()
            return new_cves

        except Exception as e:
            print(f"Error fetching CVEs: {str(e)}")
            return []

    def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
        """Generate SIGMA rule based on CVE data"""
        if not cve.description:
            return None

        # Analyze CVE to determine appropriate template
        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        # Generate rule content
        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        # Determine detection type and confidence
        detection_type = self._determine_detection_type(description_lower)
        confidence_level = self._calculate_confidence(cve)

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            rule_name=f"CVE-{cve.cve_id.split('-')[1]}-{cve.cve_id.split('-')[2]} Detection",
            rule_content=rule_content,
            detection_type=detection_type,
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=confidence_level,
            auto_generated=True
        )

        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str], exploit_indicators: dict = None):
        """Select appropriate SIGMA rule template based on CVE and exploit analysis"""
        templates = self.db.query(RuleTemplate).all()

        # If we have exploit indicators, use them to determine the best template
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                powershell_template = next((t for t in templates if "PowerShell" in t.template_name), None)
                if powershell_template:
                    return powershell_template

            if exploit_indicators.get('network'):
                network_template = next((t for t in templates if "Network Connection" in t.template_name), None)
                if network_template:
                    return network_template

            if exploit_indicators.get('files'):
                file_template = next((t for t in templates if "File Modification" in t.template_name), None)
                if file_template:
                    return file_template

            if exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                process_template = next((t for t in templates if "Process Execution" in t.template_name), None)
                if process_template:
                    return process_template

        # Fallback to original logic
        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return next((t for t in templates if "Process Execution" in t.template_name), None)
            elif "network" in description or "remote" in description:
                return next((t for t in templates if "Network Connection" in t.template_name), None)
            elif "file" in description or "write" in description:
                return next((t for t in templates if "File Modification" in t.template_name), None)

        # Default to process execution template
        return next((t for t in templates if "Process Execution" in t.template_name), None)

    def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: dict = None) -> str:
        """Populate template with CVE-specific data and exploit indicators"""
        try:
            # Use exploit indicators if available, otherwise extract from description
            if exploit_indicators:
                suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
                suspicious_ports = []
                file_patterns = exploit_indicators.get('files', [])

                # Extract ports from network indicators
                for net_indicator in exploit_indicators.get('network', []):
                    if ':' in str(net_indicator):
                        try:
                            port = int(str(net_indicator).split(':')[-1])
                            suspicious_ports.append(port)
                        except ValueError:
                            pass
            else:
                # Fallback to original extraction
                suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
                suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
                file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            # Determine severity level
            level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"

            # Create enhanced description
            enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description
            if exploit_indicators:
                enhanced_description += " [Enhanced with GitHub exploit analysis]"

            # Build tags
            tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()]
            if exploit_indicators:
                tags.append("exploit.github")

            rule_content = template.template_content.format(
title=f"CVE-{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
|
|
                description=enhanced_description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                tags="\n - ".join(tags),
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level
            )

            return rule_content

        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    def _get_mitre_technique(self, description: str, exploit_indicators: dict = None) -> str:
        """Map CVE and exploit indicators to MITRE ATT&CK techniques"""
        desc_lower = description.lower()

        # Check exploit indicators first
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "t1059.001"  # PowerShell
            elif exploit_indicators.get('commands'):
                return "t1059.003"  # Windows Command Shell
            elif exploit_indicators.get('network'):
                return "t1071.001"  # Web Protocols
            elif exploit_indicators.get('files'):
                return "t1105"  # Ingress Tool Transfer
            elif exploit_indicators.get('processes'):
                return "t1106"  # Native API

        # Fallback to description analysis
        if "powershell" in desc_lower:
            return "t1059.001"
        elif "command" in desc_lower or "cmd" in desc_lower:
            return "t1059.003"
        elif "network" in desc_lower or "remote" in desc_lower:
            return "t1071.001"
        elif "file" in desc_lower or "upload" in desc_lower:
            return "t1105"
        elif "process" in desc_lower or "execution" in desc_lower:
            return "t1106"
        else:
            return "execution"  # Generic

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> List:
        """Extract suspicious indicators from CVE description"""
        if indicator_type == "process":
            # Look for executable names or process patterns
            exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
            return exe_pattern[:5] if exe_pattern else None

        elif indicator_type == "port":
            # Look for port numbers
            port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
            return [int(p) for p in port_pattern[:3]] if port_pattern else None

        elif indicator_type == "file":
            # Look for file extensions or paths
            file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return file_pattern[:5] if file_pattern else None

        return None

    def _determine_detection_type(self, description: str, exploit_indicators: dict = None) -> str:
        """Determine detection type based on CVE description and exploit indicators"""
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "powershell"
            elif exploit_indicators.get('network'):
                return "network"
            elif exploit_indicators.get('files'):
                return "file"
            elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                return "process"

        # Fallback to original logic
        if "remote" in description or "network" in description:
            return "network"
        elif "process" in description or "execution" in description:
            return "process"
        elif "file" in description or "filesystem" in description:
            return "file"
        else:
            return "general"

    def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str:
        """Calculate confidence level for the generated rule"""
        base_confidence = 0

        # CVSS score contributes to confidence
        if cve.cvss_score:
            if cve.cvss_score >= 9.0:
                base_confidence += 3
            elif cve.cvss_score >= 7.0:
                base_confidence += 2
            else:
                base_confidence += 1

        # Exploit-based rules get higher confidence
        if exploit_based:
            base_confidence += 2

        # Map to confidence levels
        if base_confidence >= 4:
            return "high"
        elif base_confidence >= 2:
            return "medium"
        else:
            return "low"
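# The .format() call in _populate_template() assumes each RuleTemplate.template_content
# carries these placeholders: {title}, {description}, {rule_id}, {date}, {cve_url},
# {cve_id}, {tags}, {suspicious_processes}, {suspicious_ports}, {file_patterns}, {level}.
# A minimal sketch of such a template (illustrative only, not the seeded content):
#
#   title: {title}
#   id: {rule_id}
#   description: {description}
#   references:
#       - {cve_url}
#   tags:
#       - {tags}
#   detection:
#       selection:
#           Image|endswith: {suspicious_processes}
#       condition: selection
#   level: {level}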
# Dependency
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    retry_count = 0
    max_retries = 3

    while True:
        try:
            db = SessionLocal()
            service = CVESigmaService(db)
            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")

            # Use a longer initial period (30 days) to find CVEs
            new_cves = await service.fetch_recent_cves(days_back=30)

            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")

                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
                retry_count = 0  # Reset retry count on success
            else:
                print("No new CVEs found in this cycle")
                # After first successful run, reduce to 7 days for regular updates
                if retry_count == 0:
                    print("Switching to 7-day lookback for future runs...")

            db.close()

        except Exception as e:
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
            if retry_count >= max_retries:
                print(f"Max retries reached, waiting longer before next attempt...")
                await asyncio.sleep(1800)  # Wait 30 minutes on repeated failures
                retry_count = 0
            else:
                await asyncio.sleep(300)  # Wait 5 minutes before retry
            continue

        # Wait 1 hour before next fetch (or 30 minutes if there were errors)
        wait_time = 3600 if retry_count == 0 else 1800
        print(f"Next CVE fetch in {wait_time//60} minutes...")
        await asyncio.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize database
    Base.metadata.create_all(bind=engine)

    # Initialize rule templates
    db = SessionLocal()
    try:
        existing_templates = db.query(RuleTemplate).count()
        if existing_templates == 0:
            logger.info("No rule templates found. Database initialization will handle template creation.")
    except Exception as e:
        logger.error(f"Error checking rule templates: {e}")
    finally:
        db.close()

    # Initialize and start the job scheduler
    try:
        from job_scheduler import initialize_scheduler
        from job_executors import register_all_executors

        # Initialize scheduler
        scheduler = initialize_scheduler()
        scheduler.set_db_session_factory(SessionLocal)

        # Register all job executors
        register_all_executors(scheduler)

        # Start the scheduler
        scheduler.start()

        logger.info("Job scheduler initialized and started")

    except Exception as e:
        logger.error(f"Error initializing job scheduler: {e}")

    yield

    # Shutdown
    try:
        from job_scheduler import get_scheduler
        scheduler = get_scheduler()
        scheduler.stop()
        logger.info("Job scheduler stopped")
    except Exception as e:
        logger.error(f"Error stopping job scheduler: {e}")

# FastAPI app
app = FastAPI(title="CVE-SIGMA Auto Generator", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
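# Local development entry point (assumed, since the module name is not shown here):
#   uvicorn main:app --host 0.0.0.0 --port 8000
# The allowed CORS origin above matches a frontend dev server on http://localhost:3000.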
@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()
    # Convert UUID to string for each CVE
    result = []
    for cve in cves:
        cve_dict = {
            'id': str(cve.id),
            'cve_id': cve.cve_id,
            'description': cve.description,
            'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
            'severity': cve.severity,
            'published_date': cve.published_date,
            'affected_products': cve.affected_products,
            'reference_urls': cve.reference_urls
        }
        result.append(CVEResponse(**cve_dict))
    return result
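# Example (assuming the API is served on localhost:8000):
#   curl "http://localhost:8000/api/cves?skip=0&limit=10"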
@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")

    cve_dict = {
        'id': str(cve.id),
        'cve_id': cve.cve_id,
        'description': cve.description,
        'cvss_score': float(cve.cvss_score) if cve.cvss_score else None,
        'severity': cve.severity,
        'published_date': cve.published_date,
        'affected_products': cve.affected_products,
        'reference_urls': cve.reference_urls
    }
    return CVEResponse(**cve_dict)

@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    rules = db.query(SigmaRule).order_by(SigmaRule.created_at.desc()).offset(skip).limit(limit).all()
    # Convert UUID to string for each rule
    result = []
    for rule in rules:
        rule_dict = {
            'id': str(rule.id),
            'cve_id': rule.cve_id,
            'rule_name': rule.rule_name,
            'rule_content': rule.rule_content,
            'detection_type': rule.detection_type,
            'log_source': rule.log_source,
            'confidence_level': rule.confidence_level,
            'auto_generated': rule.auto_generated,
            'exploit_based': rule.exploit_based or False,
            'github_repos': rule.github_repos or [],
            'exploit_indicators': rule.exploit_indicators,
            'created_at': rule.created_at
        }
        result.append(SigmaRuleResponse(**rule_dict))
    return result

@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    rules = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()
    # Convert UUID to string for each rule
    result = []
    for rule in rules:
        rule_dict = {
            'id': str(rule.id),
            'cve_id': rule.cve_id,
            'rule_name': rule.rule_name,
            'rule_content': rule.rule_content,
            'detection_type': rule.detection_type,
            'log_source': rule.log_source,
            'confidence_level': rule.confidence_level,
            'auto_generated': rule.auto_generated,
            'exploit_based': rule.exploit_based or False,
            'github_repos': rule.github_repos or [],
            'exploit_indicators': rule.exploit_indicators,
            'created_at': rule.created_at
        }
        result.append(SigmaRuleResponse(**rule_dict))
    return result

@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    async def fetch_task():
        try:
            service = CVESigmaService(db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)

            rules_generated = 0
            for cve in new_cves:
                sigma_rule = service.generate_sigma_rule(cve)
                if sigma_rule:
                    rules_generated += 1

            db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}

@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity"""
    try:
        # Test with a simple request using current date
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }

        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }

        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {headers}")

        response = requests.get(url, params=params, headers=headers, timeout=15)

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })

            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
                result["fallback_status_code"] = fallback_response.status_code

                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })

        return result

    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }

@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    total_cves = db.query(CVE).count()
    total_rules = db.query(SigmaRule).count()
    recent_cves = db.query(CVE).filter(CVE.published_date >= datetime.utcnow() - timedelta(days=7)).count()

    # Enhanced stats with bulk processing info
    bulk_processed_cves = db.query(CVE).filter(CVE.bulk_processed == True).count()
    cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).count()
    nomi_sec_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'nomi_sec').count()

    return {
        "total_cves": total_cves,
        "total_sigma_rules": total_rules,
        "recent_cves_7_days": recent_cves,
        "bulk_processed_cves": bulk_processed_cves,
        "cves_with_pocs": cves_with_pocs,
        "nomi_sec_rules": nomi_sec_rules,
        "poc_coverage": (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0,
        "nomi_sec_coverage": (nomi_sec_rules / total_rules * 100) if total_rules > 0 else 0
    }
# New bulk processing endpoints
@app.post("/api/bulk-seed")
async def start_bulk_seed(background_tasks: BackgroundTasks,
                          request: BulkSeedRequest,
                          db: Session = Depends(get_db)):
    """Start bulk seeding process"""

    async def bulk_seed_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.full_bulk_seed(
                start_year=request.start_year,
                end_year=request.end_year,
                skip_nvd=request.skip_nvd,
                skip_nomi_sec=request.skip_nomi_sec
            )
            logger.info(f"Bulk seed completed: {result}")
        except Exception as e:
            logger.error(f"Bulk seed failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(bulk_seed_task)

    return {
        "message": "Bulk seeding process started",
        "status": "started",
        "start_year": request.start_year,
        "end_year": request.end_year or datetime.now().year,
        "skip_nvd": request.skip_nvd,
        "skip_nomi_sec": request.skip_nomi_sec
    }
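# Example (assuming the API is served on localhost:8000):
#   curl -X POST "http://localhost:8000/api/bulk-seed" \
#        -H "Content-Type: application/json" \
#        -d '{"start_year": 2020, "skip_nomi_sec": true}'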
@app.post("/api/incremental-update")
async def start_incremental_update(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start incremental update process"""

    async def incremental_update_task():
        try:
            from bulk_seeder import BulkSeeder
            seeder = BulkSeeder(db)
            result = await seeder.incremental_update()
            logger.info(f"Incremental update completed: {result}")
        except Exception as e:
            logger.error(f"Incremental update failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(incremental_update_task)

    return {
        "message": "Incremental update process started",
        "status": "started"
    }

@app.post("/api/sync-nomi-sec")
async def sync_nomi_sec(background_tasks: BackgroundTasks,
                        request: NomiSecSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize nomi-sec PoC data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='nomi_sec_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            from nomi_sec_client import NomiSecClient
            client = NomiSecClient(db)

            if request.cve_id:
                # Sync specific CVE
                if job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"Nomi-sec sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"Nomi-sec bulk sync completed: {result}")

            # Update job status if not cancelled
            if not job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()

        except Exception as e:
            if not job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()

            logger.error(f"Nomi-sec sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": f"Nomi-sec sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
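# Job lifecycle for the sync endpoints above and below: a BulkProcessingJob row moves
# from 'pending' to 'running' and then to 'completed' or 'failed' (or 'cancelled'),
# while the in-memory running_jobs and job_cancellation_flags dicts track live tasks
# and let bulk syncs check for cancellation via the flag callable passed to the client.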
@app.post("/api/sync-github-pocs")
async def sync_github_pocs(background_tasks: BackgroundTasks,
                           request: GitHubPoCSyncRequest,
                           db: Session = Depends(get_db)):
    """Synchronize GitHub PoC data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='github_poc_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    async def sync_task():
        try:
            job.status = 'running'
            job.started_at = datetime.utcnow()
            db.commit()

            client = GitHubPoCClient(db)

            if request.cve_id:
                # Sync specific CVE
                if job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_pocs(request.cve_id)
                logger.info(f"GitHub PoC sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_all_cves(batch_size=request.batch_size)
                logger.info(f"GitHub PoC bulk sync completed: {result}")

            # Update job status if not cancelled
            if not job_cancellation_flags.get(job_id, False):
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                db.commit()

        except Exception as e:
            if not job_cancellation_flags.get(job_id, False):
                job.status = 'failed'
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()

            logger.error(f"GitHub PoC sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)

    background_tasks.add_task(sync_task)

    return {
        "message": f"GitHub PoC sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-exploitdb")
async def sync_exploitdb(background_tasks: BackgroundTasks,
                         request: ExploitDBSyncRequest,
                         db: Session = Depends(get_db)):
    """Synchronize ExploitDB data from git mirror"""

    # Create job record
    job = BulkProcessingJob(
        job_type='exploitdb_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from exploitdb_client_local import ExploitDBLocalClient
            client = ExploitDBLocalClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_exploits(request.cve_id)
                logger.info(f"ExploitDB sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_exploitdb(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"ExploitDB bulk sync completed: {result}")

            # Update job status if not cancelled
            if not job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"ExploitDB sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": f"ExploitDB sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
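# Note: the ExploitDB task above (and the CISA KEV task below) opens its own SessionLocal()
# because the request-scoped `db` dependency is closed once the HTTP response returns, so
# the background work re-loads the job row in its own session before updating it.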
@app.post("/api/sync-cisa-kev")
async def sync_cisa_kev(background_tasks: BackgroundTasks,
                        request: CISAKEVSyncRequest,
                        db: Session = Depends(get_db)):
    """Synchronize CISA Known Exploited Vulnerabilities data"""

    # Create job record
    job = BulkProcessingJob(
        job_type='cisa_kev_sync',
        status='pending',
        job_metadata={
            'cve_id': request.cve_id,
            'batch_size': request.batch_size
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    job_id = str(job.id)
    running_jobs[job_id] = job
    job_cancellation_flags[job_id] = False

    async def sync_task():
        # Create a new database session for the background task
        task_db = SessionLocal()
        try:
            # Get the job in the new session
            task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
            if not task_job:
                logger.error(f"Job {job_id} not found in task session")
                return

            task_job.status = 'running'
            task_job.started_at = datetime.utcnow()
            task_db.commit()

            from cisa_kev_client import CISAKEVClient
            client = CISAKEVClient(task_db)

            if request.cve_id:
                # Sync specific CVE
                if job_cancellation_flags.get(job_id, False):
                    logger.info(f"Job {job_id} cancelled before starting")
                    return

                result = await client.sync_cve_kev_data(request.cve_id)
                logger.info(f"CISA KEV sync for {request.cve_id}: {result}")
            else:
                # Sync all CVEs with cancellation support
                result = await client.bulk_sync_kev_data(
                    batch_size=request.batch_size,
                    cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                )
                logger.info(f"CISA KEV bulk sync completed: {result}")

            # Update job status if not cancelled
            if not job_cancellation_flags.get(job_id, False):
                task_job.status = 'completed'
                task_job.completed_at = datetime.utcnow()
                task_db.commit()

        except Exception as e:
            if not job_cancellation_flags.get(job_id, False):
                # Get the job again in case it was modified
                task_job = task_db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job.id).first()
                if task_job:
                    task_job.status = 'failed'
                    task_job.error_message = str(e)
                    task_job.completed_at = datetime.utcnow()
                    task_db.commit()

            logger.error(f"CISA KEV sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Clean up tracking and close the task session
            running_jobs.pop(job_id, None)
            job_cancellation_flags.pop(job_id, None)
            task_db.close()

    background_tasks.add_task(sync_task)

    return {
        "message": f"CISA KEV sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
        "status": "started",
        "job_id": job_id,
        "cve_id": request.cve_id,
        "batch_size": request.batch_size
    }
@app.post("/api/sync-references")
async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Start reference data synchronization"""

    try:
        from reference_client import ReferenceClient
        client = ReferenceClient(db)

        # Create job ID
        job_id = str(uuid.uuid4())

        # Add job to tracking
        running_jobs[job_id] = {
            'type': 'reference_sync',
            'status': 'running',
            'cve_id': request.cve_id,
            'batch_size': request.batch_size,
            'max_cves': request.max_cves,
            'force_resync': request.force_resync,
            'started_at': datetime.utcnow()
        }

        # Create cancellation flag
        job_cancellation_flags[job_id] = False

        async def sync_task():
            try:
                if request.cve_id:
                    # Single CVE sync
                    result = await client.sync_cve_references(request.cve_id)
                    running_jobs[job_id]['result'] = result
                    running_jobs[job_id]['status'] = 'completed'
                else:
                    # Bulk sync
                    result = await client.bulk_sync_references(
                        batch_size=request.batch_size,
                        max_cves=request.max_cves,
                        force_resync=request.force_resync,
                        cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
                    )
                    running_jobs[job_id]['result'] = result
                    running_jobs[job_id]['status'] = 'completed'

                running_jobs[job_id]['completed_at'] = datetime.utcnow()

            except Exception as e:
                logger.error(f"Reference sync task failed: {e}")
                running_jobs[job_id]['status'] = 'failed'
                running_jobs[job_id]['error'] = str(e)
                running_jobs[job_id]['completed_at'] = datetime.utcnow()
            finally:
                # Clean up cancellation flag
                job_cancellation_flags.pop(job_id, None)

        background_tasks.add_task(sync_task)

        return {
            "message": f"Reference sync started" + (f" for {request.cve_id}" if request.cve_id else " for all CVEs"),
            "status": "started",
            "job_id": job_id,
            "cve_id": request.cve_id,
            "batch_size": request.batch_size,
            "max_cves": request.max_cves,
            "force_resync": request.force_resync
        }

    except Exception as e:
        logger.error(f"Failed to start reference sync: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start reference sync: {str(e)}")

@app.get("/api/reference-stats")
async def get_reference_stats(db: Session = Depends(get_db)):
    """Get reference synchronization statistics"""

    try:
        from reference_client import ReferenceClient
        client = ReferenceClient(db)

        # Get sync status
        status = await client.get_reference_sync_status()

        # Get quality distribution from reference data
        quality_distribution = {}
        from sqlalchemy import text
        cves_with_references = db.query(CVE).filter(
            text("reference_data::text LIKE '%\"reference_analysis\"%'")
        ).all()

        for cve in cves_with_references:
            if cve.reference_data and 'reference_analysis' in cve.reference_data:
                ref_analysis = cve.reference_data['reference_analysis']
                high_conf_refs = ref_analysis.get('high_confidence_references', 0)
                total_refs = ref_analysis.get('reference_count', 0)

                if total_refs > 0:
                    quality_ratio = high_conf_refs / total_refs
                    if quality_ratio >= 0.8:
                        quality_tier = 'excellent'
                    elif quality_ratio >= 0.6:
                        quality_tier = 'good'
                    elif quality_ratio >= 0.4:
                        quality_tier = 'fair'
                    else:
                        quality_tier = 'poor'

                    quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1

        # Get reference type distribution
        reference_type_distribution = {}
        for cve in cves_with_references:
            if cve.reference_data and 'reference_analysis' in cve.reference_data:
                ref_analysis = cve.reference_data['reference_analysis']
                ref_types = ref_analysis.get('reference_types', [])
                for ref_type in ref_types:
                    reference_type_distribution[ref_type] = reference_type_distribution.get(ref_type, 0) + 1

        return {
            'reference_sync_status': status,
            'quality_distribution': quality_distribution,
            'reference_type_distribution': reference_type_distribution,
            'total_with_reference_analysis': len(cves_with_references),
            'source': 'reference_extraction'
        }

    except Exception as e:
        logger.error(f"Failed to get reference stats: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get reference stats: {str(e)}")
@app.get("/api/exploitdb-stats")
async def get_exploitdb_stats(db: Session = Depends(get_db)):
    """Get ExploitDB-related statistics"""

    try:
        from exploitdb_client_local import ExploitDBLocalClient
        client = ExploitDBLocalClient(db)

        # Get sync status
        status = await client.get_exploitdb_sync_status()

        # Get quality distribution from ExploitDB data
        quality_distribution = {}
        from sqlalchemy import text
        cves_with_exploitdb = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"exploitdb\"%'")
        ).all()

        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    quality_tier = exploit.get('quality_analysis', {}).get('quality_tier', 'unknown')
                    quality_distribution[quality_tier] = quality_distribution.get(quality_tier, 0) + 1

        # Get category distribution
        category_distribution = {}
        for cve in cves_with_exploitdb:
            if cve.poc_data and 'exploitdb' in cve.poc_data:
                exploits = cve.poc_data['exploitdb'].get('exploits', [])
                for exploit in exploits:
                    category = exploit.get('category', 'unknown')
                    category_distribution[category] = category_distribution.get(category, 0) + 1

        return {
            "exploitdb_sync_status": status,
            "quality_distribution": quality_distribution,
            "category_distribution": category_distribution,
            "total_exploitdb_cves": len(cves_with_exploitdb),
            "total_exploits": sum(
                len(cve.poc_data.get('exploitdb', {}).get('exploits', []))
                for cve in cves_with_exploitdb
                if cve.poc_data and 'exploitdb' in cve.poc_data
            )
        }

    except Exception as e:
        logger.error(f"Error getting ExploitDB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/github-poc-stats")
async def get_github_poc_stats(db: Session = Depends(get_db)):
    """Get GitHub PoC-related statistics"""
    try:
        # Get basic statistics
        github_poc_rules = db.query(SigmaRule).filter(SigmaRule.poc_source == 'github_poc').count()

        # poc_data is expected to be a JSON array whose entries carry 'source' and
        # 'quality_analysis'; the '0' path element inspects the first entry.
        cves_with_github_pocs = db.query(CVE).filter(
            CVE.poc_data.isnot(None),  # Check if poc_data exists
            func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
        ).count()

        # Get quality distribution
        quality_distribution = {}
        try:
            quality_results = db.query(
                func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_tier').label('tier'),
                func.count().label('count')
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).group_by('tier').all()

            for tier, count in quality_results:
                if tier:
                    quality_distribution[tier] = count
        except Exception as e:
            logger.warning(f"Error getting quality distribution: {e}")
            quality_distribution = {}

        # Calculate average quality score
        try:
            avg_quality = db.query(
                func.avg(func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer))
            ).filter(
                CVE.poc_data.isnot(None),
                func.json_extract_path_text(CVE.poc_data, '0', 'source') == 'github_poc'
            ).scalar() or 0
        except Exception as e:
            logger.warning(f"Error calculating average quality: {e}")
            avg_quality = 0

        return {
            'github_poc_rules': github_poc_rules,
            'cves_with_github_pocs': cves_with_github_pocs,
            'quality_distribution': quality_distribution,
            'average_quality_score': float(avg_quality) if avg_quality else 0,
            'source': 'github_poc'
        }
    except Exception as e:
        logger.error(f"Error getting GitHub PoC stats: {e}")
        return {"error": str(e)}

@app.get("/api/github-poc-status")
async def get_github_poc_status(db: Session = Depends(get_db)):
    """Get GitHub PoC data availability status"""
    try:
        client = GitHubPoCClient(db)

        # Check if GitHub PoC data is available
        github_poc_data = client.load_github_poc_data()

        return {
            'github_poc_data_available': len(github_poc_data) > 0,
            'total_cves_with_pocs': len(github_poc_data),
            'sample_cve_ids': list(github_poc_data.keys())[:10],  # First 10 CVE IDs
            'data_path': str(client.github_poc_path),
            'path_exists': client.github_poc_path.exists()
        }
    except Exception as e:
        logger.error(f"Error checking GitHub PoC status: {e}")
        return {"error": str(e)}

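# Example request (illustrative; host and port assume the default uvicorn settings used at the
# bottom of this module):
#
#     curl http://localhost:8000/api/github-poc-status
#
# The response carries the availability flag, the CVE count, a sample of CVE IDs, and the
# data path that was checked.
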
@app.get("/api/cisa-kev-stats")
async def get_cisa_kev_stats(db: Session = Depends(get_db)):
    """Get CISA KEV-related statistics"""
    try:
        from cisa_kev_client import CISAKEVClient
        client = CISAKEVClient(db)

        # Get sync status
        status = await client.get_kev_sync_status()

        # Collect threat level, category, ransomware, and threat score statistics
        # from CISA KEV data in a single pass
        from sqlalchemy import text
        cves_with_kev = db.query(CVE).filter(
            text("poc_data::text LIKE '%\"cisa_kev\"%'")
        ).all()

        threat_level_distribution = {}
        category_distribution = {}
        ransomware_stats = {'known': 0, 'unknown': 0}
        threat_scores = []

        for cve in cves_with_kev:
            if cve.poc_data and 'cisa_kev' in cve.poc_data:
                vuln_data = cve.poc_data['cisa_kev'].get('vulnerability_data', {})

                threat_level = vuln_data.get('threat_level', 'unknown')
                threat_level_distribution[threat_level] = threat_level_distribution.get(threat_level, 0) + 1

                category = vuln_data.get('vulnerability_category', 'unknown')
                category_distribution[category] = category_distribution.get(category, 0) + 1

                ransomware_use = vuln_data.get('known_ransomware_use', 'Unknown').lower()
                if ransomware_use == 'known':
                    ransomware_stats['known'] += 1
                else:
                    ransomware_stats['unknown'] += 1

                threat_score = vuln_data.get('threat_score', 0)
                if threat_score:
                    threat_scores.append(threat_score)

        # Calculate average threat score
        avg_threat_score = sum(threat_scores) / len(threat_scores) if threat_scores else 0

        return {
            "cisa_kev_sync_status": status,
            "threat_level_distribution": threat_level_distribution,
            "category_distribution": category_distribution,
            "ransomware_stats": ransomware_stats,
            "average_threat_score": round(avg_threat_score, 2),
            "total_kev_cves": len(cves_with_kev),
            "total_with_threat_scores": len(threat_scores)
        }

    except Exception as e:
        logger.error(f"Error getting CISA KEV stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))

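# Shape assumed by the loop above (illustrative; the field names are exactly those read via
# .get() calls, the values are placeholders rather than real KEV entries):
#
#     cve.poc_data['cisa_kev'] = {
#         'vulnerability_data': {
#             'threat_level': 'critical',
#             'vulnerability_category': 'rce',
#             'known_ransomware_use': 'Known',
#             'threat_score': 85,
#         }
#     }
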
@app.get("/api/bulk-jobs")
async def get_bulk_jobs(limit: int = 10, db: Session = Depends(get_db)):
    """Get bulk processing job status"""
    jobs = db.query(BulkProcessingJob).order_by(
        BulkProcessingJob.created_at.desc()
    ).limit(limit).all()

    result = []
    for job in jobs:
        job_dict = {
            'id': str(job.id),
            'job_type': job.job_type,
            'status': job.status,
            'year': job.year,
            'total_items': job.total_items,
            'processed_items': job.processed_items,
            'failed_items': job.failed_items,
            'error_message': job.error_message,
            'metadata': job.job_metadata,
            'started_at': job.started_at,
            'completed_at': job.completed_at,
            'created_at': job.created_at
        }
        result.append(job_dict)

    return result

@app.get("/api/bulk-status")
async def get_bulk_status(db: Session = Depends(get_db)):
    """Get comprehensive bulk processing status"""
    try:
        from bulk_seeder import BulkSeeder
        seeder = BulkSeeder(db)
        status = await seeder.get_seeding_status()
        return status
    except Exception as e:
        logger.error(f"Error getting bulk status: {e}")
        return {"error": str(e)}

@app.get("/api/poc-stats")
async def get_poc_stats(db: Session = Depends(get_db)):
    """Get PoC-related statistics"""
    try:
        from nomi_sec_client import NomiSecClient
        client = NomiSecClient(db)
        stats = await client.get_sync_status()

        # Additional PoC statistics
        high_quality_cves = db.query(CVE).filter(
            CVE.poc_count > 0,
            func.json_extract_path_text(CVE.poc_data, '0', 'quality_analysis', 'quality_score').cast(Integer) > 60
        ).count()

        stats.update({
            'high_quality_cves': high_quality_cves,
            'avg_poc_count': db.query(func.avg(CVE.poc_count)).filter(CVE.poc_count > 0).scalar() or 0
        })

        return stats
    except Exception as e:
        logger.error(f"Error getting PoC stats: {e}")
        return {"error": str(e)}

@app.get("/api/cve2capec-stats")
async def get_cve2capec_stats():
    """Get CVE2CAPEC MITRE ATT&CK mapping statistics"""
    try:
        client = CVE2CAPECClient()
        stats = client.get_stats()

        return {
            "status": "success",
            "data": stats,
            "description": "CVE to MITRE ATT&CK technique mappings from CVE2CAPEC repository"
        }
    except Exception as e:
        logger.error(f"Error getting CVE2CAPEC stats: {e}")
        return {"error": str(e)}

@app.post("/api/regenerate-rules")
async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
                                 request: RuleRegenRequest,
                                 db: Session = Depends(get_db)):
    """Regenerate SIGMA rules using enhanced nomi-sec data"""
    async def regenerate_task():
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator
            generator = EnhancedSigmaGenerator(db)

            # Get CVEs with PoC data
            cves_with_pocs = db.query(CVE).filter(CVE.poc_count > 0).all()

            rules_generated = 0
            rules_updated = 0

            for cve in cves_with_pocs:
                # Check if we should regenerate
                existing_rule = db.query(SigmaRule).filter(
                    SigmaRule.cve_id == cve.cve_id
                ).first()

                if existing_rule and existing_rule.poc_source == 'nomi_sec' and not request.force:
                    continue

                # Generate enhanced rule
                result = await generator.generate_enhanced_rule(cve)

                if result['success']:
                    if existing_rule:
                        rules_updated += 1
                    else:
                        rules_generated += 1

            logger.info(f"Rule regeneration completed: {rules_generated} new, {rules_updated} updated")

        except Exception as e:
            logger.error(f"Rule regeneration failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(regenerate_task)

    return {
        "message": "SIGMA rule regeneration started",
        "status": "started",
        "force": request.force
    }

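# Example request (illustrative): kick off a forced regeneration pass.
#
#     curl -X POST http://localhost:8000/api/regenerate-rules \
#          -H "Content-Type: application/json" \
#          -d '{"force": true}'
#
# This assumes RuleRegenRequest exposes a boolean `force` field, as used above.
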
@app.post("/api/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Generate SIGMA rules using LLM API for enhanced analysis"""
    # Parse request parameters
    cve_id = request.get('cve_id')
    force = request.get('force', False)
    llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
    llm_model = request.get('model', os.getenv('LLM_MODEL'))

    # Validation
    if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id):
        raise HTTPException(status_code=400, detail="Invalid CVE ID format")

    async def llm_generation_task():
        """Background task for LLM-enhanced rule generation"""
        try:
            from enhanced_sigma_generator import EnhancedSigmaGenerator

            generator = EnhancedSigmaGenerator(db, llm_provider, llm_model)

            # Process specific CVE or all CVEs with PoC data
            if cve_id:
                cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
                if not cve:
                    logger.error(f"CVE {cve_id} not found")
                    return

                cves_to_process = [cve]
            else:
                # Process CVEs with PoC data that either have no rules or force update
                query = db.query(CVE).filter(CVE.poc_count > 0)

                if not force:
                    # Only process CVEs without existing LLM-generated rules
                    existing_llm_rules = db.query(SigmaRule).filter(
                        SigmaRule.detection_type.like('llm_%')
                    ).all()
                    existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
                    cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids]
                else:
                    cves_to_process = query.all()

            logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")

            rules_generated = 0
            rules_updated = 0
            failures = 0

            for cve in cves_to_process:
                try:
                    # Check if CVE has sufficient PoC data
                    if not cve.poc_data or not cve.poc_count:
                        logger.debug(f"Skipping {cve.cve_id} - no PoC data")
                        continue

                    # Generate LLM-enhanced rule
                    result = await generator.generate_enhanced_rule(cve, use_llm=True)

                    if result.get('success'):
                        if result.get('updated'):
                            rules_updated += 1
                        else:
                            rules_generated += 1

                        logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
                    else:
                        failures += 1
                        logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")

                except Exception as e:
                    failures += 1
                    logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
                    continue

            logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")

        except Exception as e:
            logger.error(f"LLM-enhanced rule generation failed: {e}")
            import traceback
            traceback.print_exc()

    background_tasks.add_task(llm_generation_task)

    return {
        "message": "LLM-enhanced SIGMA rule generation started",
        "status": "started",
        "cve_id": cve_id,
        "force": force,
        "provider": llm_provider,
        "model": llm_model,
        "note": "Requires appropriate LLM API key to be set"
    }

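# Example request (illustrative): regenerate the rule for a single CVE with an explicit provider.
#
#     curl -X POST http://localhost:8000/api/llm-enhanced-rules \
#          -H "Content-Type: application/json" \
#          -d '{"cve_id": "CVE-2021-44228", "force": true, "provider": "ollama"}'
#
# The CVE ID and provider above are placeholders; omitted fields fall back to the
# LLM_PROVIDER / LLM_MODEL environment variables parsed at the top of the endpoint.
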
@app.get("/api/llm-status")
async def get_llm_status():
    """Check LLM API availability status"""
    try:
        from llm_client import LLMClient

        # Get current provider configuration
        provider = os.getenv('LLM_PROVIDER')
        model = os.getenv('LLM_MODEL')

        client = LLMClient(provider=provider, model=model)
        provider_info = client.get_provider_info()

        # Get all available providers
        all_providers = LLMClient.get_available_providers()

        return {
            "current_provider": provider_info,
            "available_providers": all_providers,
            "status": "ready" if client.is_available() else "unavailable"
        }
    except Exception as e:
        logger.error(f"Error checking LLM status: {e}")
        return {
            "current_provider": {"provider": "unknown", "available": False},
            "available_providers": [],
            "status": "error",
            "error": str(e)
        }

@app.post("/api/llm-switch")
async def switch_llm_provider(request: dict):
    """Switch LLM provider and model"""
    try:
        from llm_client import LLMClient

        provider = request.get('provider')
        model = request.get('model')

        if not provider:
            raise HTTPException(status_code=400, detail="Provider is required")

        # Validate provider
        if provider not in LLMClient.SUPPORTED_PROVIDERS:
            raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")

        # Test the new configuration
        client = LLMClient(provider=provider, model=model)

        if not client.is_available():
            raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured")

        # Update environment variables (note: this only affects the current process,
        # not the container environment or any .env file)
        os.environ['LLM_PROVIDER'] = provider
        if model:
            os.environ['LLM_MODEL'] = model

        provider_info = client.get_provider_info()

        return {
            "message": f"Switched to {provider}",
            "provider_info": provider_info,
            "status": "success"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error switching LLM provider: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
    """Cancel a running job"""
    try:
        # Find the job in the database
        job = db.query(BulkProcessingJob).filter(BulkProcessingJob.id == job_id).first()
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")

        if job.status not in ['pending', 'running']:
            raise HTTPException(status_code=400, detail=f"Cannot cancel job with status: {job.status}")

        # Set cancellation flag
        job_cancellation_flags[job_id] = True

        # Update job status
        job.status = 'cancelled'
        job.cancelled_at = datetime.utcnow()
        job.error_message = "Job cancelled by user"

        db.commit()

        logger.info(f"Job {job_id} cancellation requested")

        return {
            "message": f"Job {job_id} cancellation requested",
            "status": "cancelled",
            "job_id": job_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling job {job_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

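# Cancellation is cooperative: the endpoint above only sets job_cancellation_flags[job_id];
# a long-running worker has to poll that flag and stop itself. Minimal sketch (illustrative,
# not wired into any worker here):
#
#     def process_items(job_id: str, items):
#         for item in items:
#             if job_cancellation_flags.get(job_id):
#                 logger.info(f"Job {job_id} cancelled, stopping early")
#                 break
#             handle(item)  # hypothetical per-item work
#         job_cancellation_flags.pop(job_id, None)
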
@app.get("/api/running-jobs")
async def get_running_jobs(db: Session = Depends(get_db)):
    """Get all currently running jobs"""
    try:
        jobs = db.query(BulkProcessingJob).filter(
            BulkProcessingJob.status.in_(['pending', 'running'])
        ).order_by(BulkProcessingJob.created_at.desc()).all()

        result = []
        for job in jobs:
            result.append({
                'id': str(job.id),
                'job_type': job.job_type,
                'status': job.status,
                'year': job.year,
                'total_items': job.total_items,
                'processed_items': job.processed_items,
                'failed_items': job.failed_items,
                'error_message': job.error_message,
                'started_at': job.started_at,
                'created_at': job.created_at,
                'can_cancel': job.status in ['pending', 'running']
            })

        return result
    except Exception as e:
        logger.error(f"Error getting running jobs: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/ollama-pull-model")
async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
    """Pull an Ollama model"""
    try:
        from llm_client import LLMClient

        model = request.get('model')
        if not model:
            raise HTTPException(status_code=400, detail="Model name is required")

        # Create a background task to pull the model
        def pull_model_task():
            try:
                client = LLMClient(provider='ollama', model=model)
                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')

                if client._pull_ollama_model(base_url, model):
                    logger.info(f"Successfully pulled Ollama model: {model}")
                else:
                    logger.error(f"Failed to pull Ollama model: {model}")
            except Exception as e:
                logger.error(f"Error in model pull task: {e}")

        background_tasks.add_task(pull_model_task)

        return {
            "message": f"Started pulling model {model}",
            "status": "started",
            "model": model
        }

    except HTTPException:
        # Let the 400 for a missing model name propagate with its original status code
        raise
    except Exception as e:
        logger.error(f"Error starting model pull: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/ollama-models")
async def get_ollama_models():
    """Get available Ollama models"""
    try:
        from llm_client import LLMClient

        client = LLMClient(provider='ollama')
        available_models = client._get_ollama_available_models()

        return {
            "available_models": available_models,
            "total_models": len(available_models),
            "status": "success"
        }

    except Exception as e:
        logger.error(f"Error getting Ollama models: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================================================
# SCHEDULER ENDPOINTS
# ============================================================================

class SchedulerControlRequest(BaseModel):
    action: str  # 'start', 'stop', 'restart'

class JobControlRequest(BaseModel):
    job_name: str
    action: str  # 'enable', 'disable', 'trigger'

class UpdateScheduleRequest(BaseModel):
    job_name: str
    schedule: str  # Cron expression

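# Illustrative use of the models above (the job name is hypothetical; real names come from
# the scheduler configuration). The schedule field uses standard 5-field cron syntax:
#
#     UpdateScheduleRequest(job_name="nvd_incremental_update", schedule="0 2 * * *")  # daily at 02:00
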
@app.get("/api/scheduler/status")
async def get_scheduler_status():
    """Get scheduler status and job information"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        status = scheduler.get_job_status()

        return {
            "scheduler_status": status,
            "timestamp": datetime.utcnow().isoformat()
        }

    except Exception as e:
        logger.error(f"Error getting scheduler status: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/scheduler/control")
async def control_scheduler(request: SchedulerControlRequest):
    """Control scheduler (start/stop/restart)"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()

        if request.action == 'start':
            scheduler.start()
            message = "Scheduler started"
        elif request.action == 'stop':
            scheduler.stop()
            message = "Scheduler stopped"
        elif request.action == 'restart':
            scheduler.stop()
            scheduler.start()
            message = "Scheduler restarted"
        else:
            raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")

        return {
            "message": message,
            "action": request.action,
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        # Preserve the 400 raised above instead of converting it to a 500
        raise
    except Exception as e:
        logger.error(f"Error controlling scheduler: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/scheduler/job/control")
async def control_job(request: JobControlRequest):
    """Control individual jobs (enable/disable/trigger)"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()

        if request.action == 'enable':
            success = scheduler.enable_job(request.job_name)
            message = f"Job {request.job_name} enabled" if success else f"Job {request.job_name} not found"
        elif request.action == 'disable':
            success = scheduler.disable_job(request.job_name)
            message = f"Job {request.job_name} disabled" if success else f"Job {request.job_name} not found"
        elif request.action == 'trigger':
            success = scheduler.trigger_job(request.job_name)
            message = f"Job {request.job_name} triggered" if success else f"Failed to trigger job {request.job_name}"
        else:
            raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")

        return {
            "message": message,
            "job_name": request.job_name,
            "action": request.action,
            "success": success,
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        # Preserve the 400 raised above for invalid actions
        raise
    except Exception as e:
        logger.error(f"Error controlling job: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/scheduler/job/schedule")
async def update_job_schedule(request: UpdateScheduleRequest):
    """Update job schedule"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        success = scheduler.update_job_schedule(request.job_name, request.schedule)

        if success:
            # Get updated job info
            job_status = scheduler.get_job_status(request.job_name)
            return {
                "message": f"Schedule updated for job {request.job_name}",
                "job_name": request.job_name,
                "new_schedule": request.schedule,
                "next_run": job_status.get("next_run"),
                "success": True,
                "timestamp": datetime.utcnow().isoformat()
            }
        else:
            raise HTTPException(status_code=400, detail=f"Failed to update schedule for job {request.job_name}")

    except HTTPException:
        # Preserve the 400 raised above instead of re-wrapping it as a 500
        raise
    except Exception as e:
        logger.error(f"Error updating job schedule: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/scheduler/job/{job_name}")
async def get_job_status(job_name: str):
    """Get status of a specific job"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        status = scheduler.get_job_status(job_name)

        if "error" in status:
            raise HTTPException(status_code=404, detail=status["error"])

        return {
            "job_status": status,
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/scheduler/reload")
async def reload_scheduler_config():
    """Reload scheduler configuration from file"""
    try:
        from job_scheduler import get_scheduler

        scheduler = get_scheduler()
        success = scheduler.reload_config()

        if success:
            return {
                "message": "Scheduler configuration reloaded successfully",
                "success": True,
                "timestamp": datetime.utcnow().isoformat()
            }
        else:
            raise HTTPException(status_code=500, detail="Failed to reload configuration")

    except HTTPException:
        # Let the explicit 500 above pass through with its original detail message
        raise
    except Exception as e:
        logger.error(f"Error reloading scheduler config: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
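
# Equivalent invocation from the shell (illustrative; "main" assumes this module is saved as main.py):
#
#     uvicorn main:app --host 0.0.0.0 --port 8000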