"""
Classes for loading pretrained models
"""
from os.path import join
from omegaconf import OmegaConf

import torch
import torch.nn as nn

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
# from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training


def get_pretrained_loader(pretrained_model_name_or_path: str,
                          huggingface_token: str = None,
                          **model_kwargs: any):
    """
    Build the loader object matching a pretrained model name or path.

    The choice is made by case-sensitive substring search on
    `pretrained_model_name_or_path`; the first matching marker wins:
      'lama' -> PretrainedLlamaLoader   (Llama or llama)
      'wen3' -> PretrainedQwen3Loader   (Qwen3 or qwen3)
      'wen2' -> PretrainedQwen2Loader   (Qwen2.5 or qwen2)
    When no marker matches, a notice is printed and the generic
    PretrainedModelLoader is used.

    `huggingface_token` and all `model_kwargs` are forwarded unchanged to the
    selected loader's constructor.
    """
    # Order matters: markers are tested top to bottom
    marker_to_loader = (
        ('lama', PretrainedLlamaLoader),
        ('wen3', PretrainedQwen3Loader),
        ('wen2', PretrainedQwen2Loader),
    )
    loader_cls = next(
        (candidate for marker, candidate in marker_to_loader
         if marker in pretrained_model_name_or_path),
        None,
    )
    if loader_cls is None:
        print(f'-> {pretrained_model_name_or_path} using default pretrained model loader')
        loader_cls = PretrainedModelLoader

    return loader_cls(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        huggingface_token=huggingface_token,
        **model_kwargs,
    )


class PretrainedModelLoader():
    """
    Class for loading a pretrained model. 
    Example:
      model_loader = PretrainedModelLoader(**model_kwargs)
      model = model_loader.load()

    The constructor only collects keyword arguments into `self.loading_kwargs`;
    nothing is downloaded or instantiated until `load()` or `load_tokenizer()`
    is called.
    """
    def __init__(self,
                 pretrained_model_name_or_path: str,
                 cache_dir: str = None,
                 return_dict: bool = True,  # False
                 device_map: str = 'cpu',
                 low_cpu_mem_usage: bool = True,
                 torch_dtype: str = 'bfloat16',
                 rope_theta: float = 10000.,
                 attn_implementation: str = 'sdpa',  # eager
                 load_in_8bit: bool = False,
                 load_in_4bit: bool = False,
                 huggingface_token: str = None,
                 peft_id: str = None,
                 rope_scaling: dict = None,
                 **other_kwargs: any) -> None:
        """
        Collect the arguments later passed to `from_pretrained`.

        Args:
            pretrained_model_name_or_path: model identifier or local path.
            cache_dir, return_dict, device_map, low_cpu_mem_usage, rope_theta,
            attn_implementation, load_in_8bit, load_in_4bit: stored verbatim
                in `self.loading_kwargs`.
            torch_dtype: name of a `torch` dtype attribute (e.g. 'bfloat16'),
                or an actual `torch.dtype` instance.
            huggingface_token: if given, stored under the 'token' key
                (for gated models, e.g., Llama 3).
            peft_id: stored as `self.peft_id`; this class does not read it
                in `load()`.
            rope_scaling: optional plain dict or OmegaConf config; stored as a
                plain Python container under 'rope_scaling'.
            **other_kwargs: merged into `self.loading_kwargs` last, so they
                override any of the keys above.

        Raises:
            NotImplementedError: if `load_in_8bit` or `load_in_4bit` is set.
        """
        print(f'-> Using {attn_implementation} attention')

        # Accept both the dtype name and an already-resolved torch.dtype;
        # getattr(torch, <torch.dtype>) would otherwise raise a TypeError.
        if not isinstance(torch_dtype, torch.dtype):
            torch_dtype = getattr(torch, torch_dtype)

        self.loading_kwargs = {
            'pretrained_model_name_or_path': pretrained_model_name_or_path,
            'cache_dir': cache_dir,
            'return_dict': return_dict,
            'load_in_8bit': load_in_8bit,
            'load_in_4bit': load_in_4bit,
            'device_map': device_map,
            'low_cpu_mem_usage': low_cpu_mem_usage,
            'torch_dtype': torch_dtype,
            'rope_theta': rope_theta,
            'attn_implementation': attn_implementation,
        }
        if rope_scaling is not None:  # Llama 3.1 patch
            if isinstance(rope_scaling, dict):
                # Plain dicts are not OmegaConf configs, so `to_container`
                # would reject them; a shallow copy is all that is needed.
                rope_scaling = dict(rope_scaling)
            else:
                rope_scaling = OmegaConf.to_container(rope_scaling)
            self.loading_kwargs['rope_scaling'] = rope_scaling
        # Extra kwargs go last so they can override the defaults above
        self.loading_kwargs.update(other_kwargs)

        self.quantization = load_in_8bit or load_in_4bit
        self.peft_id = peft_id
        self.gradient_checkpointing = False
        if huggingface_token is not None:  # for gated models, e.g., Llama 3
            self.loading_kwargs['token'] = huggingface_token

        if self.quantization:
            raise NotImplementedError(
                'Loading with load_in_8bit / load_in_4bit is not supported'
            )

    def load(self) -> nn.Module:
        """
        Load pretrained model

        Returns the result of `AutoModelForCausalLM.from_pretrained` called
        with `self.loading_kwargs`. If `self.quantization` is truthy the model
        is additionally passed through `prepare_model_for_kbit_training`; note
        that `__init__` currently raises before that flag can be set, so this
        branch is only reachable if the attribute is changed after construction.
        """
        model = AutoModelForCausalLM.from_pretrained(**self.loading_kwargs)
        if self.quantization:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=self.gradient_checkpointing,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        return model

    def load_tokenizer(self):
        """
        Load pretrained tokenizer

        Tries `AutoTokenizer.from_pretrained` first and, on any exception,
        prints the error and falls back to `LlamaTokenizer.from_pretrained`.
        Both calls receive the full `self.loading_kwargs`, i.e. the same
        keyword arguments used for the model.
        NOTE(review): model-only keys (e.g. 'torch_dtype', 'device_map') are
        forwarded to the tokenizer as well -- confirm this is intended.
        """
        try:
            return AutoTokenizer.from_pretrained(**self.loading_kwargs)
        except Exception as e:
            # Deliberate best-effort: any failure triggers the fallback below
            print("-> Error with `AutoTokenizer.from_pretrained(**self.loading_kwargs)`:", e)
            print("-> Trying `LlamaTokenizer.from_pretrained(**self.loading_kwargs)`")
            # MZ 6/1: Mistral-7B-Instruct-v0.3 in Transformers v4.36 doesn't work with the above
            return LlamaTokenizer.from_pretrained(**self.loading_kwargs)


class PretrainedLlamaLoader(PretrainedModelLoader):
    """
    Loader for Llama checkpoints, with optional custom attention model
    classes and optional merging of a PEFT checkpoint.
    """
    def load(self, model_type: str = None):
        """
        Load the pretrained Llama model

        `model_type` selects the model class:
          - None: `transformers.LlamaForCausalLM`
          - contains one of the custom attention tags listed below:
            `KVLincLlamaForCausalLM` from the local `modeling_llama` module
          - anything else: `AutoModelForCausalLM`; when it equals
            'flash_attention_2' it is also written into
            `self.loading_kwargs['attn_implementation']`

        If `self.peft_id` is set, the PEFT checkpoint is loaded on top of the
        model and `merge_and_unload()` is called on the result. If
        `self.quantization` is truthy, the model is passed through
        `prepare_model_for_kbit_training` before being returned.
        """
        custom_attention_tags = (
            'kivi_kvlinc', 'kv_linc', 'kivi_attention',
            'quarot_attention', 'resq_attention', 'gear_attention',
        )

        if model_type is None:
            from transformers import LlamaForCausalLM
            model_class = LlamaForCausalLM
        elif any(tag in model_type for tag in custom_attention_tags):
            from .modeling_llama import KVLincLlamaForCausalLM
            model_class = KVLincLlamaForCausalLM
        else:
            if model_type == 'flash_attention_2':
                self.loading_kwargs['attn_implementation'] = model_type
            from transformers import AutoModelForCausalLM
            model_class = AutoModelForCausalLM
            print('-> Loading from AutoModelForCausalLM')

        model = model_class.from_pretrained(**self.loading_kwargs)

        if self.peft_id is not None:
            from peft import PeftModel
            print('-> Loading PEFT checkpoint')
            peft_wrapped = PeftModel.from_pretrained(
                model,
                self.peft_id,
                torch_dtype=self.loading_kwargs['torch_dtype'],
                device_map='cpu',
                cache_dir=self.loading_kwargs['cache_dir'],
            )
            model = peft_wrapped.merge_and_unload()

        if not self.quantization:
            return model

        return prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=self.gradient_checkpointing,
            gradient_checkpointing_kwargs={'use_reentrant': False},
        )

    def load_tokenizer(self):
        """
        Load the tokenizer via `AutoTokenizer.from_pretrained`, passing
        `self.loading_kwargs` (no fallback tokenizer is attempted here).
        """
        tokenizer = AutoTokenizer.from_pretrained(**self.loading_kwargs)
        return tokenizer

class PretrainedQwen3Loader(PretrainedModelLoader):
    """
    Loader for Qwen3 checkpoints, with optional custom attention model
    classes and optional merging of a PEFT checkpoint.
    """
    def load(self, model_type: str = None):
        """
        Load the pretrained Qwen3 model

        `model_type` selects the model class:
          - None: `transformers.Qwen3ForCausalLM`
          - contains one of the custom attention tags listed below:
            `KVLincQwen3ForCausalLM` from the local `modeling_qwen3` module
          - anything else: `AutoModelForCausalLM`; when it equals
            'flash_attention_2' it is also written into
            `self.loading_kwargs['attn_implementation']`

        If `self.peft_id` is set, the PEFT checkpoint is loaded on top of the
        model and `merge_and_unload()` is called on the result. If
        `self.quantization` is truthy, the model is passed through
        `prepare_model_for_kbit_training` before being returned.
        """
        custom_attention_tags = (
            'kivi_kvlinc', 'kv_linc', 'kivi_attention',
            'quarot_attention', 'resq_attention', 'gear_attention',
        )

        if model_type is None:
            from transformers import Qwen3ForCausalLM
            model_class = Qwen3ForCausalLM
        elif any(tag in model_type for tag in custom_attention_tags):
            from .modeling_qwen3 import KVLincQwen3ForCausalLM
            model_class = KVLincQwen3ForCausalLM
        else:
            if model_type == 'flash_attention_2':
                self.loading_kwargs['attn_implementation'] = model_type
            from transformers import AutoModelForCausalLM
            model_class = AutoModelForCausalLM
            print('-> Loading from AutoModelForCausalLM')

        model = model_class.from_pretrained(**self.loading_kwargs)

        if self.peft_id is not None:
            from peft import PeftModel
            print('-> Loading PEFT checkpoint')
            peft_wrapped = PeftModel.from_pretrained(
                model,
                self.peft_id,
                torch_dtype=self.loading_kwargs['torch_dtype'],
                device_map='cpu',
                cache_dir=self.loading_kwargs['cache_dir'],
            )
            model = peft_wrapped.merge_and_unload()

        if not self.quantization:
            return model

        return prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=self.gradient_checkpointing,
            gradient_checkpointing_kwargs={'use_reentrant': False},
        )

    def load_tokenizer(self):
        """
        Load the tokenizer via `AutoTokenizer.from_pretrained`, passing
        `self.loading_kwargs` (no fallback tokenizer is attempted here).
        """
        tokenizer = AutoTokenizer.from_pretrained(**self.loading_kwargs)
        return tokenizer


class PretrainedQwen2Loader(PretrainedModelLoader):
    """
    Loader for Qwen2 / Qwen2.5 checkpoints, with optional custom attention
    model classes and optional merging of a PEFT checkpoint.
    """
    def load(self, model_type: str = None):
        """
        Load the pretrained Qwen2 model

        `model_type` selects the model class:
          - None: `transformers.Qwen2ForCausalLM`
          - contains one of the custom attention tags listed below:
            `KVLincQwen2ForCausalLM` from the local `modeling_qwen2` module
          - anything else: `AutoModelForCausalLM`; when it equals
            'flash_attention_2' it is also written into
            `self.loading_kwargs['attn_implementation']`

        If `self.peft_id` is set, the PEFT checkpoint is loaded on top of the
        model and `merge_and_unload()` is called on the result. If
        `self.quantization` is truthy, the model is passed through
        `prepare_model_for_kbit_training` before being returned.
        """
        custom_attention_tags = (
            'kivi_kvlinc', 'kv_linc', 'kivi_attention',
            'quarot_attention', 'resq_attention', 'gear_attention',
        )

        if model_type is None:
            from transformers import Qwen2ForCausalLM
            model_class = Qwen2ForCausalLM
        elif any(tag in model_type for tag in custom_attention_tags):
            from .modeling_qwen2 import KVLincQwen2ForCausalLM
            model_class = KVLincQwen2ForCausalLM
        else:
            if model_type == 'flash_attention_2':
                self.loading_kwargs['attn_implementation'] = model_type
            from transformers import AutoModelForCausalLM
            model_class = AutoModelForCausalLM
            print('-> Loading from AutoModelForCausalLM')

        model = model_class.from_pretrained(**self.loading_kwargs)

        if self.peft_id is not None:
            from peft import PeftModel
            print('-> Loading PEFT checkpoint')
            peft_wrapped = PeftModel.from_pretrained(
                model,
                self.peft_id,
                torch_dtype=self.loading_kwargs['torch_dtype'],
                device_map='cpu',
                cache_dir=self.loading_kwargs['cache_dir'],
            )
            model = peft_wrapped.merge_and_unload()

        if not self.quantization:
            return model

        return prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=self.gradient_checkpointing,
            gradient_checkpointing_kwargs={'use_reentrant': False},
        )

    def load_tokenizer(self):
        """
        Load the tokenizer via `AutoTokenizer.from_pretrained`, passing
        `self.loading_kwargs` (no fallback tokenizer is attempted here).
        """
        tokenizer = AutoTokenizer.from_pretrained(**self.loading_kwargs)
        return tokenizer


