import numpy as np
import torch
import random
import os
from tqdm import tqdm
import copy
import sys
import argparse
from datasets import load_dataset
from dataclasses import dataclass, asdict
import json

from pprint import pprint
import itertools
import src
from src import tasks_repo as tasks_mapping
from src import Evaluator
from src.fuzzycopy import generator
from src.utils import load_json, Timer, time_str
# Choose the compute device once, at import time: the GPU when torch reports
# a usable CUDA device, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# To force CPU execution while debugging, override with: device = "cpu"

def parse_args():
    """Declare this script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``tasks``, ``model_id``, ``batch_size``,
        ``num_fewshot``, ``num_fewshot2``, ``kc``, ``description_dict_path``,
        ``limit`` and ``output_base_path``.
    """
    cli = argparse.ArgumentParser()

    # Options that need nothing beyond a type and a default, kept as data so
    # the declaration order (and therefore the --help order) is easy to see.
    typed_options = (
        ("--tasks", str, "a_level,a_level_symbol"),
        ("--model_id", str, "meta-llama/Llama-2-7b-hf"),
        ("--batch_size", int, 3),
        ("--num_fewshot", int, 1),
        ("--num_fewshot2", int, 1),
        ("--kc", int, 1),
    )
    for flag, value_type, fallback in typed_options:
        cli.add_argument(flag, type=value_type, default=fallback)

    # Declared without an explicit type, so argparse keeps the raw string.
    cli.add_argument("--description_dict_path", default="./data/description.json")

    limit_help = (
        "Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples."
    )
    cli.add_argument("--limit", type=float, default=100, help=limit_help)

    cli.add_argument("--output_base_path", type=str, default=None)

    return cli.parse_args()


def _context_before_last_input(prompt):
    """Return the part of *prompt* that precedes its last ``'input:'`` marker."""
    marker_at = prompt.rfind('input:')
    if marker_at == -1:
        # NOTE(review): with no marker, the earlier hand-written scan ended up
        # slicing ``prompt[0:-1]`` (i.e. dropping the final character). That
        # result is preserved so saved matrices stay comparable — confirm
        # whether this fallback is actually intended.
        return prompt[:-1]
    return prompt[:marker_at]


def _last_position_attention(attentions, n_layers, n_heads, n_ctx):
    """Build the per-layer attention matrix for the last query position.

    Row ``i`` is the concatenation, over heads ``k``, of
    ``attentions[i][0][k][-1][:n_ctx]`` divided by a fixed scale, giving an
    array of shape ``(n_layers, n_heads * n_ctx)``.
    """
    # NOTE(review): 32.0 is hard-coded in the original code; presumably it is
    # the head count of the default model — confirm before running models
    # with a different geometry.
    scale = 32.0
    matrix = np.zeros((n_layers, n_ctx * n_heads))
    for layer in range(n_layers):
        # Slice on the tensor first so only the needed values are converted.
        rows = attentions[layer][0][:n_heads, -1, :n_ctx]
        # detach/cpu are no-ops for plain CPU tensors but make the conversion
        # valid for CUDA tensors and tensors that track gradients; float32 is
        # needed because NumPy has no bfloat16.
        rows = rows.detach().to(torch.float32).cpu().numpy()
        matrix[layer] = rows.astype(np.float64).reshape(-1) / scale
    return matrix


def inference(
        task_name,
        evaluator,
        n_shot,
        n_shot2,
        kc,
        description,
        limit,
        output_base_path,
    ):
    """Evaluate one task, dump an attention matrix per example, write a JSON report.

    For every evaluated validation document the function builds a few-shot
    prompt, asks ``evaluator`` for a solution and its attentions, scores the
    solution with ``evaluator.eval3`` and saves an attention matrix to
    ``matrices/<task_name>/no<k>.npy`` (``k`` starts at 1). Afterwards it
    writes ``<output dir>/<task_name>.json`` whose first entry is
    ``{"acc": mean accuracy}`` followed by one dict per evaluated document.

    Args:
        task_name: key into ``tasks_mapping``; also used for output paths.
        evaluator: object providing ``prompt_to_solution2``, ``eval3``,
            ``translation`` and ``get_model_id``.
        n_shot, n_shot2, kc: forwarded to ``task.fewshot_context``.
        description: forwarded to ``task.fewshot_context`` as ``description``.
        limit: ``None`` to evaluate every document; a value below 1 is treated
            as a fraction of the validation set; otherwise it is truncated to
            an integer number of documents.
        output_base_path: directory for the JSON report; ``./output/debug``
            is used when it is empty or ``None``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    print(f"inference for {task_name}...")
    timer = Timer()

    # os.makedirs creates missing parents and tolerates an existing directory,
    # and (unlike a shell ``mkdir``) never passes task_name through a shell.
    matrix_dir = os.path.join("matrices", task_name)
    os.makedirs(matrix_dir, exist_ok=True)

    task = tasks_mapping[task_name]()

    # The returned docs are not used here; the call is kept in case the task
    # prepares few-shot state when it is invoked — TODO confirm and drop.
    task.training_docs()
    test_examples = list(task.validation_docs())

    # Two independent generators with the same fixed seed: one drives the
    # few-shot sampling, the other the order of the test examples.
    fewshot_rng = random.Random(1234)
    random.Random(1234).shuffle(test_examples)
    print("**************")
    print(len(test_examples))
    print("**************")

    fewshot_kwargs = {}

    if limit is not None:
        limit = int(len(test_examples) * limit) if limit < 1.0 else int(limit)

    # Probe the model once with a tiny prompt to learn how many layers and
    # heads its attention output has.
    print("qlens")
    _, _, probe_attentions = evaluator.prompt_to_solution2("How are you", seq_len=3)
    total_layers = len(probe_attentions)
    total_heads = len(probe_attentions[0][0])
    print(total_layers)
    print(total_heads)

    seq_len = 28
    accuracy_sum = 0
    n_evaluated = 0
    results = []

    # islice accepts ``None`` as the stop value, meaning "no limit".
    for doc_id, doc in enumerate(itertools.islice(test_examples, 0, limit)):
        if task_name == "ab_level":
            fewshot_kwargs["doc_id"] = doc_id

        if (doc_id + 1) % 10 == 0:
            print(f"===== {task_name} =========== {doc_id+1}")

        # The second return value (the raw prompt) is not needed here.
        prompt, _ = task.fewshot_context(
            doc, kc, n_shot, n_shot2, fewshot_rng,
            description=description, **fewshot_kwargs
        )
        answer = task.doc_to_target(doc)

        solution, _, attentions = evaluator.prompt_to_solution2(prompt, seq_len=seq_len)
        result = evaluator.eval3(prompt, answer, solution, seq_len=seq_len)

        # Only attention onto the text before the final 'input:' is kept; the
        # length of evaluator.translation(...) decides how many positions
        # that is (presumably a token sequence — verify against Evaluator).
        context = _context_before_last_input(prompt)
        context_units = evaluator.translation(context)
        matrix = _last_position_attention(
            attentions, total_layers, total_heads, len(context_units)
        )
        np.save(os.path.join(matrix_dir, f"no{doc_id + 1}.npy"), matrix)

        accuracy_sum += result.accuracy
        results.append(asdict(result))
        n_evaluated += 1

    # Average over what was actually evaluated: this equals
    # min(len(test_examples), limit) for a numeric limit, and also works for
    # limit=None or an empty run instead of raising.
    mean_accuracy = accuracy_sum / n_evaluated if n_evaluated else 0.0
    results = [{"acc": mean_accuracy}] + results

    file_path = output_base_path if output_base_path else "./output/debug"
    file_name = f"{task_name}"

    os.makedirs(file_path, exist_ok=True)
    print(f"save to {file_path}/{file_name}")

    with open(f"{file_path}/{file_name}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(f"{task_name} | {evaluator.get_model_id()} | use time {time_str(timer.end())}")


def main():
    """Run ``inference`` for every task named on the command line.

    Parses the CLI options, loads the task-description dictionary, builds a
    single ``Evaluator`` shared by all tasks, then processes the
    comma-separated ``--tasks`` one after another and prints the elapsed time.
    """
    cli_args = parse_args()
    requested_tasks = cli_args.tasks.split(",")
    print("this run comtain tasks: ", requested_tasks)

    descriptions = load_json(cli_args.description_dict_path)
    stopwatch = Timer()

    shared_evaluator = Evaluator(cli_args.model_id, device)

    # Settings that are the same for every task in this run.
    common_settings = {
        "evaluator": shared_evaluator,
        "n_shot": cli_args.num_fewshot,
        "n_shot2": cli_args.num_fewshot2,
        "kc": cli_args.kc,
        "limit": cli_args.limit,
        "output_base_path": cli_args.output_base_path,
    }

    for current_task in requested_tasks:
        # A missing or empty description falls back to the empty string.
        task_description = descriptions.get(current_task) or ""
        inference(
            task_name=current_task,
            description=task_description,
            **common_settings,
        )

    print(f"{cli_args.model_id} all done | use time {time_str(stopwatch.end())}")


if __name__ == '__main__':
    main()
