# auto_sigma_rule_generator/backend/llm_client.py

"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
import os
import logging
from typing import Optional, Dict, Any, List

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
import yaml

logger = logging.getLogger(__name__)

class LLMClient:
    """Multi-provider LLM client for SIGMA rule generation using LangChain."""

    SUPPORTED_PROVIDERS = {
        'openai': {
            'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
            'env_key': 'OPENAI_API_KEY',
            'default_model': 'gpt-4o-mini'
        },
        'anthropic': {
            'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
            'env_key': 'ANTHROPIC_API_KEY',
            'default_model': 'claude-3-5-sonnet-20241022'
        },
        'ollama': {
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
            'env_key': 'OLLAMA_BASE_URL',
            'default_model': 'llama3.2'
        }
    }

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """Initialize LLM client with specified provider and model."""
        self.provider = provider or self._detect_provider()
        self.model = model or self._get_default_model(self.provider)
        self.llm = None
        self.output_parser = StrOutputParser()
        self._initialize_llm()

    def _detect_provider(self) -> str:
        """Auto-detect available LLM provider based on environment variables."""
        # Check for API keys in order of preference
        if os.getenv('ANTHROPIC_API_KEY'):
            return 'anthropic'
        elif os.getenv('OPENAI_API_KEY'):
            return 'openai'
        elif os.getenv('OLLAMA_BASE_URL'):
            return 'ollama'
        else:
            # Default to OpenAI if no keys are found
            return 'openai'

    def _get_default_model(self, provider: str) -> str:
        """Get the default model for the specified provider."""
        return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')

    def _initialize_llm(self):
        """Initialize the LLM based on provider and model."""
        try:
            if self.provider == 'openai':
                api_key = os.getenv('OPENAI_API_KEY')
                if not api_key:
                    logger.warning("OpenAI API key not found")
                    return
                self.llm = ChatOpenAI(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000
                )
            elif self.provider == 'anthropic':
                api_key = os.getenv('ANTHROPIC_API_KEY')
                if not api_key:
                    logger.warning("Anthropic API key not found")
                    return
                self.llm = ChatAnthropic(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000
                )
            elif self.provider == 'ollama':
                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
                # Check if the model is available locally; if not, try to pull it
                if not self._check_ollama_model_available(base_url, self.model):
                    logger.info(f"Model {self.model} not found, attempting to pull...")
                    if self._pull_ollama_model(base_url, self.model):
                        logger.info(f"Successfully pulled model {self.model}")
                    else:
                        logger.error(f"Failed to pull model {self.model}")
                        return
                self.llm = Ollama(
                    model=self.model,
                    base_url=base_url,
                    temperature=0.1
                )

            if self.llm:
                logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
            else:
                logger.error(f"Failed to initialize LLM client for provider: {self.provider}")
        except Exception as e:
            logger.error(f"Error initializing LLM client: {e}")
            self.llm = None

    def is_available(self) -> bool:
        """Check if the LLM client is available and configured."""
        return self.llm is not None

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider and configuration."""
        provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})

        # For Ollama, report the models that are actually available locally
        available_models = provider_info.get('models', [])
        if self.provider == 'ollama':
            ollama_models = self._get_ollama_available_models()
            if ollama_models:
                available_models = ollama_models

        return {
            'provider': self.provider,
            'model': self.model,
            'available': self.is_available(),
            'supported_models': provider_info.get('models', []),
            'available_models': available_models,
            'env_key': provider_info.get('env_key', ''),
            'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
        }

    async def generate_sigma_rule(self,
                                  cve_id: str,
                                  poc_content: str,
                                  cve_description: str,
                                  existing_rule: Optional[str] = None) -> Optional[str]:
        """
        Generate or enhance a SIGMA rule using the configured LLM.

        Args:
            cve_id: CVE identifier
            poc_content: Proof-of-concept code content from GitHub
            cve_description: CVE description from NVD
            existing_rule: Optional existing SIGMA rule to enhance

        Returns:
            Generated SIGMA rule YAML content, or None if generation failed
        """
        if not self.is_available():
            logger.warning("LLM client not available")
            return None

        try:
            # Create the prompt template
            prompt = self._build_sigma_generation_prompt(
                cve_id, poc_content, cve_description, existing_rule
            )

            # Create the chain
            chain = prompt | self.llm | self.output_parser

            # Generate the response
            response = await chain.ainvoke({
                "cve_id": cve_id,
                "poc_content": poc_content[:4000],  # Truncate overly long PoC content
                "cve_description": cve_description,
                "existing_rule": existing_rule or "None"
            })

            # Extract the SIGMA rule from the response
            sigma_rule = self._extract_sigma_rule(response)
            logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
            return sigma_rule
        except Exception as e:
            logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
            return None

    def _build_sigma_generation_prompt(self,
                                       cve_id: str,
                                       poc_content: str,
                                       cve_description: str,
                                       existing_rule: Optional[str] = None) -> ChatPromptTemplate:
        """Build the prompt template for SIGMA rule generation."""
        system_message = """You are a cybersecurity expert specializing in SIGMA rule creation for threat detection. Your goal is to analyze exploit code from GitHub PoC repositories and create syntactically correct SIGMA rules.

**Your Task:**
1. Analyze the exploit code to identify:
   - Process execution patterns
   - File system activities
   - Network connections
   - Registry modifications
   - Command line arguments
   - Suspicious behaviors

2. Create a SIGMA rule that:
   - Follows proper SIGMA syntax (YAML format)
   - Includes appropriate detection logic
   - Has relevant metadata (title, description, author, date, references)
   - Uses correct field names for the target log source
   - Includes proper condition logic
   - Maps to relevant MITRE ATT&CK techniques when applicable

3. Focus on detection patterns that would catch this specific exploit in action

**Important Requirements:**
- Output ONLY the SIGMA rule in valid YAML format
- Do not include explanations or comments outside the YAML
- Use proper SIGMA rule structure with title, id, status, description, references, author, date, logsource, detection, and condition
- Make the rule specific enough to detect the exploit but not so narrow that it misses variants
- Include relevant tags and MITRE ATT&CK technique mappings"""

        if existing_rule:
            user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}

**Proof-of-Concept Code:**
```
{poc_content}
```

**Existing SIGMA Rule (to enhance):**
```yaml
{existing_rule}
```

Please enhance the existing rule with insights from the PoC code analysis."""
        else:
            user_template = """**CVE Information:**
- CVE ID: {cve_id}
- Description: {cve_description}

**Proof-of-Concept Code:**
```
{poc_content}
```

Please create a new SIGMA rule based on the PoC code analysis."""

        # Use (role, template) tuples so placeholders such as {cve_id} and
        # {poc_content} are treated as prompt variables and filled at invoke time.
        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("human", user_template)
        ])

    def _extract_sigma_rule(self, response_text: str) -> str:
        """Extract SIGMA rule YAML from the LLM response."""
        # Look for a fenced YAML block in the response
        lines = response_text.split('\n')
        yaml_lines = []
        in_yaml = False

        for line in lines:
            stripped = line.strip()
            if stripped.startswith('```'):
                if in_yaml:
                    # Closing fence: stop collecting
                    break
                # Opening fence (``` or ```yaml): start collecting
                in_yaml = True
                continue
            if in_yaml or stripped.startswith('title:'):
                yaml_lines.append(line)
                in_yaml = True

        if not yaml_lines:
            # No YAML block found; return the whole response
            return response_text.strip()

        return '\n'.join(yaml_lines).strip()

    async def enhance_existing_rule(self,
                                    existing_rule: str,
                                    poc_content: str,
                                    cve_id: str) -> Optional[str]:
        """
        Enhance an existing SIGMA rule with PoC analysis.

        Args:
            existing_rule: Existing SIGMA rule YAML
            poc_content: PoC code content
            cve_id: CVE identifier

        Returns:
            Enhanced SIGMA rule, or None if enhancement failed
        """
        if not self.is_available():
            return None

        try:
            system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.

**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective

Output ONLY the enhanced SIGMA rule in valid YAML format."""

            user_template = """**CVE ID:** {cve_id}

**PoC Code:**
```
{poc_content}
```

**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""

            # (role, template) tuples keep the placeholders as prompt variables
            prompt = ChatPromptTemplate.from_messages([
                ("system", system_message),
                ("human", user_template)
            ])
            chain = prompt | self.llm | self.output_parser

            response = await chain.ainvoke({
                "cve_id": cve_id,
                "poc_content": poc_content[:3000],
                "existing_rule": existing_rule
            })

            enhanced_rule = self._extract_sigma_rule(response)
            logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
            return enhanced_rule
        except Exception as e:
            logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
            return None

    def validate_sigma_rule(self, rule_content: str) -> bool:
        """Validate that the generated rule is syntactically correct SIGMA."""
        try:
            # Parse as YAML
            parsed = yaml.safe_load(rule_content)
            if not isinstance(parsed, dict):
                logger.warning("Rule content is not a YAML mapping")
                return False

            # Check required fields
            required_fields = ['title', 'id', 'description', 'logsource', 'detection']
            for field in required_fields:
                if field not in parsed:
                    logger.warning(f"Missing required field: {field}")
                    return False

            # Check detection structure
            detection = parsed.get('detection', {})
            if not isinstance(detection, dict):
                logger.warning("Detection field must be a dictionary")
                return False

            # Detection must include a condition
            if 'condition' not in detection:
                logger.warning("Detection must have a condition")
                return False

            # Check logsource structure
            logsource = parsed.get('logsource', {})
            if not isinstance(logsource, dict):
                logger.warning("Logsource field must be a dictionary")
                return False

            logger.info("SIGMA rule validation passed")
            return True
        except yaml.YAMLError as e:
            logger.warning(f"YAML parsing error: {e}")
            return False
        except Exception as e:
            logger.warning(f"Rule validation error: {e}")
            return False

    @classmethod
    def get_available_providers(cls) -> List[Dict[str, Any]]:
        """Get list of available LLM providers and their configuration status."""
        providers = []
        for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items():
            env_key = provider_info.get('env_key', '')
            api_key_configured = bool(os.getenv(env_key))
            providers.append({
                'name': provider_name,
                'models': provider_info.get('models', []),
                'default_model': provider_info.get('default_model', ''),
                'env_key': env_key,
                'api_key_configured': api_key_configured,
                'available': api_key_configured or provider_name == 'ollama'
            })
        return providers

    def switch_provider(self, provider: str, model: str = None):
        """Switch to a different LLM provider and model."""
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(f"Unsupported provider: {provider}")

        self.provider = provider
        self.model = model or self._get_default_model(provider)
        self._initialize_llm()
        logger.info(f"Switched to provider: {provider} with model: {self.model}")

    def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
        """Check if an Ollama model is available locally."""
        try:
            import requests
            response = requests.get(f"{base_url}/api/tags", timeout=10)
            if response.status_code == 200:
                data = response.json()
                models = data.get('models', [])
                for m in models:
                    if m.get('name', '').startswith(model + ':') or m.get('name') == model:
                        return True
            return False
        except Exception as e:
            logger.error(f"Error checking Ollama models: {e}")
            return False

    def _pull_ollama_model(self, base_url: str, model: str) -> bool:
        """Pull an Ollama model."""
        try:
            import requests
            import json

            # Use the pull API endpoint
            payload = {"name": model}
            response = requests.post(
                f"{base_url}/api/pull",
                json=payload,
                timeout=300,  # 5 minute timeout for model download
                stream=True
            )

            if response.status_code == 200:
                # Stream the response to monitor progress
                for line in response.iter_lines():
                    if line:
                        try:
                            data = json.loads(line.decode('utf-8'))
                            if data.get('status'):
                                logger.info(f"Ollama pull progress: {data.get('status')}")
                            if data.get('error'):
                                logger.error(f"Ollama pull error: {data.get('error')}")
                                return False
                        except json.JSONDecodeError:
                            continue
                return True
            else:
                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                return False
        except Exception as e:
            logger.error(f"Error pulling Ollama model {model}: {e}")
            return False

    def _get_ollama_available_models(self) -> List[str]:
        """Get list of available Ollama models."""
        try:
            import requests
            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
            response = requests.get(f"{base_url}/api/tags", timeout=10)
            if response.status_code == 200:
                data = response.json()
                models = data.get('models', [])
                return [m.get('name', '') for m in models if m.get('name')]
            return []
        except Exception as e:
            logger.error(f"Error getting Ollama models: {e}")
            return []
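

# The block below is an illustrative usage sketch, not part of the module's API.
# It assumes one of the provider environment variables (e.g. OPENAI_API_KEY,
# ANTHROPIC_API_KEY, or OLLAMA_BASE_URL) may be set, and the sample rule and CVE
# values are purely hypothetical placeholders. It shows how a caller might list
# providers, inspect the auto-detected configuration, validate a minimal SIGMA
# rule, and (if a provider is configured) request rule generation.
if __name__ == "__main__":
    import asyncio

    # List providers and whether they are configured in this environment
    for provider in LLMClient.get_available_providers():
        print(f"{provider['name']}: configured={provider['api_key_configured']}")

    client = LLMClient()  # auto-detects a provider from environment variables
    print(client.get_provider_info())

    # Minimal, hypothetical SIGMA rule used only to exercise validate_sigma_rule()
    sample_rule = """
title: Example Suspicious Process Creation
id: 00000000-0000-0000-0000-000000000000
status: experimental
description: Illustrative rule used to demonstrate validation only
logsource:
    category: process_creation
    product: windows
detection:
    selection:
        Image|endswith: '\\example.exe'
    condition: selection
"""
    print("Sample rule valid:", client.validate_sigma_rule(sample_rule))

    # Rule generation requires a configured provider, so guard on is_available()
    if client.is_available():
        rule = asyncio.run(client.generate_sigma_rule(
            cve_id="CVE-0000-00000",                  # placeholder identifier
            poc_content="print('proof-of-concept')",  # placeholder PoC content
            cve_description="Placeholder CVE description"
        ))
        print(rule)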