import numpy as np
import torch
import random
import os
from tqdm import tqdm

import sys
import argparse
from datasets import load_dataset
from dataclasses import dataclass, asdict
import json

from pprint import pprint
import itertools
import src
from src import tasks_repo as tasks_mapping
from src import Evaluator
from src.fuzzycopy import generator
from src.utils import load_json, Timer, time_str
# Pick the compute device once at import time: the GPU when PyTorch can see one,
# otherwise the CPU. To force CPU execution, bind `device` to "cpu" instead.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def parse_args():
    """Parse the command-line options for an evaluation run.

    Returns:
        argparse.Namespace with ``tasks``, ``model_id``, ``batch_size``,
        ``num_fewshot``, ``num_fewshot2``, ``kc``, ``description_dict_path``,
        ``limit`` and ``output_base_path``.
    """
    cli = argparse.ArgumentParser()
    limit_help = (
        "Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples."
    )
    option_specs = (
        ("--tasks", {"type": str, "default": "a_level,a_level_symbol"}),
        ("--model_id", {"type": str, "default": "meta-llama/Llama-2-7b-hf"}),
        ("--batch_size", {"type": int, "default": 3}),
        ("--num_fewshot", {"type": int, "default": 1}),
        ("--num_fewshot2", {"type": int, "default": 1}),
        ("--kc", {"type": int, "default": 1}),
        ("--description_dict_path", {"default": "./data/description.json"}),
        ("--limit", {"type": float, "default": 100, "help": limit_help}),
        ("--output_base_path", {"type": str, "default": None}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _segment_attention(attentions, bounds):
    """Reduce one example's attention maps to a per-layer segment-to-segment matrix.

    Args:
        attentions: array-like indexed ``[layer][batch][head][query][key]``; only
            batch element 0 is used.
        bounds: non-decreasing token offsets; segment ``s`` covers positions
            ``bounds[s] <= pos < bounds[s + 1]`` on both the query and key axes.

    Returns:
        float64 array of shape ``(n_layers, n_seg, n_seg)``. Entry ``[l, q, k]`` is
        the attention weight from segment ``q``'s tokens onto segment ``k``'s tokens
        in layer ``l``, summed over heads and query positions and divided by
        ``n_heads * len(segment q)``. Rows for empty query segments stay 0 rather
        than 0/0 = NaN, which would poison a running average.

    Raises:
        ValueError: if ``attentions`` is not 5-D or the segments extend past the
            attention map.
    """
    att = np.asarray(attentions)
    if att.ndim != 5:
        raise ValueError(
            "expected attentions indexed [layer][batch][head][query][key], "
            f"got shape {att.shape}"
        )
    att = att[:, 0]
    n_layers, n_heads, n_query, n_key = att.shape
    if bounds[-1] > min(n_query, n_key):
        raise ValueError(
            f"segments cover {bounds[-1]} tokens but the attention map is "
            f"{n_query}x{n_key}"
        )

    n_seg = len(bounds) - 1
    out = np.zeros((n_layers, n_seg, n_seg), dtype=float)
    for q in range(n_seg):
        q_len = bounds[q + 1] - bounds[q]
        if q_len <= 0:
            continue
        rows = att[:, :, bounds[q]:bounds[q + 1], :]
        for k in range(n_seg):
            # Accumulate in float64: half-precision sums over many tokens can overflow.
            out[:, q, k] = rows[..., bounds[k]:bounds[k + 1]].sum(
                axis=(1, 2, 3), dtype=np.float64
            )
        out[:, q, :] /= n_heads * q_len
    return out


def inference(
        task_name,
        evaluator,
        n_shot,
        n_shot2,
        kc,
        description,
        limit,
        output_base_path,
    ):
    """Evaluate one task and average attention flow between prompt segments.

    For each validation example (shuffled with a fixed seed), ``task.fewshot_context``
    returns the prompt plus four parts (a, b, c, explanation), which are expected to
    concatenate to the prompt; the equality is printed per example. The model's
    attention maps are reduced with ``_segment_attention`` to a ``(layers, 4, 4)``
    matrix, and these matrices are averaged over all processed examples.

    Args:
        task_name: key into ``tasks_mapping``.
        evaluator: provides ``prompt_to_solution2``, ``eval3``, ``translation`` and
            ``get_model_id``.
        n_shot, n_shot2, kc: forwarded to ``task.fewshot_context``.
        description: task description forwarded to ``task.fewshot_context``.
        limit: maximum number of test examples; a value < 1 is a fraction of the
            test set; ``None`` processes every example.
        output_base_path: output directory; defaults to ``./output/debug``.

    Side effects:
        Writes ``<dir>/<task_name>_metatoken.npy`` (averaged segment-attention matrix)
        and ``<dir>/<task_name>.json`` (overall accuracy followed by per-example
        results).
    """
    print(f"inference for {task_name}...")
    timer = Timer()

    task = tasks_mapping[task_name]()
    test_examples = list(task.validation_docs())

    # Fixed seeds keep both the few-shot sampling and the test order reproducible.
    rnd = random.Random(5211)
    random.Random(5211).shuffle(test_examples)  # TODO: tem zhuoyan
    print("**************")
    print(len(test_examples))
    print("**************")

    if limit is not None:
        limit = int(len(test_examples) * limit) if limit < 1.0 else int(limit)

    fewshot_kwargs = {}
    accs = 0
    results = []
    n_done = 0
    totaling = None  # running sum of per-example (layers, 4, 4) matrices

    for doc_id, doc in enumerate(itertools.islice(test_examples, limit)):
        n_done += 1
        if task_name == "ab_level":
            fewshot_kwargs["doc_id"] = doc_id

        if (doc_id + 1) % 10 == 0:
            print(f"===== {task_name} =========== {doc_id+1}")

        qn, parta, partb, partc, partexp = task.fewshot_context(
            doc, kc, n_shot, n_shot2, rnd, description=description, **fewshot_kwargs
        )
        prompt = qn
        answer = task.doc_to_target(doc)

        solution, _, attentions = evaluator.prompt_to_solution2(prompt, seq_len=20)
        result = evaluator.eval3(prompt, answer, solution, seq_len=20)

        # Sanity check: segment boundaries are only meaningful when the parts
        # concatenate exactly to the prompt fed to the model.
        print((qn == parta + partb + partc + partexp))

        # Token count per segment. Part a is counted in full; one token is dropped
        # from each later part, presumably the BOS that `translation` adds when a
        # part is tokenized on its own (TODO confirm). Clamped so an empty part
        # cannot produce a negative length.
        seg_lens = [len(evaluator.translation(parta))] + [
            max(len(evaluator.translation(part)) - 1, 0)
            for part in (partb, partc, partexp)
        ]
        bounds = [0, *itertools.accumulate(seg_lens)]

        toks = _segment_attention(attentions, bounds)
        totaling = toks if totaling is None else totaling + toks

        accs += result.accuracy
        results.append(asdict(result))

    if totaling is None:
        # Nothing was processed: keep the historical (32, 4, 4) shape, all zeros.
        totaling = np.zeros((32, 4, 4), dtype=float)
    else:
        totaling /= n_done
    # n_done equals min(len(test_examples), limit) and is also correct for limit=None.
    acc = accs / n_done if n_done else 0.0
    results = [{"acc": acc}] + results

    file_path = output_base_path or "./output/debug"
    file_name = f"{task_name}"
    os.makedirs(file_path, exist_ok=True)
    print(f"save to {file_path}/{file_name}")
    np.save(f"{file_path}/{file_name}_metatoken", totaling)
    with open(f"{file_path}/{file_name}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(f"{task_name} | {evaluator.get_model_id()} | use time {time_str(timer.end())}")


def main():
    """Script entry point: evaluate each comma-separated task with one shared model."""
    args = parse_args()
    task_names = args.tasks.split(",")
    print("this run comtain tasks: ", task_names)
    descriptions = load_json(args.description_dict_path)
    run_timer = Timer()

    # The model is loaded once and reused for every task.
    evaluator = Evaluator(args.model_id, device)
    shared_kwargs = {
        "evaluator": evaluator,
        "n_shot": args.num_fewshot,
        "n_shot2": args.num_fewshot2,
        "kc": args.kc,
        "limit": args.limit,
        "output_base_path": args.output_base_path,
    }

    for name in task_names:
        task_description = descriptions.get(name) or ""
        inference(task_name=name, description=task_description, **shared_kwargs)

    print(f"{args.model_id} all done | use time {time_str(run_timer.end())}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
