# auto_sigma_rule_generator/backend/nomi_sec_client.py

"""
Nomi-sec PoC-in-GitHub Integration Client
Interfaces with the nomi-sec PoC-in-GitHub API for curated exploit data
"""
import aiohttp
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
import time
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NomiSecClient:
    """Client for interacting with nomi-sec PoC-in-GitHub API"""

    def __init__(self, db_session: Session):
        self.db_session = db_session
        self.base_url = "https://poc-in-github.motikan2010.net/api/v1"
        self.rss_url = "https://poc-in-github.motikan2010.net/rss"

        # Optimized rate limiting
        self.rate_limit_delay = 0.2  # 200ms between requests (5 requests/second)
        self.last_request_time = 0
        self.concurrent_requests = 3  # Allow concurrent requests

        # Cache for recently fetched data
        self.cache = {}
        self.cache_ttl = 300  # 5 minutes

    async def _make_request(self, session: aiohttp.ClientSession,
                            url: str, params: dict = None,
                            retries: int = 0) -> Optional[dict]:
        """Make an optimized rate-limited request to the API"""
        try:
            # Optimized rate limiting
            current_time = time.time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.rate_limit_delay:
                await asyncio.sleep(self.rate_limit_delay - time_since_last)

            async with session.get(url, params=params,
                                   timeout=aiohttp.ClientTimeout(total=10)) as response:
                self.last_request_time = time.time()

                if response.status == 200:
                    return await response.json()
                elif response.status == 429 and retries < 3:  # Rate limited, retry a few times
                    logger.warning("Rate limited, retrying after delay")
                    await asyncio.sleep(2.0)
                    return await self._make_request(session, url, params, retries + 1)
                else:
                    logger.warning(f"API request failed: {response.status} for {url}")
                    return None

        except asyncio.TimeoutError:
            logger.warning(f"Request timeout for {url}")
            return None
        except Exception as e:
            logger.error(f"Error making request to {url}: {e}")
            return None

    async def get_pocs_for_cve(self, cve_id: str, session: aiohttp.ClientSession = None) -> List[dict]:
        """Get all PoC repositories for a specific CVE with optimized session reuse"""
        cache_key = f"cve_{cve_id}"

        # Check cache
        if cache_key in self.cache:
            cached_data, timestamp = self.cache[cache_key]
            if time.time() - timestamp < self.cache_ttl:
                logger.debug(f"Cache hit for {cve_id}")
                return cached_data

        # Use provided session or create new one
        if session:
            params = {"cve_id": cve_id}
            data = await self._make_request(session, self.base_url, params)
        else:
            # Optimized connector with connection pooling
            connector = aiohttp.TCPConnector(
                limit=100,
                limit_per_host=10,
                ttl_dns_cache=300,
                use_dns_cache=True
            )
            async with aiohttp.ClientSession(connector=connector) as new_session:
                params = {"cve_id": cve_id}
                data = await self._make_request(new_session, self.base_url, params)

        if data and "pocs" in data:
            pocs = data["pocs"]
            # Cache the result
            self.cache[cache_key] = (pocs, time.time())
            logger.debug(f"Found {len(pocs)} PoCs for {cve_id}")
            return pocs
        else:
            logger.debug(f"No PoCs found for {cve_id}")
            return []

    async def get_recent_pocs(self, limit: int = 100) -> List[dict]:
        """Get recent PoCs from the API"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "sort": "created_at"}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                return data["pocs"]
            else:
                return []

    async def get_high_quality_pocs(self, min_stars: int = 5, limit: int = 100) -> List[dict]:
        """Get high-quality PoCs sorted by star count"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "sort": "stargazers_count"}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                # Filter by star count
                filtered_pocs = [
                    poc for poc in data["pocs"]
                    if int(poc.get("stargazers_count", "0")) >= min_stars
                ]
                return filtered_pocs
            else:
                return []

    async def search_pocs(self, query: str, limit: int = 50) -> List[dict]:
        """Search for PoCs using a query string"""
        async with aiohttp.ClientSession() as session:
            params = {"limit": limit, "q": query}
            data = await self._make_request(session, self.base_url, params)

            if data and "pocs" in data:
                return data["pocs"]
            else:
                return []

    def analyze_poc_quality(self, poc: dict) -> dict:
        """Analyze the quality of a PoC repository"""
        quality_score = 0
        factors = {}

        # Star count factor (0-40 points)
        stars = int(poc.get("stargazers_count", "0"))
        star_score = min(stars * 2, 40)  # 2 points per star, max 40
        quality_score += star_score
        factors["star_score"] = star_score

        # Recency factor (0-20 points)
        try:
            updated_at = datetime.fromisoformat(poc.get("updated_at", "").replace('Z', '+00:00'))
            days_old = (datetime.now(updated_at.tzinfo) - updated_at).days
            recency_score = max(20 - (days_old // 30), 0)  # Lose 1 point per month
            quality_score += recency_score
            factors["recency_score"] = recency_score
        except (ValueError, TypeError, AttributeError):
            # Missing or malformed timestamp: no recency credit
            factors["recency_score"] = 0

        # Description quality factor (0-15 points)
        description = poc.get("description", "")
        desc_score = 0
        if description:
            desc_score = min(len(description) // 10, 15)  # 1 point per 10 chars, max 15
        quality_score += desc_score
        factors["description_score"] = desc_score

        # Vulnerability description factor (0-15 points)
        vuln_desc = poc.get("vuln_description", "")
        vuln_score = 0
        if vuln_desc:
            vuln_score = min(len(vuln_desc) // 20, 15)  # 1 point per 20 chars, max 15
        quality_score += vuln_score
        factors["vuln_description_score"] = vuln_score

        # Repository name relevance factor (0-10 points)
        repo_name = poc.get("name", "").lower()
        cve_id = poc.get("cve_id", "").lower()
        name_score = 0
        if cve_id and cve_id.replace("-", "") in repo_name.replace("-", ""):
            name_score = 10
        elif any(keyword in repo_name for keyword in ["exploit", "poc", "cve", "vuln"]):
            name_score = 5
        quality_score += name_score
        factors["name_relevance_score"] = name_score

        return {
            "quality_score": quality_score,
            "factors": factors,
            "quality_tier": self._get_quality_tier(quality_score)
        }

    def _get_quality_tier(self, score: int) -> str:
        """Get quality tier based on score"""
        if score >= 80:
            return "excellent"
        elif score >= 60:
            return "good"
        elif score >= 40:
            return "fair"
        elif score >= 20:
            return "poor"
        else:
            return "very_poor"

    def extract_exploit_indicators(self, poc: dict) -> dict:
        """Extract exploit indicators from PoC metadata"""
        indicators = {
            "processes": [],
            "files": [],
            "network": [],
            "registry": [],
            "commands": [],
            "urls": [],
            "techniques": []
        }

        # Extract from description and vulnerability description
        text_sources = [
            poc.get("description", ""),
            poc.get("vuln_description", ""),
            poc.get("name", "")
        ]
        full_text = " ".join(text_sources).lower()

        # Enhanced process patterns
        process_patterns = [
            r'\b(cmd\.exe|powershell\.exe|bash|sh|python\.exe|java\.exe|node\.exe)\b',
            r'\b(createprocess|shellexecute|system|winexec)\b',
            r'\b(reverse.?shell|bind.?shell|web.?shell)\b',
            r'\b(mshta\.exe|rundll32\.exe|regsvr32\.exe|wscript\.exe|cscript\.exe)\b',
            r'\b(certutil\.exe|bitsadmin\.exe|schtasks\.exe)\b'
        ]

        for pattern in process_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["processes"].extend(matches)

        # Enhanced file patterns
        file_patterns = [
            r'\b([a-zA-Z]:\\[^\\\s]+\\[^\\\s]+\.[a-zA-Z0-9]+)\b',  # Windows paths
            r'\b(/[^/\s]+/[^/\s]+\.[a-zA-Z0-9]+)\b',  # Unix paths
            r'\b(\w+\.(exe|dll|bat|ps1|py|sh|jar|php|jsp|asp|aspx))\b',  # Common executable files
            r'\b(\w+\.(txt|log|tmp|temp|dat|bin))\b',  # Common data files
            r'\b(payload|exploit|shell|backdoor|trojan)\b'  # Malicious file indicators
        ]

        for pattern in file_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            if matches and isinstance(matches[0], tuple):
                indicators["files"].extend([m[0] for m in matches])
            else:
                indicators["files"].extend(matches)

        # Enhanced network patterns
        network_patterns = [
            r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b',  # IP addresses
            r'\b((?:\d{1,5})|(?:0x[a-fA-F0-9]{1,4}))\b',  # Ports
            r'\b(http[s]?://[^\s<>"]+)\b',  # URLs
            r'\b([a-zA-Z0-9-]+\.[a-zA-Z]{2,})\b'  # Domain names
        ]

        for pattern in network_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            if 'http' in pattern:
                indicators["urls"].extend(matches)
            elif '\\d{1,3}\\.' in pattern or '\\d{1,5}' in pattern:
                indicators["network"].extend(matches)
            else:
                indicators["network"].extend(matches)

        # Enhanced command patterns
        command_patterns = [
            r'\b(curl|wget|nc|netcat|ncat|telnet)\b',
            r'\b(whoami|id|uname|systeminfo|ipconfig|ifconfig)\b',
            r'\b(cat|type|more|less|head|tail)\b',
            r'\b(echo|print|printf)\b',
            r'\b(base64|decode|encode)\b',
            r'\b(invoke|iex|downloadstring)\b',
            r'\b(net\s+user|net\s+localgroup)\b',
            r'\b(sc\s+create|sc\s+start)\b'
        ]

        for pattern in command_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["commands"].extend(matches)

        # Registry patterns
        registry_patterns = [
            r'\b(HKEY_[A-Z_]+)\b',
            r'\b(HKLM|HKCU|HKCR|HKU|HKCC)\b',
            r'\b(reg\s+add|reg\s+query|reg\s+delete)\b',
            r'\b(SOFTWARE\\\\[^\\\s]+)\b',
            r'\b(SYSTEM\\\\[^\\\s]+)\b'
        ]

        for pattern in registry_patterns:
            matches = re.findall(pattern, full_text, re.IGNORECASE)
            indicators["registry"].extend(matches)

        # Clean up and deduplicate
        for key in indicators:
            # Remove empty strings and duplicates
            indicators[key] = list(set([item for item in indicators[key] if item and len(item.strip()) > 2]))
            # Limit to reasonable number of indicators
            indicators[key] = indicators[key][:15]

        return indicators

    async def sync_cve_pocs(self, cve_id: str, session: aiohttp.ClientSession = None) -> dict:
        """Synchronize PoC data for a specific CVE with session reuse"""
        from main import CVE, SigmaRule

        # Get existing CVE
        cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
        if not cve:
            logger.warning(f"CVE {cve_id} not found in database")
            return {"error": "CVE not found"}

        # Fetch PoCs from nomi-sec API with session reuse
        pocs = await self.get_pocs_for_cve(cve_id, session)

        if not pocs:
            logger.info(f"No PoCs found for {cve_id}")
            return {"cve_id": cve_id, "pocs_found": 0}

        # Analyze and store PoC data
        poc_data = []
        github_repos = []
        total_quality_score = 0

        for poc in pocs:
            quality_analysis = self.analyze_poc_quality(poc)
            exploit_indicators = self.extract_exploit_indicators(poc)

            poc_entry = {
                "id": poc.get("id"),
                "name": poc.get("name"),
                "owner": poc.get("owner"),
                "full_name": poc.get("full_name"),
                "html_url": poc.get("html_url"),
                "description": poc.get("description"),
                "stargazers_count": int(poc.get("stargazers_count", "0")),
                "created_at": poc.get("created_at"),
                "updated_at": poc.get("updated_at"),
                "quality_analysis": quality_analysis,
                "exploit_indicators": exploit_indicators
            }

            poc_data.append(poc_entry)
            github_repos.append(poc.get("html_url", ""))
            total_quality_score += quality_analysis["quality_score"]

        # Update CVE with PoC data
        cve.poc_count = len(pocs)
        cve.poc_data = poc_data
        cve.updated_at = datetime.utcnow()

        # Update or create SIGMA rule with enhanced PoC data
        sigma_rule = self.db_session.query(SigmaRule).filter(
            SigmaRule.cve_id == cve_id
        ).first()

        if sigma_rule:
            sigma_rule.poc_source = 'nomi_sec'
            sigma_rule.poc_quality_score = total_quality_score // len(pocs) if pocs else 0
            sigma_rule.nomi_sec_data = {
                "total_pocs": len(pocs),
                "average_quality": total_quality_score // len(pocs) if pocs else 0,
                "best_poc": max(poc_data, key=lambda x: x["quality_analysis"]["quality_score"]) if poc_data else None,
                "total_stars": sum(p["stargazers_count"] for p in poc_data)
            }
            sigma_rule.github_repos = github_repos
            sigma_rule.updated_at = datetime.utcnow()

            # Extract best exploit indicators
            best_indicators = {}
            for poc in poc_data:
                for key, values in poc["exploit_indicators"].items():
                    if key not in best_indicators:
                        best_indicators[key] = []
                    best_indicators[key].extend(values)

            # Deduplicate and store
            for key in best_indicators:
                best_indicators[key] = list(set(best_indicators[key]))

            sigma_rule.exploit_indicators = json.dumps(best_indicators)

        self.db_session.commit()

        logger.info(f"Synchronized {len(pocs)} PoCs for {cve_id}")

        return {
            "cve_id": cve_id,
            "pocs_found": len(pocs),
            "total_quality_score": total_quality_score,
            "average_quality": total_quality_score // len(pocs) if pocs else 0,
            "github_repos": github_repos
        }

    async def bulk_sync_all_cves(self, batch_size: int = 100,
                                 cancellation_flag: Optional[Callable[[], bool]] = None) -> dict:
        """Synchronize PoC data for all CVEs in database"""
        from main import CVE, BulkProcessingJob

        # Create bulk processing job
        job = BulkProcessingJob(
            job_type='nomi_sec_sync',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'batch_size': batch_size}
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_found = 0
        results = []

        try:
            # Get all CVEs from database
            cves = self.db_session.query(CVE).all()
            job.total_items = len(cves)
            self.db_session.commit()

            # Process in batches
            for i in range(0, len(cves), batch_size):
                # Check for cancellation before each batch
                if cancellation_flag and cancellation_flag():
                    logger.info("Bulk sync cancelled by user")
                    job.status = 'cancelled'
                    job.cancelled_at = datetime.utcnow()
                    job.error_message = "Job cancelled by user"
                    break

                batch = cves[i:i + batch_size]

                for cve in batch:
                    # Check for cancellation before each CVE
                    if cancellation_flag and cancellation_flag():
                        logger.info("Bulk sync cancelled by user")
                        job.status = 'cancelled'
                        job.cancelled_at = datetime.utcnow()
                        job.error_message = "Job cancelled by user"
                        break

                    try:
                        result = await self.sync_cve_pocs(cve.cve_id)
                        total_processed += 1
                        if result.get("pocs_found", 0) > 0:
                            total_found += result["pocs_found"]
                            results.append(result)
                        job.processed_items += 1

                        # Minimal delay for faster processing
                        await asyncio.sleep(0.05)
                    except Exception as e:
                        logger.error(f"Error syncing PoCs for {cve.cve_id}: {e}")
                        job.failed_items += 1

                # Break out of outer loop if cancelled
                if job.status == 'cancelled':
                    break

                # Commit after each batch
                self.db_session.commit()
                logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")

            # Update job status (only if not cancelled)
            if job.status != 'cancelled':
                job.status = 'completed'
                job.completed_at = datetime.utcnow()
                job.job_metadata.update({
                    'total_processed': total_processed,
                    'total_pocs_found': total_found,
                    'cves_with_pocs': len(results)
                })

        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Bulk PoC sync job failed: {e}")

        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_pocs_found': total_found,
            'cves_with_pocs': len(results)
        }

    async def bulk_sync_poc_data(self, batch_size: int = 50, max_cves: int = None,
                                 force_resync: bool = False) -> dict:
        """Optimized bulk synchronization of PoC data with performance improvements"""
        from main import CVE, SigmaRule, BulkProcessingJob

        # Create job tracking
        job = BulkProcessingJob(
            job_type='nomi_sec_sync',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'batch_size': batch_size, 'max_cves': max_cves, 'force_resync': force_resync}
        )
        self.db_session.add(job)
        self.db_session.commit()

        try:
            # Get CVEs that need PoC sync - optimized query
            query = self.db_session.query(CVE)

            if not force_resync:
                # Skip CVEs that were recently synced or already have nomi-sec data
                recent_cutoff = datetime.utcnow() - timedelta(days=7)
                query = query.filter(
                    or_(
                        CVE.poc_data.is_(None),
                        and_(
                            CVE.updated_at < recent_cutoff,
                            CVE.poc_count == 0
                        )
                    )
                )

            # Prioritize recent CVEs and high CVSS scores
            query = query.order_by(
                CVE.published_date.desc(),
                CVE.cvss_score.desc().nullslast()
            )

            if max_cves:
                query = query.limit(max_cves)

            cves = query.all()
            job.total_items = len(cves)
            self.db_session.commit()

            logger.info(f"Starting optimized nomi-sec sync for {len(cves)} CVEs")

            total_processed = 0
            total_found = 0
            concurrent_semaphore = asyncio.Semaphore(self.concurrent_requests)

            # Create shared session with optimized settings
            connector = aiohttp.TCPConnector(
                limit=50,
                limit_per_host=10,
                ttl_dns_cache=300,
                use_dns_cache=True,
                keepalive_timeout=30
            )

            async with aiohttp.ClientSession(connector=connector) as shared_session:

                async def process_cve_batch(cve_batch):
                    """Process a batch of CVEs concurrently with shared session"""
                    async def process_single_cve(cve):
                        async with concurrent_semaphore:
                            try:
                                result = await self.sync_cve_pocs(cve.cve_id, shared_session)
                                return result
                            except Exception as e:
                                logger.error(f"Error syncing {cve.cve_id}: {e}")
                                return {'error': str(e), 'cve_id': cve.cve_id}

                    # Process batch concurrently
                    tasks = [process_single_cve(cve) for cve in cve_batch]
                    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
                    return batch_results

                # Process in optimized batches
                for i in range(0, len(cves), batch_size):
                    batch = cves[i:i + batch_size]

                    # Process batch concurrently
                    batch_results = await process_cve_batch(batch)

                    for result in batch_results:
                        if isinstance(result, Exception):
                            logger.error(f"Exception in batch processing: {result}")
                            job.failed_items += 1
                        elif isinstance(result, dict) and 'error' not in result:
                            total_processed += 1
                            if result.get('pocs_found', 0) > 0:
                                total_found += result['pocs_found']
                            job.processed_items += 1
                        else:
                            job.failed_items += 1

                    # Commit progress every batch
                    self.db_session.commit()

                    logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}, "
                                f"found {total_found} PoCs so far")

                    # Small delay between batches
                    await asyncio.sleep(0.05)

            # Update job completion
            job.status = 'completed'
            job.completed_at = datetime.utcnow()
            job.job_metadata.update({
                'total_processed': total_processed,
                'total_pocs_found': total_found,
                'processing_time_seconds': (job.completed_at - job.started_at).total_seconds()
            })
            self.db_session.commit()

            logger.info(f"Nomi-sec sync completed: {total_processed} CVEs processed, {total_found} PoCs found")

            return {
                'job_id': str(job.id),
                'status': 'completed',
                'total_processed': total_processed,
                'total_pocs_found': total_found,
                'processing_time': (job.completed_at - job.started_at).total_seconds()
            }

        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            self.db_session.commit()

            logger.error(f"Nomi-sec sync failed: {e}")
            return {
                'job_id': str(job.id),
                'status': 'failed',
                'error': str(e)
            }

    async def get_sync_status(self) -> dict:
        """Get synchronization status"""
        from main import CVE, SigmaRule

        # Count CVEs with PoC data
        total_cves = self.db_session.query(CVE).count()
        cves_with_pocs = self.db_session.query(CVE).filter(CVE.poc_count > 0).count()

        # Count SIGMA rules with nomi-sec data
        total_rules = self.db_session.query(SigmaRule).count()
        rules_with_nomi_sec = self.db_session.query(SigmaRule).filter(
            SigmaRule.poc_source == 'nomi_sec'
        ).count()

        return {
            'total_cves': total_cves,
            'cves_with_pocs': cves_with_pocs,
            'poc_coverage': (cves_with_pocs / total_cves * 100) if total_cves > 0 else 0,
            'total_rules': total_rules,
            'rules_with_nomi_sec': rules_with_nomi_sec,
            'nomi_sec_coverage': (rules_with_nomi_sec / total_rules * 100) if total_rules > 0 else 0
        }