import os, sys
# Project root = parent of the directory containing this script. It is appended
# to sys.path so the project-local imports below (presumably `utils`, `datas`)
# resolve when the script is run directly.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os

from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

# Keyword arguments forwarded to `model.generate` for every continuation in `main`.
# NOTE(review): temperature/top_p typically only take effect when sampling is
# enabled (e.g. do_sample=True) — confirm against the model's generation config.
model_custom_config = dict(max_new_tokens=100, temperature=0.1, top_p=0.9)

# Metric names that get a per-sample list plus "-mean" / "-var" summary slots.
_METRICS = ("bleu", "rouge-1", "rouge-2", "rouge-l")

# Template for the score record kept per context length. Its layout is:
#   "record": list of {"response", "target"} pairs,
#   "<metric>": per-sample score list,
#   "<metric>-mean" / "<metric>-var": running summaries (start at 0.0).
# Treat this as read-only; take a (deep) copy before appending to its lists.
default_dict = {"record": []}
default_dict.update((metric, []) for metric in _METRICS)
default_dict.update((f"{metric}-mean", 0.0) for metric in _METRICS)
default_dict.update((f"{metric}-var", 0.0) for metric in _METRICS)

def _new_score_entry():
    """Return a fresh, independent copy of `default_dict` for one context length.

    A deep copy is required: reusing `default_dict` itself would make every
    context length append into the very same shared lists.
    """
    import copy

    return copy.deepcopy(default_dict)


def main(args):
    """Score model continuations with BLEU / ROUGE at increasing context lengths.

    For every text in the dataset, the first ``start`` tokens are fed to the
    model as the prompt. Its next ``gen_len`` tokens
    (``gen_len = model_custom_config["max_new_tokens"]``) are then compared with
    the true next ``gen_len`` tokens of the text. ``start`` takes the values
    ``args.start``, ``args.start + args.step``, ... while ``start < args.end``
    and the text still has ``gen_len`` tokens after ``start``.

    Scores are grouped by ``start`` in ``all_length_score``. Each entry is a copy
    of ``default_dict``. ``"bleu-mean"`` / ``"bleu-var"`` are refreshed after
    every window. The ROUGE "-mean"/"-var" fields are left at their template
    values; the per-window ROUGE dicts are kept in the ``"rouge-*"`` lists. After
    every batch the whole table is pickled to
    ``<cwd>/<args.log_dir>/<args.save_file>``, so partial results survive an
    interrupted run.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``batch_size``, ``cuda``
            ("auto" or a device index), ``model_path``, ``method``, ``start``,
            ``step``, ``end``, ``log_dir`` and ``save_file``.

    Raises:
        ValueError: If ``args.step`` is not positive (the sweep would never end).
    """
    if args.step <= 0:
        raise ValueError(f"--step must be positive, got {args.step}")

    dataset = get_data(args.dataset, args)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    if args.cuda == "auto":
        device = "auto"
    else:
        device = torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method)
    model.eval()

    # Only used by the optional prompt-prefix variant commented out below.
    prefix_prompt = get_promt(args.model_path)

    # Generated-window length; the reference window is sliced to the same size.
    gen_len = model_custom_config["max_new_tokens"]
    all_length_score = defaultdict(dict)
    rouge = Rouge()

    # Create the log directory up front so the first dump cannot fail after
    # the (expensive) generation work has already been done.
    save_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, args.save_file)

    for data in tqdm(dataloader):
        with torch.no_grad():
            # Score every text in the batch, not only the first one.
            # Assumes the default collate yields a list of strings — TODO confirm.
            for query in data["text"]:
                # query = prefix_prompt + query + "\n\n"
                inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
                input_ids = inputs_token.input_ids
                total_len = input_ids.shape[1]
                print("input token length: {}".format(total_len))

                start = args.start
                while start < args.end and start + gen_len <= total_len:
                    input_token = input_ids[:, :start]
                    target_token = input_ids[0, start:start + gen_len]

                    print("outputs = model.generate(input_token, **model_custom_config)")
                    outputs = model.generate(input_token, **model_custom_config)

                    # Assumes a decoder-only model whose generate() output starts
                    # with the prompt, so the new tokens begin at index `start`.
                    response = tokenizer.decode(outputs[0, start:start + gen_len])
                    target = tokenizer.decode(target_token)

                    # Explicit membership test: defaultdict(dict) would otherwise
                    # hand back an empty dict without the template's keys.
                    if start not in all_length_score:
                        all_length_score[start] = _new_score_entry()
                    entry = all_length_score[start]
                    entry["record"].append({"response": response, "target": target})

                    # ROUGE
                    print("rouge_score = rouge.get_scores(hyps=generated, refs=reference)")
                    rouge_score = rouge.get_scores(hyps=[response], refs=[target])[0]
                    for key in ("rouge-1", "rouge-2", "rouge-l"):
                        entry[key].append(rouge_score[key])

                    # BLEU on whitespace-split tokens, single reference.
                    bleu = sentence_bleu([target.split(" ")], response.split(" "))
                    print("BLEU Score:", bleu)
                    entry["bleu"].append(bleu)

                    bleu_values = np.array(entry["bleu"])
                    entry["bleu-mean"] = np.nanmean(bleu_values)
                    entry["bleu-var"] = np.nanvar(bleu_values)

                    start += args.step

                    # Free per-window tensors before the next, longer prompt.
                    del input_token, target_token, outputs
                    torch.cuda.empty_cache()

        # Checkpoint after every batch.
        with open(save_path, "wb") as f:
            pickle.dump({"all_length_score": all_length_score}, f)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every CLI option, registered in this order.
    # --start/--step/--end control the context-length sweep performed in main().
    cli_options = (
        ("--model_path", str, "/data/persist/models/llama2-7b"),
        ("--method", str, "leaky-rerope"),
        ("--dataset", str, "/data/persist/dataset/gov_report/test.txt"),
        ("--save_file", str, "rouge_llama2-7b_test.pkl"),
        ("--batch_size", int, 1),
        ("--log_dir", str, "../logs"),
        ("--cuda", str, "1"),
        ("--seed", int, 0),
        ("--start", int, 1024),
        ("--step", int, 1024),
        ("--end", int, 32 * 1024),
    )
    for flag, flag_type, flag_default in cli_options:
        parser.add_argument(flag, type=flag_type, default=flag_default)

    args = parser.parse_args()
    set_seed(args.seed)
    main(args)

















