652 lines
No EOL
26 KiB
Python
652 lines
No EOL
26 KiB
Python
"""
LangChain-based LLM client for enhanced SIGMA rule generation.
Supports multiple LLM providers: OpenAI, Anthropic, and local models.
"""
|
|
import json
import logging
import os
from datetime import date
from typing import Optional, Dict, Any, List

import yaml
from langchain_anthropic import ChatAnthropic
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class LLMClient:
    """Multi-provider LLM client for SIGMA rule generation using LangChain.

    The client wraps one LangChain model (OpenAI, Anthropic or Ollama),
    builds the prompts used to generate or enhance SIGMA rules, and cleans
    the model output down to plain YAML.  When the selected provider cannot
    be initialized, ``self.llm`` is ``None`` and the generation methods
    return ``None`` instead of raising.
    """

    SUPPORTED_PROVIDERS = {
        'openai': {
            'models': ['gpt-4o', 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo'],
            'env_key': 'OPENAI_API_KEY',
            'default_model': 'gpt-4o-mini',
        },
        'anthropic': {
            'models': ['claude-3-5-sonnet-20241022', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'],
            'env_key': 'ANTHROPIC_API_KEY',
            'default_model': 'claude-3-5-sonnet-20241022',
        },
        'ollama': {
            'models': ['llama3.2', 'codellama', 'mistral', 'llama2'],
            'env_key': 'OLLAMA_BASE_URL',
            'default_model': 'llama3.2',
        },
    }

    # Base URL used for Ollama when OLLAMA_BASE_URL is not set.
    _DEFAULT_OLLAMA_URL = 'http://localhost:11434'

    # Literal prompt variables; a line still containing one is an echo of
    # the prompt, never part of a rule.
    _PROMPT_PLACEHOLDERS = (
        '{cve_id}', '{poc_content}', '{cve_description}', '{existing_rule}',
    )

    # Phrases that mark commentary written after an unfenced rule.
    _STOP_PHRASES = (
        'please note', 'this rule should', 'make sure to',
        'you can modify', 'adjust the', 'also, this is',
        'based on the analysis', 'the rule above',
    )

    # Header-style prompt artifacts, matched at the start of an unindented line.
    _ARTIFACT_PREFIXES = (
        'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
        'important:', 'explanation:', 'analysis:', 'notes:',
    )

    # Phrases that identify explanatory prose on lines that are not YAML.
    _PROSE_PHRASES = (
        'please note', 'you should replace', 'this is a proof-of-concept',
        'please make sure', 'note that', 'important:', 'remember to',
        'analysis shows', 'based on the', 'the rule above', 'this rule',
        'human:', 'cve id:', 'cve description:', 'poc code:', 'exploit code:',
        'analyze this', 'create a', 'output only', 'generate a',
    )

    # Words that suggest a colon-less, unindented line is an explanation.
    _EXPLANATION_WORDS = ('rule', 'detect', 'should', 'will', 'can', 'may')

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """Initialize the client.

        Args:
            provider: Provider name; auto-detected from the environment when
                omitted.
            model: Model name; the provider's default model when omitted.
        """
        self.provider = provider or self._detect_provider()
        self.model = model or self._get_default_model(self.provider)
        self.llm = None
        self.output_parser = StrOutputParser()

        self._initialize_llm()

    def _detect_provider(self) -> str:
        """Return the first provider whose environment variable is set.

        Preference order is Anthropic, OpenAI, Ollama; 'openai' is returned
        when none of the variables is set.
        """
        if os.getenv('ANTHROPIC_API_KEY'):
            return 'anthropic'
        if os.getenv('OPENAI_API_KEY'):
            return 'openai'
        if os.getenv('OLLAMA_BASE_URL'):
            return 'ollama'
        # Default to OpenAI if no keys are found.
        return 'openai'

    def _get_default_model(self, provider: str) -> str:
        """Return the default model for ``provider`` ('gpt-4o-mini' if unknown)."""
        return self.SUPPORTED_PROVIDERS.get(provider, {}).get('default_model', 'gpt-4o-mini')

    def _initialize_llm(self):
        """Create the LangChain model for the current provider and model.

        On any failure (missing API key, model pull failure, unknown provider,
        constructor error) ``self.llm`` is left as ``None``.
        """
        # Drop any previous client first: every early return below must leave
        # the client unavailable rather than still bound to the old provider.
        self.llm = None

        try:
            if self.provider == 'openai':
                api_key = os.getenv('OPENAI_API_KEY')
                if not api_key:
                    logger.warning("OpenAI API key not found")
                    return

                self.llm = ChatOpenAI(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000,
                )

            elif self.provider == 'anthropic':
                api_key = os.getenv('ANTHROPIC_API_KEY')
                if not api_key:
                    logger.warning("Anthropic API key not found")
                    return

                self.llm = ChatAnthropic(
                    model=self.model,
                    api_key=api_key,
                    temperature=0.1,
                    max_tokens=2000,
                )

            elif self.provider == 'ollama':
                base_url = os.getenv('OLLAMA_BASE_URL', self._DEFAULT_OLLAMA_URL)

                # Check if the model is available; if not, try to pull it.
                if not self._check_ollama_model_available(base_url, self.model):
                    logger.info(f"Model {self.model} not found, attempting to pull...")
                    if self._pull_ollama_model(base_url, self.model):
                        logger.info(f"Successfully pulled model {self.model}")
                    else:
                        logger.error(f"Failed to pull model {self.model}")
                        return

                self.llm = Ollama(
                    model=self.model,
                    base_url=base_url,
                    temperature=0.1,
                )

            if self.llm is not None:
                logger.info(f"LLM client initialized: {self.provider} with model {self.model}")
            else:
                logger.error(f"Failed to initialize LLM client for provider: {self.provider}")

        except Exception as e:
            logger.error(f"Error initializing LLM client: {e}")
            self.llm = None

    def is_available(self) -> bool:
        """Return True when an LLM instance has been initialized."""
        return self.llm is not None

    def get_provider_info(self) -> Dict[str, Any]:
        """Return the current provider, model and configuration status.

        For Ollama, ``available_models`` lists the models reported by the
        server when that query succeeds; otherwise it equals
        ``supported_models``.
        """
        provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
        supported_models = provider_info.get('models', [])
        env_key = provider_info.get('env_key', '')

        available_models = supported_models
        if self.provider == 'ollama':
            ollama_models = self._get_ollama_available_models()
            if ollama_models:
                available_models = ollama_models

        return {
            'provider': self.provider,
            'model': self.model,
            'available': self.is_available(),
            'supported_models': supported_models,
            'available_models': available_models,
            'env_key': env_key,
            'api_key_configured': bool(os.getenv(env_key)),
        }

    async def generate_sigma_rule(self,
                                  cve_id: str,
                                  poc_content: str,
                                  cve_description: str,
                                  existing_rule: Optional[str] = None) -> Optional[str]:
        """Generate or enhance a SIGMA rule using the configured LLM.

        Args:
            cve_id: CVE identifier
            poc_content: Proof-of-concept code content from GitHub
            cve_description: CVE description from NVD
            existing_rule: Optional existing SIGMA rule to enhance

        Returns:
            Generated SIGMA rule YAML content, or None if the client is
            unavailable, the call failed, or no rule could be extracted.
        """
        if not self.is_available():
            logger.warning("LLM client not available")
            return None

        # Missing PoC/description should not abort generation on a TypeError.
        poc_content = poc_content or ""
        cve_description = cve_description or ""

        try:
            prompt = self._build_sigma_generation_prompt(
                cve_id, poc_content, cve_description, existing_rule
            )
            chain = prompt | self.llm | self.output_parser

            input_data = {
                "cve_id": cve_id,
                "poc_content": poc_content[:4000],  # Truncate if too long
                "cve_description": cve_description,
            }
            # Only the enhancement template declares this variable.
            if existing_rule:
                input_data["existing_rule"] = existing_rule

            logger.info(f"Sending to LLM for {cve_id}: CVE={cve_id}, Description length={len(cve_description)}, PoC length={len(poc_content)}")

            response = await chain.ainvoke(input_data)
            logger.info(f"Raw LLM response for {cve_id}: {response[:200]}...")

            # Extract the rule, then strip remaining prose/prompt artifacts.
            sigma_rule = self._extract_sigma_rule(response)
            sigma_rule = self._post_process_sigma_rule(sigma_rule)

            if not sigma_rule:
                # An empty result is a failure, not a generated rule.
                logger.warning(f"LLM response for {cve_id} contained no SIGMA rule")
                return None

            logger.info(f"Final processed rule for {cve_id}: {sigma_rule[:200]}...")
            logger.info(f"Successfully generated SIGMA rule for {cve_id} using {self.provider}")
            return sigma_rule

        except Exception as e:
            logger.error(f"Failed to generate SIGMA rule for {cve_id} using {self.provider}: {e}")
            return None

    def _build_sigma_generation_prompt(self,
                                       cve_id: str,
                                       poc_content: str,
                                       cve_description: str,
                                       existing_rule: Optional[str] = None) -> "ChatPromptTemplate":
        """Build the prompt template for SIGMA rule generation.

        The system turn is a static ``SystemMessage`` so the braces it uses as
        examples are sent verbatim.  The human turn is a ``("human", template)``
        pair so ``{cve_id}``, ``{cve_description}``, ``{poc_content}`` and
        ``{existing_rule}`` are substituted at invoke time.  (A
        ``HumanMessage`` instance would be passed through unformatted and the
        model would receive the raw placeholders.)

        The arguments only select which user template is used (enhancement
        when ``existing_rule`` is truthy); values are supplied when the chain
        is invoked.
        """
        system_message = """You are a cybersecurity expert specializing in SIGMA rule creation following the official SIGMA specification.

**CRITICAL: You must follow the exact SIGMA specification format:**

1. **YAML Structure Requirements:**
- Use UTF-8 encoding with LF line breaks
- Indent with 4 spaces (no tabs)
- Use lowercase keys only
- Use single quotes for string values
- No quotes for numeric values
- Follow proper YAML syntax

2. **MANDATORY Fields (must include):**
- title: Brief description (max 256 chars)
- logsource: Log data source specification
- detection: Search identifiers and conditions
- condition: How detection elements combine

3. **RECOMMENDED Fields:**
- id: Unique UUID
- status: 'experimental' (for new rules)
- description: Detailed explanation
- author: 'AI Generated'
- date: Current date (YYYY/MM/DD)
- references: Array with CVE link
- tags: MITRE ATT&CK techniques

4. **Detection Structure:**
- Use selection blocks (selection, selection1, etc.)
- Condition references these selections
- Use proper field names (Image, CommandLine, ProcessName, etc.)
- Support wildcards (*) and value lists

**ABSOLUTE REQUIREMENTS:**
- Output ONLY valid YAML
- NO explanatory text before or after
- NO comments or instructions
- NO markdown formatting or code blocks
- NEVER repeat the input prompt or template
- NEVER include variables like {cve_id} or {poc_content}
- NO "Human:", "CVE ID:", "Description:" headers
- NO "Analyze this" or "Output only" text
- Start IMMEDIATELY with 'title:'
- End with the last YAML line only
- Ensure perfect YAML syntax

**STRUCTURE REQUIREMENTS:**
- title: Descriptive title that MUST include the exact CVE ID provided by the user
- id: Generate a unique UUID (not '12345678-1234-1234-1234-123456789012')
- status: experimental
- description: Specific description based on CVE and PoC analysis
- author: 'AI Generated'
- date: Current date (__CURRENT_DATE__)
- references: Include the EXACT CVE URL with the CVE ID provided by the user
- tags: Relevant MITRE ATT&CK techniques based on PoC analysis
- logsource: Appropriate category based on exploit type
- detection: Specific indicators from PoC analysis (NOT generic examples)
- condition: Logic connecting the detection selections

**CRITICAL RULES:**
1. You MUST use the EXACT CVE ID provided in the user input - NEVER generate a different CVE ID
2. Analyze the provided CVE and PoC content to create SPECIFIC detection patterns
3. DO NOT hallucinate or invent CVE IDs from your training data
4. Use the CVE ID exactly as provided in the title and references"""

        # Stamp today's date instead of a fixed one so generated rules are
        # not all dated identically.
        system_message = system_message.replace(
            '__CURRENT_DATE__', date.today().strftime('%Y/%m/%d')
        )

        if existing_rule:
            user_template = """CVE ID: {cve_id}
CVE Description: {cve_description}

PoC Code:
{poc_content}

Existing SIGMA Rule:
{existing_rule}

Enhance this rule with PoC insights. Output only valid SIGMA YAML starting with 'title:'."""
        else:
            user_template = """CREATE A SPECIFIC SIGMA RULE FOR THIS EXACT CVE:

**MANDATORY CVE ID TO USE: {cve_id}**
**CVE Description: {cve_description}**

**Proof-of-Concept Code Analysis:**
{poc_content}

**CRITICAL REQUIREMENTS:**
1. Use EXACTLY this CVE ID in the title: {cve_id}
2. Use EXACTLY this CVE URL in references: https://nvd.nist.gov/vuln/detail/{cve_id}
3. Analyze the CVE description to understand the vulnerability type
4. Extract specific indicators from the PoC code (files, processes, commands, network patterns)
5. Create detection logic based on the actual exploit behavior
6. Use relevant logsource category (process_creation, file_event, network_connection, etc.)
7. Include appropriate MITRE ATT&CK tags based on the exploit techniques

**IMPORTANT: You MUST use the exact CVE ID "{cve_id}" - do NOT generate a different CVE ID!**

Output ONLY valid SIGMA YAML starting with 'title:' that includes the exact CVE ID {cve_id}."""

        return ChatPromptTemplate.from_messages([
            SystemMessage(content=system_message),
            ("human", user_template),
        ])

    @staticmethod
    def _looks_like_yaml_key(stripped: str) -> bool:
        """Return True when ``stripped`` starts with a lowercase ``key:`` token.

        Used to tell YAML mapping lines (``description: ...``) apart from
        prose that merely contains a colon (``Please note: ...``).
        """
        key, sep, _ = stripped.partition(':')
        if not sep or not key or key.startswith('#'):
            return False
        return key == key.lower() and not any(ch.isspace() for ch in key)

    @classmethod
    def _is_trailing_prose(cls, line: str) -> bool:
        """Return True when an unindented line is commentary, not YAML.

        Indented lines, list items and comments are always treated as rule
        content.  An unindented line counts as prose when it has no colon, or
        when it is not a ``key:`` line and contains a known stop phrase.
        """
        stripped = line.strip()
        if not stripped or line[0] in ' \t':
            return False
        if stripped.startswith(('-', '#')):
            return False
        if ':' not in stripped:
            return True
        if cls._looks_like_yaml_key(stripped):
            return False
        lowered = stripped.lower()
        return any(phrase in lowered for phrase in cls._STOP_PHRASES)

    def _extract_sigma_rule(self, response_text: str) -> str:
        """Extract the SIGMA rule YAML from a raw LLM response.

        Collection starts at a yaml/yml code fence or at the first line
        beginning with ``title:``.  It ends at the closing fence, or - when the
        rule was not fenced - at the first unindented line that is clearly
        prose.  If nothing is collected, every non-comment line containing a
        colon is returned as a best-effort fallback.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The extracted YAML text, stripped; '' for empty input.
        """
        if not response_text:
            return ''

        lines = response_text.split('\n')
        yaml_lines = []
        collecting = False
        inside_fence = False

        for line in lines:
            stripped = line.strip()
            lowered = stripped.lower()

            # Code fence markers are never part of the rule.
            if stripped.startswith('```'):
                if collecting:
                    if stripped == '```':
                        break
                elif lowered.startswith(('```yaml', '```yml')):
                    collecting = True
                    inside_fence = True
                else:
                    # Track unrelated/bare fences so a rule inside one is
                    # read up to its closing marker.
                    inside_fence = not inside_fence
                continue

            if not collecting:
                if not stripped.startswith('title:'):
                    continue
                collecting = True

            # Skip explanatory comments.
            if stripped.startswith('#') and ('please' in lowered or 'note' in lowered):
                continue

            # Outside a fence there is no end marker, so stop at the first
            # line that reads as commentary after the rule.
            if yaml_lines and not inside_fence and self._is_trailing_prose(line):
                break

            yaml_lines.append(line)

        if not yaml_lines:
            # Fallback: keep any line with YAML-like structure.
            yaml_lines = [
                line for line in lines
                if ':' in line and not line.strip().startswith('#')
            ]

        return '\n'.join(yaml_lines).strip()

    def _post_process_sigma_rule(self, rule_content: str) -> str:
        """Remove prompt artifacts and stray prose from an extracted rule.

        Blank lines and lines echoing a prompt placeholder are dropped.
        Phrase-based filters are applied only to unindented lines that are
        neither list items nor ``key:`` lines, so legitimate values such as
        ``description: This rule detects ...`` or brace-delimited GUIDs are
        preserved.

        Args:
            rule_content: Text produced by ``_extract_sigma_rule``.

        Returns:
            The cleaned rule text, stripped; '' for empty input.
        """
        if not rule_content:
            return ''

        cleaned_lines = []

        for line in rule_content.split('\n'):
            stripped = line.strip()
            if not stripped:
                continue

            # An unsubstituted template variable means the prompt was echoed.
            if any(placeholder in stripped for placeholder in self._PROMPT_PLACEHOLDERS):
                continue

            is_indented = line[0] in ' \t'
            if not is_indented and not stripped.startswith('-'):
                lowered = stripped.lower()

                # Header-style artifacts such as "Human:" or "CVE ID:".
                if lowered.startswith(self._ARTIFACT_PREFIXES):
                    continue

                if not self._looks_like_yaml_key(stripped):
                    if any(phrase in lowered for phrase in self._PROSE_PHRASES):
                        continue
                    # Colon-less text that talks about the rule is explanation.
                    if ':' not in stripped and any(
                        word in lowered for word in self._EXPLANATION_WORDS
                    ):
                        continue

            cleaned_lines.append(line)

        return '\n'.join(cleaned_lines).strip()

    async def enhance_existing_rule(self,
                                    existing_rule: str,
                                    poc_content: str,
                                    cve_id: str) -> Optional[str]:
        """Enhance an existing SIGMA rule with PoC analysis.

        Args:
            existing_rule: Existing SIGMA rule YAML
            poc_content: PoC code content
            cve_id: CVE identifier

        Returns:
            Enhanced SIGMA rule, or None if the client is unavailable, the
            call failed, or no rule could be extracted.
        """
        if not self.is_available():
            logger.warning("LLM client not available")
            return None

        poc_content = poc_content or ""

        try:
            system_message = """You are a SIGMA rule enhancement expert. Analyze the following PoC code and enhance the existing SIGMA rule with more specific detection patterns.

**Task:** Enhance the existing rule by:
1. Adding more specific detection patterns found in the PoC
2. Improving the condition logic
3. Adding relevant tags or MITRE ATT&CK mappings
4. Keeping the rule structure intact but making it more effective

Output ONLY the enhanced SIGMA rule in valid YAML format."""

            user_template = """**CVE ID:** {cve_id}

**PoC Code:**
```
{poc_content}
```

**Existing SIGMA Rule:**
```yaml
{existing_rule}
```"""

            # The human turn must be a template pair; a HumanMessage instance
            # would be sent with its placeholders unformatted.
            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content=system_message),
                ("human", user_template),
            ])

            chain = prompt | self.llm | self.output_parser

            response = await chain.ainvoke({
                "cve_id": cve_id,
                "poc_content": poc_content[:3000],
                "existing_rule": existing_rule,
            })

            enhanced_rule = self._extract_sigma_rule(response)
            enhanced_rule = self._post_process_sigma_rule(enhanced_rule)

            if not enhanced_rule:
                logger.warning(f"LLM response for {cve_id} contained no SIGMA rule")
                return None

            logger.info(f"Successfully enhanced SIGMA rule for {cve_id}")
            return enhanced_rule

        except Exception as e:
            logger.error(f"Failed to enhance SIGMA rule for {cve_id}: {e}")
            return None

    def validate_sigma_rule(self, rule_content: str) -> bool:
        """Check that ``rule_content`` is structurally a SIGMA rule.

        The rule must parse as a YAML mapping with ``title``, ``logsource``
        and ``detection``; the title must be a non-empty string of at most
        256 characters; ``detection`` must contain a condition (also accepted
        at the root for leniency) and at least one search identifier of any
        name; ``status``, if present, must be a known value.

        Returns:
            True when all checks pass, False otherwise (including on YAML
            parse errors).
        """
        try:
            parsed = yaml.safe_load(rule_content)

            if not isinstance(parsed, dict):
                logger.warning("Rule must be a YAML dictionary")
                return False

            # Check MANDATORY fields per SIGMA spec.
            for field in ('title', 'logsource', 'detection'):
                if field not in parsed:
                    logger.warning(f"Missing mandatory field: {field}")
                    return False

            title = parsed.get('title')
            if not isinstance(title, str) or not title.strip() or len(title) > 256:
                logger.warning("Title must be a non-empty string ≤256 characters")
                return False

            if not isinstance(parsed.get('logsource'), dict):
                logger.warning("Logsource must be a dictionary")
                return False

            detection = parsed.get('detection')
            if not isinstance(detection, dict):
                logger.warning("Detection must be a dictionary")
                return False

            # Check for condition (can be in detection or at root).
            if 'condition' not in detection and 'condition' not in parsed:
                logger.warning("Missing condition field")
                return False

            # Search identifiers may carry any name, so every key other than
            # the reserved ones counts as one.
            reserved_keys = ('condition', 'timeframe')
            if not any(key not in reserved_keys for key in detection):
                logger.warning("Detection must have at least one selection")
                return False

            status = parsed.get('status')
            if status and status not in ['stable', 'test', 'experimental', 'deprecated', 'unsupported']:
                logger.warning(f"Invalid status: {status}")
                return False

            logger.info("SIGMA rule validation passed")
            return True

        except yaml.YAMLError as e:
            logger.warning(f"YAML parsing error: {e}")
            return False
        except Exception as e:
            logger.warning(f"Rule validation error: {e}")
            return False

    @classmethod
    def get_available_providers(cls) -> List[Dict[str, Any]]:
        """Return every supported provider with its configuration status.

        A provider is reported as available when its environment variable is
        set; Ollama is always reported as available.
        """
        providers = []

        for provider_name, provider_info in cls.SUPPORTED_PROVIDERS.items():
            env_key = provider_info.get('env_key', '')
            api_key_configured = bool(os.getenv(env_key))

            providers.append({
                'name': provider_name,
                'models': provider_info.get('models', []),
                'default_model': provider_info.get('default_model', ''),
                'env_key': env_key,
                'api_key_configured': api_key_configured,
                'available': api_key_configured or provider_name == 'ollama',
            })

        return providers

    def switch_provider(self, provider: str, model: Optional[str] = None):
        """Switch to a different LLM provider and model.

        Args:
            provider: Key of ``SUPPORTED_PROVIDERS``.
            model: Model name; the provider's default when omitted.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        if provider not in self.SUPPORTED_PROVIDERS:
            raise ValueError(f"Unsupported provider: {provider}")

        self.provider = provider
        self.model = model or self._get_default_model(provider)
        self._initialize_llm()

        logger.info(f"Switched to provider: {provider} with model: {self.model}")
        if not self.is_available():
            logger.warning(f"Provider {provider} is selected but its LLM client could not be initialized")

    def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
        """Return True if the Ollama server at ``base_url`` lists ``model``.

        A listed name matches when it equals ``model`` or starts with
        ``model`` followed by ``:``.  Any request error is logged and
        reported as False.
        """
        try:
            import requests

            response = requests.get(f"{base_url.rstrip('/')}/api/tags", timeout=10)
            if response.status_code == 200:
                for entry in response.json().get('models', []):
                    name = entry.get('name', '')
                    if name == model or name.startswith(model + ':'):
                        return True
            return False
        except Exception as e:
            logger.error(f"Error checking Ollama models: {e}")
            return False

    def _pull_ollama_model(self, base_url: str, model: str) -> bool:
        """Pull ``model`` through the Ollama pull endpoint.

        The streamed progress lines are logged.  Returns True when the stream
        completes without an error entry, False on an HTTP error, a reported
        pull error, or any exception.
        """
        try:
            import requests

            # The context manager releases the streamed connection on every
            # exit path, including the early return on a pull error.
            with requests.post(
                f"{base_url.rstrip('/')}/api/pull",
                json={"name": model},
                timeout=300,  # 5 minutes timeout for model download
                stream=True,
            ) as response:
                if response.status_code != 200:
                    logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                    return False

                # Stream the response to monitor progress.
                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode('utf-8'))
                    except json.JSONDecodeError:
                        continue
                    if data.get('status'):
                        logger.info(f"Ollama pull progress: {data.get('status')}")
                    if data.get('error'):
                        logger.error(f"Ollama pull error: {data.get('error')}")
                        return False
                return True

        except Exception as e:
            logger.error(f"Error pulling Ollama model {model}: {e}")
            return False

    def _get_ollama_available_models(self) -> List[str]:
        """Return the model names reported by the Ollama server.

        Uses ``OLLAMA_BASE_URL`` or the default local URL; returns an empty
        list on a non-200 response or any error.
        """
        try:
            import requests

            base_url = os.getenv('OLLAMA_BASE_URL', self._DEFAULT_OLLAMA_URL)
            response = requests.get(f"{base_url.rstrip('/')}/api/tags", timeout=10)
            if response.status_code == 200:
                models = response.json().get('models', [])
                return [m.get('name', '') for m in models if m.get('name')]
            return []
        except Exception as e:
            logger.error(f"Error getting Ollama models: {e}")
            return []