add ollama to docker-compose for local model testing
parent 3c120462ac
commit 08d6e33bbc

5 changed files with 873 additions and 1 deletion
@@ -94,6 +94,15 @@ class LLMClient:
         elif self.provider == 'ollama':
             base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+
+            # Check if model is available, if not try to pull it
+            if not self._check_ollama_model_available(base_url, self.model):
+                logger.info(f"Model {self.model} not found, attempting to pull...")
+                if self._pull_ollama_model(base_url, self.model):
+                    logger.info(f"Successfully pulled model {self.model}")
+                else:
+                    logger.error(f"Failed to pull model {self.model}")
+                    return
 
             self.llm = Ollama(
                 model=self.model,
                 base_url=base_url,

@@ -116,11 +125,20 @@ class LLMClient:
     def get_provider_info(self) -> Dict[str, Any]:
         """Get information about the current provider and configuration."""
         provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
+
+        # For Ollama, get actually available models
+        available_models = provider_info.get('models', [])
+        if self.provider == 'ollama':
+            ollama_models = self._get_ollama_available_models()
+            if ollama_models:
+                available_models = ollama_models
+
         return {
             'provider': self.provider,
             'model': self.model,
             'available': self.is_available(),
             'supported_models': provider_info.get('models', []),
+            'available_models': available_models,
             'env_key': provider_info.get('env_key', ''),
             'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
         }

@@ -396,3 +414,71 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
         self._initialize_llm()
 
         logger.info(f"Switched to provider: {provider} with model: {self.model}")
+
+    def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
+        """Check if an Ollama model is available locally"""
+        try:
+            import requests
+            response = requests.get(f"{base_url}/api/tags", timeout=10)
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get('models', [])
+                for m in models:
+                    if m.get('name', '').startswith(model + ':') or m.get('name') == model:
+                        return True
+            return False
+        except Exception as e:
+            logger.error(f"Error checking Ollama models: {e}")
+            return False
+
+    def _pull_ollama_model(self, base_url: str, model: str) -> bool:
+        """Pull an Ollama model"""
+        try:
+            import requests
+            import json
+
+            # Use the pull API endpoint
+            payload = {"name": model}
+            response = requests.post(
+                f"{base_url}/api/pull",
+                json=payload,
+                timeout=300,  # 5 minutes timeout for model download
+                stream=True
+            )
+
+            if response.status_code == 200:
+                # Stream the response to monitor progress
+                for line in response.iter_lines():
+                    if line:
+                        try:
+                            data = json.loads(line.decode('utf-8'))
+                            if data.get('status'):
+                                logger.info(f"Ollama pull progress: {data.get('status')}")
+                            if data.get('error'):
+                                logger.error(f"Ollama pull error: {data.get('error')}")
+                                return False
+                        except json.JSONDecodeError:
+                            continue
+                return True
+            else:
+                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Error pulling Ollama model {model}: {e}")
+            return False
+
+    def _get_ollama_available_models(self) -> List[str]:
+        """Get list of available Ollama models"""
+        try:
+            import requests
+            base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+            response = requests.get(f"{base_url}/api/tags", timeout=10)
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get('models', [])
+                return [m.get('name', '') for m in models if m.get('name')]
+            return []
+        except Exception as e:
+            logger.error(f"Error getting Ollama models: {e}")
+            return []
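As an aside (not part of the diff), a minimal sketch of how the Ollama branch above gets exercised, assuming the LLMClient(provider=..., model=...) constructor that the API handlers further down in this commit use, and an Ollama server reachable via OLLAMA_BASE_URL:

    # Illustrative sketch only, not in this commit.
    # Assumes OLLAMA_BASE_URL points at a running Ollama instance.
    import os
    from llm_client import LLMClient

    os.environ.setdefault('OLLAMA_BASE_URL', 'http://localhost:11434')

    client = LLMClient(provider='ollama', model='llama3.2')
    info = client.get_provider_info()
    print(info['available_models'])   # model names reported by Ollama's /api/tags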
@@ -1461,6 +1461,7 @@ async def sync_references(request: ReferenceSyncRequest, background_tasks: BackgroundTasks
         result = await client.bulk_sync_references(
             batch_size=request.batch_size,
             max_cves=request.max_cves,
+            force_resync=request.force_resync,
             cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
         )
         running_jobs[job_id]['result'] = result
@@ -2088,6 +2089,60 @@ async def get_running_jobs(db: Session = Depends(get_db)):
         logger.error(f"Error getting running jobs: {e}")
         raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/api/ollama-pull-model")
+async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
+    """Pull an Ollama model"""
+    try:
+        from llm_client import LLMClient
+
+        model = request.get('model')
+        if not model:
+            raise HTTPException(status_code=400, detail="Model name is required")
+
+        # Create a background task to pull the model
+        def pull_model_task():
+            try:
+                client = LLMClient(provider='ollama', model=model)
+                base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
+
+                if client._pull_ollama_model(base_url, model):
+                    logger.info(f"Successfully pulled Ollama model: {model}")
+                else:
+                    logger.error(f"Failed to pull Ollama model: {model}")
+            except Exception as e:
+                logger.error(f"Error in model pull task: {e}")
+
+        background_tasks.add_task(pull_model_task)
+
+        return {
+            "message": f"Started pulling model {model}",
+            "status": "started",
+            "model": model
+        }
+
+    except Exception as e:
+        logger.error(f"Error starting model pull: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/api/ollama-models")
+async def get_ollama_models():
+    """Get available Ollama models"""
+    try:
+        from llm_client import LLMClient
+
+        client = LLMClient(provider='ollama')
+        available_models = client._get_ollama_available_models()
+
+        return {
+            "available_models": available_models,
+            "total_models": len(available_models),
+            "status": "success"
+        }
+
+    except Exception as e:
+        logger.error(f"Error getting Ollama models: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
 
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
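Illustratively, the two endpoints added above can be exercised like this (a sketch, not part of the commit; it assumes the API is reachable on port 8000, matching the uvicorn.run call above):

    # Illustrative sketch only, not in this commit: call the new Ollama endpoints.
    import requests

    api = "http://localhost:8000"

    # Kick off a background pull of a model.
    resp = requests.post(f"{api}/api/ollama-pull-model", json={"model": "llama3.2"})
    print(resp.json())   # e.g. {"message": "Started pulling model llama3.2", "status": "started", "model": "llama3.2"}

    # List models currently available in Ollama.
    print(requests.get(f"{api}/api/ollama-models").json())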
backend/reference_client.py — new file, 603 lines
@@ -0,0 +1,603 @@
"""
Reference Data Extraction Client
Extracts and analyzes text content from CVE references and KEV records
"""

import aiohttp
import asyncio
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from urllib.parse import urlparse, urljoin
from sqlalchemy.orm import Session
from sqlalchemy import text, func
import hashlib
from bs4 import BeautifulSoup
import ssl
import certifi

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ReferenceClient:
    """Client for extracting and analyzing reference content from CVE and KEV records"""

    def __init__(self, db_session: Session):
        self.db_session = db_session

        # Rate limiting
        self.rate_limit_delay = 2.0  # 2 seconds between requests
        self.last_request_time = 0

        # Cache for processed URLs
        self.url_cache = {}
        self.cache_ttl = 86400  # 24 hours cache

        # SSL context for secure requests
        self.ssl_context = ssl.create_default_context(cafile=certifi.where())
        # Allow self-signed certificates for some sites that might have issues
        self.ssl_context.check_hostname = False
        self.ssl_context.verify_mode = ssl.CERT_NONE

        # Common headers to avoid being blocked
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
            'Sec-Fetch-Dest': 'document',
            'Sec-Fetch-Mode': 'navigate',
            'Sec-Fetch-Site': 'none'
        }

        # Supported reference types
        self.reference_types = {
            'security_advisory': ['security', 'advisory', 'bulletin', 'alert', 'cve', 'vulnerability'],
            'patch': ['patch', 'fix', 'update', 'hotfix', 'security-update'],
            'exploit': ['exploit', 'poc', 'proof-of-concept', 'github.com', 'exploit-db'],
            'technical_analysis': ['analysis', 'research', 'technical', 'writeup', 'blog'],
            'vendor_advisory': ['microsoft', 'apple', 'oracle', 'cisco', 'vmware', 'adobe'],
            'cve_database': ['cve.mitre.org', 'nvd.nist.gov', 'cve.org']
        }
    async def _make_request(self, session: aiohttp.ClientSession, url: str) -> Optional[Tuple[str, str]]:
        """Make a rate-limited request to fetch URL content"""
        try:
            # Rate limiting
            current_time = asyncio.get_event_loop().time()
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.rate_limit_delay:
                await asyncio.sleep(self.rate_limit_delay - time_since_last)

            # Check cache first
            url_hash = hashlib.md5(url.encode()).hexdigest()
            if url_hash in self.url_cache:
                cache_entry = self.url_cache[url_hash]
                if datetime.now().timestamp() - cache_entry['timestamp'] < self.cache_ttl:
                    logger.info(f"Using cached content for {url}")
                    return cache_entry['content'], cache_entry['content_type']

            async with session.get(url, headers=self.headers) as response:
                self.last_request_time = asyncio.get_event_loop().time()

                if response.status == 200:
                    content_type = response.headers.get('content-type', '').lower()

                    # Only process text content
                    if 'text/html' in content_type or 'text/plain' in content_type or 'application/json' in content_type:
                        try:
                            content = await response.text(encoding='utf-8', errors='ignore')
                        except UnicodeDecodeError:
                            # Fallback to binary content if text decode fails
                            content_bytes = await response.read()
                            content = content_bytes.decode('utf-8', errors='ignore')

                        # Cache the result
                        self.url_cache[url_hash] = {
                            'content': content,
                            'content_type': content_type,
                            'timestamp': datetime.now().timestamp()
                        }

                        return content, content_type
                    else:
                        logger.warning(f"Unsupported content type {content_type} for {url}")
                        return None, None
                elif response.status in [301, 302, 303, 307, 308]:
                    logger.info(f"Redirect response {response.status} for {url}")
                    return None, None
                else:
                    logger.warning(f"Request failed: {response.status} for {url}")
                    return None, None

        except aiohttp.ClientError as e:
            logger.warning(f"Client error fetching {url}: {e}")
            return None, None
        except asyncio.TimeoutError:
            logger.warning(f"Timeout fetching {url}")
            return None, None
        except Exception as e:
            logger.error(f"Unexpected error fetching {url}: {e}")
            return None, None

    def _extract_text_from_html(self, html_content: str) -> str:
        """Extract meaningful text from HTML content"""
        try:
            soup = BeautifulSoup(html_content, 'html.parser')

            # Remove script and style elements
            for script in soup(["script", "style"]):
                script.decompose()

            # Extract text from common content areas
            content_selectors = [
                'article', 'main', '.content', '#content',
                '.post-content', '.entry-content', 'section'
            ]

            text_content = ""
            for selector in content_selectors:
                elements = soup.select(selector)
                if elements:
                    for element in elements:
                        text_content += element.get_text(separator=' ', strip=True) + '\n'
                    break

            # If no structured content found, get all text
            if not text_content.strip():
                text_content = soup.get_text(separator=' ', strip=True)

            # Clean up the text
            text_content = re.sub(r'\s+', ' ', text_content)
            text_content = text_content.strip()

            return text_content

        except Exception as e:
            logger.error(f"Error extracting text from HTML: {e}")
            return ""

    def _analyze_reference_content(self, url: str, content: str) -> Dict[str, Any]:
        """Analyze reference content to extract security-relevant information"""
        analysis = {
            'url': url,
            'content_length': len(content),
            'reference_type': 'unknown',
            'security_keywords': [],
            'technical_indicators': [],
            'patch_information': [],
            'exploit_indicators': [],
            'cve_mentions': [],
            'severity_indicators': [],
            'mitigation_steps': [],
            'affected_products': [],
            'attack_vectors': [],
            'confidence_score': 0
        }

        if not content:
            return analysis

        content_lower = content.lower()

        # Classify reference type
        domain = urlparse(url).netloc.lower()
        for ref_type, keywords in self.reference_types.items():
            if any(keyword in domain or keyword in content_lower for keyword in keywords):
                analysis['reference_type'] = ref_type
                break

        # Extract CVE mentions
        cve_pattern = r'(CVE-\d{4}-\d{4,7})'
        cve_matches = re.findall(cve_pattern, content, re.IGNORECASE)
        analysis['cve_mentions'] = list(set(cve_matches))

        # Security keywords
        security_keywords = [
            'vulnerability', 'exploit', 'attack', 'malware', 'backdoor',
            'privilege escalation', 'remote code execution', 'rce',
            'sql injection', 'xss', 'csrf', 'buffer overflow',
            'authentication bypass', 'authorization', 'injection',
            'denial of service', 'dos', 'ddos', 'ransomware'
        ]

        for keyword in security_keywords:
            if keyword in content_lower:
                analysis['security_keywords'].append(keyword)

        # Technical indicators
        technical_patterns = [
            r'\b(function|method|class|variable)\s+\w+',
            r'\b(file|directory|path|folder)\s+[^\s]+',
            r'\b(port|service|protocol)\s+\d+',
            r'\b(registry|key|value)\s+[^\s]+',
            r'\b(process|executable|binary)\s+[^\s]+',
            r'\b(dll|exe|bat|ps1|sh|py|jar)\b',
            r'\b(http|https|ftp|smb|tcp|udp)://[^\s]+',
            r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b'
        ]

        for pattern in technical_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                analysis['technical_indicators'].extend(matches[:10])  # Limit to 10 per pattern

        # Patch information
        patch_keywords = [
            'patch', 'fix', 'update', 'hotfix', 'security update',
            'kb\d+', 'ms\d+-\d+', 'version \d+\.\d+',
            'download', 'install', 'upgrade'
        ]

        for keyword in patch_keywords:
            if re.search(keyword, content_lower):
                analysis['patch_information'].append(keyword)

        # Exploit indicators
        exploit_keywords = [
            'proof of concept', 'poc', 'exploit code', 'payload',
            'shellcode', 'reverse shell', 'metasploit', 'nmap',
            'vulnerability assessment', 'penetration test', 'bypass'
        ]

        for keyword in exploit_keywords:
            if keyword in content_lower:
                analysis['exploit_indicators'].append(keyword)

        # Severity indicators
        severity_patterns = [
            r'\b(critical|high|medium|low)\s+(severity|risk|priority)',
            r'\b(cvss|score)\s*[:=]?\s*(\d+\.\d+|\d+)',
            r'\b(exploitability|impact)\s*[:=]?\s*(high|medium|low)'
        ]

        for pattern in severity_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                analysis['severity_indicators'].extend([' '.join(match) if isinstance(match, tuple) else match for match in matches])

        # Mitigation steps
        mitigation_keywords = [
            'mitigation', 'workaround', 'prevention', 'remediation',
            'disable', 'block', 'restrict', 'configure', 'setting'
        ]

        # Find sentences containing mitigation keywords
        sentences = re.split(r'[.!?]+', content)
        for sentence in sentences:
            if any(keyword in sentence.lower() for keyword in mitigation_keywords):
                if len(sentence.strip()) > 20:  # Avoid very short sentences
                    analysis['mitigation_steps'].append(sentence.strip()[:200])  # Limit length

        # Calculate confidence score
        score = 0
        score += min(len(analysis['security_keywords']) * 5, 25)
        score += min(len(analysis['technical_indicators']) * 2, 20)
        score += min(len(analysis['cve_mentions']) * 10, 30)
        score += min(len(analysis['patch_information']) * 3, 15)
        score += min(len(analysis['exploit_indicators']) * 4, 20)

        if analysis['reference_type'] != 'unknown':
            score += 10

        analysis['confidence_score'] = min(score, 100)

        # Clean up and deduplicate
        for key in ['security_keywords', 'technical_indicators', 'patch_information',
                    'exploit_indicators', 'severity_indicators', 'mitigation_steps']:
            analysis[key] = list(set(analysis[key]))[:10]  # Limit to 10 items each

        return analysis
    async def extract_reference_content(self, url: str) -> Optional[Dict[str, Any]]:
        """Extract and analyze content from a single reference URL"""
        try:
            connector = aiohttp.TCPConnector(
                ssl=self.ssl_context,
                limit=100,
                limit_per_host=30,
                ttl_dns_cache=300,
                use_dns_cache=True,
            )
            timeout = aiohttp.ClientTimeout(total=60, connect=30)
            async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
                content, content_type = await self._make_request(session, url)

                if not content:
                    return None

                # Extract text from HTML if needed
                if content_type and 'text/html' in content_type:
                    text_content = self._extract_text_from_html(content)
                else:
                    text_content = content

                # Analyze the content
                analysis = self._analyze_reference_content(url, text_content)

                # Add metadata
                analysis.update({
                    'extracted_at': datetime.utcnow().isoformat(),
                    'content_type': content_type,
                    'text_length': len(text_content),
                    'source': 'reference_extraction'
                })

                return analysis

        except Exception as e:
            logger.error(f"Error extracting reference content from {url}: {e}")
            return None

    async def sync_cve_references(self, cve_id: str) -> Dict[str, Any]:
        """Sync reference data for a specific CVE"""
        from main import CVE, SigmaRule

        # Get existing CVE
        cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
        if not cve:
            logger.warning(f"CVE {cve_id} not found in database")
            return {"error": "CVE not found"}

        if not cve.reference_urls:
            logger.info(f"No reference URLs found for CVE {cve_id}")
            return {"cve_id": cve_id, "references_processed": 0}

        logger.info(f"Processing {len(cve.reference_urls)} references for CVE {cve_id}")

        processed_references = []
        successful_extractions = 0

        for url in cve.reference_urls:
            try:
                # Extract reference content
                ref_analysis = await self.extract_reference_content(url)

                if ref_analysis:
                    processed_references.append(ref_analysis)
                    successful_extractions += 1
                    logger.info(f"Successfully extracted content from {url}")
                else:
                    logger.warning(f"Failed to extract content from {url}")

                # Small delay between requests
                await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"Error processing reference {url}: {e}")

        # Update CVE with reference data
        cve.reference_data = {
            'reference_analysis': {
                'references': processed_references,
                'total_references': len(cve.reference_urls),
                'successful_extractions': successful_extractions,
                'extraction_rate': successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
                'extracted_at': datetime.utcnow().isoformat(),
                'source': 'reference_extraction'
            }
        }

        cve.reference_sync_status = 'completed' if successful_extractions > 0 else 'failed'
        cve.reference_last_synced = datetime.utcnow()
        cve.updated_at = datetime.utcnow()

        # Update SIGMA rule with reference data
        sigma_rule = self.db_session.query(SigmaRule).filter(
            SigmaRule.cve_id == cve_id
        ).first()

        if sigma_rule:
            # Aggregate indicators from all references
            aggregated_indicators = {
                'security_keywords': [],
                'technical_indicators': [],
                'exploit_indicators': [],
                'patch_information': [],
                'attack_vectors': [],
                'mitigation_steps': []
            }

            for ref in processed_references:
                for key in aggregated_indicators.keys():
                    if key in ref:
                        aggregated_indicators[key].extend(ref[key])

            # Deduplicate
            for key in aggregated_indicators:
                aggregated_indicators[key] = list(set(aggregated_indicators[key]))

            # Update rule with reference data
            if not sigma_rule.nomi_sec_data:
                sigma_rule.nomi_sec_data = {}

            sigma_rule.nomi_sec_data['reference_analysis'] = {
                'aggregated_indicators': aggregated_indicators,
                'reference_count': len(processed_references),
                'high_confidence_references': len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
                'reference_types': list(set([r.get('reference_type') for r in processed_references if r.get('reference_type') != 'unknown'])),
                'source': 'reference_extraction'
            }

            # Update exploit indicators
            if sigma_rule.exploit_indicators:
                existing_indicators = json.loads(sigma_rule.exploit_indicators)
            else:
                existing_indicators = {}

            for key, values in aggregated_indicators.items():
                if key not in existing_indicators:
                    existing_indicators[key] = []
                existing_indicators[key].extend(values)
                existing_indicators[key] = list(set(existing_indicators[key]))

            sigma_rule.exploit_indicators = json.dumps(existing_indicators)
            sigma_rule.updated_at = datetime.utcnow()

        self.db_session.commit()

        logger.info(f"Successfully synchronized reference data for {cve_id}")

        return {
            "cve_id": cve_id,
            "references_processed": len(processed_references),
            "successful_extractions": successful_extractions,
            "extraction_rate": successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
            "high_confidence_references": len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
            "source": "reference_extraction"
        }

    async def bulk_sync_references(self, batch_size: int = 50, max_cves: int = None,
                                   force_resync: bool = False, cancellation_flag: Optional[callable] = None) -> Dict[str, Any]:
        """Bulk synchronize reference data for multiple CVEs"""
        from main import CVE, BulkProcessingJob

        # Create bulk processing job
        job = BulkProcessingJob(
            job_type='reference_sync',
            status='running',
            started_at=datetime.utcnow(),
            job_metadata={'batch_size': batch_size, 'max_cves': max_cves}
        )
        self.db_session.add(job)
        self.db_session.commit()

        total_processed = 0
        total_references = 0
        successful_extractions = 0

        try:
            # Get CVEs that have reference URLs but no reference analysis
            query = self.db_session.query(CVE).filter(
                CVE.reference_urls.isnot(None),
                func.array_length(CVE.reference_urls, 1) > 0
            )

            # Filter out CVEs that already have reference analysis (unless force_resync is True)
            if not force_resync:
                query = query.filter(
                    CVE.reference_sync_status != 'completed'
                )

            if max_cves:
                cves = query.limit(max_cves).all()
            else:
                cves = query.all()

            job.total_items = len(cves)
            self.db_session.commit()

            logger.info(f"Starting bulk reference sync for {len(cves)} CVEs")

            # Process in batches
            for i in range(0, len(cves), batch_size):
                # Check for cancellation
                if cancellation_flag and cancellation_flag():
                    logger.info("Bulk reference sync cancelled by user")
                    job.status = 'cancelled'
                    job.cancelled_at = datetime.utcnow()
                    job.error_message = "Job cancelled by user"
                    break

                batch = cves[i:i + batch_size]

                for cve in batch:
                    # Check for cancellation
                    if cancellation_flag and cancellation_flag():
                        logger.info("Bulk reference sync cancelled by user")
                        job.status = 'cancelled'
                        job.cancelled_at = datetime.utcnow()
                        job.error_message = "Job cancelled by user"
                        break

                    try:
                        result = await self.sync_cve_references(cve.cve_id)

                        if "error" not in result:
                            total_processed += 1
                            total_references += result.get("references_processed", 0)
                            successful_extractions += result.get("successful_extractions", 0)
                        else:
                            job.failed_items += 1

                        job.processed_items += 1

                        # Longer delay for reference extraction to be respectful
                        await asyncio.sleep(2)

                    except Exception as e:
                        logger.error(f"Error processing references for {cve.cve_id}: {e}")
                        job.failed_items += 1

                # Break out of outer loop if cancelled
                if job.status == 'cancelled':
                    break

                # Commit after each batch
                self.db_session.commit()
                logger.info(f"Processed reference batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")

            # Update job status
            if job.status != 'cancelled':
                job.status = 'completed'
                job.completed_at = datetime.utcnow()

            job.job_metadata.update({
                'total_processed': total_processed,
                'total_references': total_references,
                'successful_extractions': successful_extractions,
                'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
                'source': 'reference_extraction'
            })

        except Exception as e:
            job.status = 'failed'
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            logger.error(f"Bulk reference sync job failed: {e}")

        finally:
            self.db_session.commit()

        return {
            'job_id': str(job.id),
            'status': job.status,
            'total_processed': total_processed,
            'total_references': total_references,
            'successful_extractions': successful_extractions,
            'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
            'source': 'reference_extraction'
        }

    async def get_reference_sync_status(self) -> Dict[str, Any]:
        """Get reference synchronization status"""
        from main import CVE

        # Count CVEs with reference URLs
        total_cves = self.db_session.query(CVE).count()

        cves_with_refs = self.db_session.query(CVE).filter(
            CVE.reference_urls.isnot(None),
            func.array_length(CVE.reference_urls, 1) > 0
        ).count()

        # Count CVEs with reference analysis
        cves_with_analysis = self.db_session.query(CVE).filter(
            CVE.reference_sync_status == 'completed'
        ).count()

        return {
            'total_cves': total_cves,
            'cves_with_references': cves_with_refs,
            'cves_with_analysis': cves_with_analysis,
            'reference_coverage': (cves_with_refs / total_cves * 100) if total_cves > 0 else 0,
            'analysis_coverage': (cves_with_analysis / cves_with_refs * 100) if cves_with_refs > 0 else 0,
            'sync_status': 'active' if cves_with_analysis > 0 else 'pending',
            'source': 'reference_extraction'
        }
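Two notes on the new client, added here for clarity rather than taken from the diff. First, the confidence score is additive and capped: for example, a vendor advisory mentioning two CVEs (20), four security keywords (20), five technical-indicator matches (10), and two patch keywords (6), with a recognized reference type (+10), scores 66 out of a maximum 100. Second, a minimal sketch of driving the client directly; only the ReferenceClient API comes from the file above, while SessionLocal is an assumed placeholder for the application's SQLAlchemy session factory:

    # Illustrative sketch only, not in this commit.
    # SessionLocal is an assumption; point it at the app's real session factory.
    import asyncio
    from database import SessionLocal          # hypothetical module
    from reference_client import ReferenceClient

    async def run_reference_sync():
        db = SessionLocal()
        try:
            client = ReferenceClient(db)
            result = await client.bulk_sync_references(batch_size=10, max_cves=50,
                                                        force_resync=False)
            print(result['status'], result['extraction_rate'])
        finally:
            db.close()

    asyncio.run(run_reference_sync())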
backend/setup_ollama.py — new executable file, 114 lines
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""

import os
import sys
import time
import requests
import json
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Wait for Ollama service to be ready"""
    for i in range(max_retries):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except Exception as e:
            logger.info(f"Waiting for Ollama service... ({i+1}/{max_retries})")
        time.sleep(delay)

    logger.error("Ollama service is not ready after maximum retries")
    return False

def check_model_exists(base_url: str, model: str) -> bool:
    """Check if a model exists in Ollama"""
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = data.get('models', [])
            for m in models:
                model_name = m.get('name', '')
                if model_name.startswith(model + ':') or model_name == model:
                    logger.info(f"Model {model} already exists")
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
        return False

def pull_model(base_url: str, model: str) -> bool:
    """Pull an Ollama model"""
    try:
        logger.info(f"Pulling model {model}...")
        payload = {"name": model}
        response = requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=1800,  # 30 minutes timeout for model download
            stream=True
        )

        if response.status_code == 200:
            # Stream the response to monitor progress
            for line in response.iter_lines():
                if line:
                    try:
                        data = json.loads(line.decode('utf-8'))
                        status = data.get('status', '')
                        if 'pulling' in status.lower() or 'downloading' in status.lower():
                            logger.info(f"Ollama: {status}")
                        elif data.get('error'):
                            logger.error(f"Ollama pull error: {data.get('error')}")
                            return False
                    except json.JSONDecodeError:
                        continue

            logger.info(f"Successfully pulled model {model}")
            return True
        else:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False

    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False

def main():
    """Main setup function"""
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Wait for Ollama service to be ready
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Check if model already exists
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
        sys.exit(0)

    # Pull the model
    if pull_model(base_url, model):
        logger.info(f"Setup completed successfully - model {model} is ready")
        sys.exit(0)
    else:
        logger.error(f"Failed to pull model {model}")
        sys.exit(1)

if __name__ == "__main__":
    main()
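Since the script exposes plain functions, they can also be reused outside main(), for example from a test or an ad-hoc check; a small sketch (not part of the commit), assuming Ollama is reachable at the compose-internal address used above:

    # Illustrative sketch only, not in this commit: reuse the setup helpers directly.
    from setup_ollama import wait_for_ollama, check_model_exists, pull_model

    base_url = "http://ollama:11434"   # or http://localhost:11434 outside compose
    model = "llama3.2"

    if wait_for_ollama(base_url) and not check_model_exists(base_url, model):
        pull_model(base_url, model)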
@@ -32,6 +32,8 @@ services:
     depends_on:
       db:
         condition: service_healthy
+      ollama-setup:
+        condition: service_completed_successfully
     volumes:
       - ./backend:/app
       - ./github_poc_collector:/github_poc_collector

@@ -67,6 +69,18 @@ services:
       - OLLAMA_HOST=0.0.0.0
     restart: unless-stopped
 
+  ollama-setup:
+    build: ./backend
+    depends_on:
+      - ollama
+    environment:
+      OLLAMA_BASE_URL: http://ollama:11434
+      LLM_MODEL: llama3.2
+    volumes:
+      - ./backend:/app
+    command: python setup_ollama.py
+    restart: "no"
+
 volumes:
   postgres_data:
   redis_data: