import copy
from dataclasses import dataclass, field
from logging import Logger as LoggerType
from typing import List, Optional, Union, Dict, Any
import torch
from torch import nn
from transformers import AutoTokenizer, AutoConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from utils.model_config import LAYER_CLASS_REGISTRY, MODEL_CLASS_MAP, MODEL_SCHEMA_REGISTRY, ModelMetadata
from utils.layer_utils import LayerSchema


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to load or fine-tune.
    This class supports flexible model types, optional quantization, and extra kwargs for config/model.

    The ``model_kwargs`` / ``config_kwargs`` / ``tokenizer_kwargs`` fields accept
    either a dict or a JSON object string. ``__post_init__`` normalizes them to
    plain dicts (``None`` and empty strings become ``{}``), so consumers can
    unpack them with ``**``.
    """

    model_name_or_path: str = field(
        default='meta-llama/Llama-2-7b-hf',
        metadata={
            "help": (
                "The model checkpoint to use. Set to a path or Hugging Face model ID."
            )
        },
    )
    model_type: str = field(
        default='causal_lm',
        metadata={
            "help": (
                "The task-specific model type to load."
            ),
            # Live view of the registry keys, evaluated at import time.
            "choices": MODEL_CLASS_MAP.keys(),
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Tokenizer name or path if different from model_name_or_path."}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Config name or path if different from model_name_or_path."}
    )

    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by HuggingFace tokenizers library)."}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (branch, tag, or commit)."}
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The HF Hub token to use for downloading gated models. "
                "If not set, uses the one stored by `huggingface-cli login`."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust custom code from Hugging Face Hub. Use only with trusted repositories."
            )
        },
    )
    torch_dtype: str = field(
        default='auto',
        metadata={
            "help": (
                "Override the default `torch.dtype` for model weights. If `auto`, inferred from model. "
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        }
    )
    low_cpu_mem_usage: bool = field(
        default=True,
        metadata={
            "help": (
                "Load model using minimal CPU memory (useful for LLMs). "
                "Parameters are loaded lazily from checkpoint."
            )
        }
    )
    model_kwargs: Optional[Union[str, Dict[str, Any]]] = field(
        default_factory=dict,
        metadata={"help": "Extra arguments to pass to from_pretrained. Use JSON string on CLI, or dict in YAML/JSON file."}
    )
    config_kwargs: Union[str, Dict[str, Any]] = field(
        default_factory=dict,
        metadata={"help": "Extra arguments to pass to AutoConfig. Use JSON string on CLI, or dict in YAML/JSON file."}
    )
    tokenizer_kwargs: Union[str, Dict[str, Any]] = field(
        default_factory=dict,
        metadata={"help": "Extra arguments to pass to AutoTokenizer. Use JSON string on CLI, or dict in YAML/JSON file."}
    )

    def __post_init__(self):
        """Normalize the extra-kwargs fields into plain dicts.

        The help text advertises JSON strings on the CLI, but these fields are
        meant to be unpacked as keyword arguments, which only works for
        mappings. JSON strings are parsed here. ``None`` and blank strings
        become ``{}``. Anything that is not a mapping is rejected, so
        misconfiguration surfaces when the arguments are built rather than
        deep inside model loading.

        Raises:
            ValueError: If a field holds invalid JSON or a non-object value.
        """
        import json

        for name in ("model_kwargs", "config_kwargs", "tokenizer_kwargs"):
            value = getattr(self, name)
            if value is None:
                value = {}
            elif isinstance(value, str):
                try:
                    value = json.loads(value) if value.strip() else {}
                except json.JSONDecodeError as err:
                    raise ValueError(f"`{name}` is not valid JSON: {value!r}") from err
            if not isinstance(value, dict):
                raise ValueError(
                    f"`{name}` must be a dict or a JSON object string, got {type(value).__name__}."
                )
            setattr(self, name, value)


def get_model_class(model_type: str | None) -> type[_BaseAutoModelClass]:
    """
    Look up the Auto model class registered for a task-specific model type.

    Args:
        model_type (str | None): Key into ``MODEL_CLASS_MAP``
            (e.g. "causal_lm", "token_classification").

    Returns:
        type[_BaseAutoModelClass]: The Auto model class mapped to ``model_type``.

    Raises:
        ValueError: If ``model_type`` is None or has no entry in ``MODEL_CLASS_MAP``.
    """
    is_supported = model_type is not None and model_type in MODEL_CLASS_MAP
    if not is_supported:
        raise ValueError(f"Unsupported model type: {model_type}")
    return MODEL_CLASS_MAP[model_type]


def resolve_schema_from_model(model: nn.Module) -> List[LayerSchema]:
    """
    Resolve the list of layer schemas describing ``model``.

    Resolution order:
      1. If the model's class name is a key of ``MODEL_SCHEMA_REGISTRY``, that
         registered entry is returned as-is.
      2. Otherwise ``model.modules()`` is scanned (the model itself followed by
         all of its descendants). Every module whose class name is a key of
         ``LAYER_CLASS_REGISTRY`` contributes that registered schema, once per
         matching module instance and in traversal order. Modules of
         unregistered types (the root model, containers, leaf layers, ...) are
         skipped.

    Args:
        model: The loaded model to describe.

    Returns:
        List[LayerSchema]: The resolved schema list.

    Raises:
        ValueError: If neither the model class nor any of its modules has a
            registered schema.
    """
    class_name = model.__class__.__name__
    if class_name in MODEL_SCHEMA_REGISTRY:
        return MODEL_SCHEMA_REGISTRY[class_name]

    # nn.Module.modules() yields the root model first and then every submodule.
    # Requiring *each* of them to be registered made this fallback raise on the
    # root model itself, so unregistered module types are skipped instead.
    layers_schema = [
        LAYER_CLASS_REGISTRY[module.__class__.__name__]
        for module in model.modules()
        if module.__class__.__name__ in LAYER_CLASS_REGISTRY
    ]
    if not layers_schema:
        raise ValueError(
            f"No schema registered for model class: {class_name}, "
            f"and none of its module types are registered in LAYER_CLASS_REGISTRY"
        )
    return layers_schema


def _coerce_extra_kwargs(value: Union[str, Dict[str, Any], None], name: str) -> Dict[str, Any]:
    """Normalize a user-supplied extra-kwargs field (dict, JSON object string, or None) into a new dict.

    Raises:
        ValueError: If ``value`` is invalid JSON or does not describe a mapping.
    """
    import json

    if value is None:
        return {}
    if isinstance(value, str):
        if not value.strip():
            return {}
        try:
            value = json.loads(value)
        except json.JSONDecodeError as err:
            raise ValueError(f"`{name}` is not valid JSON: {value!r}") from err
    if not isinstance(value, dict):
        raise ValueError(f"`{name}` must be a dict or a JSON object string, got {type(value).__name__}.")
    return dict(value)


def load_model_and_tokenizer(logger: LoggerType, model_args: ModelArguments, device: str = 'cpu', cache_dir: Optional[str] = None):
    """
    Load the config, tokenizer and model described by ``model_args`` and build
    a teacher/student pair from the loaded model.

    User-supplied ``config_kwargs`` / ``tokenizer_kwargs`` / ``model_kwargs``
    (dicts or JSON object strings) are merged over the defaults built here, so
    a user value for the same key wins instead of raising a duplicate-keyword
    ``TypeError``.

    Args:
        logger: Logger used for progress messages.
        model_args: Model/config/tokenizer loading options.
        device: Currently unused by this function; both returned models are
            moved to CPU.
        cache_dir: Optional cache directory forwarded to every ``from_pretrained`` call.

    Returns:
        ModelMetadata: Holds ``config``, ``tokenizer``, ``teacher_model`` (the
        loaded model), ``student_model`` (a deep copy of it) and the resolved
        layer ``schema``.

    Raises:
        ValueError: If no config/tokenizer/model source is given, if
            ``model_args.model_type`` is unsupported, if an extra-kwargs field
            is malformed, or if no layer schema can be resolved.
    """
    # Hub options shared by the config, tokenizer and model loaders.
    hub_kwargs = {
        "cache_dir": cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }

    # Load configuration: an explicit config_name takes precedence over the model path.
    config_source = model_args.config_name or model_args.model_name_or_path
    if not config_source:
        raise ValueError("Either config_name or model_name_or_path must be provided to load the model configuration.")
    config = AutoConfig.from_pretrained(
        config_source,
        **{**hub_kwargs, **_coerce_extra_kwargs(model_args.config_kwargs, "config_kwargs")},
    )

    # Load tokenizer: an explicit tokenizer_name takes precedence over the model path.
    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
    if not tokenizer_source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        **{
            **hub_kwargs,
            "use_fast": model_args.use_fast_tokenizer,
            **_coerce_extra_kwargs(model_args.tokenizer_kwargs, "tokenizer_kwargs"),
        },
    )
    # Fall back to EOS when no (truthy) pad token is defined.
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

    # Load model
    model_cls = get_model_class(model_args.model_type)
    if not model_args.model_name_or_path:
        raise ValueError("You must provide a model_name_or_path to load the model.")
    # "auto"/None are passed through; other dtype names are resolved to torch attributes.
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model_load_kwargs = {
        **hub_kwargs,
        "from_tf": ".ckpt" in model_args.model_name_or_path,
        "config": config,
        "torch_dtype": torch_dtype,
        "low_cpu_mem_usage": model_args.low_cpu_mem_usage,
        # Previously documented but never forwarded to from_pretrained.
        **_coerce_extra_kwargs(model_args.model_kwargs, "model_kwargs"),
    }
    model = model_cls.from_pretrained(model_args.model_name_or_path, **model_load_kwargs)

    schema = resolve_schema_from_model(model)

    logger.info('Creating Student Model')
    # NOTE(review): `device` is ignored here; both models are kept on CPU.
    # A model dispatched/quantized via model_kwargs may not support `.to()` — confirm for such setups.
    model = model.to('cpu')
    model.eval()
    # The student is an independent deep copy of the (eval-mode) teacher.
    student_model = copy.deepcopy(model)
    student_model.eval()

    return ModelMetadata(config=config, tokenizer=tokenizer,
                         teacher_model=model, student_model=student_model, schema=schema)