auto_sigma_rule_generator/backend/nomi_sec_client.py

"""
Nomi-sec PoC-in-GitHub Integration Client
Interfaces with the nomi-sec PoC-in-GitHub API for curated exploit data
"""
import aiohttp
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import time
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NomiSecClient:
    """Client for interacting with nomi-sec PoC-in-GitHub API"""

    def __init__(self, db_session: Session):
        self.db_session = db_session
        self.base_url = "https://poc-in-github.motikan2010.net/api/v1"
        self.rss_url = "https://poc-in-github.motikan2010.net/rss"
        # Rate limiting
        self.rate_limit_delay = 1.0  # 1 second between requests
        self.last_request_time = 0
        # Cache for recently fetched data
        self.cache = {}
        self.cache_ttl = 300  # 5 minutes

    async def _make_request(self, session: aiohttp.ClientSession,
                            url: str, params: dict = None) -> Optional[dict]:
        """Make a rate-limited request to the API"""
        try:
            # Rate limiting
            current_time = time.time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.rate_limit_delay:
                await asyncio.sleep(self.rate_limit_delay - time_since_last)

            async with session.get(url, params=params, timeout=30) as response:
                self.last_request_time = time.time()

                if response.status == 200:
                    return await response.json()
                else:
                    logger.warning(f"API request failed: {response.status} for {url}")
                    return None

        except Exception as e:
            logger.error(f"Error making request to {url}: {e}")
            return None

    async def get_pocs_for_cve(self, cve_id: str) -> List[dict]:
        """Get all PoC repositories for a specific CVE"""
        cache_key = f"cve_{cve_id}"

        # Check cache
        if cache_key in self.cache:
            cached_data, timestamp = self.cache[cache_key]
            if time.time() - timestamp < self.cache_ttl:
                return cached_data

        async with aiohttp.ClientSession() as session:
            params = {"cve_id": cve_id}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                pocs = data["pocs"]
                # Cache the result
                self.cache[cache_key] = (pocs, time.time())
                logger.info(f"Found {len(pocs)} PoCs for {cve_id}")
                return pocs
            else:
                logger.info(f"No PoCs found for {cve_id}")
                return []

    async def get_recent_pocs(self, limit: int = 100) -> List[dict]:
        """Get recent PoCs from the API"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "sort": "created_at"}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                return data["pocs"]
            else:
                return []

    async def get_high_quality_pocs(self, min_stars: int = 5, limit: int = 100) -> List[dict]:
        """Get high-quality PoCs sorted by star count"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "sort": "stargazers_count"}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                # Filter by star count
                filtered_pocs = [
                    poc for poc in data["pocs"]
                    if int(poc.get("stargazers_count", "0")) >= min_stars
                ]
                return filtered_pocs
            else:
                return []

    async def search_pocs(self, query: str, limit: int = 50) -> List[dict]:
        """Search for PoCs using a query string"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "q": query}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                return data["pocs"]
            else:
                return []

    def analyze_poc_quality(self, poc: dict) -> dict:
        """Analyze the quality of a PoC repository"""
        quality_score = 0
        factors = {}

        # Star count factor (0-40 points)
        stars = int(poc.get("stargazers_count", "0"))
        star_score = min(stars * 2, 40)  # 2 points per star, max 40
        quality_score += star_score
        factors["star_score"] = star_score

        # Recency factor (0-20 points)
        try:
            updated_at = datetime.fromisoformat(poc.get("updated_at", "").replace('Z', '+00:00'))
            days_old = (datetime.now(updated_at.tzinfo) - updated_at).days
            recency_score = max(20 - (days_old // 30), 0)  # Lose 1 point per month
            quality_score += recency_score
            factors["recency_score"] = recency_score
        except Exception:
            # Missing or malformed timestamp: no recency credit
            factors["recency_score"] = 0

        # Description quality factor (0-15 points)
        description = poc.get("description", "")
        desc_score = 0
        if description:
            desc_score = min(len(description) // 10, 15)  # 1 point per 10 chars, max 15
        quality_score += desc_score
        factors["description_score"] = desc_score

        # Vulnerability description factor (0-15 points)
        vuln_desc = poc.get("vuln_description", "")
        vuln_score = 0
        if vuln_desc:
            vuln_score = min(len(vuln_desc) // 20, 15)  # 1 point per 20 chars, max 15
        quality_score += vuln_score
        factors["vuln_description_score"] = vuln_score

        # Repository name relevance factor (0-10 points)
        repo_name = poc.get("name", "").lower()
        cve_id = poc.get("cve_id", "").lower()
        name_score = 0
        if cve_id and cve_id.replace("-", "") in repo_name.replace("-", ""):
            name_score = 10
        elif any(keyword in repo_name for keyword in ["exploit", "poc", "cve", "vuln"]):
            name_score = 5
        quality_score += name_score
        factors["name_relevance_score"] = name_score

        return {
            "quality_score": quality_score,
            "factors": factors,
            "quality_tier": self._get_quality_tier(quality_score)
        }
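
    # Note: the factor weights above sum to at most 100 points
    # (40 stars + 20 recency + 15 description + 15 vuln description
    # + 10 name relevance), which is the scale the tier thresholds
    # below are calibrated against.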

    def _get_quality_tier(self, score: int) -> str:
        """Get quality tier based on score"""
        if score >= 80:
            return "excellent"
        elif score >= 60:
            return "good"
        elif score >= 40:
            return "fair"
        elif score >= 20:
            return "poor"
        else:
            return "very_poor"

    def extract_exploit_indicators(self, poc: dict) -> dict:
        """Extract exploit indicators from PoC metadata"""
        indicators = {
            "processes": [],
            "files": [],
            "network": [],
            "registry": [],
            "commands": [],
            "urls": [],
            "techniques": []
        }

        # Extract from description and vulnerability description
        text_sources = [
            poc.get("description", ""),
            poc.get("vuln_description", ""),
            poc.get("name", "")
        ]
        full_text = " ".join(text_sources).lower()

        # Enhanced process patterns
        process_patterns = [
            r'\b(cmd\.exe|powershell\.exe|bash|sh|python\.exe|java\.exe|node\.exe)\b',
            r'\b(createprocess|shellexecute|system|winexec)\b',
            r'\b(reverse.?shell|bind.?shell|web.?shell)\b',
            r'\b(mshta\.exe|rundll32\.exe|regsvr32\.exe|wscript\.exe|cscript\.exe)\b',
            r'\b(certutil\.exe|bitsadmin\.exe|schtasks\.exe)\b'
        ]

        for pattern in process_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["processes"].extend(matches)

        # Enhanced file patterns
        file_patterns = [
            r'\b([a-zA-Z]:\\[^\\\s]+\\[^\\\s]+\.[a-zA-Z0-9]+)\b',  # Windows paths
            r'\b(/[^/\s]+/[^/\s]+\.[a-zA-Z0-9]+)\b',  # Unix paths
            r'\b(\w+\.(exe|dll|bat|ps1|py|sh|jar|php|jsp|asp|aspx))\b',  # Common executable files
            r'\b(\w+\.(txt|log|tmp|temp|dat|bin))\b',  # Common data files
            r'\b(payload|exploit|shell|backdoor|trojan)\b'  # Malicious file indicators
        ]

        for pattern in file_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            if matches and isinstance(matches[0], tuple):
                indicators["files"].extend([m[0] for m in matches])
            else:
                indicators["files"].extend(matches)

        # Enhanced network patterns
        network_patterns = [
            r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b',  # IP addresses
            r'\b((?:\d{1,5})|(?:0x[a-fA-F0-9]{1,4}))\b',  # Ports
            r'\b(http[s]?://[^\s<>"]+)\b',  # URLs
            r'\b([a-zA-Z0-9-]+\.[a-zA-Z]{2,})\b'  # Domain names
        ]

        for pattern in network_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            if 'http' in pattern:
                indicators["urls"].extend(matches)
            elif '\\d{1,3}\\.' in pattern or '\\d{1,5}' in pattern:
                indicators["network"].extend(matches)
            else:
                indicators["network"].extend(matches)

        # Enhanced command patterns
        command_patterns = [
            r'\b(curl|wget|nc|netcat|ncat|telnet)\b',
            r'\b(whoami|id|uname|systeminfo|ipconfig|ifconfig)\b',
            r'\b(cat|type|more|less|head|tail)\b',
            r'\b(echo|print|printf)\b',
            r'\b(base64|decode|encode)\b',
            r'\b(invoke|iex|downloadstring)\b',
            r'\b(net\s+user|net\s+localgroup)\b',
            r'\b(sc\s+create|sc\s+start)\b'
        ]

        for pattern in command_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["commands"].extend(matches)

        # Registry patterns
        registry_patterns = [
            r'\b(HKEY_[A-Z_]+)\b',
            r'\b(HKLM|HKCU|HKCR|HKU|HKCC)\b',
            r'\b(reg\s+add|reg\s+query|reg\s+delete)\b',
            r'\b(SOFTWARE\\\\[^\\\s]+)\b',
            r'\b(SYSTEM\\\\[^\\\s]+)\b'
        ]

        for pattern in registry_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["registry"].extend(matches)

        # Clean up and deduplicate
        for key in indicators:
            # Remove empty strings and duplicates
            indicators[key] = list(set([item for item in indicators[key] if item and len(item.strip()) > 2]))
            # Limit to reasonable number of indicators
            indicators[key] = indicators[key][:15]

        return indicators

    async def sync_cve_pocs(self, cve_id: str) -> dict:
        """Synchronize PoC data for a specific CVE"""
        from main import CVE, SigmaRule

        # Get existing CVE
        cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
        if not cve:
            logger.warning(f"CVE {cve_id} not found in database")
            return {"error": "CVE not found"}

        # Fetch PoCs from nomi-sec API
        pocs = await self.get_pocs_for_cve(cve_id)

        if not pocs:
            logger.info(f"No PoCs found for {cve_id}")
            return {"cve_id": cve_id, "pocs_found": 0}

        # Analyze and store PoC data
        poc_data = []
        github_repos = []
        total_quality_score = 0

        for poc in pocs:
            quality_analysis = self.analyze_poc_quality(poc)
            exploit_indicators = self.extract_exploit_indicators(poc)

            poc_entry = {
                "id": poc.get("id"),
                "name": poc.get("name"),
                "owner": poc.get("owner"),
                "full_name": poc.get("full_name"),
                "html_url": poc.get("html_url"),
                "description": poc.get("description"),
                "stargazers_count": int(poc.get("stargazers_count", "0")),
                "created_at": poc.get("created_at"),
                "updated_at": poc.get("updated_at"),
                "quality_analysis": quality_analysis,
                "exploit_indicators": exploit_indicators
            }
            poc_data.append(poc_entry)
            github_repos.append(poc.get("html_url", ""))
            total_quality_score += quality_analysis["quality_score"]

        # Update CVE with PoC data
        cve.poc_count = len(pocs)
        cve.poc_data = poc_data
        cve.updated_at = datetime.utcnow()

        # Update or create SIGMA rule with enhanced PoC data
        sigma_rule = self.db_session.query(SigmaRule).filter(
            SigmaRule.cve_id == cve_id
        ).first()

        if sigma_rule:
            sigma_rule.poc_source = 'nomi_sec'
            sigma_rule.poc_quality_score = total_quality_score // len(pocs) if pocs else 0
            sigma_rule.nomi_sec_data = {
                "total_pocs": len(pocs),
                "average_quality": total_quality_score // len(pocs) if pocs else 0,
                "best_poc": max(poc_data, key=lambda x: x["quality_analysis"]["quality_score"]) if poc_data else None,
                "total_stars": sum(p["stargazers_count"] for p in poc_data)
            }
            sigma_rule.github_repos = github_repos
            sigma_rule.updated_at = datetime.utcnow()

            # Extract best exploit indicators
            best_indicators = {}
            for poc in poc_data:
                for key, values in poc["exploit_indicators"].items():
                    if key not in best_indicators:
                        best_indicators[key] = []
                    best_indicators[key].extend(values)

            # Deduplicate and store
            for key in best_indicators:
                best_indicators[key] = list(set(best_indicators[key]))

            sigma_rule.exploit_indicators = json.dumps(best_indicators)

        self.db_session.commit()

        logger.info(f"Synchronized {len(pocs)} PoCs for {cve_id}")

        return {
            "cve_id": cve_id,
            "pocs_found": len(pocs),
            "total_quality_score": total_quality_score,
            "average_quality": total_quality_score // len(pocs) if pocs else 0,
            "github_repos": github_repos
        }

    async def bulk_sync_all_cves(self, batch_size: int = 100, cancellation_flag: Optional[callable] = None) -> dict:
        """Synchronize PoC data for all CVEs in database"""
        from main import CVE, BulkProcessingJob

        # Create bulk processing job
        job = BulkProcessingJob(
            job_type='nomi_sec_sync',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'batch_size': batch_size}
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_found = 0
        results = []

        try:
            # Get all CVEs from database
            cves = self.db_session.query(CVE).all()
            job.total_items = len(cves)
            self.db_session.commit()

            # Process in batches
            for i in range(0, len(cves), batch_size):
                # Check for cancellation before each batch
                if cancellation_flag and cancellation_flag():
                    logger.info("Bulk sync cancelled by user")
                    job.status = 'cancelled'
                    job.cancelled_at = datetime.utcnow()
                    job.error_message = "Job cancelled by user"
                    break

                batch = cves[i:i + batch_size]

                for cve in batch:
                    # Check for cancellation before each CVE
                    if cancellation_flag and cancellation_flag():
                        logger.info("Bulk sync cancelled by user")
                        job.status = 'cancelled'
                        job.cancelled_at = datetime.utcnow()
                        job.error_message = "Job cancelled by user"
                        break

                    try:
                        result = await self.sync_cve_pocs(cve.cve_id)
                        total_processed += 1

                        if result.get("pocs_found", 0) > 0:
                            total_found += result["pocs_found"]
                            results.append(result)

                        job.processed_items += 1

                        # Small delay to avoid overwhelming the API
                        await asyncio.sleep(0.5)

                    except Exception as e:
                        logger.error(f"Error syncing PoCs for {cve.cve_id}: {e}")
                        job.failed_items += 1

                # Break out of outer loop if cancelled
                if job.status == 'cancelled':
                    break

                # Commit after each batch
                self.db_session.commit()
                logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")

            # Update job status (only if not cancelled)
            if job.status != 'cancelled':
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                job.job_metadata.update({
                    'total_processed': total_processed,
                    'total_pocs_found': total_found,
                    'cves_with_pocs': len(results)
                })

        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Bulk PoC sync job failed: {e}")

        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_pocs_found': total_found,
            'cves_with_pocs': len(results)
        }

    async def get_sync_status(self) -> dict:
        """Get synchronization status"""
        from main import CVE, SigmaRule

        # Count CVEs with PoC data
        total_cves = self.db_session.query(CVE).count()
        cves_with_pocs = self.db_session.query(CVE).filter(CVE.poc_count > 0).count()

        # Count SIGMA rules with nomi-sec data
        total_rules = self.db_session.query(SigmaRule).count()
        rules_with_nomi_sec = self.db_session.query(SigmaRule).filter(
            SigmaRule.poc_source == 'nomi_sec'
        ).count()

        return {
            'total_cves': total_cves,
            'cves_with_pocs': cves_with_pocs,
            'poc_coverage': (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0,
            'total_rules': total_rules,
            'rules_with_nomi_sec': rules_with_nomi_sec,
            'nomi_sec_coverage': (rules_with_nomi_sec / total_rules * 100) if total_rules > 0 else 0
        }
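

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the client API). It only exercises the
# read-only helpers above, which never touch self.db_session, so None is
# passed in place of a real SQLAlchemy session; the CVE id is illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        # No database needed for the HTTP-only helpers
        client = NomiSecClient(db_session=None)
        pocs = await client.get_pocs_for_cve("CVE-2021-44228")
        for poc in pocs[:5]:
            quality = client.analyze_poc_quality(poc)
            print(poc.get("full_name"), quality["quality_tier"], quality["quality_score"])

    asyncio.run(_demo())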