auto_sigma_rule_generator/backend/mcdevitt_poc_client.py
bpmcdevitt a6fb367ed4 refactor: modularize backend architecture for improved maintainability
- Extract database models from monolithic main.py (2,373 lines) into organized modules
- Implement service layer pattern with dedicated business logic classes
- Split API endpoints into modular FastAPI routers by functionality
- Add centralized configuration management with environment variable handling
- Create proper separation of concerns across data, service, and presentation layers

**Architecture Changes:**
- models/: SQLAlchemy database models (CVE, SigmaRule, RuleTemplate, BulkProcessingJob)
- config/: Centralized settings and database configuration
- services/: Business logic (CVEService, SigmaRuleService, GitHubExploitAnalyzer)
- routers/: Modular API endpoints (cves, sigma_rules, bulk_operations, llm_operations); see the wiring sketch after this list
- schemas/: Pydantic request/response models
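
A minimal sketch of how the slimmed `main.py` can wire these modules together (module and router names follow the list above; the exact symbols in the repo may differ):

```python
# main.py (illustrative sketch; actual symbol names may differ)
from fastapi import FastAPI

from routers import cves, sigma_rules, bulk_operations, llm_operations

app = FastAPI(title="Auto SIGMA Rule Generator")

# Each functional area contributes its own APIRouter
app.include_router(cves.router)
app.include_router(sigma_rules.router)
app.include_router(bulk_operations.router)
app.include_router(llm_operations.router)
```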

**Key Improvements:**
- 95% reduction in main.py size (2,373 → 120 lines)
- Updated 15+ backend files with proper import structure
- Eliminated circular dependencies and tight coupling
- Enhanced testability with isolated service components
- Better code organization for team collaboration

**Backward Compatibility:**
- All API endpoints maintain same URLs and behavior
- Zero breaking changes to existing functionality
- Database schema unchanged
- Environment variables preserved

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-14 17:51:23 -05:00

593 lines · No EOL · 25 KiB · Python

"""
GitHub PoC Collector Integration Client
Reads JSON files from github_poc_collector and fetches GitHub repo contents for SIGMA rule generation
"""
import aiohttp
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from pathlib import Path
import re
import base64
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GitHubPoCClient:
"""Client for processing GitHub PoC collector data and fetching GitHub contents"""
def __init__(self, db_session: Session, github_token: Optional[str] = None):
self.db_session = db_session
self.github_token = github_token or os.getenv('GITHUB_TOKEN')
self.base_url = "https://api.github.com"
# Rate limiting - GitHub API: 5000 requests/hour with token, 60 without
self.rate_limit_delay = 0.8 if self.github_token else 60.0 # seconds
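        # 0.8 s/request is roughly 4,500 requests/hour (under the 5,000/hour authenticated cap);
        # 60 s/request matches the 60/hour unauthenticated cap.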
self.last_request_time = 0
# GitHub PoC collector path (mounted in Docker container)
self.github_poc_path = Path("/github_poc_collector/exploits")
# Cache for repository contents
self.repo_cache = {}
self.cache_ttl = 1800 # 30 minutes
def load_github_poc_data(self) -> Dict[str, List[dict]]:
"""Load all PoC data from GitHub PoC collector JSON files"""
poc_data = {}
if not self.github_poc_path.exists():
logger.error(f"GitHub PoC path not found: {self.github_poc_path}")
return poc_data
# Walk through year directories
for year_dir in self.github_poc_path.iterdir():
if year_dir.is_dir():
for json_file in year_dir.glob("*.json"):
try:
cve_id = json_file.stem # CVE-YYYY-NNNN
with open(json_file, 'r') as f:
repos = json.load(f)
# Filter out repositories with no stars or very low quality
filtered_repos = []
for repo in repos:
if isinstance(repo, dict) and repo.get('html_url'):
# Basic quality filtering
stars = repo.get('stargazers_count', 0)
description = repo.get('description', '') or ''
# Skip very low quality repos
if stars > 0 or len(description) > 20:
filtered_repos.append(repo)
if filtered_repos:
poc_data[cve_id] = filtered_repos
except Exception as e:
logger.error(f"Error loading {json_file}: {e}")
logger.info(f"Loaded PoC data for {len(poc_data)} CVEs")
return poc_data
    async def _make_github_request(self, session: aiohttp.ClientSession,
                                   url: str, params: Optional[dict] = None,
                                   retry_on_rate_limit: bool = True) -> Optional[dict]:
"""Make a rate-limited request to GitHub API"""
try:
# Rate limiting
current_time = asyncio.get_event_loop().time()
time_since_last = current_time - self.last_request_time
if time_since_last < self.rate_limit_delay:
await asyncio.sleep(self.rate_limit_delay - time_since_last)
headers = {
'Accept': 'application/vnd.github.v3+json',
'User-Agent': 'Auto-SIGMA-Rule-Generator/1.0'
}
if self.github_token:
headers['Authorization'] = f'token {self.github_token}'
async with session.get(url, params=params, headers=headers, timeout=30) as response:
self.last_request_time = asyncio.get_event_loop().time()
if response.status == 200:
return await response.json()
                elif response.status == 403 and retry_on_rate_limit:
                    logger.warning(f"Rate limit exceeded for {url}")
                    # Back off, then retry a single time to avoid unbounded recursion
                    await asyncio.sleep(60)
                    return await self._make_github_request(session, url, params,
                                                           retry_on_rate_limit=False)
else:
logger.warning(f"GitHub API request failed: {response.status} for {url}")
return None
except Exception as e:
logger.error(f"Error making GitHub request to {url}: {e}")
return None
    async def get_repository_contents(self, repo_url: str) -> Dict[str, Any]:
"""Fetch repository contents from GitHub API"""
# Extract owner/repo from URL
try:
# Parse GitHub URL: https://github.com/owner/repo
parts = repo_url.replace('https://github.com/', '').split('/')
if len(parts) < 2:
return {}
owner, repo = parts[0], parts[1]
repo_key = f"{owner}/{repo}"
# Check cache
if repo_key in self.repo_cache:
cached_data, timestamp = self.repo_cache[repo_key]
if (datetime.now().timestamp() - timestamp) < self.cache_ttl:
return cached_data
async with aiohttp.ClientSession() as session:
# Get repository metadata
repo_url_api = f"{self.base_url}/repos/{owner}/{repo}"
repo_data = await self._make_github_request(session, repo_url_api)
if not repo_data:
return {}
# Get repository contents (files)
contents_url = f"{self.base_url}/repos/{owner}/{repo}/contents"
contents_data = await self._make_github_request(session, contents_url)
# Get README content if available
readme_content = ""
if contents_data:
readme_files = [f for f in contents_data
if f.get('name', '').lower().startswith('readme')]
if readme_files:
readme_file = readme_files[0]
readme_url = readme_file.get('download_url')
if readme_url:
try:
async with session.get(readme_url) as readme_response:
if readme_response.status == 200:
readme_content = await readme_response.text()
except Exception as e:
logger.warning(f"Error fetching README: {e}")
# Extract key files (potential exploit code)
key_files = []
if contents_data:
for file_info in contents_data:
if file_info.get('type') == 'file':
file_name = file_info.get('name', '')
file_size = file_info.get('size', 0)
# Focus on code files that might contain exploits
if (file_name.lower().endswith(('.py', '.sh', '.pl', '.rb', '.js', '.c', '.cpp', '.java', '.go', '.rs', '.php'))
and file_size < 50000): # Skip very large files
try:
file_content = await self._get_file_content(session, file_info.get('download_url'))
if file_content:
key_files.append({
'name': file_name,
'size': file_size,
'content': file_content[:10000] # Truncate very long files
})
except Exception as e:
logger.warning(f"Error fetching file {file_name}: {e}")
result = {
'repo_data': repo_data,
'readme_content': readme_content,
'key_files': key_files,
'fetched_at': datetime.now().isoformat()
}
# Cache the result
self.repo_cache[repo_key] = (result, datetime.now().timestamp())
return result
except Exception as e:
logger.error(f"Error fetching repository contents for {repo_url}: {e}")
return {}
async def _get_file_content(self, session: aiohttp.ClientSession, download_url: str) -> Optional[str]:
"""Fetch individual file content"""
try:
async with session.get(download_url, timeout=15) as response:
if response.status == 200:
# Try to decode as text
try:
content = await response.text()
return content
                    except UnicodeDecodeError:
# If text decoding fails, try binary
content = await response.read()
return content.decode('utf-8', errors='ignore')
return None
except Exception as e:
logger.warning(f"Error fetching file content: {e}")
return None
    def analyze_repository_for_indicators(self, repo_data: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze repository contents for exploit indicators"""
indicators = {
"processes": [],
"files": [],
"network": [],
"registry": [],
"commands": [],
"urls": [],
"techniques": [],
"cve_references": [],
"exploit_techniques": []
}
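        # Each bucket is later de-duplicated and capped at 20 entries
        # (see the cleanup loop at the end of this method).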
# Combine all text content for analysis
text_sources = []
# Add README content
if repo_data.get('readme_content'):
text_sources.append(repo_data['readme_content'])
# Add repository description
if repo_data.get('repo_data', {}).get('description'):
text_sources.append(repo_data['repo_data']['description'])
# Add key file contents
for file_info in repo_data.get('key_files', []):
text_sources.append(file_info.get('content', ''))
full_text = " ".join(text_sources).lower()
# Extract CVE references
cve_pattern = r'cve-\d{4}-\d{4,7}'
cve_matches = re.findall(cve_pattern, full_text, re.IGNORECASE)
indicators["cve_references"] = list(set(cve_matches))
# Enhanced process patterns
process_patterns = [
r'\b(cmd\.exe|powershell\.exe|bash|sh|python\.exe|java\.exe|node\.exe)\b',
r'\b(createprocess|shellexecute|system|winexec|execve|fork|spawn)\b',
r'\b(reverse.?shell|bind.?shell|web.?shell|backdoor)\b',
r'\b(mshta\.exe|rundll32\.exe|regsvr32\.exe|wscript\.exe|cscript\.exe)\b',
r'\b(certutil\.exe|bitsadmin\.exe|schtasks\.exe|wmic\.exe)\b'
]
for pattern in process_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
indicators["processes"].extend(matches)
# Enhanced file patterns
file_patterns = [
r'\b([a-zA-Z]:\\[^\\\s]+\\[^\\\s]+\.[a-zA-Z0-9]+)\b', # Windows paths
r'\b(/[^/\s]+/[^/\s]+\.[a-zA-Z0-9]+)\b', # Unix paths
r'\b(\w+\.(exe|dll|bat|ps1|py|sh|jar|php|jsp|asp|aspx|bin))\b', # Executable files
r'\b(payload|exploit|shell|backdoor|trojan|malware)\b' # Malicious indicators
]
for pattern in file_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
if matches and isinstance(matches[0], tuple):
indicators["files"].extend([m[0] for m in matches])
else:
indicators["files"].extend(matches)
# Enhanced network patterns
network_patterns = [
r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b', # IP addresses
r'\b(https?://[^\s<>"]+)\b', # URLs
r'\b([a-zA-Z0-9-]+\.[a-zA-Z]{2,})\b', # Domain names
r'\b(port|socket|connect|bind|listen)\s*[=:]\s*(\d+)\b' # Port references
]
for pattern in network_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
if 'http' in pattern:
indicators["urls"].extend(matches)
else:
indicators["network"].extend([m if isinstance(m, str) else m[0] for m in matches])
# Enhanced command patterns
command_patterns = [
r'\b(curl|wget|nc|netcat|ncat|telnet|ssh|scp|rsync)\b',
r'\b(whoami|id|uname|systeminfo|ipconfig|ifconfig|ps|top|netstat)\b',
r'\b(cat|type|more|less|head|tail|find|grep|awk|sed)\b',
r'\b(echo|print|printf|base64|decode|encode)\b',
r'\b(invoke|iex|downloadstring|powershell|cmd)\b',
r'\b(net\s+user|net\s+localgroup|net\s+share)\b',
r'\b(sc\s+create|sc\s+start|sc\s+stop|service)\b'
]
for pattern in command_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
indicators["commands"].extend(matches)
# Registry patterns (Windows)
registry_patterns = [
r'\b(HKEY_[A-Z_]+)\b',
r'\b(HKLM|HKCU|HKCR|HKU|HKCC)\b',
r'\b(reg\s+add|reg\s+query|reg\s+delete|regedit)\b',
r'\b(SOFTWARE\\\\[^\\\s]+)\b',
r'\b(SYSTEM\\\\[^\\\s]+)\b'
]
for pattern in registry_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
indicators["registry"].extend(matches)
# MITRE ATT&CK technique indicators
technique_patterns = [
r'\b(privilege.?escalation|lateral.?movement|persistence|evasion)\b',
r'\b(injection|hijack|bypass|overflow|buffer.?overflow)\b',
r'\b(credential.?dump|password.?spray|brute.?force)\b',
r'\b(remote.?code.?execution|arbitrary.?code|code.?injection)\b',
r'\b(dll.?injection|process.?hollow|process.?injection)\b'
]
for pattern in technique_patterns:
matches = re.findall(pattern, full_text, re.IGNORECASE)
indicators["techniques"].extend(matches)
# Clean up and deduplicate all indicators
for key in indicators:
# Remove empty strings and duplicates
indicators[key] = list(set([
item.strip() for item in indicators[key]
if item and len(str(item).strip()) > 2
]))
# Limit to reasonable number of indicators
indicators[key] = indicators[key][:20]
return indicators
    def calculate_quality_score(self, repo_info: dict, repo_contents: Dict[str, Any]) -> Dict[str, Any]:
"""Calculate quality score for a repository"""
quality_score = 0
factors = {}
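        # Component scores sum to a 0-100 scale:
        # stars (0-30) + forks (0-20) + recency (0-20) + description (0-15) + README (0-15)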
# Star count factor (0-30 points)
stars = repo_info.get('stargazers_count', 0)
star_score = min(stars * 3, 30) # 3 points per star, max 30
quality_score += star_score
factors["star_score"] = star_score
# Fork count factor (0-20 points)
forks = repo_info.get('forks_count', 0)
fork_score = min(forks * 2, 20) # 2 points per fork, max 20
quality_score += fork_score
factors["fork_score"] = fork_score
# Recency factor (0-20 points)
try:
created_at = datetime.fromisoformat(repo_info.get('created_at', '').replace('Z', '+00:00'))
days_old = (datetime.now(created_at.tzinfo) - created_at).days
recency_score = max(20 - (days_old // 30), 0) # Lose 1 point per month
quality_score += recency_score
factors["recency_score"] = recency_score
        except (ValueError, TypeError, AttributeError):
            factors["recency_score"] = 0
# Description quality factor (0-15 points)
description = repo_info.get('description', '') or ''
desc_score = min(len(description) // 10, 15) # 1 point per 10 chars, max 15
quality_score += desc_score
factors["description_score"] = desc_score
# README quality factor (0-15 points)
readme_content = repo_contents.get('readme_content', '')
readme_score = min(len(readme_content) // 50, 15) # 1 point per 50 chars, max 15
quality_score += readme_score
factors["readme_score"] = readme_score
return {
"quality_score": quality_score,
"factors": factors,
"quality_tier": self._get_quality_tier(quality_score)
}
def _get_quality_tier(self, score: int) -> str:
"""Get quality tier based on score"""
if score >= 80:
return "excellent"
elif score >= 60:
return "good"
elif score >= 40:
return "fair"
elif score >= 20:
return "poor"
else:
return "very_poor"
async def sync_cve_pocs(self, cve_id: str) -> dict:
"""Synchronize PoC data for a specific CVE using GitHub PoC data"""
from models import CVE, SigmaRule
# Get existing CVE
cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
if not cve:
logger.warning(f"CVE {cve_id} not found in database")
return {"error": "CVE not found"}
# Load GitHub PoC data
github_poc_data = self.load_github_poc_data()
if cve_id not in github_poc_data:
logger.info(f"No PoCs found for {cve_id} in GitHub PoC data")
return {"cve_id": cve_id, "pocs_found": 0}
repos = github_poc_data[cve_id]
# Process each repository
poc_data = []
github_repos = []
total_quality_score = 0
for repo_info in repos:
try:
# Fetch repository contents
repo_contents = await self.get_repository_contents(repo_info['html_url'])
# Analyze for indicators
indicators = self.analyze_repository_for_indicators(repo_contents)
# Calculate quality score
quality_analysis = self.calculate_quality_score(repo_info, repo_contents)
poc_entry = {
"id": repo_info.get('name', ''),
"name": repo_info.get('name', ''),
"full_name": repo_info.get('full_name', ''),
"html_url": repo_info.get('html_url', ''),
"description": repo_info.get('description', ''),
"stargazers_count": repo_info.get('stargazers_count', 0),
"forks_count": repo_info.get('forks_count', 0),
"created_at": repo_info.get('created_at', ''),
"quality_analysis": quality_analysis,
"exploit_indicators": indicators,
"source": "mcdevitt_github"
}
poc_data.append(poc_entry)
github_repos.append(repo_info.get('html_url', ''))
total_quality_score += quality_analysis["quality_score"]
except Exception as e:
logger.error(f"Error processing repo {repo_info.get('html_url', '')}: {e}")
continue
# Update CVE with PoC data
cve.poc_count = len(poc_data)
cve.poc_data = poc_data
cve.updated_at = datetime.utcnow()
        # Update the existing SIGMA rule (if any) with enhanced PoC data
sigma_rule = self.db_session.query(SigmaRule).filter(
SigmaRule.cve_id == cve_id
).first()
if sigma_rule:
sigma_rule.poc_source = 'github_poc'
sigma_rule.poc_quality_score = total_quality_score // len(poc_data) if poc_data else 0
sigma_rule.nomi_sec_data = {
"total_pocs": len(poc_data),
"average_quality": total_quality_score // len(poc_data) if poc_data else 0,
"best_poc": max(poc_data, key=lambda x: x["quality_analysis"]["quality_score"]) if poc_data else None,
"total_stars": sum(p["stargazers_count"] for p in poc_data),
"source": "github_poc"
}
sigma_rule.github_repos = github_repos
sigma_rule.updated_at = datetime.utcnow()
# Extract best exploit indicators
best_indicators = {}
for poc in poc_data:
for key, values in poc["exploit_indicators"].items():
if key not in best_indicators:
best_indicators[key] = []
best_indicators[key].extend(values)
# Deduplicate and store
for key in best_indicators:
best_indicators[key] = list(set(best_indicators[key]))
sigma_rule.exploit_indicators = json.dumps(best_indicators)
self.db_session.commit()
logger.info(f"Synchronized {len(poc_data)} PoCs for {cve_id}")
return {
"cve_id": cve_id,
"pocs_found": len(poc_data),
"total_quality_score": total_quality_score,
"average_quality": total_quality_score // len(poc_data) if poc_data else 0,
"github_repos": github_repos,
"source": "github_poc"
}
async def bulk_sync_all_cves(self, batch_size: int = 50) -> dict:
"""Bulk synchronize all CVEs with GitHub PoC data"""
from models import CVE, BulkProcessingJob
# Load all GitHub PoC data first
github_poc_data = self.load_github_poc_data()
if not github_poc_data:
return {"error": "No GitHub PoC data found"}
# Create bulk processing job
job = BulkProcessingJob(
job_type='github_poc_sync',
status='running',
started_at=datetime.utcnow(),
total_items=len(github_poc_data),
job_metadata={'batch_size': batch_size}
)
self.db_session.add(job)
self.db_session.commit()
total_processed = 0
total_found = 0
results = []
try:
# Process each CVE that has PoC data
cve_ids = list(github_poc_data.keys())
for i in range(0, len(cve_ids), batch_size):
batch = cve_ids[i:i + batch_size]
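                # batch_size mainly controls commit/logging cadence; CVEs are still
                # processed one at a time with a 1 s pause between requests (below).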
for cve_id in batch:
try:
result = await self.sync_cve_pocs(cve_id)
total_processed += 1
if result.get("pocs_found", 0) > 0:
total_found += result["pocs_found"]
results.append(result)
job.processed_items += 1
# Small delay to avoid overwhelming GitHub API
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Error syncing PoCs for {cve_id}: {e}")
job.failed_items += 1
# Commit after each batch
self.db_session.commit()
logger.info(f"Processed batch {i//batch_size + 1}/{(len(cve_ids) + batch_size - 1)//batch_size}")
# Update job status
job.status = 'completed'
job.completed_at = datetime.utcnow()
job.job_metadata.update({
'total_processed': total_processed,
'total_pocs_found': total_found,
'cves_with_pocs': len(results)
})
except Exception as e:
job.status = 'failed'
job.error_message = str(e)
job.completed_at = datetime.utcnow()
logger.error(f"Bulk McDevitt sync job failed: {e}")
finally:
self.db_session.commit()
return {
'job_id': str(job.id),
'status': job.status,
'total_processed': total_processed,
'total_pocs_found': total_found,
'cves_with_pocs': len(results)
}
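# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a SQLAlchemy session
# factory is available from the backend's config/database module; the import
# path and factory name below are placeholders, not part of this module.
#
#   import asyncio
#   from config.database import SessionLocal  # hypothetical import path
#
#   async def main():
#       db = SessionLocal()
#       try:
#           client = GitHubPoCClient(db)
#           result = await client.sync_cve_pocs("CVE-2021-44228")
#           print(result)
#       finally:
#           db.close()
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------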