auto_sigma_rule_generator/backend/setup_ollama.py
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""
import os
import sys
import time
import json
import logging

import requests

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Wait for Ollama service to be ready"""
    for i in range(max_retries):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except Exception:
            logger.info(f"Waiting for Ollama service... ({i+1}/{max_retries})")
        time.sleep(delay)
    logger.error("Ollama service is not ready after maximum retries")
    return False


def check_model_exists(base_url: str, model: str) -> bool:
    """Check if a model exists in Ollama"""
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            models = data.get('models', [])
            for m in models:
                model_name = m.get('name', '')
                # Match an exact name or a tagged variant (e.g. "llama3.2:latest")
                if model_name.startswith(model + ':') or model_name == model:
                    logger.info(f"Model {model} already exists")
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
        return False


def pull_model(base_url: str, model: str) -> bool:
    """Pull an Ollama model"""
    try:
        logger.info(f"Pulling model {model}...")
        payload = {"name": model}
        response = requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=1800,  # 30-minute timeout for model download
            stream=True
        )
        if response.status_code == 200:
            # Stream the response to monitor progress
            for line in response.iter_lines():
                if line:
                    try:
                        data = json.loads(line.decode('utf-8'))
                        status = data.get('status', '')
                        if 'pulling' in status.lower() or 'downloading' in status.lower():
                            logger.info(f"Ollama: {status}")
                        elif data.get('error'):
                            logger.error(f"Ollama pull error: {data.get('error')}")
                            return False
                    except json.JSONDecodeError:
                        continue
            logger.info(f"Successfully pulled model {model}")
            return True
        else:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False
    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False


def setup_custom_sigma_model(base_url: str) -> bool:
    """Setup custom SIGMA model if a fine-tuned model exists."""
    model_path = "/app/models/sigma_llama_finetuned"
    if not os.path.exists(model_path):
        logger.info("Fine-tuned model not found, skipping custom model setup")
        return False

    logger.info("Fine-tuned model found, setting up custom Ollama model...")

    # Create a simple Modelfile for the custom model
    modelfile_content = """FROM llama3.2
TEMPLATE \"\"\"### Instruction:
{{ .Prompt }}

### Response:
\"\"\"

PARAMETER temperature 0.1
PARAMETER top_p 0.9
PARAMETER stop "### Instruction:"
PARAMETER stop "### Response:"

SYSTEM \"\"\"You are a cybersecurity expert specializing in SIGMA rule creation. Generate valid SIGMA rules in YAML format based on the provided CVE and exploit information.\"\"\"
"""
    try:
        # Write the Modelfile to a well-known location
        with open("/tmp/Modelfile.sigma", "w") as f:
            f.write(modelfile_content)
        # Actually creating the model would need the ollama CLI (or the HTTP
        # API; see the hypothetical sketch below). For now, just log that the
        # configuration is prepared.
        logger.info("Custom SIGMA model configuration prepared")
        return True
    except Exception as e:
        logger.error(f"Error setting up custom model: {e}")
        return False


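# A minimal sketch of how the custom model could be registered over Ollama's
# HTTP API instead of the CLI. It assumes the /api/create endpoint accepts a
# "name" and "modelfile" payload and streams JSON status lines, as in earlier
# Ollama releases; verify against the Ollama version in the container. This
# helper and the model name "sigma-llama" are illustrative, not part of the
# original script.
def create_custom_model(base_url: str, modelfile_content: str,
                        name: str = "sigma-llama") -> bool:
    """Create a custom Ollama model from a Modelfile string (hypothetical sketch)."""
    try:
        response = requests.post(
            f"{base_url}/api/create",
            json={"name": name, "modelfile": modelfile_content},
            timeout=600,
            stream=True
        )
        if response.status_code != 200:
            logger.error(f"Failed to create model {name}: HTTP {response.status_code}")
            return False
        # Stream status lines until the create completes
        for line in response.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line.decode('utf-8'))
                if data.get('error'):
                    logger.error(f"Ollama create error: {data['error']}")
                    return False
                logger.info(f"Ollama: {data.get('status', '')}")
            except json.JSONDecodeError:
                continue
        logger.info(f"Successfully created model {name}")
        return True
    except Exception as e:
        logger.error(f"Error creating model {name}: {e}")
        return False

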
def main():
    """Main setup function"""
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Wait for the Ollama service to be ready
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Check if the model already exists
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
    else:
        # Pull the base model
        if pull_model(base_url, model):
            logger.info(f"Successfully pulled model {model}")
        else:
            logger.error(f"Failed to pull model {model}")
            sys.exit(1)

    # Setup custom SIGMA model if a fine-tuned model exists
    setup_custom_sigma_model(base_url)

    logger.info("Setup completed successfully")
    sys.exit(0)


if __name__ == "__main__":
    main()
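
# Example invocation (values are the defaults read above; adjust as needed):
#   OLLAMA_BASE_URL=http://ollama:11434 LLM_MODEL=llama3.2 python setup_ollama.py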