add ollama to docker-compose for local model testing
Parent: 3c120462ac
Commit: 08d6e33bbc
5 changed files with 873 additions and 1 deletion
backend/llm_client.py

@@ -94,6 +94,15 @@ class LLMClient:
        elif self.provider == 'ollama':
            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')

            # Check if model is available, if not try to pull it
            if not self._check_ollama_model_available(base_url, self.model):
                logger.info(f"Model {self.model} not found, attempting to pull...")
                if self._pull_ollama_model(base_url, self.model):
                    logger.info(f"Successfully pulled model {self.model}")
                else:
                    logger.error(f"Failed to pull model {self.model}")
                    return

            self.llm = Ollama(
                model=self.model,
                base_url=base_url,
@@ -116,11 +125,20 @@ class LLMClient:
    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider and configuration."""
        provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})

        # For Ollama, get actually available models
        available_models = provider_info.get('models', [])
        if self.provider == 'ollama':
            ollama_models = self._get_ollama_available_models()
            if ollama_models:
                available_models = ollama_models

        return {
            'provider': self.provider,
            'model': self.model,
            'available': self.is_available(),
            'supported_models': provider_info.get('models', []),
            'available_models': available_models,
            'env_key': provider_info.get('env_key', ''),
            'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
        }
@@ -395,4 +413,72 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
        self.model = model or self._get_default_model(provider)
        self._initialize_llm()

        logger.info(f"Switched to provider: {provider} with model: {self.model}")

    def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
        """Check if an Ollama model is available locally"""
        try:
            import requests
            response = requests.get(f"{base_url}/api/tags", timeout=10)
            if response.status_code == 200:
                data = response.json()
                models = data.get('models', [])
                for m in models:
                    if m.get('name', '').startswith(model + ':') or m.get('name') == model:
                        return True
            return False
        except Exception as e:
            logger.error(f"Error checking Ollama models: {e}")
            return False

    def _pull_ollama_model(self, base_url: str, model: str) -> bool:
        """Pull an Ollama model"""
        try:
            import requests
            import json

            # Use the pull API endpoint
            payload = {"name": model}
            response = requests.post(
                f"{base_url}/api/pull",
                json=payload,
                timeout=300,  # 5 minutes timeout for model download
                stream=True
            )

            if response.status_code == 200:
                # Stream the response to monitor progress
                for line in response.iter_lines():
                    if line:
                        try:
                            data = json.loads(line.decode('utf-8'))
                            if data.get('status'):
                                logger.info(f"Ollama pull progress: {data.get('status')}")
                            if data.get('error'):
                                logger.error(f"Ollama pull error: {data.get('error')}")
                                return False
                        except json.JSONDecodeError:
                            continue
                return True
            else:
                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                return False

        except Exception as e:
            logger.error(f"Error pulling Ollama model {model}: {e}")
            return False

    def _get_ollama_available_models(self) -> List[str]:
        """Get list of available Ollama models"""
        try:
            import requests
            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
            response = requests.get(f"{base_url}/api/tags", timeout=10)
            if response.status_code == 200:
                data = response.json()
                models = data.get('models', [])
                return [m.get('name', '') for m in models if m.get('name')]
            return []
        except Exception as e:
            logger.error(f"Error getting Ollama models: {e}")
            return []
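For local testing, the new Ollama path can be exercised end to end through the public constructor and get_provider_info(). A minimal sketch, assuming the code runs inside the backend container where the compose service name ollama resolves and llama3.2 is the configured model (both values come from the compose file further down); the constructor call mirrors how main.py builds the client:

# Hypothetical smoke test for the Ollama provider path; run inside the backend container.
import os
from llm_client import LLMClient  # same import main.py uses

os.environ.setdefault('OLLAMA_BASE_URL', 'http://ollama:11434')  # assumed reachable compose service name

client = LLMClient(provider='ollama', model='llama3.2')  # should exercise the check-then-pull path shown above
info = client.get_provider_info()
print(info['provider'], info['model'])
print(info['available_models'])  # populated from Ollama's GET /api/tags when the provider is 'ollama'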
backend/main.py

@@ -1461,6 +1461,7 @@ async def sync_references(request: ReferenceSyncRequest, background_tasks: Backg
        result = await client.bulk_sync_references(
            batch_size=request.batch_size,
            max_cves=request.max_cves,
            force_resync=request.force_resync,
            cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
        )
        running_jobs[job_id]['result'] = result
@@ -2088,6 +2089,60 @@ async def get_running_jobs(db: Session = Depends(get_db)):
        logger.error(f"Error getting running jobs: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/ollama-pull-model")
async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
    """Pull an Ollama model"""
    try:
        from llm_client import LLMClient

        model = request.get('model')
        if not model:
            raise HTTPException(status_code=400, detail="Model name is required")

        # Create a background task to pull the model
        def pull_model_task():
            try:
                client = LLMClient(provider='ollama', model=model)
                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')

                if client._pull_ollama_model(base_url, model):
                    logger.info(f"Successfully pulled Ollama model: {model}")
                else:
                    logger.error(f"Failed to pull Ollama model: {model}")
            except Exception as e:
                logger.error(f"Error in model pull task: {e}")

        background_tasks.add_task(pull_model_task)

        return {
            "message": f"Started pulling model {model}",
            "status": "started",
            "model": model
        }

    except Exception as e:
        logger.error(f"Error starting model pull: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/ollama-models")
async def get_ollama_models():
    """Get available Ollama models"""
    try:
        from llm_client import LLMClient

        client = LLMClient(provider='ollama')
        available_models = client._get_ollama_available_models()

        return {
            "available_models": available_models,
            "total_models": len(available_models),
            "status": "success"
        }

    except Exception as e:
        logger.error(f"Error getting Ollama models: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
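The two new endpoints can be driven from the host with plain HTTP calls. A rough sketch, assuming the backend API is published on port 8000 as in the uvicorn call above; the request and response shapes are the ones defined by the endpoints:

# Hypothetical client-side check of the new endpoints.
import requests

models = requests.get("http://localhost:8000/api/ollama-models", timeout=10).json()
print(models)  # e.g. {"available_models": [...], "total_models": 1, "status": "success"}

pull = requests.post("http://localhost:8000/api/ollama-pull-model",
                     json={"model": "llama3.2"}, timeout=10).json()
print(pull)    # e.g. {"message": "Started pulling model llama3.2", "status": "started", "model": "llama3.2"}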
backend/reference_client.py (new file, 603 lines)

@@ -0,0 +1,603 @@
"""
Reference Data Extraction Client
Extracts and analyzes text content from CVE references and KEV records
"""

import aiohttp
import asyncio
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from urllib.parse import urlparse, urljoin
from sqlalchemy.orm import Session
from sqlalchemy import text, func
import hashlib
from bs4 import BeautifulSoup
import ssl
import certifi

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ReferenceClient:
    """Client for extracting and analyzing reference content from CVE and KEV records"""

    def __init__(self, db_session: Session):
        self.db_session = db_session

        # Rate limiting
        self.rate_limit_delay = 2.0  # 2 seconds between requests
        self.last_request_time = 0

        # Cache for processed URLs
        self.url_cache = {}
        self.cache_ttl = 86400  # 24 hours cache

        # SSL context for secure requests
        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
        # Allow self-signed certificates for some sites that might have issues
        self.ssl_context.check_hostname = False
        self.ssl_context.verify_mode = ssl.CERT_NONE

        # Common headers to avoid being blocked
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
            'Sec-Fetch-Dest': 'document',
            'Sec-Fetch-Mode': 'navigate',
            'Sec-Fetch-Site': 'none'
        }

        # Supported reference types
        self.reference_types = {
            'security_advisory': ['security', 'advisory', 'bulletin', 'alert', 'cve', 'vulnerability'],
            'patch': ['patch', 'fix', 'update', 'hotfix', 'security-update'],
            'exploit': ['exploit', 'poc', 'proof-of-concept', 'github.com', 'exploit-db'],
            'technical_analysis': ['analysis', 'research', 'technical', 'writeup', 'blog'],
            'vendor_advisory': ['microsoft', 'apple', 'oracle', 'cisco', 'vmware', 'adobe'],
            'cve_database': ['cve.mitre.org', 'nvd.nist.gov', 'cve.org']
        }
    async def _make_request(self, session: aiohttp.ClientSession, url: str) -> Optional[Tuple[str, str]]:
        """Make a rate-limited request to fetch URL content"""
        try:
            # Rate limiting
            current_time = asyncio.get_event_loop().time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.rate_limit_delay:
                await asyncio.sleep(self.rate_limit_delay - time_since_last)

            # Check cache first
            url_hash = hashlib.md5(url.encode()).hexdigest()
            if url_hash in self.url_cache:
                cache_entry = self.url_cache[url_hash]
                if datetime.now().timestamp() - cache_entry['timestamp'] < self.cache_ttl:
                    logger.info(f"Using cached content for {url}")
                    return cache_entry['content'], cache_entry['content_type']

            async with session.get(url, headers=self.headers) as response:
                self.last_request_time = asyncio.get_event_loop().time()

                if response.status == 200:
                    content_type = response.headers.get('content-type', '').lower()

                    # Only process text content
                    if 'text/html' in content_type or 'text/plain' in content_type or 'application/json' in content_type:
                        try:
                            content = await response.text(encoding='utf-8', errors='ignore')
                        except UnicodeDecodeError:
                            # Fallback to binary content if text decode fails
                            content_bytes = await response.read()
                            content = content_bytes.decode('utf-8', errors='ignore')

                        # Cache the result
                        self.url_cache[url_hash] = {
                            'content': content,
                            'content_type': content_type,
                            'timestamp': datetime.now().timestamp()
                        }

                        return content, content_type
                    else:
                        logger.warning(f"Unsupported content type {content_type} for {url}")
                        return None, None
                elif response.status in [301, 302, 303, 307, 308]:
                    logger.info(f"Redirect response {response.status} for {url}")
                    return None, None
                else:
                    logger.warning(f"Request failed: {response.status} for {url}")
                    return None, None

        except aiohttp.ClientError as e:
            logger.warning(f"Client error fetching {url}: {e}")
            return None, None
        except asyncio.TimeoutError:
            logger.warning(f"Timeout fetching {url}")
            return None, None
        except Exception as e:
            logger.error(f"Unexpected error fetching {url}: {e}")
            return None, None
    def _extract_text_from_html(self, html_content: str) -> str:
        """Extract meaningful text from HTML content"""
        try:
            soup = BeautifulSoup(html_content, 'html.parser')

            # Remove script and style elements
            for script in soup(["script", "style"]):
                script.decompose()

            # Extract text from common content areas
            content_selectors = [
                'article', 'main', '.content', '#content',
                '.post-content', '.entry-content', 'section'
            ]

            text_content = ""
            for selector in content_selectors:
                elements = soup.select(selector)
                if elements:
                    for element in elements:
                        text_content += element.get_text(separator=' ', strip=True) + '\n'
                    break

            # If no structured content found, get all text
            if not text_content.strip():
                text_content = soup.get_text(separator=' ', strip=True)

            # Clean up the text
            text_content = re.sub(r'\s+', ' ', text_content)
            text_content = text_content.strip()

            return text_content

        except Exception as e:
            logger.error(f"Error extracting text from HTML: {e}")
            return ""
    def _analyze_reference_content(self, url: str, content: str) -> Dict[str, Any]:
        """Analyze reference content to extract security-relevant information"""
        analysis = {
            'url': url,
            'content_length': len(content),
            'reference_type': 'unknown',
            'security_keywords': [],
            'technical_indicators': [],
            'patch_information': [],
            'exploit_indicators': [],
            'cve_mentions': [],
            'severity_indicators': [],
            'mitigation_steps': [],
            'affected_products': [],
            'attack_vectors': [],
            'confidence_score': 0
        }

        if not content:
            return analysis

        content_lower = content.lower()

        # Classify reference type
        domain = urlparse(url).netloc.lower()
        for ref_type, keywords in self.reference_types.items():
            if any(keyword in domain or keyword in content_lower for keyword in keywords):
                analysis['reference_type'] = ref_type
                break

        # Extract CVE mentions
        cve_pattern = r'(CVE-\d{4}-\d{4,7})'
        cve_matches = re.findall(cve_pattern, content, re.IGNORECASE)
        analysis['cve_mentions'] = list(set(cve_matches))

        # Security keywords
        security_keywords = [
            'vulnerability', 'exploit', 'attack', 'malware', 'backdoor',
            'privilege escalation', 'remote code execution', 'rce',
            'sql injection', 'xss', 'csrf', 'buffer overflow',
            'authentication bypass', 'authorization', 'injection',
            'denial of service', 'dos', 'ddos', 'ransomware'
        ]

        for keyword in security_keywords:
            if keyword in content_lower:
                analysis['security_keywords'].append(keyword)

        # Technical indicators
        technical_patterns = [
            r'\b(function|method|class|variable)\s+\w+',
            r'\b(file|directory|path|folder)\s+[^\s]+',
            r'\b(port|service|protocol)\s+\d+',
            r'\b(registry|key|value)\s+[^\s]+',
            r'\b(process|executable|binary)\s+[^\s]+',
            r'\b(dll|exe|bat|ps1|sh|py|jar)\b',
            r'\b(http|https|ftp|smb|tcp|udp)://[^\s]+',
            r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b'
        ]

        for pattern in technical_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                analysis['technical_indicators'].extend(matches[:10])  # Limit to 10 per pattern

        # Patch information
        patch_keywords = [
            'patch', 'fix', 'update', 'hotfix', 'security update',
            r'kb\d+', r'ms\d+-\d+', r'version \d+\.\d+',
            'download', 'install', 'upgrade'
        ]

        for keyword in patch_keywords:
            if re.search(keyword, content_lower):
                analysis['patch_information'].append(keyword)

        # Exploit indicators
        exploit_keywords = [
            'proof of concept', 'poc', 'exploit code', 'payload',
            'shellcode', 'reverse shell', 'metasploit', 'nmap',
            'vulnerability assessment', 'penetration test', 'bypass'
        ]

        for keyword in exploit_keywords:
            if keyword in content_lower:
                analysis['exploit_indicators'].append(keyword)

        # Severity indicators
        severity_patterns = [
            r'\b(critical|high|medium|low)\s+(severity|risk|priority)',
            r'\b(cvss|score)\s*[:=]?\s*(\d+\.\d+|\d+)',
            r'\b(exploitability|impact)\s*[:=]?\s*(high|medium|low)'
        ]

        for pattern in severity_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                analysis['severity_indicators'].extend([' '.join(match) if isinstance(match, tuple) else match for match in matches])

        # Mitigation steps
        mitigation_keywords = [
            'mitigation', 'workaround', 'prevention', 'remediation',
            'disable', 'block', 'restrict', 'configure', 'setting'
        ]

        # Find sentences containing mitigation keywords
        sentences = re.split(r'[.!?]+', content)
        for sentence in sentences:
            if any(keyword in sentence.lower() for keyword in mitigation_keywords):
                if len(sentence.strip()) > 20:  # Avoid very short sentences
                    analysis['mitigation_steps'].append(sentence.strip()[:200])  # Limit length

        # Calculate confidence score
        score = 0
        score += min(len(analysis['security_keywords']) * 5, 25)
        score += min(len(analysis['technical_indicators']) * 2, 20)
        score += min(len(analysis['cve_mentions']) * 10, 30)
        score += min(len(analysis['patch_information']) * 3, 15)
        score += min(len(analysis['exploit_indicators']) * 4, 20)

        if analysis['reference_type'] != 'unknown':
            score += 10

        analysis['confidence_score'] = min(score, 100)

        # Clean up and deduplicate
        for key in ['security_keywords', 'technical_indicators', 'patch_information',
                    'exploit_indicators', 'severity_indicators', 'mitigation_steps']:
            analysis[key] = list(set(analysis[key]))[:10]  # Limit to 10 items each

        return analysis
    async def extract_reference_content(self, url: str) -> Optional[Dict[str, Any]]:
        """Extract and analyze content from a single reference URL"""
        try:
            connector = aiohttp.TCPConnector(
                ssl=self.ssl_context,
                limit=100,
                limit_per_host=30,
                ttl_dns_cache=300,
                use_dns_cache=True,
            )
            timeout = aiohttp.ClientTimeout(total=60, connect=30)
            async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
                content, content_type = await self._make_request(session, url)

                if not content:
                    return None

                # Extract text from HTML if needed
                if content_type and 'text/html' in content_type:
                    text_content = self._extract_text_from_html(content)
                else:
                    text_content = content

                # Analyze the content
                analysis = self._analyze_reference_content(url, text_content)

                # Add metadata
                analysis.update({
                    'extracted_at': datetime.utcnow().isoformat(),
                    'content_type': content_type,
                    'text_length': len(text_content),
                    'source': 'reference_extraction'
                })

                return analysis

        except Exception as e:
            logger.error(f"Error extracting reference content from {url}: {e}")
            return None
    async def sync_cve_references(self, cve_id: str) -> Dict[str, Any]:
        """Sync reference data for a specific CVE"""
        from main import CVE, SigmaRule

        # Get existing CVE
        cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
        if not cve:
            logger.warning(f"CVE {cve_id} not found in database")
            return {"error": "CVE not found"}

        if not cve.reference_urls:
            logger.info(f"No reference URLs found for CVE {cve_id}")
            return {"cve_id": cve_id, "references_processed": 0}

        logger.info(f"Processing {len(cve.reference_urls)} references for CVE {cve_id}")

        processed_references = []
        successful_extractions = 0

        for url in cve.reference_urls:
            try:
                # Extract reference content
                ref_analysis = await self.extract_reference_content(url)

                if ref_analysis:
                    processed_references.append(ref_analysis)
                    successful_extractions += 1
                    logger.info(f"Successfully extracted content from {url}")
                else:
                    logger.warning(f"Failed to extract content from {url}")

                # Small delay between requests
                await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"Error processing reference {url}: {e}")

        # Update CVE with reference data
        cve.reference_data = {
            'reference_analysis': {
                'references': processed_references,
                'total_references': len(cve.reference_urls),
                'successful_extractions': successful_extractions,
                'extraction_rate': successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
                'extracted_at': datetime.utcnow().isoformat(),
                'source': 'reference_extraction'
            }
        }

        cve.reference_sync_status = 'completed' if successful_extractions > 0 else 'failed'
        cve.reference_last_synced = datetime.utcnow()
        cve.updated_at = datetime.utcnow()

        # Update SIGMA rule with reference data
        sigma_rule = self.db_session.query(SigmaRule).filter(
            SigmaRule.cve_id == cve_id
        ).first()

        if sigma_rule:
            # Aggregate indicators from all references
            aggregated_indicators = {
                'security_keywords': [],
                'technical_indicators': [],
                'exploit_indicators': [],
                'patch_information': [],
                'attack_vectors': [],
                'mitigation_steps': []
            }

            for ref in processed_references:
                for key in aggregated_indicators.keys():
                    if key in ref:
                        aggregated_indicators[key].extend(ref[key])

            # Deduplicate
            for key in aggregated_indicators:
                aggregated_indicators[key] = list(set(aggregated_indicators[key]))

            # Update rule with reference data
            if not sigma_rule.nomi_sec_data:
                sigma_rule.nomi_sec_data = {}

            sigma_rule.nomi_sec_data['reference_analysis'] = {
                'aggregated_indicators': aggregated_indicators,
                'reference_count': len(processed_references),
                'high_confidence_references': len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
                'reference_types': list(set([r.get('reference_type') for r in processed_references if r.get('reference_type') != 'unknown'])),
                'source': 'reference_extraction'
            }

            # Update exploit indicators
            if sigma_rule.exploit_indicators:
                existing_indicators = json.loads(sigma_rule.exploit_indicators)
            else:
                existing_indicators = {}

            for key, values in aggregated_indicators.items():
                if key not in existing_indicators:
                    existing_indicators[key] = []
                existing_indicators[key].extend(values)
                existing_indicators[key] = list(set(existing_indicators[key]))

            sigma_rule.exploit_indicators = json.dumps(existing_indicators)
            sigma_rule.updated_at = datetime.utcnow()

        self.db_session.commit()

        logger.info(f"Successfully synchronized reference data for {cve_id}")

        return {
            "cve_id": cve_id,
            "references_processed": len(processed_references),
            "successful_extractions": successful_extractions,
            "extraction_rate": successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
            "high_confidence_references": len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
            "source": "reference_extraction"
        }
    async def bulk_sync_references(self, batch_size: int = 50, max_cves: int = None,
                                   force_resync: bool = False, cancellation_flag: Optional[callable] = None) -> Dict[str, Any]:
        """Bulk synchronize reference data for multiple CVEs"""
        from main import CVE, BulkProcessingJob

        # Create bulk processing job
        job = BulkProcessingJob(
            job_type='reference_sync',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'batch_size': batch_size, 'max_cves': max_cves}
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_references = 0
        successful_extractions = 0

        try:
            # Get CVEs that have reference URLs but no reference analysis
            query = self.db_session.query(CVE).filter(
                CVE.reference_urls.isnot(None),
                func.array_length(CVE.reference_urls, 1) > 0
            )

            # Filter out CVEs that already have reference analysis (unless force_resync is True)
            if not force_resync:
                query = query.filter(
                    CVE.reference_sync_status != 'completed'
                )

            if max_cves:
                cves = query.limit(max_cves).all()
            else:
                cves = query.all()

            job.total_items = len(cves)
            self.db_session.commit()

            logger.info(f"Starting bulk reference sync for {len(cves)} CVEs")

            # Process in batches
            for i in range(0, len(cves), batch_size):
                # Check for cancellation
                if cancellation_flag and cancellation_flag():
                    logger.info("Bulk reference sync cancelled by user")
                    job.status = 'cancelled'
                    job.cancelled_at = datetime.utcnow()
                    job.error_message = "Job cancelled by user"
                    break

                batch = cves[i:i + batch_size]

                for cve in batch:
                    # Check for cancellation
                    if cancellation_flag and cancellation_flag():
                        logger.info("Bulk reference sync cancelled by user")
                        job.status = 'cancelled'
                        job.cancelled_at = datetime.utcnow()
                        job.error_message = "Job cancelled by user"
                        break

                    try:
                        result = await self.sync_cve_references(cve.cve_id)

                        if "error" not in result:
                            total_processed += 1
                            total_references += result.get("references_processed", 0)
                            successful_extractions += result.get("successful_extractions", 0)
                        else:
                            job.failed_items += 1

                        job.processed_items += 1

                        # Longer delay for reference extraction to be respectful
                        await asyncio.sleep(2)

                    except Exception as e:
                        logger.error(f"Error processing references for {cve.cve_id}: {e}")
                        job.failed_items += 1

                # Break out of the outer loop if cancelled
                if job.status == 'cancelled':
                    break

                # Commit after each batch
                self.db_session.commit()
                logger.info(f"Processed reference batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")

            # Update job status
            if job.status != 'cancelled':
                job.status = 'completed'
                job.completed_at = datetime.utcnow()

            job.job_metadata.update({
                'total_processed': total_processed,
                'total_references': total_references,
                'successful_extractions': successful_extractions,
                'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
                'source': 'reference_extraction'
            })

        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Bulk reference sync job failed: {e}")

        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_references': total_references,
            'successful_extractions': successful_extractions,
            'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
            'source': 'reference_extraction'
        }
    async def get_reference_sync_status(self) -> Dict[str, Any]:
        """Get reference synchronization status"""
        from main import CVE

        # Count CVEs with reference URLs
        total_cves = self.db_session.query(CVE).count()

        cves_with_refs = self.db_session.query(CVE).filter(
            CVE.reference_urls.isnot(None),
            func.array_length(CVE.reference_urls, 1) > 0
        ).count()

        # Count CVEs with reference analysis
        cves_with_analysis = self.db_session.query(CVE).filter(
            CVE.reference_sync_status == 'completed'
        ).count()

        return {
            'total_cves': total_cves,
            'cves_with_references': cves_with_refs,
            'cves_with_analysis': cves_with_analysis,
            'reference_coverage': (cves_with_refs / total_cves * 100) if total_cves > 0 else 0,
            'analysis_coverage': (cves_with_analysis / cves_with_refs * 100) if cves_with_refs > 0 else 0,
            'sync_status': 'active' if cves_with_analysis > 0 else 'pending',
            'source': 'reference_extraction'
        }
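A minimal sketch of driving the new client for a single CVE, assuming an open SQLAlchemy session and a CVE row whose reference_urls are already populated (the CVE id is only an example); the bulk variant is what the sync_references endpoint above drives:

# Hypothetical one-off run of the reference extraction for one CVE.
import asyncio
from reference_client import ReferenceClient

def sync_one(db_session, cve_id="CVE-2021-44228"):  # example CVE id
    client = ReferenceClient(db_session)
    result = asyncio.run(client.sync_cve_references(cve_id))
    print(result.get("references_processed"), result.get("extraction_rate"))
    return result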
backend/setup_ollama.py (new executable file, 114 lines)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""

import os
import sys
import time
import requests
import json
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Wait for the Ollama service to be ready"""
    for i in range(max_retries):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except Exception as e:
            logger.info(f"Waiting for Ollama service... ({i+1}/{max_retries})")
        time.sleep(delay)

    logger.error("Ollama service is not ready after maximum retries")
    return False

def check_model_exists(base_url: str, model: str) -> bool:
    """Check if a model exists in Ollama"""
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = data.get('models', [])
            for m in models:
                model_name = m.get('name', '')
                if model_name.startswith(model + ':') or model_name == model:
                    logger.info(f"Model {model} already exists")
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
        return False

def pull_model(base_url: str, model: str) -> bool:
    """Pull an Ollama model"""
    try:
        logger.info(f"Pulling model {model}...")
        payload = {"name": model}
        response = requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=1800,  # 30 minutes timeout for model download
            stream=True
        )

        if response.status_code == 200:
            # Stream the response to monitor progress
            for line in response.iter_lines():
                if line:
                    try:
                        data = json.loads(line.decode('utf-8'))
                        status = data.get('status', '')
                        if 'pulling' in status.lower() or 'downloading' in status.lower():
                            logger.info(f"Ollama: {status}")
                        elif data.get('error'):
                            logger.error(f"Ollama pull error: {data.get('error')}")
                            return False
                    except json.JSONDecodeError:
                        continue

            logger.info(f"Successfully pulled model {model}")
            return True
        else:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False

    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False

def main():
    """Main setup function"""
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Wait for the Ollama service to be ready
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Check if the model already exists
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
        sys.exit(0)

    # Pull the model
    if pull_model(base_url, model):
        logger.info(f"Setup completed successfully - model {model} is ready")
        sys.exit(0)
    else:
        logger.error(f"Failed to pull model {model}")
        sys.exit(1)

if __name__ == "__main__":
    main()
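In the stack this script runs automatically via the ollama-setup service defined below; for a one-off run against an Ollama instance reachable from the host, a sketch (the URL and model name here are assumptions, not values from the compose file):

# Hypothetical manual invocation of the setup helpers.
from setup_ollama import wait_for_ollama, check_model_exists, pull_model

base_url = "http://localhost:11434"  # assumes Ollama is published on the host
model = "llama3.2"
if wait_for_ollama(base_url) and not check_model_exists(base_url, model):
    pull_model(base_url, model)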
docker-compose.yml

@@ -32,6 +32,8 @@ services:
    depends_on:
      db:
        condition: service_healthy
      ollama-setup:
        condition: service_completed_successfully
    volumes:
      - ./backend:/app
      - ./github_poc_collector:/github_poc_collector
@@ -67,6 +69,18 @@ services:
      - OLLAMA_HOST=0.0.0.0
    restart: unless-stopped

  ollama-setup:
    build: ./backend
    depends_on:
      - ollama
    environment:
      OLLAMA_BASE_URL: http://ollama:11434
      LLM_MODEL: llama3.2
    volumes:
      - ./backend:/app
    command: python setup_ollama.py
    restart: "no"

volumes:
  postgres_data:
  redis_data:
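Once ollama-setup has completed successfully, the pulled model should show up in Ollama's tag listing. A quick host-side check, assuming the ollama service's port 11434 is published to the host (not shown in the hunks above):

# Hypothetical verification that the model pulled by ollama-setup is visible.
import requests

tags = requests.get("http://localhost:11434/api/tags", timeout=5).json()
print([m.get("name") for m in tags.get("models", [])])  # expect something like ['llama3.2:latest']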