952 lines
39 KiB
Python
952 lines
39 KiB
Python
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from sqlalchemy import create_engine, Column, String, Text, DECIMAL, TIMESTAMP, Boolean, ARRAY
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
from sqlalchemy.dialects.postgresql import UUID
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
import requests
|
|
import json
|
|
import re
|
|
import os
|
|
from typing import List, Optional
|
|
from pydantic import BaseModel
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
import base64
|
|
from github import Github
|
|
from urllib.parse import urlparse
|
|
import hashlib
|
|
|
|
# Database setup: connection URL (overridable via the DATABASE_URL environment
# variable), engine, session factory and the declarative base for all models.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db",
)

engine = create_engine(DATABASE_URL)

# Sessions neither autocommit nor autoflush; code using them commits explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Shared declarative base for the ORM models defined below.
Base = declarative_base()
|
|
|
|
# Database Models
class CVE(Base):
    """ORM model for a CVE record ingested from the NVD API (table ``cves``)."""

    __tablename__ = "cves"

    # Surrogate primary key, generated client-side.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # NVD identifier, e.g. "CVE-2024-12345"; unique per row.
    cve_id = Column(String(20), unique=True, nullable=False)
    description = Column(Text)
    # CVSS base score with one decimal place.
    cvss_score = Column(DECIMAL(3, 1))
    severity = Column(String(20))
    published_date = Column(TIMESTAMP)
    modified_date = Column(TIMESTAMP)
    # CPE criteria strings of the configurations NVD marks as vulnerable.
    affected_products = Column(ARRAY(String))
    reference_urls = Column(ARRAY(String))
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # Set on insert and refreshed on every ORM-issued UPDATE of the row.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
|
|
|
|
class SigmaRule(Base):
    """ORM model for a generated SIGMA detection rule (table ``sigma_rules``)."""

    __tablename__ = "sigma_rules"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # CVE identifier the rule was generated for (stored as text, no FK).
    cve_id = Column(String(20))
    rule_name = Column(String(255), nullable=False)
    # Rendered SIGMA rule text.
    rule_content = Column(Text, nullable=False)
    detection_type = Column(String(50))
    log_source = Column(String(100))
    confidence_level = Column(String(20))
    auto_generated = Column(Boolean, default=True)
    # True when the rule was derived from analysed exploit code.
    exploit_based = Column(Boolean, default=False)
    github_repos = Column(ARRAY(String))
    exploit_indicators = Column(Text)  # JSON string of extracted indicators
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
    # Set on insert and refreshed on every ORM-issued UPDATE of the row.
    updated_at = Column(TIMESTAMP, default=datetime.utcnow, onupdate=datetime.utcnow)
|
|
|
|
class RuleTemplate(Base):
    """ORM model for a reusable SIGMA rule template (table ``rule_templates``).

    ``template_content`` is rendered with ``str.format`` by the rule
    generator, and ``template_name`` is matched by substring (for example
    "Process Execution") when a template is chosen for a CVE.
    """

    __tablename__ = "rule_templates"

    # Column declaration order is kept as-is: it fixes the table's column order.
    id = Column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)
    template_name = Column(String(255), nullable=False)
    template_content = Column(Text, nullable=False)
    applicable_product_patterns = Column(ARRAY(String))
    description = Column(Text)
    created_at = Column(TIMESTAMP, default=datetime.utcnow)
|
|
|
|
# Pydantic models
# API shape of a CVE row. `id` carries the stringified UUID primary key; all
# descriptive fields are optional because NVD records may omit them.
class CVEResponse(BaseModel):
    id: str
    cve_id: str
    description: Optional[str] = None
    cvss_score: Optional[float] = None
    severity: Optional[str] = None
    published_date: Optional[datetime] = None
    affected_products: Optional[List[str]] = None
    reference_urls: Optional[List[str]] = None

    class Config:
        # Permit construction directly from ORM objects (pydantic v2 option name).
        from_attributes = True
|
|
|
|
# API shape of a generated SIGMA rule. `exploit_indicators` is passed through
# as the raw JSON text stored in the database.
class SigmaRuleResponse(BaseModel):
    id: str
    cve_id: str
    rule_name: str
    rule_content: str
    detection_type: Optional[str] = None
    log_source: Optional[str] = None
    confidence_level: Optional[str] = None
    auto_generated: bool = True
    exploit_based: bool = False
    github_repos: Optional[List[str]] = None
    exploit_indicators: Optional[str] = None
    created_at: datetime

    class Config:
        # Permit construction directly from ORM objects (pydantic v2 option name).
        from_attributes = True
|
|
|
|
# GitHub Exploit Analysis Service
class GitHubExploitAnalyzer:
    """Search GitHub for public exploit/PoC code for a CVE and mine it for indicators.

    Searching requires the ``GITHUB_TOKEN`` environment variable; without it
    ``search_exploits_for_cve`` returns an empty list.

    NOTE(review): PyGithub is synchronous, so although the analysis methods are
    ``async`` they still block the event loop while talking to GitHub.
    """

    # Indicator categories reported for every analysed repository.
    # ('urls' is reported but never populated: URL matches land in 'network'.)
    INDICATOR_KEYS = ('processes', 'files', 'registry', 'network', 'commands', 'powershell', 'urls')

    # Limits that keep GitHub API usage and result sizes bounded.
    MAX_QUERIES = 2            # search queries issued per CVE (rate limits)
    MAX_REPOS_PER_QUERY = 5    # repositories inspected per query
    MAX_EXPLOITS = 10          # repositories returned per CVE
    MAX_FILES_PER_REPO = 20    # top-level entries inspected per repository
    MAX_FILE_SIZE = 100000     # bytes; larger files are skipped
    MAX_PER_CATEGORY = 10      # indicators kept per category per file

    EXPLOIT_EXTENSIONS = ('.py', '.ps1', '.sh', '.c', '.cpp', '.js', '.rb', '.pl', '.php', '.java')
    EXPLOIT_NAME_TERMS = ('exploit', 'poc', 'payload', 'shell', 'reverse', 'bind', 'attack')

    # Regexes are compiled once for the class instead of on every analysed file.
    # Alternations of short command words are anchored with \b so that text such
    # as "form data.txt" or "internet is down" is not read as "rm data.txt" or
    # "net is down".
    _CATEGORY_PATTERNS = {
        'processes': tuple(re.compile(p, re.IGNORECASE | re.MULTILINE) for p in (
            r'CreateProcess[AW]?\s*\(\s*["\']([^"\']+)["\']',
            r'ShellExecute[AW]?\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'system\s*\(\s*["\']([^"\']+)["\']',
            r'exec\s*\(\s*["\']([^"\']+)["\']',
            r'subprocess\.(?:call|run|Popen)\s*\(\s*["\']([^"\']+)["\']',
        )),
        'files': tuple(re.compile(p, re.IGNORECASE | re.MULTILINE) for p in (
            r'(?:fopen|CreateFile|WriteFile|ReadFile)\s*\(\s*["\']([^"\']+\.[a-zA-Z0-9]+)["\']',
            r'\b(?:copy|move|del|rm)\s+["\']?([^\s"\']+\.[a-zA-Z0-9]+)["\']?',
            r'\\\\[^\\]+\\[^\\]+\\([^\\]+\.[a-zA-Z0-9]+)',
            r'[C-Z]:\\\\[^\\]+\\\\([^\\]+\.[a-zA-Z0-9]+)',
        )),
        'registry': tuple(re.compile(p, re.IGNORECASE | re.MULTILINE) for p in (
            r'(?:RegOpenKey|RegSetValue|RegCreateKey)\s*\([^,]*,\s*["\']([^"\']+)["\']',
            r'HKEY_[A-Z_]+\\\\([^"\'\\]+)',
            r'reg\s+add\s+["\']?([^"\'\\]+\\\\[^"\']+)["\']?',
        )),
        'powershell': tuple(re.compile(p, re.IGNORECASE | re.MULTILINE) for p in (
            r'(?:powershell|pwsh)\s+(?:-[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'Invoke-(?:Expression|Command|WebRequest|RestMethod)\s+["\']?([^"\']+)["\']?',
            r'Start-Process\s+["\']?([^"\']+)["\']?',
            r'Get-Process\s+["\']?([^"\']+)["\']?',
        )),
        'commands': tuple(re.compile(p, re.IGNORECASE | re.MULTILINE) for p in (
            r'\b(?:cmd|command)\s+(?:/[a-zA-Z]+\s+)*["\']?([^"\']+)["\']?',
            r'\b(?:ping|nslookup|netstat|tasklist|wmic)\s+([^\s"\']+)',
            r'\b(?:net|sc|schtasks)\s+[a-zA-Z]+\s+([^\s"\']+)',
        )),
    }

    # Network patterns may capture (host, port) pairs, which are joined as "host:port".
    _NETWORK_PATTERNS = tuple(re.compile(p, re.IGNORECASE) for p in (
        r'(?:connect|bind|listen)\s*\([^,]*,\s*(\d+)',
        r'socket\.connect\s*\(\s*\(["\']?([^"\']+)["\']?,\s*(\d+)\)',
        r'(?:http|https|ftp)://([^\s"\'<>]+)',
        r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)',
    ))

    def __init__(self):
        self.github_token = os.getenv("GITHUB_TOKEN")
        self.github = Github(self.github_token) if self.github_token else None

    @staticmethod
    def _mentions_cve(text: str, cve_id: str) -> bool:
        """Return True if ``text`` contains ``cve_id`` in dash or underscore form (case-insensitive)."""
        lowered = text.lower()
        return cve_id.lower() in lowered or cve_id.replace('-', '_').lower() in lowered

    async def search_exploits_for_cve(self, cve_id: str) -> List[dict]:
        """Search GitHub for exploit code related to a CVE.

        Args:
            cve_id: CVE identifier such as "CVE-2024-1234".

        Returns:
            Up to ``MAX_EXPLOITS`` repository analyses (see ``_analyze_repository``);
            an empty list when no token is configured or the search fails.
        """
        if not self.github:
            print(f"No GitHub token configured, skipping exploit search for {cve_id}")
            return []

        try:
            print(f"Searching GitHub for exploits for {cve_id}")

            search_queries = [
                f"{cve_id} exploit",
                f"{cve_id} poc",
                f"{cve_id} vulnerability",
                f'"{cve_id}" exploit code',
                f"{cve_id.replace('-', '_')} exploit",
            ]

            exploits = []
            seen_repos = set()

            # Only the first MAX_QUERIES queries are issued, to limit API usage.
            for query in search_queries[:self.MAX_QUERIES]:
                try:
                    repos = self.github.search_repositories(query=query, sort="updated", order="desc")

                    for repo in repos[:self.MAX_REPOS_PER_QUERY]:
                        if repo.full_name in seen_repos:
                            continue
                        seen_repos.add(repo.full_name)

                        exploit_info = await self._analyze_repository(repo, cve_id)
                        if exploit_info:
                            exploits.append(exploit_info)

                        if len(exploits) >= self.MAX_EXPLOITS:
                            break

                    if len(exploits) >= self.MAX_EXPLOITS:
                        break

                except Exception as e:
                    # One failing query should not abort the remaining ones.
                    print(f"Error searching GitHub with query '{query}': {str(e)}")
                    continue

            print(f"Found {len(exploits)} potential exploits for {cve_id}")
            return exploits

        except Exception as e:
            print(f"Error searching GitHub for {cve_id}: {str(e)}")
            return []

    async def _analyze_repository(self, repo, cve_id: str) -> Optional[dict]:
        """Analyze a GitHub repository for exploit code.

        The repository must mention the CVE in its name or description. Its
        top-level files that look like exploit code are analysed and their
        indicators merged.

        Returns:
            A summary dict (repo metadata, analysed files and merged indicators,
            each category as a sorted list), or None when nothing relevant is found.
        """
        try:
            if not self._mentions_cve(f"{repo.name} {repo.description or ''}", cve_id):
                return None

            exploit_files = []
            indicators = {key: set() for key in self.INDICATOR_KEYS}

            try:
                contents = repo.get_contents("")
                for content in contents[:self.MAX_FILES_PER_REPO]:
                    if content.type != "file" or not self._is_exploit_file(content.name):
                        continue
                    file_analysis = await self._analyze_file_content(repo, content, cve_id)
                    if not file_analysis:
                        continue
                    exploit_files.append(file_analysis)
                    for key, values in file_analysis.get('indicators', {}).items():
                        if key in indicators:
                            indicators[key].update(values)
            except Exception as e:
                # Best effort: keep whatever was collected before the failure.
                print(f"Error analyzing repo contents for {repo.full_name}: {str(e)}")

            if not exploit_files:
                return None

            return {
                'repo_name': repo.full_name,
                'repo_url': repo.html_url,
                'description': repo.description,
                'language': repo.language,
                'stars': repo.stargazers_count,
                'updated': repo.updated_at.isoformat(),
                'files': exploit_files,
                # Sorted so the output is stable across runs (sets have no order).
                'indicators': {k: sorted(v) for k, v in indicators.items()},
            }

        except Exception as e:
            print(f"Error analyzing repository {repo.full_name}: {str(e)}")
            return None

    def _is_exploit_file(self, filename: str) -> bool:
        """Check if a file is likely to contain exploit code.

        The name must have a scripting/source extension and contain either
        "cve" or one of ``EXPLOIT_NAME_TERMS`` (case-insensitive).
        """
        name = filename.lower()
        if not name.endswith(self.EXPLOIT_EXTENSIONS):
            return False
        return any(term in name for term in self.EXPLOIT_NAME_TERMS) or 'cve' in name

    async def _analyze_file_content(self, repo, file_content, cve_id: str) -> Optional[dict]:
        """Analyze individual file content for exploit indicators.

        Returns:
            ``{'filename', 'path', 'size', 'indicators'}``, or None when the file
            is too large, does not mention the CVE, yields no indicators, or
            cannot be read.
        """
        try:
            if file_content.size > self.MAX_FILE_SIZE:
                return None

            content = file_content.decoded_content.decode('utf-8', errors='ignore')

            # Only files that actually reference the CVE are considered.
            if not self._mentions_cve(content, cve_id):
                return None

            indicators = self._extract_indicators_from_code(content, file_content.name)
            if not any(indicators.values()):
                return None

            return {
                'filename': file_content.name,
                'path': file_content.path,
                'size': file_content.size,
                'indicators': indicators,
            }

        except Exception as e:
            print(f"Error analyzing file {file_content.name}: {str(e)}")
            return None

    def _extract_indicators_from_code(self, content: str, filename: str) -> dict:
        """Extract security indicators from exploit code.

        Args:
            content: Source text to scan.
            filename: Name of the scanned file (currently unused).

        Returns:
            Mapping of category -> sorted list of at most ``MAX_PER_CATEGORY``
            values. Values shorter than 3 (after stripping) or of 200+ characters
            are dropped, and categories without values are omitted.
        """
        indicators = {key: set() for key in self.INDICATOR_KEYS}

        for category, patterns in self._CATEGORY_PATTERNS.items():
            for pattern in patterns:
                for match in pattern.findall(content):
                    # findall yields tuples only for multi-group patterns; keep group 1.
                    indicators[category].add(match[0] if isinstance(match, tuple) else match)

        for pattern in self._NETWORK_PATTERNS:
            for match in pattern.findall(content):
                if isinstance(match, tuple):
                    indicators['network'].add(f"{match[0]}:{match[1]}" if len(match) >= 2 else match[0])
                else:
                    indicators['network'].add(match)

        cleaned_indicators = {}
        for key, values in indicators.items():
            # Sorting makes the MAX_PER_CATEGORY cut deterministic.
            kept = sorted(v for v in values if v and len(v.strip()) > 2 and len(v) < 200)
            if kept:
                cleaned_indicators[key] = kept[:self.MAX_PER_CATEGORY]

        return cleaned_indicators
|
|
class CVESigmaService:
    """Ingest CVEs from the NVD API and generate SIGMA rules for them.

    All persistence goes through the SQLAlchemy session handed to the
    constructor; the caller owns that session and is responsible for closing it.
    """

    NVD_API_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
    NVD_RESULTS_PER_PAGE = 100
    # Pause between consecutive NVD pages so paging stays polite.
    # NOTE(review): tune against NVD's published rate limits (they differ with
    # and without an API key).
    NVD_PAGE_DELAY_SECONDS = 6
    # CVSS metric blocks read from an NVD record, in order of preference.
    CVSS_METRIC_KEYS = ("cvssMetricV31", "cvssMetricV30", "cvssMetricV2")

    def __init__(self, db: Session):
        self.db = db
        self.nvd_api_key = os.getenv("NVD_API_KEY")

    async def fetch_recent_cves(self, days_back: int = 7, max_pages: int = 50):
        """Fetch CVEs published in the last ``days_back`` days and store the new ones.

        Results are paged through NVD's ``startIndex`` (``NVD_RESULTS_PER_PAGE``
        per page, at most ``max_pages`` pages). Each page is committed on its
        own, so a failure part-way through keeps the CVEs already stored.

        Args:
            days_back: Length of the publication-date window ending now (UTC).
            max_pages: Upper bound on the number of NVD pages requested.

        Returns:
            The newly inserted ``CVE`` objects, already committed. CVEs already
            present in the database are skipped. On error, the CVEs committed
            before the failure are returned and the session is rolled back.
        """
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=days_back)
        base_params = {
            "pubStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "pubEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z",
            "resultsPerPage": self.NVD_RESULTS_PER_PAGE,
        }
        headers = {"apiKey": self.nvd_api_key} if self.nvd_api_key else {}

        loop = asyncio.get_running_loop()
        new_cves = []
        start_index = 0
        try:
            for page in range(max_pages):
                if page:
                    await asyncio.sleep(self.NVD_PAGE_DELAY_SECONDS)
                page_params = dict(base_params, startIndex=start_index)
                # requests is blocking: run it in the default thread pool so the
                # event loop (and the API endpoints it serves) stays responsive.
                response = await loop.run_in_executor(
                    None,
                    lambda p=page_params: requests.get(
                        self.NVD_API_URL, params=p, headers=headers, timeout=30
                    ),
                )
                response.raise_for_status()
                data = response.json()

                vulnerabilities = data.get("vulnerabilities", [])
                new_cves.extend(self._store_new_cves(vulnerabilities))

                start_index += len(vulnerabilities)
                if not vulnerabilities or start_index >= data.get("totalResults", 0):
                    break
        except Exception as e:
            print(f"Error fetching CVEs: {str(e)}")
            # Discard the failed page so the caller's session remains usable.
            self.db.rollback()

        return new_cves

    def _store_new_cves(self, vulnerabilities: List[dict]) -> List[CVE]:
        """Insert the not-yet-stored CVEs of one NVD result page and commit them."""
        records = [vuln.get("cve", {}) for vuln in vulnerabilities]
        page_ids = [record.get("id") for record in records if record.get("id")]

        # One query per page instead of one existence check per CVE.
        known_ids = set()
        if page_ids:
            known_ids = {row[0] for row in self.db.query(CVE.cve_id).filter(CVE.cve_id.in_(page_ids))}

        stored = []
        for record in records:
            cve_id = record.get("id")
            if not cve_id or cve_id in known_ids:
                continue
            # Also guards against the same id appearing twice (autoflush is off,
            # so pending rows are invisible to the query above).
            known_ids.add(cve_id)
            cve_obj = self._build_cve(record)
            self.db.add(cve_obj)
            stored.append(cve_obj)

        self.db.commit()
        return stored

    def _build_cve(self, cve_data: dict) -> CVE:
        """Map one NVD ``cve`` record onto an unsaved ``CVE`` row."""
        cvss_score, severity = self._extract_cvss(cve_data.get("metrics") or {})
        return CVE(
            cve_id=cve_data.get("id"),
            description=self._english_description(cve_data.get("descriptions") or []),
            cvss_score=cvss_score,
            severity=severity,
            published_date=self._parse_nvd_timestamp(cve_data.get("published")),
            modified_date=self._parse_nvd_timestamp(cve_data.get("lastModified")),
            affected_products=[
                cpe_match.get("criteria", "")
                for config in cve_data.get("configurations") or []
                for node in config.get("nodes", [])
                for cpe_match in node.get("cpeMatch", [])
                if cpe_match.get("vulnerable")
            ],
            reference_urls=[ref.get("url", "") for ref in cve_data.get("references") or []],
        )

    @staticmethod
    def _english_description(descriptions: List[dict]) -> str:
        """Return the English description, else the first one, else ""."""
        for entry in descriptions:
            if entry.get("lang") == "en":
                return entry.get("value", "")
        return descriptions[0].get("value", "") if descriptions else ""

    @classmethod
    def _extract_cvss(cls, metrics: dict):
        """Return ``(base_score, severity)`` from the preferred CVSS metric, or ``(None, None)``."""
        for key in cls.CVSS_METRIC_KEYS:
            entries = metrics.get(key)
            if entries:
                metric = entries[0]
                cvss_data = metric.get("cvssData", {})
                # NOTE(review): for CVSS v2 NVD is assumed to put baseSeverity on the
                # metric object rather than inside cvssData, hence the fallback.
                return cvss_data.get("baseScore"), cvss_data.get("baseSeverity") or metric.get("baseSeverity")
        return None, None

    @staticmethod
    def _parse_nvd_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an NVD ISO-8601 timestamp into a naive UTC datetime; None if missing or invalid."""
        if not value:
            return None
        try:
            parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None
        offset = parsed.utcoffset()
        if offset is not None:
            # Columns hold naive UTC timestamps (cf. the datetime.utcnow defaults).
            parsed = (parsed - offset).replace(tzinfo=None)
        return parsed

    def generate_sigma_rule(self, cve: CVE) -> Optional[SigmaRule]:
        """Generate a SIGMA rule for ``cve`` and add it to the session (not committed).

        Returns:
            The new ``SigmaRule``, or None when the CVE has no description, no
            template applies, or the template could not be rendered.
        """
        if not cve.description:
            return None

        description_lower = cve.description.lower()
        affected_products = [p.lower() for p in (cve.affected_products or [])]

        template = self._select_template(description_lower, affected_products)
        if not template:
            return None

        rule_content = self._populate_template(cve, template)
        if not rule_content:
            return None

        detection_type = self._determine_detection_type(description_lower)
        confidence_level = self._calculate_confidence(cve)

        sigma_rule = SigmaRule(
            cve_id=cve.cve_id,
            # cve_id already looks like "CVE-2024-1234".
            rule_name=f"{cve.cve_id} Detection",
            rule_content=rule_content,
            detection_type=detection_type,
            log_source=template.template_name.lower().replace(" ", "_"),
            confidence_level=confidence_level,
            auto_generated=True,
        )

        self.db.add(sigma_rule)
        return sigma_rule

    def _select_template(self, description: str, affected_products: List[str], exploit_indicators: Optional[dict] = None):
        """Select a SIGMA rule template based on exploit indicators or the CVE text.

        Args:
            description: Lower-cased CVE description.
            affected_products: Lower-cased CPE strings.
            exploit_indicators: Optional indicator mapping from exploit analysis.

        Returns:
            A ``RuleTemplate`` whose name contains the chosen family name, or None.
        """
        templates = self.db.query(RuleTemplate).all()

        def find(name_fragment):
            return next((t for t in templates if name_fragment in t.template_name), None)

        if exploit_indicators:
            # Most specific indicator family first; fall through when its template is missing.
            preferences = (
                (('powershell',), "PowerShell"),
                (('network',), "Network Connection"),
                (('files',), "File Modification"),
                (('processes', 'commands'), "Process Execution"),
            )
            for keys, name_fragment in preferences:
                if any(exploit_indicators.get(key) for key in keys):
                    template = find(name_fragment)
                    if template:
                        return template

        if any("windows" in p or "microsoft" in p for p in affected_products):
            if "process" in description or "execution" in description:
                return find("Process Execution")
            if "network" in description or "remote" in description:
                return find("Network Connection")
            if "file" in description or "write" in description:
                return find("File Modification")

        # Default to the process execution template.
        return find("Process Execution")

    def _populate_template(self, cve: CVE, template: RuleTemplate, exploit_indicators: Optional[dict] = None) -> Optional[str]:
        """Render ``template`` with CVE-specific data and optional exploit indicators.

        Returns:
            The rendered rule text, or None if rendering fails (e.g. the template
            references an unknown placeholder).
        """
        try:
            if exploit_indicators:
                suspicious_processes = exploit_indicators.get('processes', []) + exploit_indicators.get('commands', [])
                file_patterns = exploit_indicators.get('files', [])
                # Ports come from "host:port" style network indicators.
                suspicious_ports = []
                for net_indicator in exploit_indicators.get('network', []):
                    if ':' in str(net_indicator):
                        try:
                            suspicious_ports.append(int(str(net_indicator).split(':')[-1]))
                        except ValueError:
                            pass
            else:
                suspicious_processes = self._extract_suspicious_indicators(cve.description, "process")
                suspicious_ports = self._extract_suspicious_indicators(cve.description, "port")
                file_patterns = self._extract_suspicious_indicators(cve.description, "file")

            level = "high" if cve.cvss_score and cve.cvss_score >= 7.0 else "medium"

            enhanced_description = cve.description[:200] + "..." if len(cve.description) > 200 else cve.description
            if exploit_indicators:
                enhanced_description += " [Enhanced with GitHub exploit analysis]"

            tags = [f"attack.{self._get_mitre_technique(cve.description, exploit_indicators)}", cve.cve_id.lower()]
            if exploit_indicators:
                tags.append("exploit.github")

            rule_content = template.template_content.format(
                # cve_id already carries the "CVE-" prefix.
                title=f"{cve.cve_id} {'Exploit-Based ' if exploit_indicators else ''}Detection",
                description=enhanced_description,
                rule_id=str(uuid.uuid4()),
                date=datetime.utcnow().strftime("%Y/%m/%d"),
                cve_url=f"https://nvd.nist.gov/vuln/detail/{cve.cve_id}",
                cve_id=cve.cve_id.lower(),
                tags="\n - ".join(tags),
                # Placeholder values keep the rendered rule well-formed when nothing was extracted.
                suspicious_processes=suspicious_processes or ["suspicious.exe", "malware.exe"],
                suspicious_ports=suspicious_ports or [4444, 8080, 9999],
                file_patterns=file_patterns or ["temp", "malware", "exploit"],
                level=level,
            )
            return rule_content

        except Exception as e:
            print(f"Error populating template: {str(e)}")
            return None

    def _get_mitre_technique(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Map CVE text and exploit indicators to a lower-case MITRE ATT&CK technique id.

        Falls back to the generic tactic name "execution" when nothing matches.
        """
        desc_lower = description.lower()

        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "t1059.001"  # PowerShell
            elif exploit_indicators.get('commands'):
                return "t1059.003"  # Windows Command Shell
            elif exploit_indicators.get('network'):
                return "t1071.001"  # Web Protocols
            elif exploit_indicators.get('files'):
                return "t1105"  # Ingress Tool Transfer
            elif exploit_indicators.get('processes'):
                return "t1106"  # Native API

        if "powershell" in desc_lower:
            return "t1059.001"
        elif "command" in desc_lower or "cmd" in desc_lower:
            return "t1059.003"
        elif "network" in desc_lower or "remote" in desc_lower:
            return "t1071.001"
        elif "file" in desc_lower or "upload" in desc_lower:
            return "t1105"
        elif "process" in desc_lower or "execution" in desc_lower:
            return "t1106"
        else:
            return "execution"  # Generic

    def _extract_suspicious_indicators(self, description: str, indicator_type: str) -> Optional[List]:
        """Extract indicators of ``indicator_type`` ("process", "port" or "file") from CVE text.

        Returns:
            Up to 5 executable names, 3 port numbers, or 5 file names; None when
            nothing matches or the type is unknown.
        """
        if indicator_type == "process":
            exe_pattern = re.findall(r'(\w+\.exe)', description, re.IGNORECASE)
            return exe_pattern[:5] if exe_pattern else None

        elif indicator_type == "port":
            port_pattern = re.findall(r'port\s+(\d+)', description, re.IGNORECASE)
            return [int(p) for p in port_pattern[:3]] if port_pattern else None

        elif indicator_type == "file":
            file_pattern = re.findall(r'(\w+\.\w{3,4})', description, re.IGNORECASE)
            return file_pattern[:5] if file_pattern else None

        return None

    def _determine_detection_type(self, description: str, exploit_indicators: Optional[dict] = None) -> str:
        """Classify the rule as "powershell", "network", "file", "process" or "general"."""
        if exploit_indicators:
            if exploit_indicators.get('powershell'):
                return "powershell"
            elif exploit_indicators.get('network'):
                return "network"
            elif exploit_indicators.get('files'):
                return "file"
            elif exploit_indicators.get('processes') or exploit_indicators.get('commands'):
                return "process"

        if "remote" in description or "network" in description:
            return "network"
        elif "process" in description or "execution" in description:
            return "process"
        elif "file" in description or "filesystem" in description:
            return "file"
        else:
            return "general"

    def _calculate_confidence(self, cve: CVE, exploit_based: bool = False) -> str:
        """Score confidence from CVSS severity (+1..+3) and exploit evidence (+2).

        Returns "high" for a score of 4 or more, "medium" for 2-3, else "low".
        """
        base_confidence = 0

        if cve.cvss_score:
            if cve.cvss_score >= 9.0:
                base_confidence += 3
            elif cve.cvss_score >= 7.0:
                base_confidence += 2
            else:
                base_confidence += 1

        if exploit_based:
            base_confidence += 2

        if base_confidence >= 4:
            return "high"
        elif base_confidence >= 2:
            return "medium"
        else:
            return "low"
|
|
|
|
# Dependency
def get_db():
    """FastAPI dependency: yield a fresh session and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
# Background task to fetch CVEs and generate rules
async def background_cve_fetch():
    """Periodically ingest new CVEs from NVD and generate SIGMA rules for them.

    The first successful cycle looks back 30 days; later cycles look back 7.
    After a successful cycle the task sleeps one hour. A failed cycle is retried
    after 5 minutes; after ``max_retries`` consecutive failures it backs off for
    30 minutes. Each cycle uses its own session, closed on every path. Runs
    until cancelled.
    """
    retry_count = 0
    max_retries = 3
    initial_days_back = 30
    regular_days_back = 7
    days_back = initial_days_back

    while True:
        db = SessionLocal()
        try:
            service = CVESigmaService(db)
            current_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
            print(f"[{current_time}] Starting CVE fetch cycle...")

            new_cves = await service.fetch_recent_cves(days_back=days_back)

            if new_cves:
                print(f"Found {len(new_cves)} new CVEs, generating SIGMA rules...")
                rules_generated = 0
                for cve in new_cves:
                    # One bad CVE must not stop rule generation for the rest.
                    try:
                        sigma_rule = service.generate_sigma_rule(cve)
                        if sigma_rule:
                            rules_generated += 1
                            print(f"Generated SIGMA rule for {cve.cve_id}")
                        else:
                            print(f"Could not generate rule for {cve.cve_id} - insufficient data")
                    except Exception as e:
                        print(f"Error generating rule for {cve.cve_id}: {str(e)}")

                db.commit()
                print(f"Successfully generated {rules_generated} SIGMA rules")
            else:
                print("No new CVEs found in this cycle")

            # After the first successful cycle, only the recent week is needed.
            if days_back != regular_days_back:
                print("Switching to 7-day lookback for future runs...")
                days_back = regular_days_back

            retry_count = 0  # a completed cycle clears earlier failures
            wait_time = 3600
            print(f"Next CVE fetch in {wait_time//60} minutes...")

        except Exception as e:
            retry_count += 1
            print(f"Background task error (attempt {retry_count}/{max_retries}): {str(e)}")
            if retry_count >= max_retries:
                print("Max retries reached, waiting longer before next attempt...")
                wait_time = 1800  # back off 30 minutes after repeated failures
                retry_count = 0
            else:
                wait_time = 300  # retry in 5 minutes

        finally:
            # Runs on success, failure and task cancellation alike.
            db.close()

        await asyncio.sleep(wait_time)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the periodic CVE ingestion task for the lifetime of the application."""
    task = asyncio.create_task(background_cve_fetch())
    try:
        yield
    finally:
        # Cancel and wait for the task, so its cleanup (closing its DB session)
        # completes before shutdown proceeds.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
|
|
|
|
# FastAPI app
app = FastAPI(lifespan=lifespan, title="CVE-SIGMA Auto Generator")

# Browsers may call the API only from http://localhost:3000, with credentials
# and any method or header.
cors_settings = dict(
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_settings)
|
|
|
|
@app.get("/api/cves", response_model=List[CVEResponse])
async def get_cves(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    # Page through CVEs, most recently published first.
    cves = db.query(CVE).order_by(CVE.published_date.desc()).offset(skip).limit(limit).all()

    # The ORM id is a UUID object and the score a Decimal; convert both for the
    # response model.
    return [
        CVEResponse(
            id=str(cve.id),
            cve_id=cve.cve_id,
            description=cve.description,
            # `is not None`: a CVSS score of 0.0 is a real value, not a missing one.
            cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
            severity=cve.severity,
            published_date=cve.published_date,
            affected_products=cve.affected_products,
            reference_urls=cve.reference_urls,
        )
        for cve in cves
    ]
|
|
|
|
@app.get("/api/cves/{cve_id}", response_model=CVEResponse)
async def get_cve(cve_id: str, db: Session = Depends(get_db)):
    # Look up a single CVE by its NVD identifier; 404 when unknown.
    cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
    if not cve:
        raise HTTPException(status_code=404, detail="CVE not found")

    return CVEResponse(
        id=str(cve.id),
        cve_id=cve.cve_id,
        description=cve.description,
        # `is not None`: a CVSS score of 0.0 is a real value, not a missing one.
        cvss_score=float(cve.cvss_score) if cve.cvss_score is not None else None,
        severity=cve.severity,
        published_date=cve.published_date,
        affected_products=cve.affected_products,
        reference_urls=cve.reference_urls,
    )
|
|
|
|
@app.get("/api/sigma-rules", response_model=List[SigmaRuleResponse])
async def get_sigma_rules(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    # Newest rules first, paginated with skip/limit.
    ordered = db.query(SigmaRule).order_by(SigmaRule.created_at.desc())
    page = ordered.offset(skip).limit(limit).all()

    responses = []
    for r in page:
        # Stringify the UUID and normalise the nullable exploit columns.
        payload = dict(
            id=str(r.id),
            cve_id=r.cve_id,
            rule_name=r.rule_name,
            rule_content=r.rule_content,
            detection_type=r.detection_type,
            log_source=r.log_source,
            confidence_level=r.confidence_level,
            auto_generated=r.auto_generated,
            exploit_based=r.exploit_based or False,
            github_repos=r.github_repos or [],
            exploit_indicators=r.exploit_indicators,
            created_at=r.created_at,
        )
        responses.append(SigmaRuleResponse(**payload))
    return responses
|
|
|
|
@app.get("/api/sigma-rules/{cve_id}", response_model=List[SigmaRuleResponse])
async def get_sigma_rules_by_cve(cve_id: str, db: Session = Depends(get_db)):
    # All rules generated for one CVE (empty list when there are none).
    matching = db.query(SigmaRule).filter(SigmaRule.cve_id == cve_id).all()

    # UUIDs become strings; nullable exploit columns get safe defaults.
    return [
        SigmaRuleResponse(
            id=str(r.id),
            cve_id=r.cve_id,
            rule_name=r.rule_name,
            rule_content=r.rule_content,
            detection_type=r.detection_type,
            log_source=r.log_source,
            confidence_level=r.confidence_level,
            auto_generated=r.auto_generated,
            exploit_based=r.exploit_based or False,
            github_repos=r.github_repos or [],
            exploit_indicators=r.exploit_indicators,
            created_at=r.created_at,
        )
        for r in matching
    ]
|
|
|
|
@app.post("/api/fetch-cves")
async def manual_fetch_cves(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    # Kick off a 30-day CVE fetch plus rule generation in the background and
    # return immediately.
    #
    # The request-scoped `db` session is deliberately not used by the job: a
    # `yield` dependency may already be torn down by the time background tasks
    # run, so the job opens, commits or rolls back, and closes its own session.

    async def fetch_task():
        task_db = SessionLocal()
        try:
            service = CVESigmaService(task_db)
            print("Manual CVE fetch initiated...")
            # Use 30 days for manual fetch to get more results
            new_cves = await service.fetch_recent_cves(days_back=30)

            rules_generated = 0
            for cve in new_cves:
                # One bad CVE must not stop rule generation for the rest.
                try:
                    if service.generate_sigma_rule(cve):
                        rules_generated += 1
                except Exception as e:
                    print(f"Error generating rule for {cve.cve_id}: {str(e)}")

            task_db.commit()
            print(f"Manual fetch complete: {len(new_cves)} CVEs, {rules_generated} rules generated")
        except Exception as e:
            print(f"Manual fetch error: {str(e)}")
            import traceback
            traceback.print_exc()
            task_db.rollback()
        finally:
            task_db.close()

    background_tasks.add_task(fetch_task)
    return {"message": "CVE fetch initiated (30-day lookback)", "status": "started"}
|
|
|
|
@app.get("/api/test-nvd")
async def test_nvd_connection():
    """Test endpoint to check NVD API connectivity"""
    # Issues a small last-modified query against NVD and reports status and
    # metadata. On a 404 it also retries without date filters.
    try:
        # Test with a simple request using current date
        end_date = datetime.utcnow()
        start_date = end_date - timedelta(days=30)

        url = "https://services.nvd.nist.gov/rest/json/cves/2.0/"
        params = {
            "lastModStartDate": start_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "lastModEndDate": end_date.strftime("%Y-%m-%dT%H:%M:%S.000+00:00"),
            "resultsPerPage": 5,
            "startIndex": 0
        }

        headers = {
            "User-Agent": "CVE-SIGMA-Generator/1.0",
            "Accept": "application/json"
        }

        nvd_api_key = os.getenv("NVD_API_KEY")
        if nvd_api_key:
            headers["apiKey"] = nvd_api_key

        # Never write the API key to the logs: print a redacted copy of the headers.
        loggable_headers = {
            name: ("***REDACTED***" if name.lower() == "apikey" else value)
            for name, value in headers.items()
        }
        print(f"Testing NVD API with URL: {url}")
        print(f"Test params: {params}")
        print(f"Test headers: {loggable_headers}")

        response = requests.get(url, params=params, headers=headers, timeout=15)

        result = {
            "status": "success" if response.status_code == 200 else "error",
            "status_code": response.status_code,
            "has_api_key": bool(nvd_api_key),
            "request_url": f"{url}?{requests.compat.urlencode(params)}",
            "response_headers": dict(response.headers)
        }

        if response.status_code == 200:
            data = response.json()
            result.update({
                "total_results": data.get("totalResults", 0),
                "results_per_page": data.get("resultsPerPage", 0),
                "vulnerabilities_returned": len(data.get("vulnerabilities", [])),
                "message": "NVD API is accessible and returning data"
            })
        else:
            result.update({
                "error_message": response.text[:200],
                "message": f"NVD API returned {response.status_code}"
            })

            # Try fallback without date filters if we get 404
            if response.status_code == 404:
                print("Trying fallback without date filters...")
                fallback_params = {
                    "resultsPerPage": 5,
                    "startIndex": 0
                }
                fallback_response = requests.get(url, params=fallback_params, headers=headers, timeout=15)
                result["fallback_status_code"] = fallback_response.status_code

                if fallback_response.status_code == 200:
                    fallback_data = fallback_response.json()
                    result.update({
                        "fallback_success": True,
                        "fallback_total_results": fallback_data.get("totalResults", 0),
                        "message": "NVD API works without date filters"
                    })

        return result

    except Exception as e:
        print(f"NVD API test error: {str(e)}")
        return {
            "status": "error",
            "message": f"Failed to connect to NVD API: {str(e)}"
        }
|
|
|
|
@app.get("/api/stats")
async def get_stats(db: Session = Depends(get_db)):
    # Summary counters: total CVEs, total rules, and CVEs published within the
    # last 7 days (UTC).
    week_ago = datetime.utcnow() - timedelta(days=7)
    return {
        "total_cves": db.query(CVE).count(),
        "total_sigma_rules": db.query(SigmaRule).count(),
        "recent_cves_7_days": db.query(CVE).filter(CVE.published_date >= week_ago).count(),
    }
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|