import os
import fire
import json

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import lm_eval
from lm_eval.utils import (
    setup_logging,
    make_table,
)

def init_profiler(model):
    """Arm the per-projection profiler on every decoder layer of a PEFT-wrapped model.

    For each layer in ``model.base_model.model.model.layers``, the seven projections
    (attention q/k/v/o and MLP up/down/gate) get ``profile_ = True`` and a fresh,
    empty ``profiler_`` dict. Whatever reads ``profile_`` and fills ``profiler_`` is
    not visible here (presumably the adapter layer's forward pass — confirm).
    ``get_routing_weights`` later reads ``"routing_weights"`` and ``"cnt"`` from
    those dicts.

    Returns:
        The same ``model`` object, mutated in place.
    """
    for decoder_layer in model.base_model.model.model.layers:
        attn = decoder_layer.self_attn
        mlp = decoder_layer.mlp
        targets = (
            attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj,
            mlp.up_proj, mlp.down_proj, mlp.gate_proj,
        )
        for proj in targets:
            proj.profile_ = True
        # Each projection gets its own dict so collected stats are never shared.
        for proj in targets:
            proj.profiler_ = {}
    return model

def get_routing_weights(model, task_name, output_dir="output_logs"):
    """Divide each projection's profiled routing weights by its count and save them as JSON.

    For every layer in ``model.base_model.model.model.layers``, each projection's
    ``profiler_["routing_weights"] / profiler_["cnt"]`` is converted with
    ``.tolist()``. It is stored under ``layer.{i}.self_attn.wq|wk|wv|wo`` for the
    q/k/v/o projections, and under ``layer.{i}.mlp.w1|w2|w3`` for
    ``up_proj``/``down_proj``/``gate_proj`` respectively. The profiler dicts are
    expected to have been filled after ``init_profiler`` armed them (presumably a
    running sum and a call count, making this an average — confirm).

    Args:
        model: PEFT-wrapped model exposing ``base_model.model.model.layers``.
        task_name: Task identifier embedded in the output file name.
        output_dir: Directory for the JSON file; created if it does not exist.

    Returns:
        The dict that was written to disk.

    Raises:
        KeyError: if a projection's ``profiler_`` lacks ``"routing_weights"`` or ``"cnt"``.
    """
    # (output key suffix, parent submodule, projection attribute), in output key order.
    projections = (
        ("self_attn.wq", "self_attn", "q_proj"),
        ("self_attn.wk", "self_attn", "k_proj"),
        ("self_attn.wv", "self_attn", "v_proj"),
        ("self_attn.wo", "self_attn", "o_proj"),
        ("mlp.w1", "mlp", "up_proj"),
        ("mlp.w2", "mlp", "down_proj"),
        ("mlp.w3", "mlp", "gate_proj"),
    )
    routing_weights_dict = {}
    for layer_idx, layer in enumerate(model.base_model.model.model.layers):
        for key_suffix, parent_name, proj_name in projections:
            profiler = getattr(getattr(layer, parent_name), proj_name).profiler_
            averaged = profiler["routing_weights"] / profiler["cnt"]
            routing_weights_dict[f"layer.{layer_idx}.{key_suffix}"] = averaged.tolist()

    # Create the directory up front; the previous version failed with
    # FileNotFoundError whenever output_logs/ did not already exist.
    os.makedirs(output_dir, exist_ok=True)
    save_file = os.path.join(output_dir, f"routing_weights_moore_235_{task_name}.json")
    with open(save_file, "w") as f:
        json.dump(routing_weights_dict, f, indent=4)
    return routing_weights_dict
    
def load_adapter_weights(
    model,
    adapter_config,
    adapter_weights,
):
    """Copy saved adapter tensors into the matching parameters of ``model`` in place.

    Checkpoint keys are normalized toward the names produced by
    ``model.base_model.model.model.layers.named_parameters()``:

    - A fixed number of leading dotted components is dropped: 4 when
      ``adapter_config["peft_type"] == "LORA"``, otherwise 2.
    - The substrings in ``name_mapping`` are then rewritten, in dict order.

    On the model side, a ``".default"`` segment is removed from each parameter
    name before lookup. Parameters whose names contain any ``name_skip`` entry are
    left untouched, as are parameters with no counterpart in the checkpoint.

    NOTE(review): because ``"lora_A"``/``"lora_B"`` are in ``name_skip``, model
    parameters with those names are never loaded here. The ``lora_A``/``lora_B``
    renames in ``name_mapping`` only matter for models whose parameters are named
    ``lora_a_``/``lora_b_`` — confirm which adapter layouts this is meant to load.

    Args:
        model: PEFT-wrapped model exposing ``base_model.model.model.layers``.
        adapter_config: Parsed ``adapter_config.json``; only ``"peft_type"`` is read.
        adapter_weights: Mapping from checkpoint parameter names to tensors.

    Returns:
        The same ``model``, with matched parameters overwritten in place.

    Raises:
        ValueError: if a matched parameter and its checkpoint tensor differ in shape.
    """
    name_mapping = {
        "moe_gate.weight": "gate.weight",
        "moe_gate.task_embedding": "task_embedding.weight",
        "moe_gate.task_linear": "task_linear.weight",
        "moe_gate.up.weight": "gate_up.weight",
        "lora_A": "lora_a_",
        "lora_B": "lora_b_",
    }
    name_skip = [
        "base_layer",
        "lora_A",
        "lora_B",
        "svd_U",
        "svd_S",
        "svd_Vh",
        "layernorm",
    ]
    # Leading dotted components to strip so checkpoint keys line up with the
    # layer-relative model names (presumably starting at the layer index — confirm).
    num_prefix = 4 if adapter_config["peft_type"] == "LORA" else 2

    new_adapter_weights = {}
    for name, params in adapter_weights.items():
        new_name = ".".join(name.split(".")[num_prefix:])
        # Replacements are applied sequentially, so mapping order is significant.
        for old, new in name_mapping.items():
            new_name = new_name.replace(old, new)
        new_adapter_weights[new_name] = params

    layers = model.base_model.model.model.layers
    with torch.no_grad():
        for name, param in layers.named_parameters():
            if any(skip in name for skip in name_skip):
                continue
            # Model names carry a ".default" segment (presumably the PEFT adapter
            # name) that the normalized checkpoint keys do not.
            key = name.replace(".default", "")
            if key not in new_adapter_weights:
                continue
            adapter_param = new_adapter_weights[key]
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if param.shape != adapter_param.shape:
                raise ValueError(
                    f"Shape mismatch for {name!r}: model has {tuple(param.shape)}, "
                    f"checkpoint has {tuple(adapter_param.shape)}"
                )
            param.copy_(adapter_param)

    return model

def load_adapter(name_or_path: str):
    """Read an adapter's config and weights from a directory.

    Args:
        name_or_path: Directory containing ``adapter_config.json`` and
            ``adapter_model.bin``. Any ``os.PathLike`` is accepted too, and a
            trailing separator is harmless.

    Returns:
        ``(adapter_config, adapter_weight)``: the parsed JSON config and the
        object returned by ``torch.load`` for the weights file.
    """
    # os.path.join (rather than `str + os.sep + ...`) also works for pathlib.Path inputs.
    config_path = os.path.join(name_or_path, "adapter_config.json")
    weights_path = os.path.join(name_or_path, "adapter_model.bin")

    with open(config_path, "r", encoding="utf8") as fp:
        adapter_config = json.load(fp)

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints from trusted sources. Consider
    # weights_only=True if the file holds only tensors.
    adapter_weight = torch.load(weights_path, weights_only=False)

    return adapter_config, adapter_weight

def _load_model(base_model, adapter_path):
    """Load ``base_model`` in bfloat16 and, if ``adapter_path`` is set, attach and fill the adapter.

    Returns the plain causal-LM model when no adapter is requested. Otherwise it
    returns the PEFT-wrapped model with weights copied in by ``load_adapter_weights``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    if not adapter_path:
        return model

    adapter_config, adapter_weight = load_adapter(adapter_path)
    lora_config = LoraConfig(
        # NOTE(review): the adapter's peft_type string is passed as the init
        # scheme; this relies on the installed peft accepting that value — confirm.
        init_lora_weights=adapter_config["peft_type"],
        r=adapter_config["r"],
        lora_alpha=adapter_config["lora_alpha"],
        # The components of a PiSSA adapter are the principal singular values and
        # vectors, so its saved dropout should be 0 to avoid randomly discarding them.
        lora_dropout=adapter_config["lora_dropout"],
        target_modules=adapter_config["target_modules"],
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return load_adapter_weights(peft_model, adapter_config, adapter_weight)


def _task_eval_settings(task_name):
    """Return the ``lm_eval.simple_evaluate`` keyword overrides for ``task_name``.

    Tasks are matched by substring and the first matching rule wins, so the order
    of the checks below is significant.
    """
    settings = {
        "num_fewshot": None,  # None = do not override the task's own setting
        "fewshot_as_multiturn": False,
        "apply_chat_template": False,
        "limit": None,
        "model_args": "",
    }
    if "mmlu" in task_name:
        settings["num_fewshot"] = 5
    elif "gsm8k" in task_name:
        settings.update(num_fewshot=8, fewshot_as_multiturn=True, apply_chat_template=True)
    elif "humaneval" in task_name:
        settings.update(fewshot_as_multiturn=True, apply_chat_template=True)
    elif "bbh" in task_name:
        settings["limit"] = 50
    elif "mbpp" in task_name or "ifeval" in task_name:
        settings.update(fewshot_as_multiturn=True, apply_chat_template=True)
    elif "gpqa" in task_name:
        # NOTE(review): this looks like a vLLM-style option. lm_eval builds a model
        # from model_args only when `model` is a string, so with the HFLM instance
        # used here it likely has no effect — confirm.
        settings["model_args"] = "gpu_memory_utilization=0.8"
    return settings


def main(
    base_model: str = "meta-llama/Llama-3.1-8B-Instruct",
    task_names: str = "mmlu_us_foreign_policy",
    adapter_path: str = None,
    num_fewshot: int = None,
    profile_routing: bool = False,
):
    """Evaluate ``base_model`` (optionally with an adapter) on one or more lm_eval tasks.

    Args:
        base_model: Hugging Face model id or local path of the base causal LM.
        task_names: ``;``-separated lm_eval task names. A list/tuple is also
            accepted (e.g. when ``fire`` parses a comma-separated value). Blank
            entries are ignored.
        adapter_path: Directory with ``adapter_config.json`` and
            ``adapter_model.bin``. When empty, the base model is evaluated as-is.
        num_fewshot: If given, overrides the per-task few-shot count for every
            task. ``None`` keeps the per-task choices from ``_task_eval_settings``.
            Previously this argument was silently ignored.
        profile_routing: Arm the profiler before each task and dump the collected
            routing weights afterwards (``init_profiler`` / ``get_routing_weights``).
            Requires ``adapter_path``.

    Raises:
        ValueError: if ``profile_routing`` is set without ``adapter_path``, or if
            no non-blank task names were given.
    """
    setup_logging("INFO")

    if profile_routing and not adapter_path:
        # init_profiler walks model.base_model.model.model.layers, the PEFT-wrapped layout.
        raise ValueError("profile_routing requires adapter_path")

    raw_names = task_names.split(";") if isinstance(task_names, str) else task_names
    task_list = [str(name).strip() for name in raw_names if str(name).strip()]
    if not task_list:
        raise ValueError(f"No task names given: {task_names!r}")

    peft_model = _load_model(base_model, adapter_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Build one TaskManager and share it across tasks instead of rebuilding it per
    # iteration. (simple_evaluate would instantiate its own if given None.)
    task_manager = lm_eval.tasks.TaskManager()

    for task_name in task_list:
        if profile_routing:
            # Start each task with empty profiler_ dicts.
            init_profiler(peft_model)

        lm = lm_eval.models.huggingface.HFLM(
            pretrained=peft_model,
            tokenizer=tokenizer,
        )

        settings = _task_eval_settings(task_name)
        if num_fewshot is not None:
            settings["num_fewshot"] = num_fewshot

        results = lm_eval.simple_evaluate(
            model=lm,
            tasks=[task_name],
            task_manager=task_manager,
            confirm_run_unsafe_code=True,
            gen_kwargs=None,
            **settings,
        )
        # simple_evaluate may return None (e.g. on non-zero ranks in distributed runs).
        if results is not None:
            print(make_table(results))

        if profile_routing:
            get_routing_weights(peft_model, task_name)


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)