# auto_sigma_rule_generator/backend/setup_ollama.py
#!/usr/bin/env python3
"""
Setup script to pull the default Ollama model on startup
"""
import os
import sys
import time
import requests
import json
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Poll the Ollama HTTP API until it responds, or give up.

    Args:
        base_url: Base URL of the Ollama server (e.g. ``http://ollama:11434``).
        max_retries: Maximum number of polling attempts.
        delay: Seconds to sleep between attempts.

    Returns:
        True once ``GET /api/tags`` answers with HTTP 200, False if the
        service never becomes ready within ``max_retries`` attempts.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except requests.RequestException:
            # Expected while the container is still starting (connection
            # refused / timeout) — keep waiting.
            logger.info(f"Waiting for Ollama service... ({attempt}/{max_retries})")
        if attempt < max_retries:
            # Don't sleep after the final attempt; we're about to give up.
            time.sleep(delay)
    logger.error("Ollama service is not ready after maximum retries")
    return False
def check_model_exists(base_url: str, model: str) -> bool:
    """Return True if *model* is already present in the local Ollama registry.

    Matches both the exact name ("llama3.2") and tagged variants
    ("llama3.2:latest"), mirroring how Ollama lists models in /api/tags.

    Args:
        base_url: Base URL of the Ollama server.
        model: Model name to look for, without tag.

    Returns:
        True if the model is listed, False on absence or any lookup failure.
    """
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            for entry in response.json().get('models', []):
                name = entry.get('name', '')
                if name == model or name.startswith(model + ':'):
                    logger.info(f"Model {model} already exists")
                    return True
        else:
            # Previously a non-200 fell through silently; surface it so a
            # misbehaving server isn't mistaken for "model not present".
            logger.error(f"Error checking models: HTTP {response.status_code}")
        return False
    except (requests.RequestException, ValueError) as e:
        # ValueError covers malformed JSON from response.json().
        logger.error(f"Error checking models: {e}")
        return False
def pull_model(base_url: str, model: str) -> bool:
    """Download *model* via the Ollama ``/api/pull`` streaming endpoint.

    The endpoint emits one JSON object per line (NDJSON); streaming them
    keeps long downloads producing log output and lets us abort early when
    the server reports an error mid-pull.

    Args:
        base_url: Base URL of the Ollama server.
        model: Model name to pull.

    Returns:
        True when the pull finished successfully, False on any failure.
    """
    try:
        logger.info(f"Pulling model {model}...")
        response = requests.post(
            f"{base_url}/api/pull",
            json={"name": model},
            timeout=1800,  # 30 minutes timeout for model download
            stream=True
        )
        if response.status_code != 200:
            logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
            logger.error(f"Response: {response.text}")
            return False
        # Stream the response to monitor progress.
        for line in response.iter_lines():
            if not line:
                continue
            try:
                event = json.loads(line.decode('utf-8'))
            except json.JSONDecodeError:
                # Skip partial or garbled lines; the next event will follow.
                continue
            status = event.get('status', '')
            if 'pulling' in status.lower() or 'downloading' in status.lower():
                logger.info(f"Ollama: {status}")
            elif event.get('error'):
                logger.error(f"Ollama pull error: {event.get('error')}")
                return False
        logger.info(f"Successfully pulled model {model}")
        return True
    except requests.RequestException as e:
        # Narrowed from bare Exception: only transport-level failures belong
        # here; JSON issues are handled per-line above.
        logger.error(f"Error pulling model {model}: {e}")
        return False
def main():
    """Entry point: make sure the configured LLM model is present in Ollama.

    Reads OLLAMA_BASE_URL and LLM_MODEL from the environment, waits for the
    service, then pulls the model unless it already exists. Exits 0 on
    success, 1 on any failure.
    """
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Guard clause: nothing to do if the service never comes up.
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Skip the (potentially huge) download when the model is already local.
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
        sys.exit(0)

    if not pull_model(base_url, model):
        logger.error(f"Failed to pull model {model}")
        sys.exit(1)

    logger.info(f"Setup completed successfully - model {model} is ready")
    sys.exit(0)


if __name__ == "__main__":
    main()