add claude client + generic llm client using langchain

Brendan McDevitt 2025-07-09 18:02:45 -05:00
parent e4a3cc6cb9
commit 20b3a63c78
7 changed files with 1067 additions and 15 deletions

.env.example

@@ -7,6 +7,31 @@ NVD_API_KEY=your_nvd_api_key_here
# Only needs "public_repo" scope for searching public repositories
GITHUB_TOKEN=your_github_token_here
# LLM API Configuration (Optional - for enhanced SIGMA rule generation)
# Choose your preferred LLM provider and configure the corresponding API key
# OpenAI Configuration
# Get your API key at: https://platform.openai.com/api-keys
OPENAI_API_KEY=your_openai_api_key_here
# Anthropic Configuration
# Get your API key at: https://console.anthropic.com/
ANTHROPIC_API_KEY=your_anthropic_api_key_here
# Ollama Configuration (for local models)
# Install Ollama locally: https://ollama.ai/
OLLAMA_BASE_URL=http://localhost:11434
# LLM Provider Selection (optional - auto-detects if not specified)
# Options: openai, anthropic, ollama
LLM_PROVIDER=openai
# LLM Model Selection (optional - uses provider default if not specified)
# OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
# Anthropic: claude-3-5-sonnet-20241022, claude-3-haiku-20240307, claude-3-opus-20240229
# Ollama: llama3.2, codellama, mistral, llama2
LLM_MODEL=gpt-4o-mini
# Database Configuration (Docker Compose will use defaults)
# DATABASE_URL=postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db
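
For reference, a minimal sketch of how the backend's new LLM client (added in this commit) consumes these variables; it assumes the environment above is already loaded (e.g. by Docker Compose or python-dotenv):

```python
import os
from llm_client import LLMClient

# LLM_PROVIDER / LLM_MODEL are optional: when unset, LLMClient auto-detects a
# provider from ANTHROPIC_API_KEY, OPENAI_API_KEY, or OLLAMA_BASE_URL and
# falls back to that provider's default model.
client = LLMClient(
    provider=os.getenv("LLM_PROVIDER"),
    model=os.getenv("LLM_MODEL"),
)
print(client.get_provider_info())
```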

backend/claude_client.py Normal file (221 lines added)

@@ -0,0 +1,221 @@
"""
Claude API client for enhanced SIGMA rule generation.
"""
import os
import logging
from typing import Optional, Dict, Any
from anthropic import Anthropic
logger = logging.getLogger(__name__)
class ClaudeClient:
"""Client for interacting with Claude API for SIGMA rule generation."""
def __init__(self, api_key: Optional[str] = None):
"""Initialize Claude client with API key."""
self.api_key = api_key or os.getenv('CLAUDE_API_KEY')
self.client = None
if self.api_key:
try:
self.client = Anthropic(api_key=self.api_key)
logger.info("Claude API client initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Claude API client: {e}")
self.client = None
else:
logger.warning("No Claude API key provided. Claude-enhanced rule generation disabled.")
def is_available(self) -> bool:
"""Check if Claude API client is available."""
return self.client is not None
async def generate_sigma_rule(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> Optional[str]:
"""
Generate or enhance a SIGMA rule using Claude API.
Args:
cve_id: CVE identifier
poc_content: Proof-of-concept code content from GitHub
cve_description: CVE description from NVD
existing_rule: Optional existing SIGMA rule to enhance
Returns:
Generated SIGMA rule YAML content or None if failed
"""
if not self.is_available():
logger.warning("Claude API client not available")
return None
try:
# Construct the prompt for Claude
prompt = self._build_sigma_generation_prompt(
cve_id, poc_content, cve_description, existing_rule
)
# Make API call to Claude
response = self.client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=2000,
temperature=0.1,
messages=[
{
"role": "user",
"content": prompt
}
]
)
# Extract the SIGMA rule from response
sigma_rule = self._extract_sigma_rule(response.content[0].text)
logger.info(f"Successfully generated SIGMA rule for {cve_id} using Claude")
return sigma_rule
except Exception as e:
logger.error(f"Failed to generate SIGMA rule for {cve_id} using Claude: {e}")
return None
def _build_sigma_generation_prompt(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> str:
"""Build the prompt for Claude to generate SIGMA rules."""
base_prompt = f"""You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.
**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}
**Proof-of-Concept Code:**
```
{poc_content[:4000]} # Truncate if too long
```
**Your Task:**
1. Analyze the exploit code to identify:
- Process execution patterns
- File system activities
- Network connections
- Registry modifications
- Command line arguments
- Suspicious behaviors
2. Create a SIGMA rule that:
- Follows proper SIGMA syntax (YAML format)
- Includes appropriate detection logic
- Has relevant metadata (title, description, author, date, references)
- Uses correct field names for the target log source
- Includes proper condition logic
- Maps to relevant MITRE ATT&CK techniques when applicable
3. Focus on detection patterns that would catch this specific exploit in action
**Important Requirements:**
- Output ONLY the SIGMA rule in valid YAML format
- Do not include explanations or comments outside the YAML
- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
- Make the rule specific enough to detect the exploit but not too narrow to miss variants
- Include relevant tags and MITRE ATT&CK technique mappings"""
if existing_rule:
base_prompt += f"""
**Existing SIGMA Rule (to enhance):**
```yaml
{existing_rule}
```
Please enhance the existing rule with insights from the PoC code analysis."""
return base_prompt
def _extract_sigma_rule(self, response_text: str) -> str:
"""Extract SIGMA rule YAML from Claude's response."""
# Look for YAML content in the response
lines = response_text.split('\n')
yaml_lines = []
in_yaml = False
for line in lines:
if line.strip().startswith('```yaml') or line.strip().startswith('```'):
in_yaml = True
continue
elif line.strip() == '```' and in_yaml:
break
elif in_yaml or line.strip().startswith('title:'):
yaml_lines.append(line)
in_yaml = True
if not yaml_lines:
# If no YAML block found, return the whole response
return response_text.strip()
return '\n'.join(yaml_lines).strip()
async def enhance_existing_rule(self,
existing_rule: str,
poc_content: str,
cve_id: str) -> Optional[str]:
"""
Enhance an existing SIGMA rule with PoC analysis.
Args:
existing_rule: Existing SIGMA rule YAML
poc_content: PoC code content
cve_id: CVE identifier
Returns:
Enhanced SIGMA rule or None if failed
"""
if not self.is_available():
return None
try:
prompt = f"""You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.
**CVE ID:** {cve_id}
**PoC Code:**
```
{poc_content[:3000]}
```
**Existing SIGMA Rule:**
```yaml
{existing_rule}
```
**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective
Output ONLY the enhanced SIGMA rule in valid YAML format."""
response = self.client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=2000,
temperature=0.1,
messages=[
{
"role": "user",
"content": prompt
}
]
)
enhanced_rule = self._extract_sigma_rule(response.content[0].text)
logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
return enhanced_rule
except Exception as e:
logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
return None
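
A usage sketch for the client above. Note that ClaudeClient reads CLAUDE_API_KEY, whereas .env.example defines ANTHROPIC_API_KEY for the LangChain client; the CVE ID and PoC snippet below are placeholders, not values from this commit:

```python
import asyncio
from claude_client import ClaudeClient

async def demo():
    client = ClaudeClient()  # picks up CLAUDE_API_KEY from the environment
    if not client.is_available():
        print("Claude client not configured")
        return
    rule = await client.generate_sigma_rule(
        cve_id="CVE-2024-0001",                          # placeholder
        poc_content="powershell -enc <base64 payload>",  # placeholder PoC snippet
        cve_description="Example command execution vulnerability",
    )
    print(rule)

asyncio.run(demo())
```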

backend/enhanced_sigma_generator.py

@@ -9,6 +9,7 @@ from datetime import datetime
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
import re
from llm_client import LLMClient
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -17,10 +18,11 @@ logger = logging.getLogger(__name__)
class EnhancedSigmaGenerator:
"""Enhanced SIGMA rule generator using nomi-sec PoC data"""
- def __init__(self, db_session: Session):
+ def __init__(self, db_session: Session, llm_provider: str = None, llm_model: str = None):
self.db_session = db_session
self.llm_client = LLMClient(provider=llm_provider, model=llm_model)
- async def generate_enhanced_rule(self, cve) -> dict:
+ async def generate_enhanced_rule(self, cve, use_llm: bool = True) -> dict:
"""Generate enhanced SIGMA rule for a CVE using PoC data"""
from main import SigmaRule, RuleTemplate
@@ -33,6 +35,20 @@ class EnhancedSigmaGenerator:
if poc_data:
best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0))
# Try LLM-enhanced generation first if enabled and available
rule_content = None
generation_method = "template"
if use_llm and self.llm_client.is_available() and best_poc:
logger.info(f"Attempting LLM-enhanced rule generation for {cve.cve_id} using {self.llm_client.provider}")
rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data)
if rule_content:
generation_method = f"llm_{self.llm_client.provider}"
# Fallback to template-based generation
if not rule_content:
logger.info(f"Using template-based rule generation for {cve.cve_id}")
# Select appropriate template based on PoC analysis
template = await self._select_template(cve, best_poc)
@@ -55,8 +71,8 @@
'cve_id': cve.cve_id,
'rule_name': f"{cve.cve_id} Enhanced Detection",
'rule_content': rule_content,
- 'detection_type': template.template_name,
+ 'detection_type': f"{generation_method}_generated",
- 'log_source': self._extract_log_source(template.template_name),
+ 'log_source': self._extract_log_source_from_content(rule_content),
'confidence_level': confidence_level,
'auto_generated': True,
'exploit_based': len(poc_data) > 0,
@@ -67,7 +83,8 @@
'best_poc_quality': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0,
'total_stars': sum(p.get('stargazers_count', 0) for p in poc_data),
'avg_stars': sum(p.get('stargazers_count', 0) for p in poc_data) / len(poc_data) if poc_data else 0,
- 'source': getattr(cve, 'poc_source', 'nomi_sec')
+ 'source': getattr(cve, 'poc_source', 'nomi_sec'),
'generation_method': generation_method
},
'github_repos': [p.get('html_url', '') for p in poc_data],
'exploit_indicators': json.dumps(self._combine_exploit_indicators(poc_data)),
@@ -100,6 +117,135 @@
logger.error(f"Error generating enhanced rule for {cve.cve_id}: {e}")
return {'success': False, 'error': str(e)}
async def _generate_llm_enhanced_rule(self, cve, best_poc: dict, poc_data: list) -> Optional[str]:
"""Generate SIGMA rule using LLM API with PoC analysis"""
try:
# Get PoC content from the best quality PoC
poc_content = await self._extract_poc_content(best_poc)
if not poc_content:
logger.warning(f"No PoC content available for {cve.cve_id}")
return None
# Generate rule using LLM
rule_content = await self.llm_client.generate_sigma_rule(
cve_id=cve.cve_id,
poc_content=poc_content,
cve_description=cve.description or "",
existing_rule=None
)
if rule_content:
# Validate the generated rule
if self.llm_client.validate_sigma_rule(rule_content):
logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
return rule_content
else:
logger.warning(f"Generated rule for {cve.cve_id} failed validation")
return None
return None
except Exception as e:
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
return None
async def _extract_poc_content(self, poc: dict) -> Optional[str]:
"""Extract actual code content from PoC repository"""
try:
import aiohttp
import asyncio
# Get repository information
repo_url = poc.get('html_url', '')
if not repo_url:
return None
# Convert GitHub URL to API URL for repository content
if 'github.com' in repo_url:
# Extract owner and repo from URL
parts = repo_url.rstrip('/').split('/')
if len(parts) >= 2:
owner = parts[-2]
repo = parts[-1]
# Get repository files via GitHub API
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents"
async with aiohttp.ClientSession() as session:
# Add timeout to prevent hanging
timeout = aiohttp.ClientTimeout(total=30)
async with session.get(api_url, timeout=timeout) as response:
if response.status == 200:
contents = await response.json()
# Look for common exploit files
target_files = [
'exploit.py', 'poc.py', 'exploit.c', 'exploit.cpp',
'exploit.java', 'exploit.rb', 'exploit.php',
'exploit.js', 'exploit.sh', 'exploit.ps1',
'README.md', 'main.py', 'index.js'
]
for file_info in contents:
if file_info.get('type') == 'file':
filename = file_info.get('name', '').lower()
# Check if this is a target file
if any(target in filename for target in target_files):
file_url = file_info.get('download_url')
if file_url:
async with session.get(file_url, timeout=timeout) as file_response:
if file_response.status == 200:
content = await file_response.text()
# Limit content size
if len(content) > 10000:
content = content[:10000] + "\n... [content truncated]"
return content
# If no specific exploit file found, return description/README
for file_info in contents:
if file_info.get('type') == 'file':
filename = file_info.get('name', '').lower()
if 'readme' in filename:
file_url = file_info.get('download_url')
if file_url:
async with session.get(file_url, timeout=timeout) as file_response:
if file_response.status == 200:
content = await file_response.text()
return content[:5000] # Smaller limit for README
# Fallback to description and metadata
description = poc.get('description', '')
if description:
return f"Repository Description: {description}"
return None
except Exception as e:
logger.error(f"Error extracting PoC content: {e}")
return None
def _extract_log_source_from_content(self, rule_content: str) -> str:
"""Extract log source from the generated rule content"""
try:
import yaml
parsed = yaml.safe_load(rule_content)
logsource = parsed.get('logsource', {})
category = logsource.get('category', '')
product = logsource.get('product', '')
if category:
return category
elif product:
return product
else:
return 'generic'
except Exception:
return 'generic'
async def _select_template(self, cve, best_poc: Optional[dict]) -> Optional[object]:
"""Select the most appropriate template based on CVE and PoC analysis"""
from main import RuleTemplate
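
A sketch of driving the generator end to end. SessionLocal is an assumed session factory and the query assumes CVE rows with nomi-sec PoC data are already seeded; neither appears in this diff:

```python
import asyncio
from enhanced_sigma_generator import EnhancedSigmaGenerator
from main import CVE, SessionLocal  # SessionLocal is assumed, not shown in this commit

async def demo():
    db = SessionLocal()
    try:
        cve = db.query(CVE).filter(CVE.poc_count > 0).first()
        generator = EnhancedSigmaGenerator(db, llm_provider="anthropic")
        result = await generator.generate_enhanced_rule(cve, use_llm=True)
        # result contains 'success' plus rule metadata such as generation_method
        print(result.get('success'), result.get('error'))
    finally:
        db.close()

asyncio.run(demo())
```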

backend/llm_client.py Normal file (398 lines added)

@@ -0,0 +1,398 @@
"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
import os
import logging
from typing import Optional, Dict, Any, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
import yaml
logger = logging.getLogger(__name__)
class LLMClient:
"""Multi-provider LLM client for SIGMA rule generation using LangChain."""
SUPPORTED_PROVIDERS = {
'openai': {
'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
'env_key': 'OPENAI_API_KEY',
'default_model': 'gpt-4o-mini'
},
'anthropic': {
'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
'env_key': 'ANTHROPIC_API_KEY',
'default_model': 'claude-3-5-sonnet-20241022'
},
'ollama': {
'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
'env_key': 'OLLAMA_BASE_URL',
'default_model': 'llama3.2'
}
}
def __init__(self, provider: str = None, model: str = None):
"""Initialize LLM client with specified provider and model."""
self.provider = provider or self._detect_provider()
self.model = model or self._get_default_model(self.provider)
self.llm = None
self.output_parser = StrOutputParser()
self._initialize_llm()
def _detect_provider(self) -> str:
"""Auto-detect available LLM provider based on environment variables."""
# Check for API keys in order of preference
if os.getenv('ANTHROPIC_API_KEY'):
return 'anthropic'
elif os.getenv('OPENAI_API_KEY'):
return 'openai'
elif os.getenv('OLLAMA_BASE_URL'):
return 'ollama'
else:
# Default to OpenAI if no keys found
return 'openai'
def _get_default_model(self, provider: str) -> str:
"""Get default model for the specified provider."""
return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')
def _initialize_llm(self):
"""Initialize the LLM based on provider and model."""
try:
if self.provider == 'openai':
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
logger.warning("OpenAI API key not found")
return
self.llm = ChatOpenAI(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'anthropic':
api_key = os.getenv('ANTHROPIC_API_KEY')
if not api_key:
logger.warning("Anthropic API key not found")
return
self.llm = ChatAnthropic(
model=self.model,
api_key=api_key,
temperature=0.1,
max_tokens=2000
)
elif self.provider == 'ollama':
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
self.llm = Ollama(
model=self.model,
base_url=base_url,
temperature=0.1
)
if self.llm:
logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
else:
logger.error(f"Failed to initialize LLM client for provider: {self.provider}")
except Exception as e:
logger.error(f"Error initializing LLM client: {e}")
self.llm = None
def is_available(self) -> bool:
"""Check if LLM client is available and configured."""
return self.llm is not None
def get_provider_info(self) -> Dict[str, Any]:
"""Get information about the current provider and configuration."""
provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
return {
'provider': self.provider,
'model': self.model,
'available': self.is_available(),
'supported_models': provider_info.get('models', []),
'env_key': provider_info.get('env_key', ''),
'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
}
async def generate_sigma_rule(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> Optional[str]:
"""
Generate or enhance a SIGMA rule using the configured LLM.
Args:
cve_id: CVE identifier
poc_content: Proof-of-concept code content from GitHub
cve_description: CVE description from NVD
existing_rule: Optional existing SIGMA rule to enhance
Returns:
Generated SIGMA rule YAML content or None if failed
"""
if not self.is_available():
logger.warning("LLM client not available")
return None
try:
# Create the prompt template
prompt = self._build_sigma_generation_prompt(
cve_id, poc_content, cve_description, existing_rule
)
# Create the chain
chain = prompt | self.llm | self.output_parser
# Generate the response
response = await chain.ainvoke({
"cve_id": cve_id,
"poc_content": poc_content[:4000], # Truncate if too long
"cve_description": cve_description,
"existing_rule": existing_rule or "None"
})
# Extract the SIGMA rule from response
sigma_rule = self._extract_sigma_rule(response)
logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
return sigma_rule
except Exception as e:
logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
return None
def _build_sigma_generation_prompt(self,
cve_id: str,
poc_content: str,
cve_description: str,
existing_rule: Optional[str] = None) -> ChatPromptTemplate:
"""Build the prompt template for SIGMA rule generation."""
system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.
**Your Task:**
1. Analyze the exploit code to identify:
- Process execution patterns
- File system activities
- Network connections
- Registry modifications
- Command line arguments
- Suspicious behaviors
2. Create a SIGMA rule that:
- Follows proper SIGMA syntax (YAML format)
- Includes appropriate detection logic
- Has relevant metadata (title, description, author, date, references)
- Uses correct field names for the target log source
- Includes proper condition logic
- Maps to relevant MITRE ATT&CK techniques when applicable
3. Focus on detection patterns that would catch this specific exploit in action
**Important Requirements:**
- Output ONLY the SIGMA rule in valid YAML format
- Do not include explanations or comments outside the YAML
- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
- Make the rule specific enough to detect the exploit but not too narrow to miss variants
- Include relevant tags and MITRE ATT&CK technique mappings"""
if existing_rule:
user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}
**Proof-of-Concept Code:**
```
{poc_content}
```
**Existing SIGMA Rule (to enhance):**
```yaml
{existing_rule}
```
Please enhance the existing rule with insights from the PoC code analysis."""
else:
user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}
**Proof-of-Concept Code:**
```
{poc_content}
```
Please create a new SIGMA rule based on the PoC code analysis."""
# Use (role, template) tuples so the {cve_id}/{poc_content} placeholders are
# substituted at invoke time; raw Message objects would be passed through verbatim.
return ChatPromptTemplate.from_messages([
("system", system_message),
("human", user_template)
])
def _extract_sigma_rule(self, response_text: str) -> str:
"""Extract SIGMA rule YAML from LLM response."""
# Look for YAML content in the response
lines = response_text.split('\n')
yaml_lines = []
in_yaml = False
for line in lines:
if line.strip().startswith('```yaml') or line.strip().startswith('```'):
in_yaml = True
continue
elif line.strip() == '```' and in_yaml:
break
elif in_yaml or line.strip().startswith('title:'):
yaml_lines.append(line)
in_yaml = True
if not yaml_lines:
# If no YAML block found, return the whole response
return response_text.strip()
return '\n'.join(yaml_lines).strip()
async def enhance_existing_rule(self,
existing_rule: str,
poc_content: str,
cve_id: str) -> Optional[str]:
"""
Enhance an existing SIGMA rule with PoC analysis.
Args:
existing_rule: Existing SIGMA rule YAML
poc_content: PoC code content
cve_id: CVE identifier
Returns:
Enhanced SIGMA rule or None if failed
"""
if not self.is_available():
return None
try:
system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.
**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective
Output ONLY the enhanced SIGMA rule in valid YAML format."""
user_template = """**CVE ID:** {cve_id}
**PoC Code:**
```
{poc_content}
```
**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""
prompt = ChatPromptTemplate.from_messages([
("system", system_message),  # (role, template) tuples so placeholders are substituted
("human", user_template)
])
chain = prompt | self.llm | self.output_parser
response = await chain.ainvoke({
"cve_id": cve_id,
"poc_content": poc_content[:3000],
"existing_rule": existing_rule
})
enhanced_rule = self._extract_sigma_rule(response)
logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
return enhanced_rule
except Exception as e:
logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
return None
def validate_sigma_rule(self, rule_content: str) -> bool:
"""Validate that the generated rule is syntactically correct SIGMA."""
try:
# Parse as YAML
parsed = yaml.safe_load(rule_content)
# Check required fields
required_fields = ['title', 'id', 'description', 'logsource', 'detection']
for field in required_fields:
if field not in parsed:
logger.warning(f"Missing required field: {field}")
return False
# Check detection structure
detection = parsed.get('detection', {})
if not isinstance(detection, dict):
logger.warning("Detection field must be a dictionary")
return False
# Should have at least one selection and a condition
if 'condition' not in detection:
logger.warning("Detection must have a condition")
return False
# Check logsource structure
logsource = parsed.get('logsource', {})
if not isinstance(logsource, dict):
logger.warning("Logsource field must be a dictionary")
return False
logger.info("SIGMA rule validation passed")
return True
except yaml.YAMLError as e:
logger.warning(f"YAML parsing error: {e}")
return False
except Exception as e:
logger.warning(f"Rule validation error: {e}")
return False
@classmethod
def get_available_providers(cls) -> List[Dict[str, Any]]:
"""Get list of available LLM providers and their configuration status."""
providers = []
for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items():
env_key = provider_info.get('env_key', '')
api_key_configured = bool(os.getenv(env_key))
providers.append({
'name': provider_name,
'models': provider_info.get('models', []),
'default_model': provider_info.get('default_model', ''),
'env_key': env_key,
'api_key_configured': api_key_configured,
'available': api_key_configured or provider_name == 'ollama'
})
return providers
def switch_provider(self, provider: str, model: str = None):
"""Switch to a different LLM provider and model."""
if provider not in self.SUPPORTED_PROVIDERS:
raise ValueError(f"Unsupported provider: {provider}")
self.provider = provider
self.model = model or self._get_default_model(provider)
self._initialize_llm()
logger.info(f"Switched to provider: {provider} with model: {self.model}")

backend/main.py

@@ -1414,6 +1414,171 @@ async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
"force": request.force
}
@app.post("/api/llm-enhanced-rules")
async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
"""Generate SIGMA rules using LLM API for enhanced analysis"""
# Parse request parameters
cve_id = request.get('cve_id')
force = request.get('force', False)
llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
llm_model = request.get('model', os.getenv('LLM_MODEL'))
# Validation
if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id):
raise HTTPException(status_code=400, detail="Invalid CVE ID format")
async def llm_generation_task():
"""Background task for LLM-enhanced rule generation"""
try:
from enhanced_sigma_generator import EnhancedSigmaGenerator
generator = EnhancedSigmaGenerator(db, llm_provider, llm_model)
# Process specific CVE or all CVEs with PoC data
if cve_id:
cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
if not cve:
logger.error(f"CVE {cve_id} not found")
return
cves_to_process = [cve]
else:
# Process CVEs with PoC data that either have no rules or force update
query = db.query(CVE).filter(CVE.poc_count > 0)
if not force:
# Only process CVEs without existing LLM-generated rules
existing_llm_rules = db.query(SigmaRule).filter(
SigmaRule.detection_type.like('llm_%')
).all()
existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids]
else:
cves_to_process = query.all()
logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")
rules_generated = 0
rules_updated = 0
failures = 0
for cve in cves_to_process:
try:
# Check if CVE has sufficient PoC data
if not cve.poc_data or not cve.poc_count:
logger.debug(f"Skipping {cve.cve_id} - no PoC data")
continue
# Generate LLM-enhanced rule
result = await generator.generate_enhanced_rule(cve, use_llm=True)
if result.get('success'):
if result.get('updated'):
rules_updated += 1
else:
rules_generated += 1
logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
else:
failures += 1
logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")
except Exception as e:
failures += 1
logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
continue
logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
except Exception as e:
logger.error(f"LLM-enhanced rule generation failed: {e}")
import traceback
traceback.print_exc()
background_tasks.add_task(llm_generation_task)
return {
"message": "LLM-enhanced SIGMA rule generation started",
"status": "started",
"cve_id": cve_id,
"force": force,
"provider": llm_provider,
"model": llm_model,
"note": "Requires appropriate LLM API key to be set"
}
@app.get("/api/llm-status")
async def get_llm_status():
"""Check LLM API availability status"""
try:
from llm_client import LLMClient
# Get current provider configuration
provider = os.getenv('LLM_PROVIDER')
model = os.getenv('LLM_MODEL')
client = LLMClient(provider=provider, model=model)
provider_info = client.get_provider_info()
# Get all available providers
all_providers = LLMClient.get_available_providers()
return {
"current_provider": provider_info,
"available_providers": all_providers,
"status": "ready" if client.is_available() else "unavailable"
}
except Exception as e:
logger.error(f"Error checking LLM status: {e}")
return {
"current_provider": {"provider": "unknown", "available": False},
"available_providers": [],
"status": "error",
"error": str(e)
}
@app.post("/api/llm-switch")
async def switch_llm_provider(request: dict):
"""Switch LLM provider and model"""
try:
from llm_client import LLMClient
provider = request.get('provider')
model = request.get('model')
if not provider:
raise HTTPException(status_code=400, detail="Provider is required")
# Validate provider
if provider not in LLMClient.SUPPORTED_PROVIDERS:
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
# Test the new configuration
client = LLMClient(provider=provider, model=model)
if not client.is_available():
raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured")
# Update environment variables (note: this only affects the current session)
os.environ['LLM_PROVIDER'] = provider
if model:
os.environ['LLM_MODEL'] = model
provider_info = client.get_provider_info()
return {
"message": f"Switched to {provider}",
"provider_info": provider_info,
"status": "success"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error switching LLM provider: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/cancel-job/{job_id}") @app.post("/api/cancel-job/{job_id}")
async def cancel_job(job_id: str, db: Session = Depends(get_db)): async def cancel_job(job_id: str, db: Session = Depends(get_db)):
"""Cancel a running job""" """Cancel a running job"""

backend/requirements.txt

@@ -15,3 +15,10 @@ lxml==4.9.3
aiohttp==3.9.1
aiofiles
pyyaml==6.0.1
langchain==0.2.0
langchain-openai==0.1.17
langchain-anthropic==0.1.15
langchain-community==0.2.0
langchain-core>=0.2.20
openai>=1.32.0
anthropic==0.40.0
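
A small smoke test for the new pins, assuming they are installed with `pip install -r backend/requirements.txt`:

```python
# Verifies the LangChain provider packages and SDKs import cleanly together.
import langchain
import langchain_core
import langchain_openai
import langchain_anthropic
import langchain_community
import openai
import anthropic
import yaml

print("langchain", langchain.__version__)
print("openai", openai.__version__)
print("anthropic", anthropic.__version__)
```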

frontend/src/App.js

@@ -21,6 +21,7 @@ function App() {
const [gitHubPocStats, setGitHubPocStats] = useState({});
const [bulkProcessing, setBulkProcessing] = useState(false);
const [hasRunningJobs, setHasRunningJobs] = useState(false);
const [llmStatus, setLlmStatus] = useState({});
useEffect(() => {
fetchData();
@@ -29,14 +30,15 @@ function App() {
const fetchData = async () => {
try {
setLoading(true);
- const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes] = await Promise.all([
+ const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes, llmStatusRes] = await Promise.all([
axios.get(`${API_BASE_URL}/api/cves`),
axios.get(`${API_BASE_URL}/api/sigma-rules`),
axios.get(`${API_BASE_URL}/api/stats`),
axios.get(`${API_BASE_URL}/api/bulk-jobs`),
axios.get(`${API_BASE_URL}/api/bulk-status`),
axios.get(`${API_BASE_URL}/api/poc-stats`),
- axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} }))
+ axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} })),
axios.get(`${API_BASE_URL}/api/llm-status`).catch(err => ({ data: {} }))
]);
setCves(cvesRes.data);
@@ -46,6 +48,7 @@ function App() {
setBulkStatus(bulkStatusRes.data);
setPocStats(pocStatsRes.data);
setGitHubPocStats(githubPocStatsRes.data);
setLlmStatus(llmStatusRes.data);
// Update running jobs state
const runningJobs = bulkJobsRes.data.filter(job => job.status === 'running' || job.status === 'pending');
@@ -166,6 +169,32 @@ function App() {
}
};
const generateLlmRules = async (force = false) => {
try {
const response = await axios.post(`${API_BASE_URL}/api/llm-enhanced-rules`, {
force: force
});
console.log('LLM rule generation response:', response.data);
fetchData();
} catch (error) {
console.error('Error generating LLM-enhanced rules:', error);
}
};
const switchLlmProvider = async (provider, model) => {
try {
const response = await axios.post(`${API_BASE_URL}/api/llm-switch`, {
provider: provider,
model: model
});
console.log('LLM provider switch response:', response.data);
fetchData(); // Refresh to get updated status
} catch (error) {
console.error('Error switching LLM provider:', error);
alert('Failed to switch LLM provider. Please check configuration.');
}
};
const getSeverityColor = (severity) => {
switch (severity?.toLowerCase()) {
case 'critical': return 'bg-red-100 text-red-800';
@@ -197,6 +226,9 @@ function App() {
<p className="text-3xl font-bold text-green-600">{stats.total_sigma_rules || 0}</p>
<p className="text-sm text-gray-500">Nomi-sec: {stats.nomi_sec_rules || 0}</p>
<p className="text-sm text-gray-500">GitHub PoCs: {gitHubPocStats.github_poc_rules || 0}</p>
<p className={`text-sm ${llmStatus.status === 'ready' ? 'text-green-600' : 'text-red-500'}`}>
LLM: {llmStatus.current_provider?.provider || 'Not Available'}
</p>
</div>
<div className="bg-white p-6 rounded-lg shadow">
<h3 className="text-lg font-medium text-gray-900">CVEs with PoCs</h3>
@@ -219,7 +251,7 @@ function App() {
{/* Bulk Processing Controls */}
<div className="bg-white rounded-lg shadow p-6">
<h2 className="text-xl font-bold text-gray-900 mb-4">Bulk Processing</h2>
- <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
+ <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-5 gap-4">
<button
onClick={() => startBulkSeed(2002)}
disabled={hasRunningJobs}
@@ -275,6 +307,64 @@ function App() {
>
{hasRunningJobs ? 'Processing...' : 'Regenerate Rules'}
</button>
<button
onClick={() => generateLlmRules()}
disabled={hasRunningJobs || llmStatus.status !== 'ready'}
className={`px-4 py-2 rounded-md text-white ${
hasRunningJobs || llmStatus.status !== 'ready'
? 'bg-gray-400 cursor-not-allowed'
: 'bg-violet-600 hover:bg-violet-700'
}`}
title={llmStatus.status !== 'ready' ? 'LLM not configured' : ''}
>
{hasRunningJobs ? 'Processing...' : 'Generate LLM Rules'}
</button>
</div>
</div>
{/* LLM Configuration */}
<div className="bg-white rounded-lg shadow p-6">
<h2 className="text-xl font-bold text-gray-900 mb-4">LLM Configuration</h2>
<div className="grid grid-cols-1 md:grid-cols-2 gap-6">
<div>
<h3 className="text-lg font-medium text-gray-900 mb-2">Current Provider</h3>
<div className="space-y-2">
<p className="text-sm text-gray-600">
Provider: <span className="font-medium">{llmStatus.current_provider?.provider || 'Not configured'}</span>
</p>
<p className="text-sm text-gray-600">
Model: <span className="font-medium">{llmStatus.current_provider?.model || 'Not configured'}</span>
</p>
<p className={`text-sm ${llmStatus.status === 'ready' ? 'text-green-600' : 'text-red-500'}`}>
Status: <span className="font-medium">{llmStatus.status || 'Unknown'}</span>
</p>
</div>
</div>
<div>
<h3 className="text-lg font-medium text-gray-900 mb-2">Available Providers</h3>
<div className="space-y-2">
{llmStatus.available_providers?.map(provider => (
<div key={provider.name} className="flex items-center justify-between p-2 bg-gray-50 rounded">
<div>
<span className="font-medium">{provider.name}</span>
<span className={`ml-2 text-xs px-2 py-1 rounded ${
provider.available ? 'bg-green-100 text-green-800' : 'bg-red-100 text-red-800'
}`}>
{provider.available ? 'Available' : 'Not configured'}
</span>
</div>
{provider.available && provider.name !== llmStatus.current_provider?.provider && (
<button
onClick={() => switchLlmProvider(provider.name, provider.default_model)}
className="text-xs bg-blue-600 hover:bg-blue-700 text-white px-2 py-1 rounded"
>
Switch
</button>
)}
</div>
))}
</div>
</div>
</div>
</div>