Add a script to clear old SIGMA rules, and start tweaking the system prompt sent to the LLM for rule generation
parent d17f961b9d
commit d38edff1cd
6 changed files with 506 additions and 105 deletions
backend/delete_sigma_rules.py | 58 (new file)
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+"""
+Script to delete all SIGMA rules from the database
+This will clear existing rules so they can be regenerated with the improved LLM client
+"""
+
+from main import SigmaRule, SessionLocal
+import logging
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def delete_all_sigma_rules():
+    """Delete all SIGMA rules from the database"""
+
+    db = SessionLocal()
+
+    try:
+        # Count existing rules
+        total_rules = db.query(SigmaRule).count()
+        logger.info(f"Found {total_rules} SIGMA rules in database")
+
+        if total_rules == 0:
+            logger.info("No SIGMA rules to delete")
+            return 0
+
+        # Delete all SIGMA rules
+        logger.info("Deleting all SIGMA rules...")
+        deleted_count = db.query(SigmaRule).delete()
+        db.commit()
+
+        logger.info(f"✅ Successfully deleted {deleted_count} SIGMA rules")
+
+        # Verify deletion
+        remaining_rules = db.query(SigmaRule).count()
+        logger.info(f"Remaining rules in database: {remaining_rules}")
+
+        return deleted_count
+
+    except Exception as e:
+        logger.error(f"Error deleting SIGMA rules: {e}")
+        db.rollback()
+        raise
+    finally:
+        db.close()
+
+if __name__ == "__main__":
+    print("🗑️ Deleting all SIGMA rules from database...")
+    print("This will allow regeneration with the improved LLM client.")
+
+    deleted_count = delete_all_sigma_rules()
+
+    if deleted_count > 0:
+        print(f"\n🎉 Successfully deleted {deleted_count} SIGMA rules!")
+        print("You can now regenerate them with the fixed LLM prompts.")
+    else:
+        print("\n✅ No SIGMA rules were found to delete.")
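
Aside: the script imports SigmaRule and SessionLocal from main. As orientation only, a minimal sketch of the shape of that interface — the in-memory engine URL and the column set are hypothetical stand-ins, not the backend's real model:

import logging
from sqlalchemy import Column, Integer, String, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")  # placeholder; the backend configures its own DB URL

class SigmaRule(Base):
    __tablename__ = 'sigma_rules'
    id = Column(Integer, primary_key=True)          # hypothetical columns
    cve_id = Column(String(20), index=True)
    rule_content = Column(Text)                     # the generated SIGMA YAML

Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)
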
@@ -38,12 +38,18 @@ class EnhancedSigmaGenerator:
         # Try LLM-enhanced generation first if enabled and available
         rule_content = None
         generation_method = "template"
+        template = None
 
         if use_llm and self.llm_client.is_available() and best_poc:
             logger.info(f"Attempting LLM-enhanced rule generation for {cve.cve_id} using {self.llm_client.provider}")
             rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data)
             if rule_content:
                 generation_method = f"llm_{self.llm_client.provider}"
+                # Create a dummy template object for LLM-generated rules
+                class LLMTemplate:
+                    def __init__(self, provider_name):
+                        self.template_name = f"LLM Generated ({provider_name})"
+                template = LLMTemplate(self.llm_client.provider)
 
         # Fallback to template-based generation
         if not rule_content:
@@ -107,7 +113,7 @@ class EnhancedSigmaGenerator:
         return {
             'success': True,
             'cve_id': cve.cve_id,
-            'template': template.template_name,
+            'template': template.template_name if template else 'Unknown',
             'confidence_level': confidence_level,
             'poc_count': len(poc_data),
             'quality_score': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0

@@ -80,18 +80,24 @@ class JobExecutors:
 
     @staticmethod
     async def nomi_sec_sync(db_session: Session, parameters: Dict[str, Any]) -> Dict[str, Any]:
-        """Execute nomi-sec PoC sync job"""
+        """Execute optimized nomi-sec PoC sync job"""
         try:
             from nomi_sec_client import NomiSecClient
 
             client = NomiSecClient(db_session)
 
-            # Extract parameters
-            batch_size = parameters.get('batch_size', 50)
+            # Extract parameters with optimized defaults
+            batch_size = parameters.get('batch_size', 100)
+            max_cves = parameters.get('max_cves', 1000)
+            force_resync = parameters.get('force_resync', False)
 
-            logger.info(f"Starting nomi-sec sync - batch_size: {batch_size}")
+            logger.info(f"Starting optimized nomi-sec sync - batch_size: {batch_size}, max_cves: {max_cves}")
 
-            result = await client.bulk_sync_poc_data(batch_size=batch_size)
+            result = await client.bulk_sync_poc_data(
+                batch_size=batch_size,
+                max_cves=max_cves,
+                force_resync=force_resync
+            )
 
             return {
                 'status': 'completed',

@@ -173,17 +173,30 @@ class LLMClient:
         # Create the chain
         chain = prompt | self.llm | self.output_parser
 
-        # Generate the response
-        response = await chain.ainvoke({
+        # Debug: Log what we're sending to the LLM
+        input_data = {
             "cve_id": cve_id,
             "poc_content": poc_content[:4000],  # Truncate if too long
             "cve_description": cve_description,
             "existing_rule": existing_rule or "None"
-        })
+        }
+        logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}")
+
+        # Generate the response
+        response = await chain.ainvoke(input_data)
+
+        # Debug: Log raw LLM response
+        logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")
 
         # Extract the SIGMA rule from response
         sigma_rule = self._extract_sigma_rule(response)
 
+        # Post-process to ensure clean YAML
+        sigma_rule = self._post_process_sigma_rule(sigma_rule)
+
+        # Debug: Log final processed rule
+        logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...")
+
         logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
         return sigma_rule
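
Aside: `prompt | self.llm | self.output_parser` is LangChain's LCEL composition. A minimal sketch of the same pipeline, assuming langchain-core and langchain-openai are installed; the model name and placeholder key are illustrative, and the project may wire a different provider:

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

prompt = ChatPromptTemplate.from_messages([
    ("system", "You write SIGMA rules."),
    ("human", "CVE ID: {cve_id}\nPoC:\n{poc_content}"),
])
# Placeholder key so construction succeeds; real calls need a valid one.
llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-placeholder")
chain = prompt | llm | StrOutputParser()
# response = await chain.ainvoke({"cve_id": "CVE-2099-0001", "poc_content": "..."})
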
@@ -198,61 +211,103 @@ class LLMClient:
                                existing_rule: Optional[str] = None) -> ChatPromptTemplate:
         """Build the prompt template for SIGMA rule generation."""
 
-        system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.
-
-**Your Task:**
-1. Analyze the exploit code to identify:
-   - Process execution patterns
-   - File system activities
-   - Network connections
-   - Registry modifications
-   - Command line arguments
-   - Suspicious behaviors
-
-2. Create a SIGMA rule that:
-   - Follows proper SIGMA syntax (YAML format)
-   - Includes appropriate detection logic
-   - Has relevant metadata (title, description, author, date, references)
-   - Uses correct field names for the target log source
-   - Includes proper condition logic
-   - Maps to relevant MITRE ATT&CK techniques when applicable
-
-3. Focus on detection patterns that would catch this specific exploit in action
-
-**Important Requirements:**
-- Output ONLY the SIGMA rule in valid YAML format
-- Do not include explanations or comments outside the YAML
-- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
-- Make the rule specific enough to detect the exploit but not too narrow to miss variants
-- Include relevant tags and MITRE ATT&CK technique mappings"""
+        system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.
+
+**CRITICAL: You must follow the exact SIGMA specification format:**
+
+1. **YAML Structure Requirements:**
+   - Use UTF-8 encoding with LF line breaks
+   - Indent with 4 spaces (no tabs)
+   - Use lowercase keys only
+   - Use single quotes for string values
+   - No quotes for numeric values
+   - Follow proper YAML syntax
+
+2. **MANDATORY Fields (must include):**
+   - title: Brief description (max 256 chars)
+   - logsource: Log data source specification
+   - detection: Search identifiers and conditions
+   - condition: How detection elements combine
+
+3. **RECOMMENDED Fields:**
+   - id: Unique UUID
+   - status: 'experimental' (for new rules)
+   - description: Detailed explanation
+   - author: 'AI Generated'
+   - date: Current date (YYYY/MM/DD)
+   - references: Array with CVE link
+   - tags: MITRE ATT&CK techniques
+
+4. **Detection Structure:**
+   - Use selection blocks (selection, selection1, etc.)
+   - Condition references these selections
+   - Use proper field names (Image, CommandLine, ProcessName, etc.)
+   - Support wildcards (*) and value lists
+
+**ABSOLUTE REQUIREMENTS:**
+- Output ONLY valid YAML
+- NO explanatory text before or after
+- NO comments or instructions
+- NO markdown formatting or code blocks
+- NEVER repeat the input prompt or template
+- NEVER include variables like {cve_id} or {poc_content}
+- NO "Human:", "CVE ID:", "Description:" headers
+- NO "Analyze this" or "Output only" text
+- Start IMMEDIATELY with 'title:'
+- End with the last YAML line only
+- Ensure perfect YAML syntax
+
+**STRUCTURE REQUIREMENTS:**
+- title: Descriptive title that MUST include the exact CVE ID provided by the user
+- id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012')
+- status: experimental
+- description: Specific description based on CVE and PoC analysis
+- author: 'AI Generated'
+- date: Current date (2025/01/11)
+- references: Include the EXACT CVE URL with the CVE ID provided by the user
+- tags: Relevant MITRE ATT&CK techniques based on PoC analysis
+- logsource: Appropriate category based on exploit type
+- detection: Specific indicators from PoC analysis (NOT generic examples)
+- condition: Logic connecting the detection selections
+
+**CRITICAL RULES:**
+1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID
+2. Analyze the provided CVE and PoC content to create SPECIFIC detection patterns
+3. DO NOT hallucinate or invent CVE IDs from your training data
+4. Use the CVE ID exactly as provided in the title and references"""
 
         if existing_rule:
-            user_template = """**CVE Information:**
-- CVE ID: {cve_id}
-- Description: {cve_description}
-
-**Proof-of-Concept Code:**
-```
-{poc_content}
-```
-
-**Existing SIGMA Rule (to enhance):**
-```yaml
-{existing_rule}
-```
-
-Please enhance the existing rule with insights from the PoC code analysis."""
+            user_template = """CVE ID: {cve_id}
+CVE Description: {cve_description}
+
+PoC Code:
+{poc_content}
+
+Existing SIGMA Rule:
+{existing_rule}
+
+Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'."""
         else:
-            user_template = """**CVE Information:**
-- CVE ID: {cve_id}
-- Description: {cve_description}
-
-**Proof-of-Concept Code:**
-```
-{poc_content}
-```
-
-Please create a new SIGMA rule based on the PoC code analysis."""
+            user_template = """CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE:
+
+**MANDATORY CVE ID TO USE: {cve_id}**
+**CVE Description: {cve_description}**
+
+**Proof-of-Concept Code Analysis:**
+{poc_content}
+
+**CRITICAL REQUIREMENTS:**
+1. Use EXACTLY this CVE ID in the title: {cve_id}
+2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{cve_id}
+3. Analyze the CVE description to understand the vulnerability type
+4. Extract specific indicators from the PoC code (files, processes, commands, network patterns)
+5. Create detection logic based on the actual exploit behavior
+6. Use relevant logsource category (process_creation, file_event, network_connection, etc.)
+7. Include appropriate MITRE ATT&CK tags based on the exploit techniques
+
+**IMPORTANT: You MUST use the exact CVE ID "{cve_id}" - do NOT generate a different CVE ID!**
+
+Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {cve_id}."""
 
         return ChatPromptTemplate.from_messages([
             SystemMessage(content=system_message),
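
Aside: to make the output format the new prompt demands concrete, here is a self-contained check against a fully illustrative rule — the CVE ID, UUID, and indicators below are invented for the example, not pipeline output:

import yaml

sample = """\
title: 'Potential CVE-2099-0001 Exploitation via Suspicious PowerShell'
id: 9b2c7f0e-4a1d-4c3b-9f2a-1d2e3f4a5b6c
status: experimental
description: 'Detects encoded PowerShell command lines seen in the hypothetical PoC'
author: 'AI Generated'
date: 2025/01/11
references:
    - https://nvd.nist.gov/vuln/detail/CVE-2099-0001
tags:
    - attack.execution
    - attack.t1059.001
logsource:
    category: process_creation
    product: windows
detection:
    selection:
        Image|endswith: '\\powershell.exe'
        CommandLine|contains: '-enc'
    condition: selection
level: high
"""

parsed = yaml.safe_load(sample)
# The same mandatory fields validate_sigma_rule checks for:
assert all(field in parsed for field in ('title', 'logsource', 'detection'))
assert 'condition' in parsed['detection']
print(parsed['title'])
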
@@ -260,28 +315,116 @@ Please create a new SIGMA rule based on the PoC code analysis."""
         ])
 
     def _extract_sigma_rule(self, response_text: str) -> str:
-        """Extract SIGMA rule YAML from LLM response."""
-        # Look for YAML content in the response
+        """Extract and clean SIGMA rule YAML from LLM response."""
         lines = response_text.split('\n')
         yaml_lines = []
-        in_yaml = False
+        in_yaml_block = False
+        found_title = False
 
         for line in lines:
-            if line.strip().startswith('```yaml') or line.strip().startswith('```'):
-                in_yaml = True
-                continue
-            elif line.strip() == '```' and in_yaml:
-                break
-            elif in_yaml or line.strip().startswith('title:'):
-                yaml_lines.append(line)
-                in_yaml = True
+            stripped = line.strip()
+
+            # Skip code block markers
+            if stripped.startswith('```'):
+                if stripped.startswith('```yaml'):
+                    in_yaml_block = True
+                elif stripped == '```' and in_yaml_block:
+                    break
+                continue
+
+            # Skip obvious non-YAML content
+            if not in_yaml_block and not found_title:
+                if not stripped.startswith('title:'):
+                    # Skip explanatory text and prompt artifacts
+                    skip_phrases = [
+                        'please note', 'this rule', 'you should', 'analysis:',
+                        'explanation:', 'based on', 'the following', 'here is',
+                        'note that', 'important:', 'remember', 'this is a',
+                        'make sure to', 'you can modify', 'adjust the',
+                        'human:', 'cve id:', 'cve description:', 'poc code:',
+                        'exploit code:', 'analyze this', 'create a', 'output only'
+                    ]
+                    if any(phrase in stripped.lower() for phrase in skip_phrases):
+                        continue
+
+                    # Skip template variables and prompt artifacts
+                    if '{' in stripped and '}' in stripped:
+                        continue
+
+                    # Skip lines that are clearly not YAML structure
+                    if stripped and ':' not in stripped and len(stripped) > 20:
+                        continue
+
+            # Start collecting when we find title or are in YAML block
+            if stripped.startswith('title:') or in_yaml_block:
+                found_title = True
+                in_yaml_block = True
+
+                # Skip explanatory comments
+                if stripped.startswith('#') and ('please' in stripped.lower() or 'note' in stripped.lower()):
+                    continue
+
+                yaml_lines.append(line)
+
+            # Stop if we encounter obvious non-YAML after starting
+            elif found_title:
+                # Stop at explanatory text after the rule
+                stop_phrases = [
+                    'please note', 'this rule should', 'make sure to',
+                    'you can modify', 'adjust the', 'also, this is',
+                    'based on the analysis', 'the rule above'
+                ]
+                if any(phrase in stripped.lower() for phrase in stop_phrases):
+                    break
+
+                # Stop at lines without colons that aren't indented (likely explanations)
+                if stripped and not line.startswith(' ') and ':' not in stripped and '-' not in stripped:
+                    break
 
         if not yaml_lines:
-            # If no YAML block found, return the whole response
-            return response_text.strip()
+            # Fallback: look for any line with YAML-like structure
+            for line in lines:
+                if ':' in line and not line.strip().startswith('#'):
+                    yaml_lines.append(line)
 
         return '\n'.join(yaml_lines).strip()
 
+    def _post_process_sigma_rule(self, rule_content: str) -> str:
+        """Post-process SIGMA rule to ensure clean YAML format."""
+        lines = rule_content.split('\n')
+        cleaned_lines = []
+
+        for line in lines:
+            stripped = line.strip()
+
+            # Skip obvious non-YAML content and prompt artifacts
+            if any(phrase in stripped.lower() for phrase in [
+                'please note', 'you should replace', 'this is a proof-of-concept',
+                'please make sure', 'note that', 'important:', 'remember to',
+                'analysis shows', 'based on the', 'the rule above', 'this rule',
+                'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
+                'analyze this', 'create a', 'output only', 'generate a'
+            ]):
+                continue
+
+            # Skip template variables
+            if '{' in stripped and '}' in stripped:
+                continue
+
+            # Skip unindented lines without colons that look like explanations
+            if stripped and ':' not in stripped and not stripped.startswith('-') and not line.startswith(' '):
+                # This might be explanatory text, skip it
+                if any(word in stripped.lower() for word in ['rule', 'detect', 'should', 'will', 'can', 'may']):
+                    continue
+
+            # Skip empty explanatory sections
+            if stripped.lower() in ['explanation:', 'analysis:', 'notes:', 'important:', '']:
+                continue
+
+            cleaned_lines.append(line)
+
+        return '\n'.join(cleaned_lines).strip()
+
     async def enhance_existing_rule(self,
                                     existing_rule: str,
                                     poc_content: str,
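
Aside: a stripped-down, standalone rendition of the fence-handling core of _extract_sigma_rule (the real method additionally applies the skip/stop phrase filters shown above):

def extract_yaml(response_text: str) -> str:
    yaml_lines, in_block = [], False
    for line in response_text.split('\n'):
        stripped = line.strip()
        if stripped.startswith('```'):
            if in_block:
                break                                  # closing fence ends the rule
            in_block = stripped.startswith('```yaml')  # opening fence starts it
            continue
        if in_block or stripped.startswith('title:'):
            in_block = True
            yaml_lines.append(line)
    return '\n'.join(yaml_lines).strip()

mock = "Here is the rule:\n```yaml\ntitle: 'Demo'\nlevel: low\n```\nHope this helps!"
print(extract_yaml(mock))  # keeps only the fenced YAML, drops the chatter
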
@@ -337,6 +480,7 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
             })
 
             enhanced_rule = self._extract_sigma_rule(response)
+            enhanced_rule = self._post_process_sigma_rule(enhanced_rule)
             logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
             return enhanced_rule
 
@@ -345,33 +489,57 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
             return None
 
     def validate_sigma_rule(self, rule_content: str) -> bool:
-        """Validate that the generated rule is syntactically correct SIGMA."""
+        """Validate that the generated rule follows SIGMA specification."""
         try:
             # Parse as YAML
             parsed = yaml.safe_load(rule_content)
 
-            # Check required fields
-            required_fields = ['title', 'id', 'description', 'logsource', 'detection']
-            for field in required_fields:
+            if not isinstance(parsed, dict):
+                logger.warning("Rule must be a YAML dictionary")
+                return False
+
+            # Check MANDATORY fields per SIGMA spec
+            mandatory_fields = ['title', 'logsource', 'detection']
+            for field in mandatory_fields:
                 if field not in parsed:
-                    logger.warning(f"Missing required field: {field}")
+                    logger.warning(f"Missing mandatory field: {field}")
                     return False
 
-            # Check detection structure
-            detection = parsed.get('detection', {})
-            if not isinstance(detection, dict):
-                logger.warning("Detection field must be a dictionary")
+            # Validate title
+            title = parsed.get('title', '')
+            if not isinstance(title, str) or len(title) > 256:
+                logger.warning("Title must be string ≤256 characters")
                 return False
 
-            # Should have at least one selection and a condition
-            if 'condition' not in detection:
-                logger.warning("Detection must have a condition")
-                return False
-
-            # Check logsource structure
+            # Validate logsource structure
             logsource = parsed.get('logsource', {})
             if not isinstance(logsource, dict):
-                logger.warning("Logsource field must be a dictionary")
+                logger.warning("Logsource must be a dictionary")
                 return False
 
+            # Validate detection structure
+            detection = parsed.get('detection', {})
+            if not isinstance(detection, dict):
+                logger.warning("Detection must be a dictionary")
+                return False
+
+            # Check for condition (can be in detection or at root)
+            has_condition = 'condition' in detection or 'condition' in parsed
+            if not has_condition:
+                logger.warning("Missing condition field")
+                return False
+
+            # Check for at least one selection
+            selection_found = any(key.startswith('selection') or key in ['selection', 'keywords', 'filter']
+                                  for key in detection.keys() if key != 'condition')
+            if not selection_found:
+                logger.warning("Detection must have at least one selection")
+                return False
+
+            # Validate status if present
+            status = parsed.get('status')
+            if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']:
+                logger.warning(f"Invalid status: {status}")
+                return False
+
             logger.info("SIGMA rule validation passed")
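
Aside: the selection requirement the validator encodes, demonstrated on toy detection dicts (illustrative only):

def has_selection(detection: dict) -> bool:
    # Mirrors the validator: any selection*/keywords/filter key besides 'condition'.
    return any(k.startswith('selection') or k in ('keywords', 'filter')
               for k in detection if k != 'condition')

detection_ok = {'selection': {'Image|endswith': '\\calc.exe'}, 'condition': 'selection'}
detection_bad = {'condition': 'selection'}  # condition with nothing to reference
print(has_selection(detection_ok), has_selection(detection_bad))  # True False
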
@@ -26,9 +26,10 @@ class NomiSecClient:
         self.base_url = "https://poc-in-github.motikan2010.net/api/v1"
         self.rss_url = "https://poc-in-github.motikan2010.net/rss"
 
-        # Rate limiting
-        self.rate_limit_delay = 1.0  # 1 second between requests
+        # Optimized rate limiting
+        self.rate_limit_delay = 0.2  # 200ms between requests (5 requests/second)
         self.last_request_time = 0
+        self.concurrent_requests = 3  # Allow concurrent requests
 
         # Cache for recently fetched data
         self.cache = {}
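
Aside: the delay-based throttle configured above, extracted into a runnable sketch with the same 0.2s spacing:

import asyncio
import time

class Throttle:
    """Minimal sketch of NomiSecClient's delay-based rate limiting."""
    def __init__(self, delay: float = 0.2):
        self.delay = delay
        self.last_request_time = 0.0

    async def wait(self):
        # Sleep only for the remainder of the spacing window.
        elapsed = time.time() - self.last_request_time
        if elapsed < self.delay:
            await asyncio.sleep(self.delay - elapsed)
        self.last_request_time = time.time()

async def demo():
    throttle = Throttle()
    for i in range(3):
        await throttle.wait()
        print(f"request {i} at {time.monotonic():.2f}")

asyncio.run(demo())
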
@@ -36,50 +37,70 @@ class NomiSecClient:
 
     async def _make_request(self, session: aiohttp.ClientSession,
                             url: str, params: dict = None) -> Optional[dict]:
-        """Make a rate-limited request to the API"""
+        """Make an optimized rate-limited request to the API"""
         try:
-            # Rate limiting
+            # Optimized rate limiting
             current_time = time.time()
             time_since_last = current_time - self.last_request_time
             if time_since_last < self.rate_limit_delay:
                 await asyncio.sleep(self.rate_limit_delay - time_since_last)
 
-            async with session.get(url, params=params, timeout=30) as response:
+            async with session.get(url, params=params, timeout=10) as response:
                 self.last_request_time = time.time()
 
                 if response.status == 200:
                     return await response.json()
                 elif response.status == 429:  # Rate limited
                     logger.warning("Rate limited, retrying after delay")
                     await asyncio.sleep(2.0)
                     return await self._make_request(session, url, params)
                 else:
                     logger.warning(f"API request failed: {response.status} for {url}")
                     return None
 
+        except asyncio.TimeoutError:
+            logger.warning(f"Request timeout for {url}")
+            return None
         except Exception as e:
             logger.error(f"Error making request to {url}: {e}")
             return None
 
-    async def get_pocs_for_cve(self, cve_id: str) -> List[dict]:
-        """Get all PoC repositories for a specific CVE"""
+    async def get_pocs_for_cve(self, cve_id: str, session: aiohttp.ClientSession = None) -> List[dict]:
+        """Get all PoC repositories for a specific CVE with optimized session reuse"""
         cache_key = f"cve_{cve_id}"
 
         # Check cache
         if cache_key in self.cache:
             cached_data, timestamp = self.cache[cache_key]
             if time.time() - timestamp < self.cache_ttl:
                 logger.debug(f"Cache hit for {cve_id}")
                 return cached_data
 
-        async with aiohttp.ClientSession() as session:
+        # Use provided session or create new one
+        if session:
             params = {"cve_id": cve_id}
             data = await self._make_request(session, self.base_url, params)
 
             if data and "pocs" in data:
                 pocs = data["pocs"]
                 # Cache the result
                 self.cache[cache_key] = (pocs, time.time())
                 logger.info(f"Found {len(pocs)} PoCs for {cve_id}")
                 return pocs
             else:
                 logger.info(f"No PoCs found for {cve_id}")
                 return []
+        else:
+            # Optimized connector with connection pooling
+            connector = aiohttp.TCPConnector(
+                limit=100,
+                limit_per_host=10,
+                ttl_dns_cache=300,
+                use_dns_cache=True
+            )
+            async with aiohttp.ClientSession(connector=connector) as new_session:
+                params = {"cve_id": cve_id}
+                data = await self._make_request(new_session, self.base_url, params)
+
+                if data and "pocs" in data:
+                    pocs = data["pocs"]
+                    # Cache the result
+                    self.cache[cache_key] = (pocs, time.time())
+                    logger.debug(f"Found {len(pocs)} PoCs for {cve_id}")
+                    return pocs
+                else:
+                    logger.debug(f"No PoCs found for {cve_id}")
+                    return []
 
     async def get_recent_pocs(self, limit: int = 100) -> List[dict]:
         """Get recent PoCs from the API"""
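
Aside: the session-reuse pattern in isolation — one pooled ClientSession shared across requests instead of a fresh session per call; the URL is a placeholder:

import asyncio
import aiohttp

async def fetch_statuses(urls):
    # One pooled session for all requests, mirroring the optimized client.
    connector = aiohttp.TCPConnector(limit=100, limit_per_host=10, ttl_dns_cache=300)
    async with aiohttp.ClientSession(connector=connector) as session:
        async def one(url):
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
                return resp.status
        return await asyncio.gather(*(one(u) for u in urls))

# print(asyncio.run(fetch_statuses(["https://example.com"] * 3)))  # needs network access
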
@@ -291,8 +312,8 @@ class NomiSecClient:
 
         return indicators
 
-    async def sync_cve_pocs(self, cve_id: str) -> dict:
-        """Synchronize PoC data for a specific CVE"""
+    async def sync_cve_pocs(self, cve_id: str, session: aiohttp.ClientSession = None) -> dict:
+        """Synchronize PoC data for a specific CVE with session reuse"""
        from main import CVE, SigmaRule
 
         # Get existing CVE
@@ -301,8 +322,8 @@ class NomiSecClient:
             logger.warning(f"CVE {cve_id} not found in database")
             return {"error": "CVE not found"}
 
-        # Fetch PoCs from nomi-sec API
-        pocs = await self.get_pocs_for_cve(cve_id)
+        # Fetch PoCs from nomi-sec API with session reuse
+        pocs = await self.get_pocs_for_cve(cve_id, session)
 
         if not pocs:
             logger.info(f"No PoCs found for {cve_id}")
@@ -438,8 +459,8 @@ class NomiSecClient:
 
                 job.processed_items += 1
 
-                # Small delay to avoid overwhelming the API
-                await asyncio.sleep(0.5)
+                # Minimal delay for faster processing
+                await asyncio.sleep(0.05)
 
             except Exception as e:
                 logger.error(f"Error syncing PoCs for {cve.cve_id}: {e}")
@@ -481,6 +502,146 @@ class NomiSecClient:
             'cves_with_pocs': len(results)
         }
 
+    async def bulk_sync_poc_data(self, batch_size: int = 50, max_cves: int = None,
+                                 force_resync: bool = False) -> dict:
+        """Optimized bulk synchronization of PoC data with performance improvements"""
+        from main import CVE, SigmaRule, BulkProcessingJob
+        import asyncio
+        from datetime import datetime, timedelta
+
+        # Create job tracking
+        job = BulkProcessingJob(
+            job_type='nomi_sec_sync',
+            status='running',
+            started_at=datetime.utcnow(),
+            job_metadata={'batch_size': batch_size, 'max_cves': max_cves, 'force_resync': force_resync}
+        )
+        self.db_session.add(job)
+        self.db_session.commit()
+
+        try:
+            # Get CVEs that need PoC sync - optimized query
+            query = self.db_session.query(CVE)
+
+            if not force_resync:
+                # Skip CVEs that were recently synced or already have nomi-sec data
+                recent_cutoff = datetime.utcnow() - timedelta(days=7)
+                query = query.filter(
+                    or_(
+                        CVE.poc_data.is_(None),
+                        and_(
+                            CVE.updated_at < recent_cutoff,
+                            CVE.poc_count == 0
+                        )
+                    )
+                )
+
+            # Prioritize recent CVEs and high CVSS scores
+            query = query.order_by(
+                CVE.published_date.desc(),
+                CVE.cvss_score.desc().nullslast()
+            )
+
+            if max_cves:
+                query = query.limit(max_cves)
+
+            cves = query.all()
+            job.total_items = len(cves)
+            self.db_session.commit()
+
+            logger.info(f"Starting optimized nomi-sec sync for {len(cves)} CVEs")
+
+            total_processed = 0
+            total_found = 0
+            concurrent_semaphore = asyncio.Semaphore(self.concurrent_requests)
+
+            # Create shared session with optimized settings
+            connector = aiohttp.TCPConnector(
+                limit=50,
+                limit_per_host=10,
+                ttl_dns_cache=300,
+                use_dns_cache=True,
+                keepalive_timeout=30
+            )
+
+            async with aiohttp.ClientSession(connector=connector) as shared_session:
+                async def process_cve_batch(cve_batch):
+                    """Process a batch of CVEs concurrently with shared session"""
+                    async def process_single_cve(cve):
+                        async with concurrent_semaphore:
+                            try:
+                                result = await self.sync_cve_pocs(cve.cve_id, shared_session)
+                                return result
+                            except Exception as e:
+                                logger.error(f"Error syncing {cve.cve_id}: {e}")
+                                return {'error': str(e), 'cve_id': cve.cve_id}
+
+                    # Process batch concurrently
+                    tasks = [process_single_cve(cve) for cve in cve_batch]
+                    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+                    return batch_results
+
+                # Process in optimized batches
+                for i in range(0, len(cves), batch_size):
+                    batch = cves[i:i + batch_size]
+
+                    # Process batch concurrently
+                    batch_results = await process_cve_batch(batch)
+
+                    for result in batch_results:
+                        if isinstance(result, Exception):
+                            logger.error(f"Exception in batch processing: {result}")
+                            job.failed_items += 1
+                        elif isinstance(result, dict) and 'error' not in result:
+                            total_processed += 1
+                            if result.get('pocs_found', 0) > 0:
+                                total_found += result['pocs_found']
+                            job.processed_items += 1
+                        else:
+                            job.failed_items += 1
+
+                    # Commit progress every batch
+                    self.db_session.commit()
+                    logger.info(f"Processed batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}, "
+                                f"found {total_found} PoCs so far")
+
+                    # Small delay between batches
+                    await asyncio.sleep(0.05)
+
+            # Update job completion
+            job.status = 'completed'
+            job.completed_at = datetime.utcnow()
+            job.job_metadata.update({
+                'total_processed': total_processed,
+                'total_pocs_found': total_found,
+                'processing_time_seconds': (job.completed_at - job.started_at).total_seconds()
+            })
+
+            self.db_session.commit()
+
+            logger.info(f"Nomi-sec sync completed: {total_processed} CVEs processed, {total_found} PoCs found")
+
+            return {
+                'job_id': str(job.id),
+                'status': 'completed',
+                'total_processed': total_processed,
+                'total_pocs_found': total_found,
+                'processing_time': (job.completed_at - job.started_at).total_seconds()
+            }
+
+        except Exception as e:
+            job.status = 'failed'
+            job.error_message = str(e)
+            job.completed_at = datetime.utcnow()
+            self.db_session.commit()
+            logger.error(f"Nomi-sec sync failed: {e}")
+
+            return {
+                'job_id': str(job.id),
+                'status': 'failed',
+                'error': str(e)
+            }
+
     async def get_sync_status(self) -> dict:
         """Get synchronization status"""
         from main import CVE, SigmaRule
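
Aside: the batching-plus-semaphore pattern from bulk_sync_poc_data, reduced to a runnable sketch where a sleep stands in for the API call:

import asyncio

async def run_batches(items, batch_size=4, concurrency=3):
    sem = asyncio.Semaphore(concurrency)

    async def process(item):
        async with sem:  # at most `concurrency` calls in flight, as in the sync job
            await asyncio.sleep(0.01)  # stand-in for sync_cve_pocs()
            return item

    results = []
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        results += await asyncio.gather(*(process(x) for x in batch), return_exceptions=True)
    return results

print(asyncio.run(run_batches(list(range(10)))))
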
@@ -50,16 +50,18 @@ jobs:
       timeout_minutes: 30
       retry_on_failure: true
 
-  # Nomi-sec PoC Sync - Update proof-of-concept data
+  # Nomi-sec PoC Sync - Update proof-of-concept data (OPTIMIZED)
   nomi_sec_sync:
     enabled: true
     schedule: "0 4 * * 1"  # Weekly on Monday at 4 AM
-    description: "Sync nomi-sec Proof-of-Concept data"
+    description: "Sync nomi-sec Proof-of-Concept data (optimized)"
     job_type: "nomi_sec_sync"
     parameters:
-      batch_size: 50
-    priority: "medium"
-    timeout_minutes: 120
+      batch_size: 100      # Increased batch size
+      max_cves: 1000       # Limit to recent/important CVEs
+      force_resync: false  # Skip recently synced CVEs
+    priority: "high"       # Increased priority
+    timeout_minutes: 60    # Reduced timeout due to optimizations
     retry_on_failure: true
 
   # GitHub PoC Sync - Update GitHub proof-of-concept data