add claude client + generic llm client using langchain
parent e4a3cc6cb9 · commit 20b3a63c78
7 changed files with 1067 additions and 15 deletions
.env.example (+25)

@@ -7,6 +7,31 @@ NVD_API_KEY=your_nvd_api_key_here
 # Only needs "public_repo" scope for searching public repositories
 GITHUB_TOKEN=your_github_token_here
 
+# LLM API Configuration (Optional - for enhanced SIGMA rule generation)
+# Choose your preferred LLM provider and configure the corresponding API key
+
+# OpenAI Configuration
+# Get your API key at: https://platform.openai.com/api-keys
+OPENAI_API_KEY=your_openai_api_key_here
+
+# Anthropic Configuration
+# Get your API key at: https://console.anthropic.com/
+ANTHROPIC_API_KEY=your_anthropic_api_key_here
+
+# Ollama Configuration (for local models)
+# Install Ollama locally: https://ollama.ai/
+OLLAMA_BASE_URL=http://localhost:11434
+
+# LLM Provider Selection (optional - auto-detects if not specified)
+# Options: openai, anthropic, ollama
+LLM_PROVIDER=openai
+
+# LLM Model Selection (optional - uses provider default if not specified)
+# OpenAI: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
+# Anthropic: claude-3-5-sonnet-20241022, claude-3-haiku-20240307, claude-3-opus-20240229
+# Ollama: llama3.2, codellama, mistral, llama2
+LLM_MODEL=gpt-4o-mini
+
 # Database Configuration (Docker Compose will use defaults)
 # DATABASE_URL=postgresql://cve_user:cve_password@localhost:5432/cve_sigma_db
 
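These variables are read by the new `backend/llm_client.py` further down: an explicit `LLM_PROVIDER` wins, otherwise the client picks a provider from whichever key is present. A minimal sketch of that resolution order (illustrative, not part of the commit):

```python
import os

def resolve_provider() -> str:
    """Mirror of the LLMClient flow: explicit LLM_PROVIDER wins, then key detection."""
    explicit = os.getenv('LLM_PROVIDER')
    if explicit:
        return explicit
    if os.getenv('ANTHROPIC_API_KEY'):
        return 'anthropic'
    if os.getenv('OPENAI_API_KEY'):
        return 'openai'
    if os.getenv('OLLAMA_BASE_URL'):
        return 'ollama'
    return 'openai'  # default when nothing is configured

print(resolve_provider())
```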
backend/claude_client.py (new file, +221)

@@ -0,0 +1,221 @@
"""
Claude API client for enhanced SIGMA rule generation.
"""
import os
import logging
from typing import Optional, Dict, Any
from anthropic import Anthropic

logger = logging.getLogger(__name__)

class ClaudeClient:
    """Client for interacting with Claude API for SIGMA rule generation."""

    def __init__(self, api_key: Optional[str] = None):
        """Initialize Claude client with API key."""
        self.api_key = api_key or os.getenv('CLAUDE_API_KEY')
        self.client = None

        if self.api_key:
            try:
                self.client = Anthropic(api_key=self.api_key)
                logger.info("Claude API client initialized successfully")
            except Exception as e:
                logger.error(f"Failed to initialize Claude API client: {e}")
                self.client = None
        else:
            logger.warning("No Claude API key provided. Claude-enhanced rule generation disabled.")

    def is_available(self) -> bool:
        """Check if Claude API client is available."""
        return self.client is not None

    async def generate_sigma_rule(self,
                                  cve_id: str,
                                  poc_content: str,
                                  cve_description: str,
                                  existing_rule: Optional[str] = None) -> Optional[str]:
        """
        Generate or enhance a SIGMA rule using Claude API.

        Args:
            cve_id: CVE identifier
            poc_content: Proof-of-concept code content from GitHub
            cve_description: CVE description from NVD
            existing_rule: Optional existing SIGMA rule to enhance

        Returns:
            Generated SIGMA rule YAML content or None if failed
        """
        if not self.is_available():
            logger.warning("Claude API client not available")
            return None

        try:
            # Construct the prompt for Claude
            prompt = self._build_sigma_generation_prompt(
                cve_id, poc_content, cve_description, existing_rule
            )

            # Make API call to Claude
            response = self.client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=2000,
                temperature=0.1,
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
            )

            # Extract the SIGMA rule from response
            sigma_rule = self._extract_sigma_rule(response.content[0].text)

            logger.info(f"Successfully generated SIGMA rule for {cve_id} using Claude")
            return sigma_rule

        except Exception as e:
            logger.error(f"Failed to generate SIGMA rule for {cve_id} using Claude: {e}")
            return None
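A minimal usage sketch for this client (illustrative only, not part of the committed file; the CVE ID and PoC snippet are placeholders):

```python
import asyncio
from claude_client import ClaudeClient

async def demo():
    client = ClaudeClient()  # reads CLAUDE_API_KEY from the environment
    if not client.is_available():
        print("Claude API key not configured")
        return
    rule_yaml = await client.generate_sigma_rule(
        cve_id="CVE-2024-0001",                         # hypothetical CVE for illustration
        poc_content="import os\nos.system('whoami')",   # placeholder PoC snippet
        cve_description="Example command-injection vulnerability",
    )
    print(rule_yaml)

asyncio.run(demo())
```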
    def _build_sigma_generation_prompt(self,
                                       cve_id: str,
                                       poc_content: str,
                                       cve_description: str,
                                       existing_rule: Optional[str] = None) -> str:
        """Build the prompt for Claude to generate SIGMA rules."""

        base_prompt = f"""You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.

**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}

**Proof-of-Concept Code:**
```
{poc_content[:4000]}  # Truncate if too long
```

**Your Task:**
1. Analyze the exploit code to identify:
- Process execution patterns
- File system activities
- Network connections
- Registry modifications
- Command line arguments
- Suspicious behaviors

2. Create a SIGMA rule that:
- Follows proper SIGMA syntax (YAML format)
- Includes appropriate detection logic
- Has relevant metadata (title, description, author, date, references)
- Uses correct field names for the target log source
- Includes proper condition logic
- Maps to relevant MITRE ATT&CK techniques when applicable

3. Focus on detection patterns that would catch this specific exploit in action

**Important Requirements:**
- Output ONLY the SIGMA rule in valid YAML format
- Do not include explanations or comments outside the YAML
- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
- Make the rule specific enough to detect the exploit but not too narrow to miss variants
- Include relevant tags and MITRE ATT&CK technique mappings"""

        if existing_rule:
            base_prompt += f"""

**Existing SIGMA Rule (to enhance):**
```yaml
{existing_rule}
```

Please enhance the existing rule with insights from the PoC code analysis."""

        return base_prompt

    def _extract_sigma_rule(self, response_text: str) -> str:
        """Extract SIGMA rule YAML from Claude's response."""
        # Look for YAML content in the response
        lines = response_text.split('\n')
        yaml_lines = []
        in_yaml = False

        for line in lines:
            if line.strip().startswith('```yaml') or line.strip().startswith('```'):
                in_yaml = True
                continue
            elif line.strip() == '```' and in_yaml:
                break
            elif in_yaml or line.strip().startswith('title:'):
                yaml_lines.append(line)
                in_yaml = True

        if not yaml_lines:
            # If no YAML block found, return the whole response
            return response_text.strip()

        return '\n'.join(yaml_lines).strip()

    async def enhance_existing_rule(self,
                                    existing_rule: str,
                                    poc_content: str,
                                    cve_id: str) -> Optional[str]:
        """
        Enhance an existing SIGMA rule with PoC analysis.

        Args:
            existing_rule: Existing SIGMA rule YAML
            poc_content: PoC code content
            cve_id: CVE identifier

        Returns:
            Enhanced SIGMA rule or None if failed
        """
        if not self.is_available():
            return None

        try:
            prompt = f"""You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.

**CVE ID:** {cve_id}

**PoC Code:**
```
{poc_content[:3000]}
```

**Existing SIGMA Rule:**
```yaml
{existing_rule}
```

**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective

Output ONLY the enhanced SIGMA rule in valid YAML format."""

            response = self.client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=2000,
                temperature=0.1,
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
            )

            enhanced_rule = self._extract_sigma_rule(response.content[0].text)
            logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
            return enhanced_rule

        except Exception as e:
            logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
            return None
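For reference, `_extract_sigma_rule` simply pulls the fenced YAML out of the model response; a short illustration with a made-up response string:

```python
from claude_client import ClaudeClient

client = ClaudeClient()
response_text = (
    "Here is the rule:\n"
    "```yaml\n"
    "title: Suspicious Whoami Execution\n"
    "detection:\n"
    "  selection:\n"
    "    Image|endswith: '\\whoami.exe'\n"
    "  condition: selection\n"
    "```"
)
# The prose before the fence is dropped and only the YAML body (starting at
# "title:") is returned; if no fenced block is found at all, the whole
# response is returned stripped.
print(client._extract_sigma_rule(response_text))
```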
backend/enhanced_sigma_generator.py

@@ -9,6 +9,7 @@ from datetime import datetime
 from typing import Dict, List, Optional, Tuple
 from sqlalchemy.orm import Session
 import re
+from llm_client import LLMClient
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -17,10 +18,11 @@ logger = logging.getLogger(__name__)
 class EnhancedSigmaGenerator:
     """Enhanced SIGMA rule generator using nomi-sec PoC data"""
     
-    def __init__(self, db_session: Session):
+    def __init__(self, db_session: Session, llm_provider: str = None, llm_model: str = None):
         self.db_session = db_session
+        self.llm_client = LLMClient(provider=llm_provider, model=llm_model)
     
-    async def generate_enhanced_rule(self, cve) -> dict:
+    async def generate_enhanced_rule(self, cve, use_llm: bool = True) -> dict:
         """Generate enhanced SIGMA rule for a CVE using PoC data"""
         from main import SigmaRule, RuleTemplate
 
@@ -33,15 +35,29 @@ class EnhancedSigmaGenerator:
         if poc_data:
             best_poc = max(poc_data, key=lambda x: x.get('quality_analysis', {}).get('quality_score', 0))
         
-        # Select appropriate template based on PoC analysis
-        template = await self._select_template(cve, best_poc)
-        
-        if not template:
-            logger.warning(f"No suitable template found for {cve.cve_id}")
-            return {'success': False, 'error': 'No suitable template'}
-        
-        # Generate rule content
-        rule_content = await self._generate_rule_content(cve, template, poc_data)
+        # Try LLM-enhanced generation first if enabled and available
+        rule_content = None
+        generation_method = "template"
+        
+        if use_llm and self.llm_client.is_available() and best_poc:
+            logger.info(f"Attempting LLM-enhanced rule generation for {cve.cve_id} using {self.llm_client.provider}")
+            rule_content = await self._generate_llm_enhanced_rule(cve, best_poc, poc_data)
+            if rule_content:
+                generation_method = f"llm_{self.llm_client.provider}"
+        
+        # Fallback to template-based generation
+        if not rule_content:
+            logger.info(f"Using template-based rule generation for {cve.cve_id}")
+            
+            # Select appropriate template based on PoC analysis
+            template = await self._select_template(cve, best_poc)
+            
+            if not template:
+                logger.warning(f"No suitable template found for {cve.cve_id}")
+                return {'success': False, 'error': 'No suitable template'}
+            
+            # Generate rule content
+            rule_content = await self._generate_rule_content(cve, template, poc_data)
         
         # Calculate confidence level
         confidence_level = self._calculate_confidence_level(cve, poc_data)
@@ -55,8 +71,8 @@ class EnhancedSigmaGenerator:
             'cve_id': cve.cve_id,
             'rule_name': f"{cve.cve_id} Enhanced Detection",
             'rule_content': rule_content,
-            'detection_type': template.template_name,
-            'log_source': self._extract_log_source(template.template_name),
+            'detection_type': f"{generation_method}_generated",
+            'log_source': self._extract_log_source_from_content(rule_content),
             'confidence_level': confidence_level,
             'auto_generated': True,
             'exploit_based': len(poc_data) > 0,
@@ -67,7 +83,8 @@ class EnhancedSigmaGenerator:
                 'best_poc_quality': best_poc.get('quality_analysis', {}).get('quality_score', 0) if best_poc else 0,
                 'total_stars': sum(p.get('stargazers_count', 0) for p in poc_data),
                 'avg_stars': sum(p.get('stargazers_count', 0) for p in poc_data) / len(poc_data) if poc_data else 0,
-                'source': getattr(cve, 'poc_source', 'nomi_sec')
+                'source': getattr(cve, 'poc_source', 'nomi_sec'),
+                'generation_method': generation_method
             },
             'github_repos': [p.get('html_url', '') for p in poc_data],
             'exploit_indicators': json.dumps(self._combine_exploit_indicators(poc_data)),
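With these changes the generator tries the LLM path first and only falls back to templates when that fails. A sketch of how it might be driven (the DB session and CVE object are assumed to come from the existing app code):

```python
import asyncio
from enhanced_sigma_generator import EnhancedSigmaGenerator

async def build_rule(db_session, cve):
    # Provider/model fall back to the LLM_PROVIDER / LLM_MODEL env vars when None
    generator = EnhancedSigmaGenerator(db_session, llm_provider="anthropic", llm_model=None)
    result = await generator.generate_enhanced_rule(cve, use_llm=True)
    # On success the stored rule's detection_type records how it was produced,
    # e.g. "llm_anthropic_generated" when the LLM path worked or
    # "template_generated" after the fallback.
    print(result.get('success'), result.get('error'))
```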
@@ -100,6 +117,135 @@ class EnhancedSigmaGenerator:
             logger.error(f"Error generating enhanced rule for {cve.cve_id}: {e}")
             return {'success': False, 'error': str(e)}
     
+    async def _generate_llm_enhanced_rule(self, cve, best_poc: dict, poc_data: list) -> Optional[str]:
+        """Generate SIGMA rule using LLM API with PoC analysis"""
+        try:
+            # Get PoC content from the best quality PoC
+            poc_content = await self._extract_poc_content(best_poc)
+            if not poc_content:
+                logger.warning(f"No PoC content available for {cve.cve_id}")
+                return None
+            
+            # Generate rule using LLM
+            rule_content = await self.llm_client.generate_sigma_rule(
+                cve_id=cve.cve_id,
+                poc_content=poc_content,
+                cve_description=cve.description or "",
+                existing_rule=None
+            )
+            
+            if rule_content:
+                # Validate the generated rule
+                if self.llm_client.validate_sigma_rule(rule_content):
+                    logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
+                    return rule_content
+                else:
+                    logger.warning(f"Generated rule for {cve.cve_id} failed validation")
+                    return None
+            
+            return None
+            
+        except Exception as e:
+            logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
+            return None
+    
+    async def _extract_poc_content(self, poc: dict) -> Optional[str]:
+        """Extract actual code content from PoC repository"""
+        try:
+            import aiohttp
+            import asyncio
+            
+            # Get repository information
+            repo_url = poc.get('html_url', '')
+            if not repo_url:
+                return None
+            
+            # Convert GitHub URL to API URL for repository content
+            if 'github.com' in repo_url:
+                # Extract owner and repo from URL
+                parts = repo_url.rstrip('/').split('/')
+                if len(parts) >= 2:
+                    owner = parts[-2]
+                    repo = parts[-1]
+                    
+                    # Get repository files via GitHub API
+                    api_url = f"https://api.github.com/repos/{owner}/{repo}/contents"
+                    
+                    async with aiohttp.ClientSession() as session:
+                        # Add timeout to prevent hanging
+                        timeout = aiohttp.ClientTimeout(total=30)
+                        async with session.get(api_url, timeout=timeout) as response:
+                            if response.status == 200:
+                                contents = await response.json()
+                                
+                                # Look for common exploit files
+                                target_files = [
+                                    'exploit.py', 'poc.py', 'exploit.c', 'exploit.cpp',
+                                    'exploit.java', 'exploit.rb', 'exploit.php',
+                                    'exploit.js', 'exploit.sh', 'exploit.ps1',
+                                    'README.md', 'main.py', 'index.js'
+                                ]
+                                
+                                for file_info in contents:
+                                    if file_info.get('type') == 'file':
+                                        filename = file_info.get('name', '').lower()
+                                        
+                                        # Check if this is a target file
+                                        if any(target in filename for target in target_files):
+                                            file_url = file_info.get('download_url')
+                                            if file_url:
+                                                async with session.get(file_url, timeout=timeout) as file_response:
+                                                    if file_response.status == 200:
+                                                        content = await file_response.text()
+                                                        # Limit content size
+                                                        if len(content) > 10000:
+                                                            content = content[:10000] + "\n... [content truncated]"
+                                                        return content
+                                
+                                # If no specific exploit file found, return description/README
+                                for file_info in contents:
+                                    if file_info.get('type') == 'file':
+                                        filename = file_info.get('name', '').lower()
+                                        if 'readme' in filename:
+                                            file_url = file_info.get('download_url')
+                                            if file_url:
+                                                async with session.get(file_url, timeout=timeout) as file_response:
+                                                    if file_response.status == 200:
+                                                        content = await file_response.text()
+                                                        return content[:5000]  # Smaller limit for README
+            
+            # Fallback to description and metadata
+            description = poc.get('description', '')
+            if description:
+                return f"Repository Description: {description}"
+            
+            return None
+            
+        except Exception as e:
+            logger.error(f"Error extracting PoC content: {e}")
+            return None
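The PoC lookup above relies on the GitHub contents API; the URL rewrite it performs looks like this (the repository name is a made-up example):

```python
repo_url = "https://github.com/example-user/CVE-2024-0001-poc"  # hypothetical PoC repo
owner, repo = repo_url.rstrip('/').split('/')[-2:]
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents"
# -> https://api.github.com/repos/example-user/CVE-2024-0001-poc/contents
# Each entry in the returned JSON listing carries "type", "name" and
# "download_url" fields, which is what _extract_poc_content filters on.
print(api_url)
```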
+    
+    def _extract_log_source_from_content(self, rule_content: str) -> str:
+        """Extract log source from the generated rule content"""
+        try:
+            import yaml
+            parsed = yaml.safe_load(rule_content)
+            logsource = parsed.get('logsource', {})
+            
+            category = logsource.get('category', '')
+            product = logsource.get('product', '')
+            
+            if category:
+                return category
+            elif product:
+                return product
+            else:
+                return 'generic'
+                
+        except Exception:
+            return 'generic'
+    
     async def _select_template(self, cve, best_poc: Optional[dict]) -> Optional[object]:
         """Select the most appropriate template based on CVE and PoC analysis"""
         from main import RuleTemplate
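A quick illustration of what `_extract_log_source_from_content` returns for a typical generated rule (the YAML is illustrative):

```python
import yaml

rule_content = (
    "title: Example Rule\n"
    "logsource:\n"
    "  category: process_creation\n"
    "  product: windows\n"
    "detection:\n"
    "  selection:\n"
    "    Image|endswith: '\\whoami.exe'\n"
    "  condition: selection\n"
)

logsource = yaml.safe_load(rule_content).get('logsource', {})
# category takes precedence over product, so this resolves to "process_creation";
# with neither present the generator falls back to "generic".
print(logsource.get('category') or logsource.get('product') or 'generic')
```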
backend/llm_client.py (new file, +398)

@@ -0,0 +1,398 @@
"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
import os
import logging
from typing import Optional, Dict, Any, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
import yaml

logger = logging.getLogger(__name__)

class LLMClient:
    """Multi-provider LLM client for SIGMA rule generation using LangChain."""

    SUPPORTED_PROVIDERS = {
        'openai': {
            'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
            'env_key': 'OPENAI_API_KEY',
            'default_model': 'gpt-4o-mini'
        },
        'anthropic': {
            'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
            'env_key': 'ANTHROPIC_API_KEY',
            'default_model': 'claude-3-5-sonnet-20241022'
        },
        'ollama': {
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
            'env_key': 'OLLAMA_BASE_URL',
            'default_model': 'llama3.2'
        }
    }

    def __init__(self, provider: str = None, model: str = None):
        """Initialize LLM client with specified provider and model."""
        self.provider = provider or self._detect_provider()
        self.model = model or self._get_default_model(self.provider)
        self.llm = None
        self.output_parser = StrOutputParser()

        self._initialize_llm()

    def _detect_provider(self) -> str:
        """Auto-detect available LLM provider based on environment variables."""
        # Check for API keys in order of preference
        if os.getenv('ANTHROPIC_API_KEY'):
            return 'anthropic'
        elif os.getenv('OPENAI_API_KEY'):
            return 'openai'
        elif os.getenv('OLLAMA_BASE_URL'):
            return 'ollama'
        else:
            # Default to OpenAI if no keys found
            return 'openai'

    def _get_default_model(self, provider: str) -> str:
        """Get default model for the specified provider."""
        return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')

    def _initialize_llm(self):
        """Initialize the LLM based on provider and model."""
        try:
            if self.provider == 'openai':
                api_key = os.getenv('OPENAI_API_KEY')
                if not api_key:
                    logger.warning("OpenAI API key not found")
                    return

                self.llm = ChatOpenAI(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000
                )

            elif self.provider == 'anthropic':
                api_key = os.getenv('ANTHROPIC_API_KEY')
                if not api_key:
                    logger.warning("Anthropic API key not found")
                    return

                self.llm = ChatAnthropic(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000
                )

            elif self.provider == 'ollama':
                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')

                self.llm = Ollama(
                    model=self.model,
                    base_url=base_url,
                    temperature=0.1
                )

            if self.llm:
                logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
            else:
                logger.error(f"Failed to initialize LLM client for provider: {self.provider}")

        except Exception as e:
            logger.error(f"Error initializing LLM client: {e}")
            self.llm = None

    def is_available(self) -> bool:
        """Check if LLM client is available and configured."""
        return self.llm is not None

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider and configuration."""
        provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
        return {
            'provider': self.provider,
            'model': self.model,
            'available': self.is_available(),
            'supported_models': provider_info.get('models', []),
            'env_key': provider_info.get('env_key', ''),
            'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
        }
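Instantiating the client is enough to see which provider was picked up; a minimal sketch, assuming at least one of the keys from `.env.example` is set:

```python
from llm_client import LLMClient

client = LLMClient()                # auto-detects anthropic -> openai -> ollama
print(client.get_provider_info())   # provider, model, availability, env key status

local = LLMClient(provider='ollama', model='llama3.2')  # explicit local model
print(local.is_available())
```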
    async def generate_sigma_rule(self,
                                  cve_id: str,
                                  poc_content: str,
                                  cve_description: str,
                                  existing_rule: Optional[str] = None) -> Optional[str]:
        """
        Generate or enhance a SIGMA rule using the configured LLM.

        Args:
            cve_id: CVE identifier
            poc_content: Proof-of-concept code content from GitHub
            cve_description: CVE description from NVD
            existing_rule: Optional existing SIGMA rule to enhance

        Returns:
            Generated SIGMA rule YAML content or None if failed
        """
        if not self.is_available():
            logger.warning("LLM client not available")
            return None

        try:
            # Create the prompt template
            prompt = self._build_sigma_generation_prompt(
                cve_id, poc_content, cve_description, existing_rule
            )

            # Create the chain
            chain = prompt | self.llm | self.output_parser

            # Generate the response
            response = await chain.ainvoke({
                "cve_id": cve_id,
                "poc_content": poc_content[:4000],  # Truncate if too long
                "cve_description": cve_description,
                "existing_rule": existing_rule or "None"
            })

            # Extract the SIGMA rule from response
            sigma_rule = self._extract_sigma_rule(response)

            logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
            return sigma_rule

        except Exception as e:
            logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
            return None

    def _build_sigma_generation_prompt(self,
                                       cve_id: str,
                                       poc_content: str,
                                       cve_description: str,
                                       existing_rule: Optional[str] = None) -> ChatPromptTemplate:
        """Build the prompt template for SIGMA rule generation."""

        system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.

**Your Task:**
1. Analyze the exploit code to identify:
- Process execution patterns
- File system activities
- Network connections
- Registry modifications
- Command line arguments
- Suspicious behaviors

2. Create a SIGMA rule that:
- Follows proper SIGMA syntax (YAML format)
- Includes appropriate detection logic
- Has relevant metadata (title, description, author, date, references)
- Uses correct field names for the target log source
- Includes proper condition logic
- Maps to relevant MITRE ATT&CK techniques when applicable

3. Focus on detection patterns that would catch this specific exploit in action

**Important Requirements:**
- Output ONLY the SIGMA rule in valid YAML format
- Do not include explanations or comments outside the YAML
- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
- Make the rule specific enough to detect the exploit but not too narrow to miss variants
- Include relevant tags and MITRE ATT&CK technique mappings"""

        if existing_rule:
            user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}

**Proof-of-Concept Code:**
```
{poc_content}
```

**Existing SIGMA Rule (to enhance):**
```yaml
{existing_rule}
```

Please enhance the existing rule with insights from the PoC code analysis."""
        else:
            user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}

**Proof-of-Concept Code:**
```
{poc_content}
```

Please create a new SIGMA rule based on the PoC code analysis."""

        return ChatPromptTemplate.from_messages([
            SystemMessage(content=system_message),
            HumanMessage(content=user_template)
        ])
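One subtlety worth noting: `ChatPromptTemplate.from_messages` only interpolates `{cve_id}`-style variables for template-style entries such as `(role, template)` tuples; plain `SystemMessage`/`HumanMessage` objects are passed through verbatim. A sketch of the tuple form that would let the `chain.ainvoke({...})` call above fill in the variables (placeholder values, not from the commit):

```python
from langchain_core.prompts import ChatPromptTemplate

# Tuple-style messages become real prompt templates, so cve_id, cve_description
# and poc_content are substituted when the chain or template is invoked.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a cybersecurity expert specializing in SIGMA rule creation."),
    ("human",
     "**CVE Information:**\n"
     "- CVE ID: {cve_id}\n"
     "- Description: {cve_description}\n\n"
     "**Proof-of-Concept Code:**\n{poc_content}\n\n"
     "Please create a new SIGMA rule based on the PoC code analysis."),
])

formatted = prompt.invoke({
    "cve_id": "CVE-2024-0001",               # hypothetical values for illustration
    "cve_description": "Example vulnerability",
    "poc_content": "print('poc')",
})
print(formatted.to_messages())
```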
    def _extract_sigma_rule(self, response_text: str) -> str:
        """Extract SIGMA rule YAML from LLM response."""
        # Look for YAML content in the response
        lines = response_text.split('\n')
        yaml_lines = []
        in_yaml = False

        for line in lines:
            if line.strip().startswith('```yaml') or line.strip().startswith('```'):
                in_yaml = True
                continue
            elif line.strip() == '```' and in_yaml:
                break
            elif in_yaml or line.strip().startswith('title:'):
                yaml_lines.append(line)
                in_yaml = True

        if not yaml_lines:
            # If no YAML block found, return the whole response
            return response_text.strip()

        return '\n'.join(yaml_lines).strip()

    async def enhance_existing_rule(self,
                                    existing_rule: str,
                                    poc_content: str,
                                    cve_id: str) -> Optional[str]:
        """
        Enhance an existing SIGMA rule with PoC analysis.

        Args:
            existing_rule: Existing SIGMA rule YAML
            poc_content: PoC code content
            cve_id: CVE identifier

        Returns:
            Enhanced SIGMA rule or None if failed
        """
        if not self.is_available():
            return None

        try:
            system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.

**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective

Output ONLY the enhanced SIGMA rule in valid YAML format."""

            user_template = """**CVE ID:** {cve_id}

**PoC Code:**
```
{poc_content}
```

**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""

            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content=system_message),
                HumanMessage(content=user_template)
            ])

            chain = prompt | self.llm | self.output_parser

            response = await chain.ainvoke({
                "cve_id": cve_id,
                "poc_content": poc_content[:3000],
                "existing_rule": existing_rule
            })

            enhanced_rule = self._extract_sigma_rule(response)
            logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
            return enhanced_rule

        except Exception as e:
            logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
            return None
    def validate_sigma_rule(self, rule_content: str) -> bool:
        """Validate that the generated rule is syntactically correct SIGMA."""
        try:
            # Parse as YAML
            parsed = yaml.safe_load(rule_content)

            # Check required fields
            required_fields = ['title', 'id', 'description', 'logsource', 'detection']
            for field in required_fields:
                if field not in parsed:
                    logger.warning(f"Missing required field: {field}")
                    return False

            # Check detection structure
            detection = parsed.get('detection', {})
            if not isinstance(detection, dict):
                logger.warning("Detection field must be a dictionary")
                return False

            # Should have at least one selection and a condition
            if 'condition' not in detection:
                logger.warning("Detection must have a condition")
                return False

            # Check logsource structure
            logsource = parsed.get('logsource', {})
            if not isinstance(logsource, dict):
                logger.warning("Logsource field must be a dictionary")
                return False

            logger.info("SIGMA rule validation passed")
            return True

        except yaml.YAMLError as e:
            logger.warning(f"YAML parsing error: {e}")
            return False
        except Exception as e:
            logger.warning(f"Rule validation error: {e}")
            return False
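A minimal rule that satisfies this validator — all five required top-level keys plus a detection condition (the UUID and field values are placeholders):

```python
from llm_client import LLMClient

minimal_rule = (
    "title: Example Whoami Detection\n"
    "id: 00000000-0000-0000-0000-000000000000\n"
    "description: Illustrative rule only\n"
    "logsource:\n"
    "  category: process_creation\n"
    "  product: windows\n"
    "detection:\n"
    "  selection:\n"
    "    Image|endswith: '\\whoami.exe'\n"
    "  condition: selection\n"
)

client = LLMClient(provider='ollama')  # provider choice does not matter for validation
print(client.validate_sigma_rule(minimal_rule))  # True: parses and has every required field
```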
    @classmethod
    def get_available_providers(cls) -> List[Dict[str, Any]]:
        """Get list of available LLM providers and their configuration status."""
        providers = []

        for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items():
            env_key = provider_info.get('env_key', '')
            api_key_configured = bool(os.getenv(env_key))

            providers.append({
                'name': provider_name,
                'models': provider_info.get('models', []),
                'default_model': provider_info.get('default_model', ''),
                'env_key': env_key,
                'api_key_configured': api_key_configured,
                'available': api_key_configured or provider_name == 'ollama'
            })

        return providers

    def switch_provider(self, provider: str, model: str = None):
        """Switch to a different LLM provider and model."""
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(f"Unsupported provider: {provider}")

        self.provider = provider
        self.model = model or self._get_default_model(provider)
        self._initialize_llm()

        logger.info(f"Switched to provider: {provider} with model: {self.model}")
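Provider discovery and switching can be exercised directly, assuming the relevant API key is configured; a small sketch:

```python
from llm_client import LLMClient

# Ollama is always reported as available since it needs no API key,
# only a reachable local server.
for entry in LLMClient.get_available_providers():
    print(entry['name'], entry['available'], entry['default_model'])

client = LLMClient()
client.switch_provider('anthropic', 'claude-3-haiku-20240307')  # re-initializes the underlying LLM
print(client.get_provider_info()['model'])
```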
backend/main.py (+165)

@@ -1414,6 +1414,171 @@ async def regenerate_sigma_rules(background_tasks: BackgroundTasks,
         "force": request.force
     }
 
+@app.post("/api/llm-enhanced-rules")
+async def generate_llm_enhanced_rules(request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
+    """Generate SIGMA rules using LLM API for enhanced analysis"""
+    
+    # Parse request parameters
+    cve_id = request.get('cve_id')
+    force = request.get('force', False)
+    llm_provider = request.get('provider', os.getenv('LLM_PROVIDER'))
+    llm_model = request.get('model', os.getenv('LLM_MODEL'))
+    
+    # Validation
+    if cve_id and not re.match(r'^CVE-\d{4}-\d{4,}$', cve_id):
+        raise HTTPException(status_code=400, detail="Invalid CVE ID format")
+    
+    async def llm_generation_task():
+        """Background task for LLM-enhanced rule generation"""
+        try:
+            from enhanced_sigma_generator import EnhancedSigmaGenerator
+            
+            generator = EnhancedSigmaGenerator(db, llm_provider, llm_model)
+            
+            # Process specific CVE or all CVEs with PoC data
+            if cve_id:
+                cve = db.query(CVE).filter(CVE.cve_id == cve_id).first()
+                if not cve:
+                    logger.error(f"CVE {cve_id} not found")
+                    return
+                
+                cves_to_process = [cve]
+            else:
+                # Process CVEs with PoC data that either have no rules or force update
+                query = db.query(CVE).filter(CVE.poc_count > 0)
+                
+                if not force:
+                    # Only process CVEs without existing LLM-generated rules
+                    existing_llm_rules = db.query(SigmaRule).filter(
+                        SigmaRule.detection_type.like('llm_%')
+                    ).all()
+                    existing_cve_ids = {rule.cve_id for rule in existing_llm_rules}
+                    cves_to_process = [cve for cve in query.all() if cve.cve_id not in existing_cve_ids]
+                else:
+                    cves_to_process = query.all()
+            
+            logger.info(f"Processing {len(cves_to_process)} CVEs for LLM-enhanced rule generation using {llm_provider}")
+            
+            rules_generated = 0
+            rules_updated = 0
+            failures = 0
+            
+            for cve in cves_to_process:
+                try:
+                    # Check if CVE has sufficient PoC data
+                    if not cve.poc_data or not cve.poc_count:
+                        logger.debug(f"Skipping {cve.cve_id} - no PoC data")
+                        continue
+                    
+                    # Generate LLM-enhanced rule
+                    result = await generator.generate_enhanced_rule(cve, use_llm=True)
+                    
+                    if result.get('success'):
+                        if result.get('updated'):
+                            rules_updated += 1
+                        else:
+                            rules_generated += 1
+                        
+                        logger.info(f"Successfully generated LLM-enhanced rule for {cve.cve_id}")
+                    else:
+                        failures += 1
+                        logger.warning(f"Failed to generate LLM-enhanced rule for {cve.cve_id}: {result.get('error')}")
+                        
+                except Exception as e:
+                    failures += 1
+                    logger.error(f"Error generating LLM-enhanced rule for {cve.cve_id}: {e}")
+                    continue
+            
+            logger.info(f"LLM-enhanced rule generation completed: {rules_generated} new, {rules_updated} updated, {failures} failures")
+            
+        except Exception as e:
+            logger.error(f"LLM-enhanced rule generation failed: {e}")
+            import traceback
+            traceback.print_exc()
+    
+    background_tasks.add_task(llm_generation_task)
+    
+    return {
+        "message": "LLM-enhanced SIGMA rule generation started",
+        "status": "started",
+        "cve_id": cve_id,
+        "force": force,
+        "provider": llm_provider,
+        "model": llm_model,
+        "note": "Requires appropriate LLM API key to be set"
+    }
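Kicking off the new endpoint from a client looks roughly like this (base URL and CVE ID are illustrative; the actual port depends on the Docker Compose setup):

```python
import asyncio
import aiohttp

async def trigger_llm_rules():
    payload = {
        "cve_id": "CVE-2024-0001",   # omit to process every CVE that has PoC data
        "provider": "anthropic",      # optional, falls back to LLM_PROVIDER
        "force": False,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8000/api/llm-enhanced-rules", json=payload) as resp:
            print(await resp.json())  # {"message": "...", "status": "started", ...}

asyncio.run(trigger_llm_rules())
```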
@app.get("/api/llm-status")
|
||||||
|
async def get_llm_status():
|
||||||
|
"""Check LLM API availability status"""
|
||||||
|
try:
|
||||||
|
from llm_client import LLMClient
|
||||||
|
|
||||||
|
# Get current provider configuration
|
||||||
|
provider = os.getenv('LLM_PROVIDER')
|
||||||
|
model = os.getenv('LLM_MODEL')
|
||||||
|
|
||||||
|
client = LLMClient(provider=provider, model=model)
|
||||||
|
provider_info = client.get_provider_info()
|
||||||
|
|
||||||
|
# Get all available providers
|
||||||
|
all_providers = LLMClient.get_available_providers()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"current_provider": provider_info,
|
||||||
|
"available_providers": all_providers,
|
||||||
|
"status": "ready" if client.is_available() else "unavailable"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking LLM status: {e}")
|
||||||
|
return {
|
||||||
|
"current_provider": {"provider": "unknown", "available": False},
|
||||||
|
"available_providers": [],
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/api/llm-switch")
|
||||||
|
async def switch_llm_provider(request: dict):
|
||||||
|
"""Switch LLM provider and model"""
|
||||||
|
try:
|
||||||
|
from llm_client import LLMClient
|
||||||
|
|
||||||
|
provider = request.get('provider')
|
||||||
|
model = request.get('model')
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=400, detail="Provider is required")
|
||||||
|
|
||||||
|
# Validate provider
|
||||||
|
if provider not in LLMClient.SUPPORTED_PROVIDERS:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
|
||||||
|
|
||||||
|
# Test the new configuration
|
||||||
|
client = LLMClient(provider=provider, model=model)
|
||||||
|
|
||||||
|
if not client.is_available():
|
||||||
|
raise HTTPException(status_code=400, detail=f"Provider {provider} is not available or not configured")
|
||||||
|
|
||||||
|
# Update environment variables (note: this only affects the current session)
|
||||||
|
os.environ['LLM_PROVIDER'] = provider
|
||||||
|
if model:
|
||||||
|
os.environ['LLM_MODEL'] = model
|
||||||
|
|
||||||
|
provider_info = client.get_provider_info()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Switched to {provider}",
|
||||||
|
"provider_info": provider_info,
|
||||||
|
"status": "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error switching LLM provider: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post("/api/cancel-job/{job_id}")
|
@app.post("/api/cancel-job/{job_id}")
|
||||||
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
|
async def cancel_job(job_id: str, db: Session = Depends(get_db)):
|
||||||
"""Cancel a running job"""
|
"""Cancel a running job"""
|
||||||
|
|
|
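The status and switch endpoints pair naturally: poll the first to see what is configured, then post to the second to change it for the running process. A client sketch under the same assumed base URL:

```python
import asyncio
import aiohttp

BASE_URL = "http://localhost:8000"  # assumed backend address

async def show_and_switch():
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{BASE_URL}/api/llm-status") as resp:
            status = await resp.json()
            print(status["status"], status["current_provider"])

        # Switch to a local Ollama model; only persists for the running process
        async with session.post(f"{BASE_URL}/api/llm-switch",
                                json={"provider": "ollama", "model": "llama3.2"}) as resp:
            print(await resp.json())

asyncio.run(show_and_switch())
```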
backend/requirements.txt (+7)

@@ -15,3 +15,10 @@ lxml==4.9.3
 aiohttp==3.9.1
 aiofiles
 pyyaml==6.0.1
+langchain==0.2.0
+langchain-openai==0.1.17
+langchain-anthropic==0.1.15
+langchain-community==0.2.0
+langchain-core>=0.2.20
+openai>=1.32.0
+anthropic==0.40.0
frontend/src/App.js

@@ -21,6 +21,7 @@ function App() {
   const [gitHubPocStats, setGitHubPocStats] = useState({});
   const [bulkProcessing, setBulkProcessing] = useState(false);
   const [hasRunningJobs, setHasRunningJobs] = useState(false);
+  const [llmStatus, setLlmStatus] = useState({});
 
   useEffect(() => {
     fetchData();
@@ -29,14 +30,15 @@ function App() {
   const fetchData = async () => {
     try {
       setLoading(true);
-      const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes] = await Promise.all([
+      const [cvesRes, rulesRes, statsRes, bulkJobsRes, bulkStatusRes, pocStatsRes, githubPocStatsRes, llmStatusRes] = await Promise.all([
         axios.get(`${API_BASE_URL}/api/cves`),
         axios.get(`${API_BASE_URL}/api/sigma-rules`),
         axios.get(`${API_BASE_URL}/api/stats`),
         axios.get(`${API_BASE_URL}/api/bulk-jobs`),
         axios.get(`${API_BASE_URL}/api/bulk-status`),
         axios.get(`${API_BASE_URL}/api/poc-stats`),
-        axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} }))
+        axios.get(`${API_BASE_URL}/api/github-poc-stats`).catch(err => ({ data: {} })),
+        axios.get(`${API_BASE_URL}/api/llm-status`).catch(err => ({ data: {} }))
       ]);
 
       setCves(cvesRes.data);
@@ -46,6 +48,7 @@ function App() {
       setBulkStatus(bulkStatusRes.data);
       setPocStats(pocStatsRes.data);
       setGitHubPocStats(githubPocStatsRes.data);
+      setLlmStatus(llmStatusRes.data);
 
       // Update running jobs state
       const runningJobs = bulkJobsRes.data.filter(job => job.status === 'running' || job.status === 'pending');
@@ -166,6 +169,32 @@ function App() {
     }
   };
 
+  const generateLlmRules = async (force = false) => {
+    try {
+      const response = await axios.post(`${API_BASE_URL}/api/llm-enhanced-rules`, {
+        force: force
+      });
+      console.log('LLM rule generation response:', response.data);
+      fetchData();
+    } catch (error) {
+      console.error('Error generating LLM-enhanced rules:', error);
+    }
+  };
+
+  const switchLlmProvider = async (provider, model) => {
+    try {
+      const response = await axios.post(`${API_BASE_URL}/api/llm-switch`, {
+        provider: provider,
+        model: model
+      });
+      console.log('LLM provider switch response:', response.data);
+      fetchData(); // Refresh to get updated status
+    } catch (error) {
+      console.error('Error switching LLM provider:', error);
+      alert('Failed to switch LLM provider. Please check configuration.');
+    }
+  };
+
   const getSeverityColor = (severity) => {
     switch (severity?.toLowerCase()) {
       case 'critical': return 'bg-red-100 text-red-800';
@@ -197,6 +226,9 @@ function App() {
             <p className="text-3xl font-bold text-green-600">{stats.total_sigma_rules || 0}</p>
             <p className="text-sm text-gray-500">Nomi-sec: {stats.nomi_sec_rules || 0}</p>
             <p className="text-sm text-gray-500">GitHub PoCs: {gitHubPocStats.github_poc_rules || 0}</p>
+            <p className={`text-sm ${llmStatus.status === 'ready' ? 'text-green-600' : 'text-red-500'}`}>
+              LLM: {llmStatus.current_provider?.provider || 'Not Available'}
+            </p>
           </div>
           <div className="bg-white p-6 rounded-lg shadow">
             <h3 className="text-lg font-medium text-gray-900">CVEs with PoCs</h3>
@@ -219,7 +251,7 @@ function App() {
       {/* Bulk Processing Controls */}
       <div className="bg-white rounded-lg shadow p-6">
         <h2 className="text-xl font-bold text-gray-900 mb-4">Bulk Processing</h2>
-        <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
+        <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-5 gap-4">
           <button
             onClick={() => startBulkSeed(2002)}
             disabled={hasRunningJobs}
@@ -275,6 +307,64 @@ function App() {
           >
             {hasRunningJobs ? 'Processing...' : 'Regenerate Rules'}
           </button>
+          <button
+            onClick={() => generateLlmRules()}
+            disabled={hasRunningJobs || llmStatus.status !== 'ready'}
+            className={`px-4 py-2 rounded-md text-white ${
+              hasRunningJobs || llmStatus.status !== 'ready'
+                ? 'bg-gray-400 cursor-not-allowed'
+                : 'bg-violet-600 hover:bg-violet-700'
+            }`}
+            title={llmStatus.status !== 'ready' ? 'LLM not configured' : ''}
+          >
+            {hasRunningJobs ? 'Processing...' : 'Generate LLM Rules'}
+          </button>
+        </div>
+      </div>
+
+      {/* LLM Configuration */}
+      <div className="bg-white rounded-lg shadow p-6">
+        <h2 className="text-xl font-bold text-gray-900 mb-4">LLM Configuration</h2>
+        <div className="grid grid-cols-1 md:grid-cols-2 gap-6">
+          <div>
+            <h3 className="text-lg font-medium text-gray-900 mb-2">Current Provider</h3>
+            <div className="space-y-2">
+              <p className="text-sm text-gray-600">
+                Provider: <span className="font-medium">{llmStatus.current_provider?.provider || 'Not configured'}</span>
+              </p>
+              <p className="text-sm text-gray-600">
+                Model: <span className="font-medium">{llmStatus.current_provider?.model || 'Not configured'}</span>
+              </p>
+              <p className={`text-sm ${llmStatus.status === 'ready' ? 'text-green-600' : 'text-red-500'}`}>
+                Status: <span className="font-medium">{llmStatus.status || 'Unknown'}</span>
+              </p>
+            </div>
+          </div>
+          <div>
+            <h3 className="text-lg font-medium text-gray-900 mb-2">Available Providers</h3>
+            <div className="space-y-2">
+              {llmStatus.available_providers?.map(provider => (
+                <div key={provider.name} className="flex items-center justify-between p-2 bg-gray-50 rounded">
+                  <div>
+                    <span className="font-medium">{provider.name}</span>
+                    <span className={`ml-2 text-xs px-2 py-1 rounded ${
+                      provider.available ? 'bg-green-100 text-green-800' : 'bg-red-100 text-red-800'
+                    }`}>
+                      {provider.available ? 'Available' : 'Not configured'}
+                    </span>
+                  </div>
+                  {provider.available && provider.name !== llmStatus.current_provider?.provider && (
+                    <button
+                      onClick={() => switchLlmProvider(provider.name, provider.default_model)}
+                      className="text-xs bg-blue-600 hover:bg-blue-700 text-white px-2 py-1 rounded"
+                    >
+                      Switch
+                    </button>
+                  )}
+                </div>
+              ))}
+            </div>
+          </div>
         </div>
       </div>
 