#!/usr/bin/env python3
"""
Centralized model loading utilities (language and vision models) for OFMU experiments.

Usage:
    from utils.model_loader import ModelLoader
    
    loader = ModelLoader()
    model, tokenizer = loader.load_language_model("llama2-7b")
    vision_model = loader.load_vision_model("resnet18", num_classes=10)
"""

import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision.models as tv_models
import yaml
from huggingface_hub import login
from torchvision import transforms
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)


class ModelLoader:
    """Centralized model loading for OFMU experiments.

    Model definitions come from a YAML file. This is ``config_path`` when it
    exists, otherwise ``<package root>/config/models.yaml``. The built-in
    defaults from :meth:`get_default_config` are used when no usable file is
    found.
    """

    # Config ``dtype`` strings -> torch dtypes. Unknown/absent values leave
    # ``torch_dtype`` unset so ``from_pretrained`` uses its own default.
    _DTYPES = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    # Hand-written, rough memory heuristics keyed by config section and model key.
    _MEMORY_ESTIMATES = {
        "language_models": {
            "llama2-7b": {"fp16": "14GB", "fp32": "28GB", "8bit": "7GB"},
            "llama3-8b": {"fp16": "16GB", "fp32": "32GB", "8bit": "8GB"},
        },
        "vision_models": {
            "resnet18": {"fp16": "50MB", "fp32": "100MB"},
            "resnet50": {"fp16": "100MB", "fp32": "200MB"},
        },
    }

    def __init__(self, config_path: Optional[str] = None, cache_dir: Optional[str] = None):
        """Create the cache dir, set up logging, load configs and log in to the Hub.

        Args:
            config_path: Optional path to a models YAML file.
            cache_dir: Download cache directory (default ``./cache/models``).
        """
        self.cache_dir = Path(cache_dir) if cache_dir else Path("./cache/models")
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.setup_logging()
        self.load_model_configs(config_path)
        self.setup_huggingface_auth()

    def setup_logging(self):
        """Setup logging configuration (root ``basicConfig`` at INFO level)."""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_model_configs(self, config_path: Optional[str] = None):
        """Load model configurations into ``self.config``.

        Args:
            config_path: Optional YAML path. If it is missing, the package-level
                ``config/models.yaml`` is tried instead.

        Falls back to :meth:`get_default_config` when no file exists or the
        file does not contain a YAML mapping (e.g. it is empty).
        """
        if config_path and Path(config_path).exists():
            config_file = Path(config_path)
        else:
            if config_path:
                # An explicit path was requested but is missing; don't switch silently.
                self.logger.warning(
                    f"Config file {config_path} not found, falling back to package config"
                )
            config_file = Path(__file__).parent.parent / "config" / "models.yaml"

        if not config_file.exists():
            self.logger.warning("Model config not found, using defaults")
            self.config = self.get_default_config()
            return

        with open(config_file, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        if isinstance(loaded, dict):
            self.config = loaded
        else:
            # safe_load returns None for an empty file; any non-mapping is unusable.
            self.logger.warning(
                f"Model config {config_file} is empty or not a mapping, using defaults"
            )
            self.config = self.get_default_config()

    def get_default_config(self) -> Dict:
        """Get default model configuration."""
        return {
            "language_models": {
                "llama2-7b": {
                    "model_name": "meta-llama/Llama-2-7b-chat-hf",
                    "tokenizer_name": "meta-llama/Llama-2-7b-chat-hf",
                    "max_length": 2048,
                    "batch_size": 4,
                    "dtype": "float16",
                    "device_map": "auto",
                    "use_auth_token": True
                },
            },
            "vision_models": {
                "resnet18": {
                    "model_name": "microsoft/resnet-18",
                    "num_classes": 10,
                    "pretrained": True
                }
            }
        }

    def setup_huggingface_auth(self):
        """Log in to the Hugging Face Hub if HUGGINGFACE_TOKEN or HF_TOKEN is set.

        Login failures are logged as warnings, not raised, so public models
        can still be loaded.
        """
        hf_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
        if hf_token:
            try:
                login(token=hf_token)
                self.logger.info("Hugging Face authentication successful")
            except Exception as e:
                self.logger.warning(f"Hugging Face authentication failed: {e}")
        else:
            self.logger.info("No Hugging Face token found. Some models may not be accessible.")

    def load_language_model(
        self,
        model_key: str,
        custom_config: Optional[Dict] = None,
        load_mode: str = "default"
    ) -> Tuple[Any, Any]:
        """
        Load language model and tokenizer from Hugging Face.

        Args:
            model_key: Model key from config (e.g., 'llama2-7b')
            custom_config: Optional custom configuration overrides
            load_mode: Loading mode ('default', 'low_memory', 'high_performance'),
                looked up under the config's ``loading_configs`` section

        Returns:
            Tuple of (model, tokenizer)

        Raises:
            ValueError: If ``model_key`` is not configured.
        """
        model_config = self._resolve_language_config(model_key, custom_config, load_mode)
        model_name = model_config["model_name"]

        self.logger.info(f"Loading model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(
            # Fall back to the model repo when no separate tokenizer is configured.
            model_config.get("tokenizer_name", model_name),
            cache_dir=str(self.cache_dir),
            trust_remote_code=model_config.get("trust_remote_code", False),
            use_auth_token=model_config.get("use_auth_token", False)
        )

        # Some causal-LM tokenizers ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        model_kwargs = self._build_model_kwargs(model_config)

        try:
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
            if model_config.get("gradient_checkpointing"):
                model.gradient_checkpointing_enable()
        except Exception as e:
            self.logger.error(f"Failed to load model {model_name}: {e}")
            raise

        self.logger.info(f"Successfully loaded {model_name}")
        return model, tokenizer

    def _resolve_language_config(
        self, model_key: str, custom_config: Optional[Dict], load_mode: str
    ) -> Dict:
        """Return a copy of *model_key*'s config with custom and load-mode overrides applied.

        Load-mode settings are applied after ``custom_config`` and so take
        precedence over it.
        """
        language_models = self.config.get("language_models") or {}
        if model_key not in language_models:
            raise ValueError(f"Unknown model key: {model_key}")

        model_config = dict(language_models[model_key])

        if custom_config:
            model_config.update(custom_config)

        loading_configs = self.config.get("loading_configs") or {}
        if load_mode in loading_configs:
            for key, value in (loading_configs[load_mode] or {}).items():
                if key == "batch_size_multiplier":
                    # Only scale when the model actually defines a batch size.
                    if "batch_size" in model_config:
                        model_config["batch_size"] = int(model_config["batch_size"] * value)
                else:
                    model_config[key] = value
        elif load_mode != "default":
            self.logger.warning(f"Unknown load mode '{load_mode}', ignoring")

        return model_config

    def _build_model_kwargs(self, model_config: Dict) -> Dict[str, Any]:
        """Translate a resolved model config into ``from_pretrained`` keyword arguments."""
        model_kwargs: Dict[str, Any] = {
            "cache_dir": str(self.cache_dir),
            "trust_remote_code": model_config.get("trust_remote_code", False),
            # NOTE(review): newer transformers releases deprecate `use_auth_token`
            # in favour of `token`; kept for compatibility with older versions.
            "use_auth_token": model_config.get("use_auth_token", False)
        }

        torch_dtype = self._DTYPES.get(model_config.get("dtype"))
        if torch_dtype is not None:
            model_kwargs["torch_dtype"] = torch_dtype

        if model_config.get("device_map"):
            model_kwargs["device_map"] = model_config["device_map"]

        # Quantization is requested ONLY through `quantization_config`: transformers
        # raises if `load_in_8bit`/`load_in_4bit` kwargs are passed alongside it.
        if model_config.get("load_in_8bit"):
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False
            )
        elif model_config.get("load_in_4bit"):
            # Optional `bnb_4bit_compute_dtype` config key; defaults to float16.
            compute_dtype = self._DTYPES.get(
                model_config.get("bnb_4bit_compute_dtype"), torch.float16
            )
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True
            )

        return model_kwargs

    def load_vision_model(
        self,
        model_key: str,
        num_classes: Optional[int] = None,
        pretrained: bool = True,
        custom_config: Optional[Dict] = None
    ) -> nn.Module:
        """
        Load a torchvision ResNet adapted for 32x32 (CIFAR-sized) inputs.

        The ImageNet stem (7x7 stride-2 conv + max-pool) is replaced by a
        freshly initialised 3x3 stride-1 conv and no pooling. Pretrained
        weights for ``conv1`` are therefore discarded.

        Args:
            model_key: Model key from config ('resnet18' or 'resnet50')
            num_classes: Number of output classes; overrides the config value
            pretrained: Whether to use pretrained weights. This argument is what
                is used; the config's ``pretrained`` key is not consulted.
            custom_config: Optional custom configuration overrides

        Returns:
            Vision model

        Raises:
            ValueError: If the key is not configured or not a supported ResNet.
        """
        vision_models = self.config.get("vision_models") or {}
        if model_key not in vision_models:
            raise ValueError(f"Unknown vision model key: {model_key}")

        model_config = dict(vision_models[model_key])

        if custom_config:
            model_config.update(custom_config)

        if num_classes is not None:
            model_config["num_classes"] = num_classes

        self.logger.info(f"Loading vision model: {model_key}")

        if not model_key.startswith("resnet"):
            raise ValueError(f"Unsupported vision model: {model_key}")

        builders = {"resnet18": tv_models.resnet18, "resnet50": tv_models.resnet50}
        if model_key not in builders:
            raise ValueError(f"Unsupported ResNet variant: {model_key}")

        # NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in
        # favour of `weights=`; kept for compatibility with older versions.
        model = builders[model_key](pretrained=pretrained)
        # Small-image stem: no stride-2 downsampling before the residual stages.
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()

        # Replace the classifier head unless the ImageNet default (1000) is requested.
        target_classes = model_config.get("num_classes", 1000)
        if target_classes != 1000:
            model.fc = nn.Linear(model.fc.in_features, target_classes)

        self.logger.info(f"Successfully loaded vision model: {model_key}")
        return model

    def get_vision_transforms(
        self,
        model_key: str,
        train: bool = True,
        custom_config: Optional[Dict] = None
    ) -> transforms.Compose:
        """
        Get vision transforms for the specified model.

        Args:
            model_key: Model key from config
            train: Whether transforms are for training (adds random crop/flip)
            custom_config: Optional config; a ``"transforms"`` list replaces the
                defaults entirely

        Returns:
            Composed transforms

        Raises:
            ValueError: If the key is unknown, or it has no default transforms
                and no custom transforms were supplied.
        """
        vision_models = self.config.get("vision_models") or {}
        if model_key not in vision_models:
            raise ValueError(f"Unknown vision model key: {model_key}")

        if custom_config and "transforms" in custom_config:
            return transforms.Compose(custom_config["transforms"])

        if not model_key.startswith("resnet"):
            raise ValueError(f"Unsupported vision model: {model_key}")

        # NOTE(review): these are ImageNet channel statistics; CIFAR-specific
        # statistics differ — confirm this is intended.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_list = [transforms.ToTensor(), normalize]
        if train:
            augmentation = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(0.5),
            ]
            transform_list = augmentation + transform_list

        return transforms.Compose(transform_list)

    def get_model_info(self, model_key: str, model_type: str = "language") -> Dict:
        """
        Get information about a model.

        Args:
            model_key: Model key from config
            model_type: Type of model ('language' or 'vision')

        Returns:
            A shallow copy of the model's configuration dictionary

        Raises:
            ValueError: If the section or key is not configured.
        """
        config_key = f"{model_type}_models"
        if config_key not in self.config or model_key not in self.config[config_key]:
            raise ValueError(f"Unknown {model_type} model key: {model_key}")

        return self.config[config_key][model_key].copy()

    def list_available_models(self) -> Dict[str, List[str]]:
        """List all configured language and vision model keys."""
        return {
            "language_models": list((self.config.get("language_models") or {}).keys()),
            "vision_models": list((self.config.get("vision_models") or {}).keys())
        }

    def estimate_memory_usage(self, model_key: str, model_type: str = "language") -> Dict[str, str]:
        """
        Estimate memory usage for a model.

        Args:
            model_key: Model key
            model_type: Type of model ('language' or 'vision')

        Returns:
            Memory usage estimates per precision, or ``{"estimate": "Unknown"}``
            for models without a heuristic
        """
        section = self._MEMORY_ESTIMATES.get(f"{model_type}_models", {})
        if model_key not in section:
            return {"estimate": "Unknown"}

        # Copy so callers cannot mutate the shared class-level table.
        return dict(section[model_key])


# Convenience functions for direct use
def load_language_model(model_key: str, **kwargs) -> Tuple[Any, Any]:
    """Return ``(model, tokenizer)`` for *model_key* using a fresh ``ModelLoader``.

    Any extra keyword arguments are forwarded unchanged to
    ``ModelLoader.load_language_model``.
    """
    return ModelLoader().load_language_model(model_key, **kwargs)


def load_vision_model(model_key: str, **kwargs) -> nn.Module:
    """Return the vision model for *model_key* using a fresh ``ModelLoader``.

    Any extra keyword arguments are forwarded unchanged to
    ``ModelLoader.load_vision_model``.
    """
    return ModelLoader().load_vision_model(model_key, **kwargs)


def get_vision_transforms(model_key: str, **kwargs) -> transforms.Compose:
    """Return the composed transforms for *model_key* using a fresh ``ModelLoader``.

    Any extra keyword arguments are forwarded unchanged to
    ``ModelLoader.get_vision_transforms``.
    """
    return ModelLoader().get_vision_transforms(model_key, **kwargs)


if __name__ == "__main__":
    # Smoke test: list the configured models and print their memory heuristics.
    demo_loader = ModelLoader()

    print("Available models:")
    available = demo_loader.list_available_models()
    for category, names in available.items():
        print(f"  {category}: {names}")

    print("\nMemory estimates:")
    for name in available["language_models"]:
        usage = demo_loader.estimate_memory_usage(name, "language")
        print(f"  {name}: {usage}")