add ollama to docker-compose for local model testing

Brendan McDevitt 2025-07-10 21:32:15 -05:00
parent 3c120462ac
commit 08d6e33bbc
5 changed files with 873 additions and 1 deletion

backend/llm_client.py

@@ -94,6 +94,15 @@ class LLMClient:
elif self.provider == 'ollama':
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
# Check if the model is available; if not, try to pull it
if not self._check_ollama_model_available(base_url, self.model):
logger.info(f"Model {self.model} not found, attempting to pull...")
if self._pull_ollama_model(base_url, self.model):
logger.info(f"Successfully pulled model {self.model}")
else:
logger.error(f"Failed to pull model {self.model}")
return
self.llm = Ollama(
model=self.model,
base_url=base_url,
@@ -116,11 +125,20 @@ class LLMClient:
def get_provider_info(self) -> Dict[str, Any]:
"""Get information about the current provider and configuration."""
provider_info = self.SUPPORTED_PROVIDERS.get(self.provider, {})
# For Ollama, report the models that are actually installed locally
available_models = provider_info.get('models', [])
if self.provider == 'ollama':
ollama_models = self._get_ollama_available_models()
if ollama_models:
available_models = ollama_models
return {
'provider': self.provider,
'model': self.model,
'available': self.is_available(),
'supported_models': provider_info.get('models', []),
'available_models': available_models,
'env_key': provider_info.get('env_key', ''),
'api_key_configured': bool(os.getenv(provider_info.get('env_key', '')))
}
@@ -396,3 +414,71 @@ Output ONLY the enhanced SIGMA rule in valid YAML format."""
self._initialize_llm()
logger.info(f"Switched to provider: {provider} with model: {self.model}")
def _check_ollama_model_available(self, base_url: str, model: str) -> bool:
"""Check if an Ollama model is available locally"""
try:
import requests
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
data = response.json()
models = data.get('models', [])
for m in models:
if m.get('name', '').startswith(model + ':') or m.get('name') == model:
return True
return False
except Exception as e:
logger.error(f"Error checking Ollama models: {e}")
return False
def _pull_ollama_model(self, base_url: str, model: str) -> bool:
"""Pull an Ollama model"""
try:
import requests
import json
# Use the pull API endpoint
payload = {"name": model}
response = requests.post(
f"{base_url}/api/pull",
json=payload,
timeout=300, # 5 minutes timeout for model download
stream=True
)
if response.status_code == 200:
# Stream the response to monitor progress
for line in response.iter_lines():
if line:
try:
data = json.loads(line.decode('utf-8'))
if data.get('status'):
logger.info(f"Ollama pull progress: {data.get('status')}")
if data.get('error'):
logger.error(f"Ollama pull error: {data.get('error')}")
return False
except json.JSONDecodeError:
continue
return True
else:
logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
return False
except Exception as e:
logger.error(f"Error pulling Ollama model {model}: {e}")
return False
def _get_ollama_available_models(self) -> List[str]:
"""Get list of available Ollama models"""
try:
import requests
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
data = response.json()
models = data.get('models', [])
return [m.get('name', '') for m in models if m.get('name')]
return []
except Exception as e:
logger.error(f"Error getting Ollama models: {e}")
return []
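Taken together, these additions let the client provision its own Ollama model. A minimal usage sketch (not part of the commit; the model name, base URL, and the assumption that the constructor initializes the provider are mine):

import os
os.environ.setdefault('OLLAMA_BASE_URL', 'http://localhost:11434')

from llm_client import LLMClient

# If llama3.2 is not installed locally, _check_ollama_model_available() returns False
# and the client tries to pull it via /api/pull before constructing the Ollama LLM.
client = LLMClient(provider='ollama', model='llama3.2')

# get_provider_info() now also reports the models actually installed, read from /api/tags.
print(client.get_provider_info()['available_models'])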

backend/main.py

@@ -1461,6 +1461,7 @@ async def sync_references(request: ReferenceSyncRequest, background_tasks: Backg
result = await client.bulk_sync_references(
batch_size=request.batch_size,
max_cves=request.max_cves,
force_resync=request.force_resync,
cancellation_flag=lambda: job_cancellation_flags.get(job_id, False)
)
running_jobs[job_id]['result'] = result
@@ -2088,6 +2089,60 @@ async def get_running_jobs(db: Session = Depends(get_db)):
logger.error(f"Error getting running jobs: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/ollama-pull-model")
async def pull_ollama_model(request: dict, background_tasks: BackgroundTasks):
"""Pull an Ollama model"""
try:
from llm_client import LLMClient
model = request.get('model')
if not model:
raise HTTPException(status_code=400, detail="Model name is required")
# Create a background task to pull the model
def pull_model_task():
try:
client = LLMClient(provider='ollama', model=model)
base_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
if client._pull_ollama_model(base_url, model):
logger.info(f"Successfully pulled Ollama model: {model}")
else:
logger.error(f"Failed to pull Ollama model: {model}")
except Exception as e:
logger.error(f"Error in model pull task: {e}")
background_tasks.add_task(pull_model_task)
return {
"message": f"Started pulling model {model}",
"status": "started",
"model": model
}
except Exception as e:
logger.error(f"Error starting model pull: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/ollama-models")
async def get_ollama_models():
"""Get available Ollama models"""
try:
from llm_client import LLMClient
client = LLMClient(provider='ollama')
available_models = client._get_ollama_available_models()
return {
"available_models": available_models,
"total_models": len(available_models),
"status": "success"
}
except Exception as e:
logger.error(f"Error getting Ollama models: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
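With the backend running (uvicorn on port 8000 as above), the new endpoints can be exercised from any HTTP client; an illustrative sketch using requests, with the host, port, and model name assumed:

import requests

# List the models currently installed in Ollama.
resp = requests.get("http://localhost:8000/api/ollama-models", timeout=10)
print(resp.json())   # {"available_models": [...], "total_models": ..., "status": "success"}

# Start a background pull of a model; the API returns immediately.
resp = requests.post(
    "http://localhost:8000/api/ollama-pull-model",
    json={"model": "llama3.2"},
    timeout=10,
)
print(resp.json())   # {"message": "Started pulling model llama3.2", "status": "started", "model": "llama3.2"}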

backend/reference_client.py (new file, 603 lines)

@@ -0,0 +1,603 @@
"""
Reference Data Extraction Client
Extracts and analyzes text content from CVE references and KEV records
"""
import aiohttp
import asyncio
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from urllib.parse import urlparse, urljoin
from sqlalchemy.orm import Session
from sqlalchemy import text, func
import hashlib
from bs4 import BeautifulSoup
import ssl
import certifi
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ReferenceClient:
"""Client for extracting and analyzing reference content from CVE and KEV records"""
def __init__(self, db_session: Session):
self.db_session = db_session
# Rate limiting
self.rate_limit_delay = 2.0 # 2 seconds between requests
self.last_request_time = 0
# Cache for processed URLs
self.url_cache = {}
self.cache_ttl = 86400 # 24 hours cache
# SSL context for secure requests
self.ssl_context = ssl.create_default_context(cafile=certifi.where())
# Disable hostname checking and certificate verification so references served with self-signed or otherwise misconfigured certificates can still be fetched
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
# Common headers to avoid being blocked
self.headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
'Accept-Language': 'en-US,en;q=0.5',
'Accept-Encoding': 'gzip, deflate, br',
'Connection': 'keep-alive',
'Upgrade-Insecure-Requests': '1',
'Sec-Fetch-Dest': 'document',
'Sec-Fetch-Mode': 'navigate',
'Sec-Fetch-Site': 'none'
}
# Supported reference types
self.reference_types = {
'security_advisory': ['security', 'advisory', 'bulletin', 'alert', 'cve', 'vulnerability'],
'patch': ['patch', 'fix', 'update', 'hotfix', 'security-update'],
'exploit': ['exploit', 'poc', 'proof-of-concept', 'github.com', 'exploit-db'],
'technical_analysis': ['analysis', 'research', 'technical', 'writeup', 'blog'],
'vendor_advisory': ['microsoft', 'apple', 'oracle', 'cisco', 'vmware', 'adobe'],
'cve_database': ['cve.mitre.org', 'nvd.nist.gov', 'cve.org']
}
async def _make_request(self, session: aiohttp.ClientSession, url: str) -> Optional[Tuple[str, str]]:
"""Make a rate-limited request to fetch URL content"""
try:
# Rate limiting
current_time = asyncio.get_event_loop().time()
time_since_last = current_time - self.last_request_time
if time_since_last < self.rate_limit_delay:
await asyncio.sleep(self.rate_limit_delay - time_since_last)
# Check cache first
url_hash = hashlib.md5(url.encode()).hexdigest()
if url_hash in self.url_cache:
cache_entry = self.url_cache[url_hash]
if datetime.now().timestamp() - cache_entry['timestamp'] < self.cache_ttl:
logger.info(f"Using cached content for {url}")
return cache_entry['content'], cache_entry['content_type']
async with session.get(url, headers=self.headers) as response:
self.last_request_time = asyncio.get_event_loop().time()
if response.status == 200:
content_type = response.headers.get('content-type', '').lower()
# Only process text content
if 'text/html' in content_type or 'text/plain' in content_type or 'application/json' in content_type:
try:
content = await response.text(encoding='utf-8', errors='ignore')
except UnicodeDecodeError:
# Fallback to binary content if text decode fails
content_bytes = await response.read()
content = content_bytes.decode('utf-8', errors='ignore')
# Cache the result
self.url_cache[url_hash] = {
'content': content,
'content_type': content_type,
'timestamp': datetime.now().timestamp()
}
return content, content_type
else:
logger.warning(f"Unsupported content type {content_type} for {url}")
return None, None
elif response.status in [301, 302, 303, 307, 308]:
logger.info(f"Redirect response {response.status} for {url}")
return None, None
else:
logger.warning(f"Request failed: {response.status} for {url}")
return None, None
except aiohttp.ClientError as e:
logger.warning(f"Client error fetching {url}: {e}")
return None, None
except asyncio.TimeoutError:
logger.warning(f"Timeout fetching {url}")
return None, None
except Exception as e:
logger.error(f"Unexpected error fetching {url}: {e}")
return None, None
def _extract_text_from_html(self, html_content: str) -> str:
"""Extract meaningful text from HTML content"""
try:
soup = BeautifulSoup(html_content, 'html.parser')
# Remove script and style elements
for script in soup(["script", "style"]):
script.decompose()
# Extract text from common content areas
content_selectors = [
'article', 'main', '.content', '#content',
'.post-content', '.entry-content', 'section'
]
text_content = ""
for selector in content_selectors:
elements = soup.select(selector)
if elements:
for element in elements:
text_content += element.get_text(separator=' ', strip=True) + '\n'
break
# If no structured content found, get all text
if not text_content.strip():
text_content = soup.get_text(separator=' ', strip=True)
# Clean up the text
text_content = re.sub(r'\s+', ' ', text_content)
text_content = text_content.strip()
return text_content
except Exception as e:
logger.error(f"Error extracting text from HTML: {e}")
return ""
def _analyze_reference_content(self, url: str, content: str) -> Dict[str, Any]:
"""Analyze reference content to extract security-relevant information"""
analysis = {
'url': url,
'content_length': len(content),
'reference_type': 'unknown',
'security_keywords': [],
'technical_indicators': [],
'patch_information': [],
'exploit_indicators': [],
'cve_mentions': [],
'severity_indicators': [],
'mitigation_steps': [],
'affected_products': [],
'attack_vectors': [],
'confidence_score': 0
}
if not content:
return analysis
content_lower = content.lower()
# Classify reference type
domain = urlparse(url).netloc.lower()
for ref_type, keywords in self.reference_types.items():
if any(keyword in domain or keyword in content_lower for keyword in keywords):
analysis['reference_type'] = ref_type
break
# Extract CVE mentions
cve_pattern = r'(CVE-\d{4}-\d{4,7})'
cve_matches = re.findall(cve_pattern, content, re.IGNORECASE)
analysis['cve_mentions'] = list(set(cve_matches))
# Security keywords
security_keywords = [
'vulnerability', 'exploit', 'attack', 'malware', 'backdoor',
'privilege escalation', 'remote code execution', 'rce',
'sql injection', 'xss', 'csrf', 'buffer overflow',
'authentication bypass', 'authorization', 'injection',
'denial of service', 'dos', 'ddos', 'ransomware'
]
for keyword in security_keywords:
if keyword in content_lower:
analysis['security_keywords'].append(keyword)
# Technical indicators
technical_patterns = [
r'\b(function|method|class|variable)\s+\w+',
r'\b(file|directory|path|folder)\s+[^\s]+',
r'\b(port|service|protocol)\s+\d+',
r'\b(registry|key|value)\s+[^\s]+',
r'\b(process|executable|binary)\s+[^\s]+',
r'\b(dll|exe|bat|ps1|sh|py|jar)\b',
r'\b(http|https|ftp|smb|tcp|udp)://[^\s]+',
r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b'
]
for pattern in technical_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
if matches:
analysis['technical_indicators'].extend(matches[:10]) # Limit to 10 per pattern
# Patch information
patch_keywords = [
'patch', 'fix', 'update', 'hotfix', 'security update',
'kb\d+', 'ms\d+-\d+', 'version \d+\.\d+',
'download', 'install', 'upgrade'
]
for keyword in patch_keywords:
if re.search(keyword, content_lower):
analysis['patch_information'].append(keyword)
# Exploit indicators
exploit_keywords = [
'proof of concept', 'poc', 'exploit code', 'payload',
'shellcode', 'reverse shell', 'metasploit', 'nmap',
'vulnerability assessment', 'penetration test', 'bypass'
]
for keyword in exploit_keywords:
if keyword in content_lower:
analysis['exploit_indicators'].append(keyword)
# Severity indicators
severity_patterns = [
r'\b(critical|high|medium|low)\s+(severity|risk|priority)',
r'\b(cvss|score)\s*[:=]?\s*(\d+\.\d+|\d+)',
r'\b(exploitability|impact)\s*[:=]?\s*(high|medium|low)'
]
for pattern in severity_patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
if matches:
analysis['severity_indicators'].extend([' '.join(match) if isinstance(match, tuple) else match for match in matches])
# Mitigation steps
mitigation_keywords = [
'mitigation', 'workaround', 'prevention', 'remediation',
'disable', 'block', 'restrict', 'configure', 'setting'
]
# Find sentences containing mitigation keywords
sentences = re.split(r'[.!?]+', content)
for sentence in sentences:
if any(keyword in sentence.lower() for keyword in mitigation_keywords):
if len(sentence.strip()) > 20: # Avoid very short sentences
analysis['mitigation_steps'].append(sentence.strip()[:200]) # Limit length
# Calculate confidence score
score = 0
score += min(len(analysis['security_keywords']) * 5, 25)
score += min(len(analysis['technical_indicators']) * 2, 20)
score += min(len(analysis['cve_mentions']) * 10, 30)
score += min(len(analysis['patch_information']) * 3, 15)
score += min(len(analysis['exploit_indicators']) * 4, 20)
if analysis['reference_type'] != 'unknown':
score += 10
analysis['confidence_score'] = min(score, 100)
# Clean up and deduplicate
for key in ['security_keywords', 'technical_indicators', 'patch_information',
'exploit_indicators', 'severity_indicators', 'mitigation_steps']:
analysis[key] = list(set(analysis[key]))[:10] # Limit to 10 items each
return analysis
async def extract_reference_content(self, url: str) -> Optional[Dict[str, Any]]:
"""Extract and analyze content from a single reference URL"""
try:
connector = aiohttp.TCPConnector(
ssl=self.ssl_context,
limit=100,
limit_per_host=30,
ttl_dns_cache=300,
use_dns_cache=True,
)
timeout = aiohttp.ClientTimeout(total=60, connect=30)
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
content, content_type = await self._make_request(session, url)
if not content:
return None
# Extract text from HTML if needed
if content_type and 'text/html' in content_type:
text_content = self._extract_text_from_html(content)
else:
text_content = content
# Analyze the content
analysis = self._analyze_reference_content(url, text_content)
# Add metadata
analysis.update({
'extracted_at': datetime.utcnow().isoformat(),
'content_type': content_type,
'text_length': len(text_content),
'source': 'reference_extraction'
})
return analysis
except Exception as e:
logger.error(f"Error extracting reference content from {url}: {e}")
return None
async def sync_cve_references(self, cve_id: str) -> Dict[str, Any]:
"""Sync reference data for a specific CVE"""
from main import CVE, SigmaRule
# Get existing CVE
cve = self.db_session.query(CVE).filter(CVE.cve_id == cve_id).first()
if not cve:
logger.warning(f"CVE {cve_id} not found in database")
return {"error": "CVE not found"}
if not cve.reference_urls:
logger.info(f"No reference URLs found for CVE {cve_id}")
return {"cve_id": cve_id, "references_processed": 0}
logger.info(f"Processing {len(cve.reference_urls)} references for CVE {cve_id}")
processed_references = []
successful_extractions = 0
for url in cve.reference_urls:
try:
# Extract reference content
ref_analysis = await self.extract_reference_content(url)
if ref_analysis:
processed_references.append(ref_analysis)
successful_extractions += 1
logger.info(f"Successfully extracted content from {url}")
else:
logger.warning(f"Failed to extract content from {url}")
# Small delay between requests
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Error processing reference {url}: {e}")
# Update CVE with reference data
cve.reference_data = {
'reference_analysis': {
'references': processed_references,
'total_references': len(cve.reference_urls),
'successful_extractions': successful_extractions,
'extraction_rate': successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
'extracted_at': datetime.utcnow().isoformat(),
'source': 'reference_extraction'
}
}
cve.reference_sync_status = 'completed' if successful_extractions > 0 else 'failed'
cve.reference_last_synced = datetime.utcnow()
cve.updated_at = datetime.utcnow()
# Update SIGMA rule with reference data
sigma_rule = self.db_session.query(SigmaRule).filter(
SigmaRule.cve_id == cve_id
).first()
if sigma_rule:
# Aggregate indicators from all references
aggregated_indicators = {
'security_keywords': [],
'technical_indicators': [],
'exploit_indicators': [],
'patch_information': [],
'attack_vectors': [],
'mitigation_steps': []
}
for ref in processed_references:
for key in aggregated_indicators.keys():
if key in ref:
aggregated_indicators[key].extend(ref[key])
# Deduplicate
for key in aggregated_indicators:
aggregated_indicators[key] = list(set(aggregated_indicators[key]))
# Update rule with reference data
if not sigma_rule.nomi_sec_data:
sigma_rule.nomi_sec_data = {}
sigma_rule.nomi_sec_data['reference_analysis'] = {
'aggregated_indicators': aggregated_indicators,
'reference_count': len(processed_references),
'high_confidence_references': len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
'reference_types': list(set([r.get('reference_type') for r in processed_references if r.get('reference_type') != 'unknown'])),
'source': 'reference_extraction'
}
# Update exploit indicators
if sigma_rule.exploit_indicators:
existing_indicators = json.loads(sigma_rule.exploit_indicators)
else:
existing_indicators = {}
for key, values in aggregated_indicators.items():
if key not in existing_indicators:
existing_indicators[key] = []
existing_indicators[key].extend(values)
existing_indicators[key] = list(set(existing_indicators[key]))
sigma_rule.exploit_indicators = json.dumps(existing_indicators)
sigma_rule.updated_at = datetime.utcnow()
self.db_session.commit()
logger.info(f"Successfully synchronized reference data for {cve_id}")
return {
"cve_id": cve_id,
"references_processed": len(processed_references),
"successful_extractions": successful_extractions,
"extraction_rate": successful_extractions / len(cve.reference_urls) if cve.reference_urls else 0,
"high_confidence_references": len([r for r in processed_references if r.get('confidence_score', 0) > 70]),
"source": "reference_extraction"
}
async def bulk_sync_references(self, batch_size: int = 50, max_cves: int = None,
force_resync: bool = False, cancellation_flag: Optional[callable] = None) -> Dict[str, Any]:
"""Bulk synchronize reference data for multiple CVEs"""
from main import CVE, BulkProcessingJob
# Create bulk processing job
job = BulkProcessingJob(
job_type='reference_sync',
status='running',
started_at=datetime.utcnow(),
job_metadata={'batch_size': batch_size, 'max_cves': max_cves}
)
self.db_session.add(job)
self.db_session.commit()
total_processed = 0
total_references = 0
successful_extractions = 0
try:
# Get CVEs that have reference URLs but no reference analysis
query = self.db_session.query(CVE).filter(
CVE.reference_urls.isnot(None),
func.array_length(CVE.reference_urls, 1) > 0
)
# Filter out CVEs that already have reference analysis (unless force_resync is True)
if not force_resync:
query = query.filter(
CVE.reference_sync_status != 'completed'
)
if max_cves:
cves = query.limit(max_cves).all()
else:
cves = query.all()
job.total_items = len(cves)
self.db_session.commit()
logger.info(f"Starting bulk reference sync for {len(cves)} CVEs")
# Process in batches
for i in range(0, len(cves), batch_size):
# Check for cancellation
if cancellation_flag and cancellation_flag():
logger.info("Bulk reference sync cancelled by user")
job.status = 'cancelled'
job.cancelled_at = datetime.utcnow()
job.error_message = "Job cancelled by user"
break
batch = cves[i:i + batch_size]
for cve in batch:
# Check for cancellation
if cancellation_flag and cancellation_flag():
logger.info("Bulk reference sync cancelled by user")
job.status = 'cancelled'
job.cancelled_at = datetime.utcnow()
job.error_message = "Job cancelled by user"
break
try:
result = await self.sync_cve_references(cve.cve_id)
if "error" not in result:
total_processed += 1
total_references += result.get("references_processed", 0)
successful_extractions += result.get("successful_extractions", 0)
else:
job.failed_items += 1
job.processed_items += 1
# Longer delay for reference extraction to be respectful
await asyncio.sleep(2)
except Exception as e:
logger.error(f"Error processing references for {cve.cve_id}: {e}")
job.failed_items += 1
# Break out of outer loop if cancelled
if job.status == 'cancelled':
break
# Commit after each batch
self.db_session.commit()
logger.info(f"Processed reference batch {i//batch_size + 1}/{(len(cves) + batch_size - 1)//batch_size}")
# Update job status
if job.status != 'cancelled':
job.status = 'completed'
job.completed_at = datetime.utcnow()
job.job_metadata.update({
'total_processed': total_processed,
'total_references': total_references,
'successful_extractions': successful_extractions,
'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
'source': 'reference_extraction'
})
except Exception as e:
job.status = 'failed'
job.error_message = str(e)
job.completed_at = datetime.utcnow()
logger.error(f"Bulk reference sync job failed: {e}")
finally:
self.db_session.commit()
return {
'job_id': str(job.id),
'status': job.status,
'total_processed': total_processed,
'total_references': total_references,
'successful_extractions': successful_extractions,
'extraction_rate': successful_extractions / total_references if total_references > 0 else 0,
'source': 'reference_extraction'
}
async def get_reference_sync_status(self) -> Dict[str, Any]:
"""Get reference synchronization status"""
from main import CVE
# Count CVEs with reference URLs
total_cves = self.db_session.query(CVE).count()
cves_with_refs = self.db_session.query(CVE).filter(
CVE.reference_urls.isnot(None),
func.array_length(CVE.reference_urls, 1) > 0
).count()
# Count CVEs with reference analysis
cves_with_analysis = self.db_session.query(CVE).filter(
CVE.reference_sync_status == 'completed'
).count()
return {
'total_cves': total_cves,
'cves_with_references': cves_with_refs,
'cves_with_analysis': cves_with_analysis,
'reference_coverage': (cves_with_refs / total_cves * 100) if total_cves > 0 else 0,
'analysis_coverage': (cves_with_analysis / cves_with_refs * 100) if cves_with_refs > 0 else 0,
'sync_status': 'active' if cves_with_analysis > 0 else 'pending',
'source': 'reference_extraction'
}
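To give a sense of what the analyzer produces, a small sketch (not part of the commit) that runs _analyze_reference_content on a fabricated snippet; it assumes a SQLAlchemy session object named session is available, since the constructor requires one even though this particular method never touches the database:

from reference_client import ReferenceClient

client = ReferenceClient(db_session=session)   # session: an existing SQLAlchemy Session
analysis = client._analyze_reference_content(
    "https://example.com/advisories/sample",
    "CVE-2024-12345 allows remote code execution. A patch is available for download.",
)
print(analysis['reference_type'])      # classified from domain/content keywords
print(analysis['cve_mentions'])        # ['CVE-2024-12345']
print(analysis['confidence_score'])    # weighted keyword/CVE/patch hits, capped at 100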

backend/setup_ollama.py (new executable file, 114 lines)

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""
import os
import sys
import time
import requests
import json
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
"""Wait for Ollama service to be ready"""
for i in range(max_retries):
try:
response = requests.get(f"{base_url}/api/tags", timeout=5)
if response.status_code == 200:
logger.info("Ollama service is ready")
return True
except Exception as e:
logger.info(f"Waiting for Ollama service... ({i+1}/{max_retries})")
time.sleep(delay)
logger.error("Ollama service is not ready after maximum retries")
return False
def check_model_exists(base_url: str, model: str) -> bool:
"""Check if a model exists in Ollama"""
try:
response = requests.get(f"{base_url}/api/tags", timeout=10)
if response.status_code == 200:
data = response.json()
models = data.get('models', [])
for m in models:
model_name = m.get('name', '')
if model_name.startswith(model + ':') or model_name == model:
logger.info(f"Model {model} already exists")
return True
return False
except Exception as e:
logger.error(f"Error checking models: {e}")
return False
def pull_model(base_url: str, model: str) -> bool:
"""Pull an Ollama model"""
try:
logger.info(f"Pulling model {model}...")
payload = {"name": model}
response = requests.post(
f"{base_url}/api/pull",
json=payload,
timeout=1800, # 30 minutes timeout for model download
stream=True
)
if response.status_code == 200:
# Stream the response to monitor progress
for line in response.iter_lines():
if line:
try:
data = json.loads(line.decode('utf-8'))
status = data.get('status', '')
if 'pulling' in status.lower() or 'downloading' in status.lower():
logger.info(f"Ollama: {status}")
elif data.get('error'):
logger.error(f"Ollama pull error: {data.get('error')}")
return False
except json.JSONDecodeError:
continue
logger.info(f"Successfully pulled model {model}")
return True
else:
logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
logger.error(f"Response: {response.text}")
return False
except Exception as e:
logger.error(f"Error pulling model {model}: {e}")
return False
def main():
"""Main setup function"""
base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
model = os.getenv('LLM_MODEL', 'llama3.2')
logger.info(f"Setting up Ollama with model {model}")
logger.info(f"Ollama URL: {base_url}")
# Wait for Ollama service to be ready
if not wait_for_ollama(base_url):
logger.error("Ollama service is not available")
sys.exit(1)
# Check if model already exists
if check_model_exists(base_url, model):
logger.info(f"Model {model} is already available")
sys.exit(0)
# Pull the model
if pull_model(base_url, model):
logger.info(f"Setup completed successfully - model {model} is ready")
sys.exit(0)
else:
logger.error(f"Failed to pull model {model}")
sys.exit(1)
if __name__ == "__main__":
main()
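The same helpers can be reused outside the compose stack, for example against an Ollama instance running on the host; a quick sketch (base URL and model name are assumptions):

from setup_ollama import wait_for_ollama, check_model_exists, pull_model

base_url = 'http://localhost:11434'   # host-local Ollama; inside compose the default is http://ollama:11434
model = 'llama3.2'                    # matches the LLM_MODEL default above

if wait_for_ollama(base_url) and not check_model_exists(base_url, model):
    pull_model(base_url, model)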

docker-compose.yml

@@ -32,6 +32,8 @@ services:
depends_on:
db:
condition: service_healthy
ollama-setup:
condition: service_completed_successfully
volumes:
- ./backend:/app
- ./github_poc_collector:/github_poc_collector
@@ -67,6 +69,18 @@ services:
- OLLAMA_HOST=0.0.0.0
restart: unless-stopped
ollama-setup:
build: ./backend
depends_on:
- ollama
environment:
OLLAMA_BASE_URL: http://ollama:11434
LLM_MODEL: llama3.2
volumes:
- ./backend:/app
command: python setup_ollama.py
restart: "no"
volumes:
postgres_data:
redis_data:
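Once the stack is up, the backend only starts after ollama-setup exits successfully, so the model should already be present. A hypothetical check from inside the backend container (the in-network base URL comes from the ollama-setup environment above):

import requests

# Ask the Ollama service which models it has; the tag pulled by ollama-setup
# (llama3.2 by default) should be in the list.
tags = requests.get("http://ollama:11434/api/tags", timeout=10).json()
print([m.get("name") for m in tags.get("models", [])])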