import argparse
import numpy as np
import pandas as pd
import torch
from utils.data import set_seed
from eval.lm_eval import eval
from eval.pred_long_bench import *
from eval.eval_long_bench import *


from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from model_impl.mistral import MistralAttention_sparse_mean
from utils.model_loader import load_model_and_tokenizer, get_model_path, save_model_and_tokenizer
from compression_pipeline import compress
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# LongBench subsets that are both predicted and scored. A single list keeps
# prediction and evaluation from drifting apart.
LONG_BENCH_DATASETS = [
    "narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa",
    "musique", "gov_report", "qmsum", "multi_news", "trec", "triviaqa",
    "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p",
]

# --method values that load the uncompressed checkpoint; any other value sets
# config.is_compress = True.
UNCOMPRESSED_METHODS = ("origin", "mean_sparse")


def build_parser():
    """Return the argument parser for this evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Path or name of the LLaMA model to load.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of samples used for calibration.')

    parser.add_argument('--k_rank', type=int, default=512, help='Low-rank dimension for compressing key representations.')
    parser.add_argument('--k_high_rank', type=int, default=2048, help='High-rank dimension used in recent windows.')
    parser.add_argument('--v_bits', type=int, default=2, help='Quantization bit-width for the value cache.')

    parser.add_argument('--sparsity', type=int, default=16, help='Sparse level of tokens to retain during sparse attention selection; overridden by "budget" when budget is not None')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size used in group-wise quantization.')

    parser.add_argument('--residual_length', type=int, default=128, help='Sliding window size for recent token selection; overridden by "budget" when budget is not None')

    parser.add_argument('--tasks', type=str, default="gsm8k,coqa", help="zero shot evaluation tasks")
    parser.add_argument('--datasets', type=str, default="zero_shot", help="zero_shot evaluation or longbench evaluation")

    parser.add_argument('--withppl', action='store_true', help="Enable perplexity evaluation on WikiText2.")
    parser.add_argument('--withprune', action='store_true', help="Enable pruning during KV cache compression.")

    parser.add_argument('--skip_layers', type=str, default="", help="Comma-separated list of layers to skip during compression or pruning.")

    parser.add_argument('--budget', type=str, default=None, help='Comma-separated token budget for sparse attention as "sink_token_num,critical_token_num,recent_token_num". '
                                                                'If specified, recent_token_num will override --residual_length and --sparsity')

    parser.add_argument('--topk_rank', type=int, default=128, help='Rank for selecting top-k latent features for reconstruction.')
    parser.add_argument('--with_overlap', action='store_true', help="Enable compute overlap score method")

    # NOTE(review): the four options below were read from `args` but never
    # registered, so the script crashed with AttributeError. Their defaults
    # are placeholders -- confirm them against the compression/attention code.
    parser.add_argument('--k_bits', type=int, default=16, help='Quantization bit-width for the key cache (stored as config.k_bits).')
    parser.add_argument('--sparse_rank', type=int, default=128, help='Rank stored as config.sparse_rank.')
    parser.add_argument('--method', type=str, default="origin", help='Method name stored as config.method; "origin" and "mean_sparse" load the uncompressed model.')
    parser.add_argument('--withfinetune', action='store_true', help="Set config.finetuned to True.")
    return parser


def parse_budget(budget):
    """Parse a "sink,critical,recent" budget string into a list of three ints.

    Returns None when *budget* is None. Raises ValueError when the string
    does not contain exactly three comma-separated integers.
    """
    if budget is None:
        return None
    values = [int(part) for part in budget.split(",")]
    if len(values) != 3:
        raise ValueError(
            f"--budget expects 3 comma-separated integers "
            f"(sink,critical,recent), got {len(values)}: {budget!r}"
        )
    return values


def model_name_from_path(model_path):
    """Return the last component of *model_path*, ignoring trailing slashes."""
    return model_path.rstrip("/").split("/")[-1]


def apply_args_to_config(config, args):
    """Copy the compression-related CLI options onto *config* and return it.

    ``config.is_compress`` is deliberately not set here: the original flow
    calls ``compress`` before that attribute exists, and main() preserves
    that order.
    """
    config.k_rank = args.k_rank
    config.k_high_rank = args.k_high_rank
    config.k_bits = args.k_bits
    config.v_bits = args.v_bits
    config.group_size = args.groupsize
    config.residual_length = args.residual_length
    config.nsamples = args.nsamples
    config.sparsity = args.sparsity
    config.sparse_rank = args.sparse_rank
    config.method = args.method
    config.finetuned = args.withfinetune
    config.with_overlap = args.with_overlap
    config.skip_layers = args.skip_layers
    config.budget = parse_budget(args.budget)
    config.topk_rank = args.topk_rank
    return config


def append_results_csv(row, csv_path="result.csv"):
    """Append *row* (a column -> value dict) to *csv_path*, creating it if needed.

    An existing file is re-read and concatenated with the new row, so columns
    are matched by name rather than by position. Results therefore stay
    aligned when CLI options are added or reordered; new columns are kept,
    and cells missing from older rows are left empty.
    """
    frame = pd.DataFrame([row])
    if os.path.exists(csv_path):
        frame = pd.concat([pd.read_csv(csv_path), frame], ignore_index=True)
    frame.to_csv(csv_path, index=False)


def run_long_bench(model, tokenizer, config, args, model_name, csv_path="result.csv"):
    """Predict and score every LongBench subset, print the scores, and log them."""
    # Fixed seed for LongBench generation, independent of --seed (which the
    # parser documents as the calibration-sampling seed).
    seed_everything(42)
    do_sample = False
    # Fresh copies so neither callee can mutate the module-level list.
    predict_long_bench(model, tokenizer, config, list(LONG_BENCH_DATASETS), device="cuda:0",
                       model_name=model_name, do_sample=do_sample)
    # eval_long_bench returns one score per dataset, in the order given.
    scores = eval_long_bench(model_name, config, list(LONG_BENCH_DATASETS), do_sample=do_sample)
    for dataset, score in zip(LONG_BENCH_DATASETS, scores):
        print(f"{dataset}: {score}")
    print("avg scores", np.mean(scores))

    row = dict(zip(LONG_BENCH_DATASETS, scores))
    row.update(vars(args))
    append_results_csv(row, csv_path)


def main(argv=None):
    """Parse *argv*, optionally compress the model, then run the requested evaluations."""
    parser = build_parser()
    args = parser.parse_args(argv)
    # Validate --budget before any expensive model/config loading.
    try:
        parse_budget(args.budget)
    except ValueError as err:
        parser.error(str(err))
    set_seed(args.seed)

    model_name = model_name_from_path(args.model)
    config = apply_args_to_config(AutoConfig.from_pretrained(args.model), args)
    if args.withprune:
        compress(args.model, config, device="cuda:0", save=True)
    config.is_compress = args.method not in UNCOMPRESSED_METHODS
    print(args.method)
    model, tokenizer = load_model_and_tokenizer(args.model, config, compressed=config.is_compress)

    datasets_list = args.datasets.split(",")
    print(args)
    if "zero_shot" in datasets_list:
        results_table = eval(model=model, tokenizer=tokenizer, tasks=args.tasks.split(","),
                             withppl=args.withppl, batch_size=1)
        print(results_table)

    if "long_bench" in datasets_list:
        run_long_bench(model, tokenizer, config, args, model_name)

    print(args)


if __name__ == "__main__":
    main()