import os
import sys
import random
import numpy as np
import torch
import utils
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from accelerate import infer_auto_device_map
from utils.quant_utils import wrap_to_quant_model, init_weight_quantizer, init_input_quantizer, init_out_quantizer, register_online_had, init_k_quantizer, init_v_quantizer
import utils.model_utils as model_utils
import utils.rotation_utils as rotation_utils
from main import evaluate
from utils.train_utils import load_json_as_namespace,create_logger
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model
from utils.snn_utils import wrap_to_snn_model_tdf, replicate_past_key_values, wrap_to_snn_model_new,  wrap_to_snn_model
from datasets import load_dataset
# Turn on cuDNN benchmark mode for the whole process; this runs at import time,
# before main() is called.
torch.backends.cudnn.benchmark = True

import gc
import torch
def free_model(model):
    """Drop this function's reference to *model* and reclaim freed memory.

    ``del model`` only removes the local name. The object is released only
    when no other reference to it exists, so the caller must drop its own
    references (e.g. ``del model`` at the call site) for this to have effect.

    Args:
        model: Any object (typically a ``torch.nn.Module``); may be ``None``.

    Returns:
        None.
    """
    del model
    # Collect first: objects kept alive only by reference cycles are freed
    # here, and only memory that has already been freed can be handed back by
    # the CUDA caching allocator in the next call.
    gc.collect()
    torch.cuda.empty_cache()
def main():
    """Compare decoder-layer-0 outputs of the quantized ANN with two SNN wrappers.

    Steps, in order:

    1. Parse the command line and seed ``random``, NumPy and torch.
    2. Rebuild the quantized model from ``--quant_model_path``: wrap it with
       ``wrap_to_quant_model``, optionally register the online Hadamard
       transform, wrap the RoPE call of every attention layer, create the
       weight / input / output quantizers requested by the saved quantization
       config, then load the checkpoint.
    3. Deep-copy that model into two ANN references (``cuda:0`` and ``cuda:1``)
       and two SNN variants: one wrapped with ``wrap_to_snn_model_tdf`` or
       ``wrap_to_snn_model`` depending on ``--esmode``, the other with
       ``wrap_to_snn_model_new``.
    4. For 100 randomly sampled MMLU ``moral_scenarios`` test questions, run
       only ``embed_tokens`` followed by ``layers[0]`` on each ANN/SNN pair and
       record the mean per-token L2 distance between the two outputs.
    5. Print the per-variant summary and write it to
       ``./experiments/results/tdf/mmlu_l2_T=<T>_result.txt``.

    Two CUDA devices are required (``cuda:0`` and ``cuda:1`` are hard-coded).
    Several CLI options are not read directly in this function; ``args`` is
    forwarded whole to the SNN wrappers, which may consume more of them.

    Returns:
        None. Results are printed and written to the result file.
    """
    import argparse
    import copy

    parser = argparse.ArgumentParser()
    parser.add_argument("--quant_model_path", type=str, help="model path of quantized model")
    parser.add_argument("--output_dir", default="./log/test_snn", type=str, help="direction of logging file")
    parser.add_argument("--real_quant", default=False, action="store_true",
                        help="use real quantization instead of fake quantization, can reduce memory footprint")
    parser.add_argument("--ppl_seqlen", type=int, default=2048, help="lenth of the training sequence.")
    parser.add_argument("--seed", type=int, default=2, help="Seed for sampling the calibration data.")
    parser.add_argument("--T", type=int, default=2, help="time step")
    parser.add_argument("--eval_ppl", action="store_true",help="evaluate perplexity on wikitext2 and c4 with 2048 context length")
    parser.add_argument("--avg_neuron", action="store_true",help="set average in spike neuron")
    parser.add_argument("--eval_tasks", type=str,default="", help="exampe:piqa,arc_easy,arc_challenge,hellaswag,winogrande")
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--max_memory", type=str, default="40GiB",help="The maximum memory of each GPU")
    parser.add_argument("--esmode", type=str, default="TDF",help="TDF or TIF")

    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    args = parser.parse_args()
    print(">>> Effective args.T:", args.T)

    # Seed every RNG in use; `random` also drives the prompt sampling below.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # init logger
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    output_dir = Path(args.output_dir)
    logger = create_logger(output_dir)

    # The quantization settings are stored next to the checkpoint.
    quant_config = load_json_as_namespace(os.path.join(args.quant_model_path, 'prefixequant_config.json'))
    if quant_config.set_prefixed_tokens:
        # NOTE: torch.load unpickles the file, so --quant_model_path must point
        # at a trusted checkpoint directory.
        prefixed_key_values = torch.load(os.path.join(args.quant_model_path, 'prefixed_key_values.pth'))
    else:
        prefixed_key_values = None

    logger.info(args)
    # NOTE(review): this is called even when prefixed_key_values is None, and
    # the result is not read again in this function -- confirm the helper
    # accepts None and whether the replicated cache should be fed to the layer.
    prefixed_key_values = replicate_past_key_values(prefixed_key_values, args.T)

    # init quantized model
    config = AutoConfig.from_pretrained(args.quant_model_path,trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(args.quant_model_path, config=config, device_map='cpu',torch_dtype=torch.float16,trust_remote_code=True)
    wrap_to_quant_model(model)
    # register online Hadamard transformation
    if quant_config.down_online_had:
        register_online_had(model)
    # wrap rope for online_had and rope output capture
    rope_function_name = model_utils.get_rope_function_name(model)
    layers = model_utils.get_layers(model)
    for layer in layers:
        rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
                    layer.self_attn,
                    rope_function_name,
                    config=model.config,
                    online_had=quant_config.qk_online_had)

    # Quantizers are created only for the components quantized below 16 bits.
    # init weight quantizer
    if quant_config.wbits < 16:
        logger.info('init weight quantizer')
        init_weight_quantizer(quant_config, model, logger=logger, minmax_init=False)

    # init input quantizer
    if quant_config.input_bits < 16:
        logger.info('init input quantizer')
        init_input_quantizer(quant_config, model, logger=logger, minmax_init=False)

    # init output quantizer
    if quant_config.output_bits < 16:
        logger.info('init output quantizer')
        init_out_quantizer(quant_config, model, logger=logger, minmax_init=False)

    print(model)

    device0 = torch.device("cuda:0")
    device1 = torch.device("cuda:1")

    # Load the saved weights into the fully wrapped model, then keep the
    # master copy on CPU so it can be deep-copied once per variant.
    load_checkpoint_in_model(model, checkpoint=args.quant_model_path, device_map=None, dtype=torch.float16)
    model = model.to("cpu")

    # One ANN reference per GPU, so each SNN variant is compared on-device.
    model_ann_0 = copy.deepcopy(model).to(device0)
    model_ann_1 = copy.deepcopy(model).to(device1)

    model_snn_0 = copy.deepcopy(model)  # SpikeLlama_with_esframework
    if args.esmode == "TDF":
        wrap_to_snn_model_tdf(model_snn_0, args)
    else:
        # Any value other than "TDF" selects this wrapper.
        wrap_to_snn_model(model_snn_0, args)

    model_snn_1 = copy.deepcopy(model)  # SpikeLlama
    wrap_to_snn_model_new(model_snn_1, args)

    # Reload the checkpoint into the SNN-wrapped copies (presumably so modules
    # replaced by the wrappers pick up the saved values -- TODO confirm).
    load_checkpoint_in_model(model_snn_0, checkpoint=args.quant_model_path, device_map=None, dtype=torch.float16)
    load_checkpoint_in_model(model_snn_1, checkpoint=args.quant_model_path, device_map=None, dtype=torch.float16)

    model_snn_0 = model_snn_0.to(device0)
    model_snn_1 = model_snn_1.to(device1)

    # Inference only: cast to fp16 and freeze every parameter.
    for m in [model_ann_0, model_ann_1, model_snn_0, model_snn_1]:
        m.half()
        for param in m.parameters():
            param.requires_grad = False

    # The tokenizer is only needed to encode the evaluation prompts.
    tokenizer = AutoTokenizer.from_pretrained(args.quant_model_path, trust_remote_code=True)
    ds = load_dataset("cais/mmlu", "moral_scenarios")
    samples = random.sample(list(ds['test']), 100)
    prompts = [item['question'] for item in samples]

    def evaluate_model(model_ann, model_snn, prompts, label, device):
        """Return L2-error statistics between ANN and SNN outputs of layer 0.

        Args:
            model_ann: Reference model; only ``model.embed_tokens`` and
                ``model.layers[0]`` are called.
            model_snn: SNN-wrapped model evaluated the same way.
            prompts: Strings to tokenize one at a time; must be non-empty.
            label: Name used in the printed lines and in the returned dict.
            device: Device both models live on; inputs are moved there.

        Returns:
            dict with keys ``label``, ``mean``, ``max``, ``min`` and ``range``
            computed over the per-prompt mean L2 errors.
        """
        mean_l2_list = []
        for prompt in prompts:
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            input_ids = inputs["input_ids"]
            seq_len = input_ids.shape[1]

            with torch.no_grad():
                # ANN path: embeddings -> first decoder layer, batch dim removed.
                x_ann = model_ann.model.embed_tokens(input_ids)
                pos_ann = torch.arange(seq_len, device=device).unsqueeze(0)
                out_ann = model_ann.model.layers[0](x_ann, position_ids=pos_ann)
                out_ann = out_ann[0] if isinstance(out_ann, tuple) else out_ann
                out_ann = out_ann.squeeze(0)

                # SNN path: the same embeddings repeated args.T times along a
                # new leading dimension, then flattened into the batch dimension.
                x_snn = model_snn.model.embed_tokens(input_ids)
                x_snn = x_snn.expand(args.T, -1, -1, -1)
                x_snn = x_snn.contiguous().view(-1, seq_len, x_snn.shape[-1])
                pos_snn = torch.arange(seq_len, device=device).unsqueeze(0).expand(x_snn.shape[0], -1)

                out_snn = model_snn.model.layers[0](x_snn, position_ids=pos_snn)
                out_snn = out_snn[0] if isinstance(out_snn, tuple) else out_snn
                # Average over the leading (repeated) dimension before comparing.
                out_snn = out_snn.mean(dim=0)

                # L2 norm over the last dimension, then mean over the rest.
                l2_errors = torch.norm(out_snn - out_ann, dim=-1)
                mean_l2 = l2_errors.mean().item()
                mean_l2_list.append(mean_l2)

                print(f"[{label}] Prompt:", prompt)
                print(f"[{label}] Mean L2 error:", mean_l2)

        total_mean = sum(mean_l2_list) / len(mean_l2_list)
        total_max = max(mean_l2_list)
        total_min = min(mean_l2_list)
        error_range = total_max - total_min

        print("=" * 50)
        print(f"[{label}] Mean L2 Error: {total_mean:.6f}")
        print(f"[{label}] Max  L2 Error: {total_max:.6f}")
        print(f"[{label}] Min  L2 Error: {total_min:.6f}")
        print(f"[{label}] Error Range  : {error_range:.6f}")

        return {
            "label": label,
            "mean": total_mean,
            "max": total_max,
            "min": total_min,
            "range": error_range
        }

    res_model = evaluate_model(model_ann_0, model_snn_0, prompts, "SpikeLlama_with_esframework", device0)
    res_model_new = evaluate_model(model_ann_1, model_snn_1, prompts, "SpikeLlama", device1)

    # NOTE(review): the directory name is "tdf" regardless of --esmode.
    result_path = f"./experiments/results/tdf/mmlu_l2_T={args.T}_result.txt"
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    # Write the report once; the file is truncated if it already exists.
    with open(result_path, "w") as f:
        f.write("Evaluation Result on MMLU (moral_scenarios)\n\n")
        for res in [res_model, res_model_new]:
            f.write(f"{res['label']}:\n")
            f.write(f"  Mean L2 Error = {res['mean']:.6f}\n")
            f.write(f"  Max  L2 Error = {res['max']:.6f}\n")
            f.write(f"  Min  L2 Error = {res['min']:.6f}\n")
            f.write(f"  Error Range   = {res['range']:.6f}\n\n")

    print(f"\n>>> Results saved to {result_path}")


# Script entry point: runs only when the file is executed directly, not on import.
if __name__ == "__main__":
    # Echo the raw command line before starting the comparison.
    print(sys.argv)
    main()
