from transformers import AutoTokenizer, AutoModel
import torch
from config import (
    MODEL_TYPE,
    FINETUNING_TYPE,
    MODEL_PATHS_MAPPING,
    get_type,
    TASK_TYPE,
    CKPT_MAPPING,
    # set_seed,
)
from typing import Optional
# import os

def _type_name(t) -> str:
    """Return ``t.__name__`` if it exists, otherwise ``str(t)``.

    Used only when formatting error messages, so that building the message can
    never itself raise ``AttributeError`` (e.g. when ``t`` is ``None``) and mask
    the ``NotImplementedError`` the caller intended to raise.
    """
    return getattr(t, "__name__", str(t))


def get_eval_model(base_model_name: str,
                   peft_name: Optional[str] = None,
                   ft_task: Optional[str] = None,
                   run_time: int = 1,
                   f_form: str = "linear",
                   zero_lora_init: bool = False,
                   direct_noise: bool = False,
                   ckpts: str = "best",
                   training_mode: str = "joint",
                   t_mapping: str = "poly",
                   fnn_hidden_size: int = 32,
                   lr: float = 5e-4,
                   use_embedding: bool = False,
                   embedding_dim: int = 1,
                   init_c: str = "kaiming_uniform_m",
                   input_mode: str = "noise_level",
                   density_radius: int = 0,
                   rank: int = 32,
                   fnn_hidden_size_2: int = 512,
                   Embed_components: str = "nd_nl",
                   Embed_type: str = "fourier",
                   mapper_num_layers: int = 2,
                   c_scale: float = 1,
                   length_alignment: bool = False,
                   whole_length: bool = False,
                   stage_1: float = 0,
                   scale_ab: float = 1,
                   clr: float = 1e-4,
                   nvt: int = 20,
                   h: int = 550,
                   epoch: int = 1,
                   ectype: str = "LSTM",
                   prompt_tuning_init: str = "TEXT",
                   ):
    """Load a base model and tokenizer, optionally with a fine-tuned checkpoint on top.

    The base model is loaded with ``AutoModel.from_pretrained`` from
    ``MODEL_PATHS_MAPPING[get_type(MODEL_TYPE, base_model_name)]``, using
    ``trust_remote_code=True`` and ``torch.bfloat16``. The tokenizer comes from
    the same path.

    If ``peft_name`` is truthy, an adapter checkpoint is loaded from
    ``CKPT_MAPPING[(model_type, task_type, finetuning_type)][key][run_time - 1]``.
    Both the key and the ``PeftModel`` implementation depend on the fine-tuning
    type:

    * HIRA: ``hira.PeftModel``, key ``(rank, length_alignment)``.
    * NARA: ``nara.PeftModel``, key ``(fnn_hidden_size, fnn_hidden_size_2,
      rank, embedding_dim, density_radius, Embed_components, lr, init_c,
      stage_1, c_scale, scale_ab, clr, Embed_type)``.
    * PTUNING: ``peft.PeftModel``, key ``(nvt, h, epoch, lr, ectype)``.
    * PROMPT_TUNING: ``peft.PeftModel``, key
      ``(nvt, prompt_tuning_init, epoch, lr)``.
    * Anything else (e.g. vanilla LoRA): ``peft.PeftModel``, key
      ``(rank, length_alignment)``. The adapter is then merged into the base
      model with ``merge_and_unload()``.

    Args:
        base_model_name: Name resolved to a ``MODEL_TYPE`` via ``get_type``.
        peft_name: Fine-tuning method name, resolved to a ``FINETUNING_TYPE``.
            Falsy means "no adapter": the plain base model is returned.
        ft_task: Task name, resolved to a ``TASK_TYPE``.
        run_time: 1-based index selecting which run's checkpoint to load.
        zero_lora_init: If True, zero the ``weight`` of every module whose
            qualified name contains ``lora_A`` or ``lora_B`` after loading.
            In the merge branch those modules no longer exist after
            ``merge_and_unload()``, so nothing is zeroed and a warning is
            printed.
        fnn_hidden_size, fnn_hidden_size_2, embedding_dim, density_radius,
        Embed_components, init_c, stage_1, c_scale, scale_ab, clr, Embed_type:
            Components of the NARA checkpoint key.
        rank, length_alignment: Components of the HIRA / generic checkpoint
            key; ``rank`` is also part of the NARA key.
        lr: Learning-rate component of the NARA / PTUNING / PROMPT_TUNING keys.
        nvt, h, epoch, ectype: Components of the PTUNING key (``nvt`` and
            ``epoch`` are also used by PROMPT_TUNING).
        prompt_tuning_init: Component of the PROMPT_TUNING key.
        f_form, direct_noise, ckpts, training_mode, t_mapping, use_embedding,
        input_mode, mapper_num_layers, whole_length: Not read by this
            function; accepted so existing callers keep working.

    Returns:
        ``(model, tokenizer, finetuning_type)``. ``finetuning_type`` is
        ``None`` when no adapter was requested.

    Raises:
        NotImplementedError: No entry in ``MODEL_PATHS_MAPPING`` for the base
            model type, or no entry in ``CKPT_MAPPING`` for the
            (model, task, fine-tuning) combination.
        ValueError: ``run_time < 1`` while an adapter is requested.
    """
    # ---- Base model + tokenizer -------------------------------------------
    base_model_type: MODEL_TYPE = get_type(MODEL_TYPE, base_model_name)
    if base_model_type not in MODEL_PATHS_MAPPING:
        raise NotImplementedError(
            f"No path found in MODEL_PATHS_MAPPING for {_type_name(base_model_type)}, please specify one."
        )
    model_path = MODEL_PATHS_MAPPING[base_model_type]
    base_model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if not peft_name:
        # No adapter requested: evaluate the plain base model.
        return base_model, tokenizer, None

    # ---- Resolve the checkpoint group ---------------------------------------
    finetuning_type: FINETUNING_TYPE = get_type(FINETUNING_TYPE, peft_name)
    task_type: TASK_TYPE = get_type(TASK_TYPE, ft_task)
    ckpt_group = (base_model_type, task_type, finetuning_type)
    if ckpt_group not in CKPT_MAPPING:
        names = (
            _type_name(base_model_type),
            _type_name(task_type),
            _type_name(finetuning_type),
        )
        raise NotImplementedError(
            f"No path found in CKPT_MAPPING for {names}, please specify one."
        )
    # run_time is used as a 1-based index below; values < 1 would silently
    # select a checkpoint from the end of the list via negative indexing.
    if run_time < 1:
        raise ValueError(f"run_time is 1-based and must be >= 1, got {run_time}.")

    # ---- Pick the adapter implementation and the per-method checkpoint key --
    merge_after_load = False
    if finetuning_type == FINETUNING_TYPE.HIRA:
        from hira import PeftModel
        ckpt_key = (rank, length_alignment)
    elif finetuning_type == FINETUNING_TYPE.NARA:
        from nara import PeftModel
        ckpt_key = (
            int(fnn_hidden_size),
            int(fnn_hidden_size_2),
            int(rank),
            embedding_dim,
            density_radius,
            Embed_components,
            float(lr),
            init_c,
            stage_1,
            c_scale,
            scale_ab,
            clr,
            Embed_type,
        )
    elif finetuning_type == FINETUNING_TYPE.PTUNING:
        from peft import PeftModel
        ckpt_key = (int(nvt), int(h), int(epoch), float(lr), ectype)
    elif finetuning_type == FINETUNING_TYPE.PROMPT_TUNING:
        from peft import PeftModel
        ckpt_key = (int(nvt), prompt_tuning_init, int(epoch), float(lr))
    else:
        # Generic peft adapters (e.g. vanilla LoRA) are merged into the base
        # weights after loading.
        from peft import PeftModel
        ckpt_key = (rank, length_alignment)
        merge_after_load = True

    ckpt_path = CKPT_MAPPING[ckpt_group][ckpt_key][run_time - 1]
    model = PeftModel.from_pretrained(base_model, ckpt_path)

    if merge_after_load:
        model = model.merge_and_unload()
        print("[Info] Loaded PeftModel and called merge_and_unload().")

    # ---- Optionally zero LoRA weights ---------------------------------------
    if zero_lora_init:
        print("[Info] Applying zero_lora_init...")
        n_zeroed = 0
        for name, module in model.named_modules():
            if "lora_A" in name or "lora_B" in name:
                weight = getattr(module, "weight", None)
                if weight is not None:
                    torch.nn.init.zeros_(weight)
                    n_zeroed += 1
        if n_zeroed == 0:
            # e.g. after merge_and_unload() the lora_* modules are gone, so
            # this flag has no effect; say so instead of silently succeeding.
            print(
                "[Warning] zero_lora_init found no lora_A/lora_B weights to zero "
                "(adapter may already be merged); the model is unchanged."
            )
        print("[Info] zero_lora_init complete.")

    return model, tokenizer, finetuning_type