"""
Model Registry for GLEAM-AI.

This module provides a registry system for managing different model versions,
configurations, and caching loaded models and statistics.
"""

import yaml
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
import logging
from functools import lru_cache

# Module-level logger named after this module (via ``__name__``); the registry
# uses it to report statistics loads and cache clears.
logger = logging.getLogger(__name__)


class ModelRegistry:
    """
    Registry for managing GLEAM-AI models and their configurations.

    This class handles:
    - Loading model configurations from config files
    - Managing multiple model versions
    - Caching loaded models and statistics
    - Providing default model information

    Caches are held per instance: ``_stats_cache`` stores the results of
    ``load_y_stats`` keyed by path, so clearing one registry never affects
    another and a registry can be garbage-collected normally.
    ``_model_cache`` is created and cleared here, but no method of this
    class populates it.
    """

    # Upper bound on the number of y-statistics entries kept per registry.
    # When the bound is reached the least recently used entry is evicted.
    _MAX_STATS_CACHE_ENTRIES = 10

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the model registry.

        Args:
            config_path: Path to the configuration file. If None, uses default "config.yaml"

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file is not valid YAML or fails validation.
            RuntimeError: If the file cannot be read for any other reason.
        """
        self.config_path = Path(config_path) if config_path else Path("config.yaml")
        self.config = self._load_config()
        self._model_cache = {}
        self._stats_cache = {}

        # Validate configuration
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """
        Load configuration from YAML file.

        Returns:
            The parsed YAML document. An empty file parses to ``None``; that
            case is rejected afterwards by ``_validate_config``.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
            ValueError: If the file is not valid YAML.
            RuntimeError: If reading the file fails for any other reason.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(
                f"Configuration file not found: {self.config_path}. "
                "Please ensure the config.yaml file exists in the current directory."
            )

        try:
            # Explicit encoding so parsing does not depend on the platform's
            # default locale.
            with open(self.config_path, 'r', encoding="utf-8") as f:
                return yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing configuration file: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading configuration file: {e}") from e

    def _validate_config(self) -> None:
        """
        Validate the configuration structure.

        Checks that the config is a mapping with a ``model_registry`` section
        containing ``default_model`` and a ``models`` mapping that includes
        the default model.

        Raises:
            ValueError: If any of those requirements is not met.
        """
        # An empty or scalar YAML document is not a mapping; report it as a
        # missing section instead of failing with a TypeError on ``in``.
        if not isinstance(self.config, dict) or "model_registry" not in self.config:
            raise ValueError(
                "Configuration file is missing 'model_registry' section. "
                "Please add model registry information to your config.yaml file."
            )

        registry = self.config["model_registry"]
        if not isinstance(registry, dict) or "default_model" not in registry:
            raise ValueError("Configuration is missing 'default_model' specification.")

        if "models" not in registry:
            raise ValueError("Configuration is missing 'models' section.")

        models = registry["models"]
        if not isinstance(models, dict):
            raise ValueError(
                "Configuration 'models' section must be a mapping of model names to settings."
            )

        default_model = registry["default_model"]
        if default_model not in models:
            raise ValueError(f"Default model '{default_model}' not found in models section.")

    def get_available_models(self) -> Dict[str, Dict[str, Any]]:
        """
        Get all available model configurations.

        Returns:
            A new dictionary mapping model names to copies of their settings,
            so modifying the result does not alter the registry's config.
        """
        models = self.config["model_registry"]["models"]
        # Copy each per-model mapping as well as the outer one; a plain
        # ``models.copy()`` would still share the inner dicts with the config.
        return {
            name: info.copy() if isinstance(info, dict) else info
            for name, info in models.items()
        }

    def get_default_model_name(self) -> str:
        """Get the name of the default model."""
        return self.config["model_registry"]["default_model"]

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about a specific model.

        Args:
            model_name: Name of the model. If None, returns default model info.

        Returns:
            Dictionary containing model information (a copy of the config entry)

        Raises:
            ValueError: If the model is not present in the configuration.
        """
        if model_name is None:
            model_name = self.get_default_model_name()

        models = self.config["model_registry"]["models"]
        if model_name not in models:
            available = list(models.keys())
            raise ValueError(
                f"Model '{model_name}' not found. Available models: {available}"
            )

        return models[model_name].copy()

    def get_model_paths(self, model_name: Optional[str] = None) -> Tuple[str, str]:
        """
        Get checkpoint and statistics paths for a model.

        Args:
            model_name: Name of the model. If None, returns default model paths.

        Returns:
            Tuple of (checkpoint_path, y_stats_path)

        Raises:
            ValueError: If the model is unknown, or its entry lacks a
                non-empty ``checkpoint_path`` or ``y_stats_path``.
        """
        # Resolve the default up front so error messages name the actual
        # model rather than 'None'.
        if model_name is None:
            model_name = self.get_default_model_name()
        model_info = self.get_model_info(model_name)

        checkpoint_path = model_info.get("checkpoint_path")
        y_stats_path = model_info.get("y_stats_path")

        if not checkpoint_path:
            raise ValueError(f"Model '{model_name}' is missing checkpoint_path")
        if not y_stats_path:
            raise ValueError(f"Model '{model_name}' is missing y_stats_path")

        return checkpoint_path, y_stats_path

    def load_y_stats(self, y_stats_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load y statistics from file with caching.

        Results are cached per registry instance, keyed by the path string,
        with at most ``_MAX_STATS_CACHE_ENTRIES`` entries (least recently
        used entry evicted first). Failed loads are not cached. A cache hit
        returns the same array objects as the first load, so callers should
        not modify them in place.

        Args:
            y_stats_path: Path to the y_stats parquet file

        Returns:
            Tuple of (y_mean, y_std) arrays

        Raises:
            FileNotFoundError: If the file does not exist.
            RuntimeError: If the file cannot be read or lacks the
                ``y_mean`` / ``y_std`` columns.
        """
        cache_key = str(y_stats_path)
        cached = self._stats_cache.get(cache_key)
        if cached is not None:
            # Re-insert to mark the entry as most recently used (dicts keep
            # insertion order, so the oldest key is always first).
            self._stats_cache[cache_key] = self._stats_cache.pop(cache_key)
            return cached

        if not Path(y_stats_path).exists():
            raise FileNotFoundError(
                f"Y statistics file not found: {y_stats_path}. "
                "Please ensure the file exists and the path is correct."
            )

        try:
            y_stats = pd.read_parquet(y_stats_path)
            y_mean = y_stats["y_mean"].values
            y_std = y_stats["y_std"].values
        except Exception as e:
            raise RuntimeError(f"Error loading y statistics from {y_stats_path}: {e}") from e

        logger.info("Loaded y statistics from %s", y_stats_path)

        # Make room before inserting so the cache never exceeds its bound.
        while len(self._stats_cache) >= self._MAX_STATS_CACHE_ENTRIES:
            del self._stats_cache[next(iter(self._stats_cache))]

        result = (y_mean, y_std)
        self._stats_cache[cache_key] = result
        return result

    def clear_cache(self) -> None:
        """Clear all cached models and statistics held by this registry."""
        self._model_cache.clear()
        self._stats_cache.clear()
        logger.info("Model registry cache cleared")

    def get_cache_info(self) -> Dict[str, int]:
        """
        Get information about cache usage.

        Returns:
            Dictionary with ``model_cache_size``, ``stats_cache_size`` and
            ``y_stats_cache_size``. The last two report the same value:
            y statistics are stored in the per-instance statistics cache, and
            ``y_stats_cache_size`` is kept so existing callers keep working.
        """
        stats_entries = len(self._stats_cache)
        return {
            "model_cache_size": len(self._model_cache),
            "stats_cache_size": stats_entries,
            "y_stats_cache_size": stats_entries,
        }
