import os, sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed, get_prompt_weave_v20_llama3b, get_prompt_2_weave_v20_llama3b, \
    get_prompt_3_weave_v20_llama3b, get_prompt_4_weave_v20_llama3b, get_prompt_5_weave_v20_llama3b, \
    get_prompt_6_weave_v20_llama3b
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os

# Keyword arguments that main() forwards unchanged to ``model.generate``.
# NOTE(review): no ``do_sample`` flag is set here; with Hugging Face
# ``generate`` the sampling knobs (temperature / top_p) are presumably
# ignored unless sampling is enabled -- confirm greedy decoding is intended.
model_custom_config = dict(
    max_new_tokens=50,
    temperature=0.1,
    top_p=0.9,
)

# For each weave-mpt method: attribute set on the ``methods.weave_mptN`` module
# -> name of the ``args`` attribute that supplies its value.
_WEAVE_MPT_OVERRIDES = {
    "weave-mpt1": {"push_pos": "push_mpt"},
    "weave-mpt2": {"push_pos": "push_mpt"},
    "weave-mpt3": {"push_pos": "push_mpt", "push_width": "push_width"},
    "weave-mpt6": {
        "push_pos": "push_mpt",
        "push_width": "push_width",
        "chunk_width": "chunk_width",
    },
    "weave-mpt7": {
        "push_pos": "push_mpt",
        "push_width": "push_width",
        "chunk_width": "chunk_width",
    },
}


def _configure_weave_mpt(args):
    """Import the module for a ``weave-mpt*`` method and set its module-level options.

    Does nothing when ``args.method`` is not one of the known weave-mpt methods.

    Raises:
        AttributeError: if ``args`` lacks an option the selected method needs.
    """
    overrides = _WEAVE_MPT_OVERRIDES.get(args.method)
    if overrides is None:
        return

    import importlib

    # "weave-mpt3" -> "methods.weave_mpt3"
    weave_mpt = importlib.import_module("methods." + args.method.replace("-", "_"))
    for module_attr, arg_name in overrides.items():
        setattr(weave_mpt, module_attr, getattr(args, arg_name))

    if args.method == "weave-mpt1":
        # weave-mpt1 additionally pins the attention chunk width to a fixed value.
        import models.mpt_7b.weave_attention as weave_attention
        weave_attention.chunk_width = 2047


def main(args):
    """Run the retrieval evaluation and pickle per-length accuracy statistics.

    For every batch the first sample's text is inserted into the prompt
    template, the model generates a continuation, and the continuation is
    scored against the sample's target with ``compare_retrieval_acc``.
    Scores are grouped by the sample's ``token_length``.

    Two pickle files are (re)written under ``<cwd>/<args.log_dir>`` after every
    sample, so they always reflect the samples processed so far:

    * ``args.save_file``             -- ``{"length_mean_var": {length: {"mean", "var"}}}``
    * ``"record_" + args.save_file`` -- ``{"all_length_acc": {length: [scores]}}``

    Args:
        args: parsed command-line namespace. Reads ``method``, ``dataset``,
            ``batch_size``, ``cuda``, ``model_path``, ``log_dir`` and
            ``save_file``, plus ``push_mpt`` / ``push_width`` / ``chunk_width``
            for the ``weave-mpt*`` methods. The whole namespace is also passed
            on to ``get_model``.
    """
    _configure_weave_mpt(args)

    dataset = get_data(args.dataset)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # "auto" is handed to get_model as-is; anything else is a CUDA device index.
    device = "auto" if args.cuda == "auto" else torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)

    # The template is filled in with str.format(sample_text) below.
    prefix_prompt = get_promt(args.model_path)
    if "llama-3b" in args.model_path:
        prefix_prompt = get_prompt_6_weave_v20_llama3b()

    # Resolve the output locations once and make sure the directory exists, so
    # the first write cannot fail on a missing log directory.
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    summary_path = f"{log_dir}/{args.save_file}"
    record_path = f"{log_dir}/{'record_' + args.save_file}"

    all_length_acc = defaultdict(list)

    model.eval()
    for data in tqdm(dataloader):
        # NOTE: only element 0 of each batch is evaluated; with batch_size > 1
        # the remaining samples of the batch are not scored.
        target = data["target"][0]
        query = prefix_prompt.format(data["text"][0])

        with torch.no_grad():
            inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
            input_ids = inputs_token.input_ids
            prompt_len = input_ids.shape[1]
            print("input token length: {}".format(prompt_len))

            outputs = model.generate(input_ids, **model_custom_config)

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(query) is unreliable: decoding the full sequence can prepend
        # special tokens or normalise whitespace, shifting the character offset
        # and leaking prompt residue into the response.
        # Assumes generate() returns prompt + continuation (decoder-only model).
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        print("response: {}".format(response))
        print("target: {}".format(target))

        acc = compare_retrieval_acc(response, target)
        print("success" if acc == 1 else "failed")

        all_length_acc[int(data["token_length"][0])].append(acc)

        all_mean_var_res = {
            token_length: {
                "mean": np.nanmean(np.array(record)),
                "var": np.nanvar(np.array(record)),
            }
            for token_length, record in all_length_acc.items()
        }

        with open(summary_path, "wb") as f:
            pickle.dump({"length_mean_var": all_mean_var_res}, f)

        with open(record_path, "wb") as f:
            pickle.dump({"all_length_acc": all_length_acc}, f)


# Script entry point: parse the command line, seed the RNGs, run the evaluation.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="/data/persist/models/vicuna/vicuna-13b-v1.3")  # mosaicml-mpt-7b
    parser.add_argument("--method", type=str, default="weave-v20")
    parser.add_argument("--dataset", type=str, default="../datas/passkey-data_dup-10_answer-6bit.json")
    parser.add_argument("--save_file", type=str, default="retrieval_weave-v20_vicuna-13b-v1.3_test.pkl")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--log_dir", type=str, default="../logs")
    parser.add_argument("--cuda", type=str, default="0")
    parser.add_argument("--hard_cuda", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    # main() reads these for the weave-mpt* methods; without them every such
    # method failed with AttributeError on the args namespace.
    # NOTE(review): integer type is assumed -- confirm against methods/weave_mpt*.
    parser.add_argument("--push_mpt", type=int, default=None)
    parser.add_argument("--push_width", type=int, default=None)
    parser.add_argument("--chunk_width", type=int, default=None)

    args = parser.parse_args()

    # Options each weave-mpt method reads in main(); report missing ones up
    # front instead of silently configuring the method with None.
    required_options = {
        "weave-mpt1": ("push_mpt",),
        "weave-mpt2": ("push_mpt",),
        "weave-mpt3": ("push_mpt", "push_width"),
        "weave-mpt6": ("push_mpt", "push_width", "chunk_width"),
        "weave-mpt7": ("push_mpt", "push_width", "chunk_width"),
    }.get(args.method, ())
    missing = [name for name in required_options if getattr(args, name) is None]
    if missing:
        parser.error(
            "--method {} requires: {}".format(
                args.method, ", ".join("--" + name for name in missing)
            )
        )

    set_seed(args.seed)
    main(args)

















