from transformers import AutoModel, AutoTokenizer
from peft import (
    LoraConfig, 
    PromptTuningConfig, 
    PromptEncoderConfig, 
    get_peft_model, 
    PeftModel, 
    TaskType
)
import torch
from omegaconf import OmegaConf
from config import MODEL_TYPE, FINETUNING_TYPE, get_type, MODEL_PATHS_MAPPING
import torch.nn as nn

def get_model_by_config(config):
    """Select and build the model/tokenizer pair described by ``config``.

    ``config.get("model")`` and ``config.get("finetuning_method")`` are
    resolved through ``get_type`` into a ``MODEL_TYPE`` and a
    ``FINETUNING_TYPE``. The matching ``get_*`` builder in this module is
    then called with ``(model_type, config)``.

    Args:
        config: Experiment configuration exposing ``.get``. It is presumably
            an OmegaConf ``DictConfig``, since the builders call
            ``OmegaConf.to_container`` on ``config.finetuning_parameters``.

    Returns:
        The selected builder's result. Every builder in this module returns
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: if ``model_type`` has no entry in
            ``MODEL_PATHS_MAPPING``, or if the fine-tuning method is not one
            of LoRA, P-Tuning, Prompt Tuning, HiRA or NARA.
    """
    model_type: MODEL_TYPE = get_type(MODEL_TYPE, config.get("model", None))
    finetuning_type: FINETUNING_TYPE = get_type(
        FINETUNING_TYPE, config.get("finetuning_method", None)
    )

    if model_type not in MODEL_PATHS_MAPPING:
        # Only classes carry ``__name__``; Enum members expose ``.name``
        # instead. Fall back gracefully so that building the message cannot
        # raise AttributeError and mask the real configuration error.
        type_name = getattr(model_type, "__name__", None) or getattr(
            model_type, "name", repr(model_type)
        )
        raise NotImplementedError(
            f"No path found in MODEL_PATHS_MAPPING for {type_name}, please specify one."
        )

    # Dispatch table: fine-tuning method -> builder(model_type, config).
    builders = {
        FINETUNING_TYPE.LORA: get_lora_model,
        FINETUNING_TYPE.PTUNING: get_ptuning_model,
        FINETUNING_TYPE.PROMPT_TUNING: get_prompt_tuning_model,
        FINETUNING_TYPE.HIRA: get_hira_models,
        FINETUNING_TYPE.NARA: get_nara_models,
    }
    builder = builders.get(finetuning_type)
    if builder is None:
        raise NotImplementedError(
            f"Unsupported finetuning method: {finetuning_type}."
        )
    return builder(model_type, config)

def get_prompt_tuning_model(model_type, config):
    """Create (or resume) a standard soft-prompt-tuning PEFT model.

    The base model and tokenizer are loaded from
    ``MODEL_PATHS_MAPPING[model_type]`` with ``trust_remote_code=True``, and
    the base model is frozen via ``_freeze_base_model``. After that:

    * if ``config.train.decoder_resume_path`` is a non-blank string, the
      adapter is resumed from it in trainable mode;
    * otherwise a fresh ``PromptTuningConfig`` is built from
      ``config.finetuning_parameters``. ``tokenizer_name_or_path`` is always
      overwritten with the model path, and ``task_type`` defaults to
      ``TaskType.CAUSAL_LM`` when absent.

    Returns:
        ``(model, tokenizer)``.
    """
    print("Initializing Prompt Tuning (Standard)...")

    path = MODEL_PATHS_MAPPING[model_type]
    base_model = AutoModel.from_pretrained(path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    resume_from = _get_resume_path(config)
    _freeze_base_model(base_model)

    if not resume_from:
        params = OmegaConf.to_container(config.finetuning_parameters, resolve=True)
        params["tokenizer_name_or_path"] = path
        params.setdefault("task_type", TaskType.CAUSAL_LM)
        model = get_peft_model(base_model, PromptTuningConfig(**params))
    else:
        model = PeftModel.from_pretrained(base_model, resume_from, is_trainable=True)
        print(f"[Prompt Tuning] Resumed from: {resume_from}")

    model.print_trainable_parameters()
    return model, tokenizer



def _get_resume_path(config):
    """Helper to extract resume path safely"""
    if hasattr(config, "train") and hasattr(config.train, "decoder_resume_path"):
        val = config.train.decoder_resume_path
        if isinstance(val, str) and val.strip():
            return val.strip()
    return None

def _freeze_base_model(model):
    """Helper to freeze model parameters"""
    for param in model.parameters():
        param.requires_grad = False
        if param.ndim == 1:
            param.data = param.data.to(torch.float32)

def get_lora_model(model_type, config):
    """Create (or resume) a LoRA PEFT model plus its tokenizer.

    The base model and tokenizer are loaded from
    ``MODEL_PATHS_MAPPING[model_type]`` with ``trust_remote_code=True``, and
    the base model is frozen via ``_freeze_base_model``. After that:

    * if ``config.train.decoder_resume_path`` is a non-blank string, the LoRA
      adapter is resumed from it in trainable mode;
    * otherwise a ``LoraConfig`` is built from ``config.finetuning_parameters``
      and applied with ``get_peft_model``.

    Returns:
        ``(model, tokenizer)``.
    """
    model_path = MODEL_PATHS_MAPPING[model_type]
    base_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    decoder_resume_path = _get_resume_path(config)
    _freeze_base_model(base_model)

    if decoder_resume_path:
        model = PeftModel.from_pretrained(
            base_model, decoder_resume_path, is_trainable=True
        )
        print(f"[LoRA] Resumed from: {decoder_resume_path}")
    else:
        ft_params = OmegaConf.to_container(config.finetuning_parameters, resolve=True)
        peft_config = LoraConfig(**ft_params)

        # Optional workaround for models exposing ``model.transformer.ff_out``.
        # The temporary name suggests this is the LM head. It is detached while
        # PEFT injects adapters, presumably so that a ``target_modules`` entry
        # "ff_out" does not also match it, and is restored afterwards even on
        # failure. Models without this layout now skip the workaround instead
        # of crashing with AttributeError.
        transformer = getattr(getattr(base_model, "model", None), "transformer", None)
        modules = getattr(transformer, "_modules", None)
        if modules is not None and "ff_out" in modules:
            modules["lm_head_ff_out_tmp"] = modules.pop("ff_out")
            try:
                model = get_peft_model(base_model, peft_config)
            finally:
                modules["ff_out"] = modules.pop("lm_head_ff_out_tmp")
        else:
            model = get_peft_model(base_model, peft_config)

    model.print_trainable_parameters()
    return model, tokenizer


def get_ptuning_model(model_type, config):
    """Create (or resume) a P-Tuning (prompt-encoder) PEFT model.

    The base model and tokenizer are loaded from
    ``MODEL_PATHS_MAPPING[model_type]`` with ``trust_remote_code=True``, and
    the base model is frozen via ``_freeze_base_model``. After that:

    * if ``config.train.decoder_resume_path`` is a non-blank string, the
      adapter is resumed from it in trainable mode;
    * otherwise a ``PromptEncoderConfig`` is built from
      ``config.finetuning_parameters``, with ``task_type`` defaulting to
      ``TaskType.CAUSAL_LM`` when absent.

    Returns:
        ``(model, tokenizer)``.
    """
    checkpoint = MODEL_PATHS_MAPPING[model_type]
    base_model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    resume_from = _get_resume_path(config)
    _freeze_base_model(base_model)

    if not resume_from:
        params = OmegaConf.to_container(config.finetuning_parameters, resolve=True)
        params.setdefault("task_type", TaskType.CAUSAL_LM)
        model = get_peft_model(base_model, PromptEncoderConfig(**params))
    else:
        model = PeftModel.from_pretrained(base_model, resume_from, is_trainable=True)
        print(f"[P-Tuning] Resumed from: {resume_from}")

    model.print_trainable_parameters()
    return model, tokenizer



def get_hira_models(model_type, config):
    """Create (or resume) a HiRA adapter model plus its tokenizer.

    ``PeftModel``, ``HiraConfig`` and ``get_peft_model`` are imported from the
    ``hira`` package inside this function, shadowing the ``peft`` names
    imported at module level.

    The base model and tokenizer are loaded from
    ``MODEL_PATHS_MAPPING[model_type]`` with ``trust_remote_code=True``, and
    the base model is frozen via ``_freeze_base_model``. After that:

    * if ``config.train.decoder_resume_path`` is a non-blank string, the
      adapter is resumed from it in trainable mode;
    * otherwise a ``HiraConfig`` is built from ``config.finetuning_parameters``
      and applied while ``model.transformer.ff_out`` is temporarily detached.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: on the fresh-config path, if
            ``model.transformer`` has no ``ff_out`` submodule.
    """
    from hira import PeftModel, HiraConfig, get_peft_model

    model_path = MODEL_PATHS_MAPPING[model_type]
    base_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    decoder_resume_path = _get_resume_path(config)
    # Freeze the base model (only adapters train) and upcast 1-D params to fp32.
    _freeze_base_model(base_model)

    if decoder_resume_path:
        model = PeftModel.from_pretrained(
            base_model, decoder_resume_path, is_trainable=True
        )
        print(f"[HIRA] Resumed from: {decoder_resume_path}")
    else:
        ft_params = OmegaConf.to_container(config.finetuning_parameters, resolve=True)
        peft_config = HiraConfig(**ft_params)

        tr = base_model.model.transformer
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if "ff_out" not in tr._modules:
            raise NotImplementedError("no transformer.ff_out found")
        # Detach ``ff_out`` (the temp name suggests it is the LM head) while
        # adapters are injected, and restore it even if injection fails.
        tr._modules["lm_head_ff_out_tmp"] = tr._modules.pop("ff_out")
        try:
            model = get_peft_model(base_model, peft_config)
        finally:
            tr._modules["ff_out"] = tr._modules.pop("lm_head_ff_out_tmp")

    model.print_trainable_parameters()
    return model, tokenizer

def get_nara_models(model_type, config):
    """Create (or resume) a NARA adapter model plus its tokenizer.

    ``PeftModel``, ``NARAConfig`` and ``get_peft_model`` are imported from the
    ``nara`` package inside this function, shadowing the ``peft`` names
    imported at module level.

    The base model and tokenizer are loaded from
    ``MODEL_PATHS_MAPPING[model_type]`` with ``trust_remote_code=True``, and
    the base model is frozen via ``_freeze_base_model``. After that:

    * if ``config.train.decoder_resume_path`` is a non-blank string, the
      adapter is resumed from it in trainable mode;
    * otherwise a ``NARAConfig`` is built from ``config.finetuning_parameters``.
      The optional ``lora_ckpt_path`` key is removed first. When
      ``target_modules`` contains "ff_out", ``model.transformer.ff_out`` is
      temporarily detached during injection. Finally, if ``lora_ckpt_path``
      was given, LoRA weights are warm-started from it through a
      ``load_lora_only`` method looked up on the model or its ``base_model``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``config.data.batch_size`` is not 1.
        NotImplementedError: if "ff_out" is targeted but ``model.transformer``
            has no ``ff_out`` submodule.
    """
    if config.data.batch_size != 1:
        raise ValueError("NARA currently only supports batch_size=1.")

    from nara import PeftModel, NARAConfig, get_peft_model

    model_path = MODEL_PATHS_MAPPING[model_type]
    base_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    decoder_resume_path = _get_resume_path(config)
    # Freeze the base model (only adapters train) and upcast 1-D params to fp32.
    _freeze_base_model(base_model)

    if decoder_resume_path:
        model = PeftModel.from_pretrained(
            base_model, decoder_resume_path, is_trainable=True
        )
        print(f"[NARA] Resumed from: {decoder_resume_path}")
    else:
        ft_params = OmegaConf.to_container(config.finetuning_parameters, resolve=True)

        # ``lora_ckpt_path`` is consumed here, not by NARAConfig, so strip it first.
        lora_ckpt_path = ft_params.pop("lora_ckpt_path", None)
        peft_config = NARAConfig(**ft_params)

        # ``target_modules`` may be omitted (left to NARAConfig's default) or
        # None; treat both as "nothing targeted" instead of crashing.
        target_modules = ft_params.get("target_modules") or ()
        if "ff_out" in target_modules:
            tr = base_model.model.transformer
            # Explicit check rather than ``assert`` so it still fires under ``python -O``.
            if "ff_out" not in tr._modules:
                raise NotImplementedError("no transformer.ff_out found")
            # Detach ``ff_out`` (the temp name suggests it is the LM head) while
            # adapters are injected, and restore it even if injection fails.
            tr._modules["lm_head_ff_out_tmp"] = tr._modules.pop("ff_out")
            try:
                model = get_peft_model(base_model, peft_config)
            finally:
                tr._modules["ff_out"] = tr._modules.pop("lm_head_ff_out_tmp")
        else:
            model = get_peft_model(base_model, peft_config)

        if lora_ckpt_path:
            print(f"[NARA] Loading warm-start LoRA weights from: {lora_ckpt_path}")
            # ``load_lora_only`` may live on the returned wrapper or on the
            # wrapped model. Try both and keep the original best-effort
            # warning if neither provides it.
            for candidate in (model, getattr(model, "base_model", None)):
                load_lora_only = getattr(candidate, "load_lora_only", None)
                if load_lora_only is not None:
                    load_lora_only(lora_ckpt_path)
                    break
            else:
                print("[NARA] Warning: could not find load_lora_only method on model.")

    model.print_trainable_parameters()
    return model, tokenizer