#!/usr/bin/env python3
"""Setup script to pull the default Ollama model on startup.

Waits for the Ollama HTTP API to come up, checks whether the configured
model is already present, and pulls it if not. Exit status: 0 on success
(model available), 1 on failure (service unreachable or pull failed).
"""

import json
import logging
import os
import sys
import time

import requests

# Module-level logger configured for container-friendly stdout logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def wait_for_ollama(base_url: str, max_retries: int = 30, delay: int = 2) -> bool:
    """Poll the Ollama service until it responds or retries are exhausted.

    Args:
        base_url: Base URL of the Ollama API (e.g. ``http://ollama:11434``).
        max_retries: Maximum number of polling attempts.
        delay: Seconds to sleep between attempts.

    Returns:
        True once ``GET /api/tags`` returns HTTP 200, False after exhausting
        all retries.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(f"{base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is ready")
                return True
        except Exception as exc:
            # Connection refused / DNS errors are expected while the
            # container is still starting — log the cause at debug only.
            logger.debug("Ollama not reachable yet: %s", exc)
            logger.info(f"Waiting for Ollama service... ({attempt+1}/{max_retries})")
        time.sleep(delay)
    logger.error("Ollama service is not ready after maximum retries")
    return False


def check_model_exists(base_url: str, model: str) -> bool:
    """Check whether *model* is already present in the Ollama instance.

    Matches either an exact model name or a tagged variant
    (e.g. ``llama3.2`` matches ``llama3.2:latest``).

    Args:
        base_url: Base URL of the Ollama API.
        model: Model name without tag (e.g. ``llama3.2``).

    Returns:
        True if the model is listed by ``GET /api/tags``, False otherwise
        (including on any request error).
    """
    try:
        response = requests.get(f"{base_url}/api/tags", timeout=10)
        if response.status_code == 200:
            data = response.json()
            for entry in data.get('models', []):
                model_name = entry.get('name', '')
                # A bare name like "llama3.2" is stored as "llama3.2:latest".
                if model_name.startswith(model + ':') or model_name == model:
                    logger.info(f"Model {model} already exists")
                    return True
        return False
    except Exception as e:
        logger.error(f"Error checking models: {e}")
        return False


def pull_model(base_url: str, model: str) -> bool:
    """Pull *model* via the Ollama API, streaming progress to the log.

    Args:
        base_url: Base URL of the Ollama API.
        model: Model name to pull.

    Returns:
        True if the pull stream completed without an error entry,
        False on HTTP error, stream error, or any exception.
    """
    try:
        logger.info(f"Pulling model {model}...")
        payload = {"name": model}
        # Use a context manager so the streamed connection is always
        # released, even if we return early on an error entry.
        with requests.post(
            f"{base_url}/api/pull",
            json=payload,
            timeout=1800,  # 30 minutes timeout for model download
            stream=True
        ) as response:
            if response.status_code != 200:
                logger.error(f"Failed to pull model {model}: HTTP {response.status_code}")
                logger.error(f"Response: {response.text}")
                return False
            # Each line of the stream is a JSON progress object.
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    # Skip malformed progress lines rather than aborting.
                    continue
                status = data.get('status', '')
                if 'pulling' in status.lower() or 'downloading' in status.lower():
                    logger.info(f"Ollama: {status}")
                elif data.get('error'):
                    logger.error(f"Ollama pull error: {data.get('error')}")
                    return False
            logger.info(f"Successfully pulled model {model}")
            return True
    except Exception as e:
        logger.error(f"Error pulling model {model}: {e}")
        return False


def main():
    """Main setup function: wait for Ollama, then ensure the model exists.

    Reads ``OLLAMA_BASE_URL`` and ``LLM_MODEL`` from the environment and
    exits with status 0 on success, 1 on failure.
    """
    base_url = os.getenv('OLLAMA_BASE_URL', 'http://ollama:11434')
    model = os.getenv('LLM_MODEL', 'llama3.2')

    logger.info(f"Setting up Ollama with model {model}")
    logger.info(f"Ollama URL: {base_url}")

    # Wait for Ollama service to be ready
    if not wait_for_ollama(base_url):
        logger.error("Ollama service is not available")
        sys.exit(1)

    # Check if model already exists
    if check_model_exists(base_url, model):
        logger.info(f"Model {model} is already available")
        sys.exit(0)

    # Pull the model
    if pull_model(base_url, model):
        logger.info(f"Setup completed successfully - model {model} is ready")
        sys.exit(0)
    else:
        logger.error(f"Failed to pull model {model}")
        sys.exit(1)


if __name__ == "__main__":
    main()