import numpy as np
import torch
import random
import os
from tqdm import tqdm

import sys
import argparse
from datasets import load_dataset
from dataclasses import dataclass, asdict
import json

from pprint import pprint
import itertools
import src
from src import tasks_repo as tasks_mapping
from src import Evaluator
from src.fuzzycopy import generator
from src.utils import load_json, Timer, time_str
# Run on the GPU when PyTorch reports a usable CUDA device, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# To force CPU execution while debugging: device = "cpu"

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: with the attributes ``tasks``, ``model_id``,
        ``batch_size``, ``num_fewshot``, ``num_fewshot2``, ``kc``,
        ``description_dict_path``, ``limit`` and ``output_base_path``.
    """
    limit_help = (
        "Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples."
    )

    # (flag, add_argument keyword arguments), kept in registration order so
    # that the generated --help output lists the options in the same order.
    option_table = (
        ("--tasks", {"type": str, "default": "a_level,a_level_symbol"}),
        ("--model_id", {"type": str, "default": "meta-llama/Llama-2-7b-hf"}),
        ("--batch_size", {"type": int, "default": 3}),
        ("--num_fewshot", {"type": int, "default": 1}),
        ("--num_fewshot2", {"type": int, "default": 1}),
        ("--kc", {"type": int, "default": 1}),
        ("--description_dict_path", {"default": "./data/description.json"}),
        ("--limit", {"type": float, "default": 100, "help": limit_help}),
        ("--output_base_path", {"type": str, "default": None}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _solve_and_score(evaluator, task_name, prompt, answer, seq_len=128):
    """Generate a solution for ``prompt`` and score it against ``answer``.

    The evaluator entry points are selected from ``task_name``:

    * contains ``"_expcot"``: ``prompt_to_solutionexp`` + ``eval2exp``
    * contains ``"_cot"``:    ``prompt_to_solution`` + ``eval2cot``
    * otherwise:              ``prompt_to_solution`` + ``eval2``

    Returns whatever the selected ``eval2*`` method returns.
    """
    if "_expcot" in task_name:
        solution = evaluator.prompt_to_solutionexp(prompt, seq_len=seq_len)
        return evaluator.eval2exp(prompt, answer, solution, seq_len=seq_len)

    solution = evaluator.prompt_to_solution(prompt, seq_len=seq_len)
    if "_cot" in task_name:
        return evaluator.eval2cot(prompt, answer, solution, seq_len=seq_len)
    return evaluator.eval2(prompt, answer, solution, seq_len=seq_len)


def inference(
        task_name,
        evaluator,
        n_shot,
        n_shot2,
        kc,
        description,
        limit,
        output_base_path,
    ):
    """Evaluate one task and dump the per-example results to a JSON file.

    The task's validation docs are shuffled with a fixed seed, the first
    ``limit`` of them are turned into few-shot prompts, solved and scored by
    ``evaluator``, and the results are written to
    ``<output dir>/<task_name>.json``.

    Args:
        task_name: Key into ``tasks_mapping``; also used as the output file
            name and to select the evaluator entry points (see
            ``_solve_and_score``).
        evaluator: Object providing ``prompt_to_solution*``, ``eval2*`` and
            ``get_model_id``.
        n_shot, n_shot2, kc: Forwarded unchanged to ``task.fewshot_context``.
        description: Forwarded to ``task.fewshot_context`` as ``description``.
        limit: ``None`` to evaluate every example; a value below 1 is taken
            as a fraction of the validation set, any other value as an
            absolute example count.
        output_base_path: Output directory; ``./output/debug`` when falsy.

    Raises:
        ValueError: If ``limit`` is negative.

    The JSON file holds a list whose first element is ``{"acc": mean}`` (the
    mean of ``result.accuracy`` over the evaluated examples, ``0.0`` when none
    were evaluated), followed by ``asdict(result)`` for each example.
    """
    print(f"inference for {task_name}...")
    timer = Timer()

    task = tasks_mapping[task_name]()

    # The returned training docs are not used here. The call is kept because
    # the task may cache them for few-shot sampling -- assumption, TODO confirm.
    task.training_docs()
    test_examples = list(task.validation_docs())

    # Two independent generators with the same fixed seed: one is handed to
    # the task for few-shot sampling, the other only shuffles the test set,
    # so both stay reproducible regardless of how much the task draws.
    rnd = random.Random(2925)
    random.Random(2925).shuffle(test_examples)
    print("**************")
    print(len(test_examples))
    print("**************")

    if limit is not None:
        if limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")
        limit = int(len(test_examples) * limit) if limit < 1.0 else int(limit)

    fewshot_kwargs = {}
    total_accuracy = 0
    per_doc_results = []

    # islice(..., None) iterates over every example, so limit=None is fine.
    for doc_id, doc in enumerate(itertools.islice(test_examples, limit)):
        if task_name == "ab_level":
            fewshot_kwargs["doc_id"] = doc_id

        if (doc_id + 1) % 10 == 0:
            print(f"===== {task_name} =========== {doc_id+1}")

        prompt = task.fewshot_context(
            doc, kc, n_shot, n_shot2, rnd, description=description, **fewshot_kwargs
        )
        answer = task.doc_to_target(doc)

        result = _solve_and_score(evaluator, task_name, prompt, answer)
        total_accuracy += result.accuracy
        per_doc_results.append(asdict(result))

    # Average over the examples actually evaluated. This equals
    # min(len(test_examples), limit) whenever that value is defined and
    # non-zero, and additionally copes with limit=None and with empty runs.
    num_evaluated = len(per_doc_results)
    if num_evaluated:
        mean_accuracy = total_accuracy / num_evaluated
    else:
        print(f"warning: no examples evaluated for {task_name}; reporting acc 0.0")
        mean_accuracy = 0.0
    results = [{"acc": mean_accuracy}] + per_doc_results

    file_path = output_base_path if output_base_path else "./output/debug"
    file_name = f"{task_name}"

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(file_path, exist_ok=True)
    print(f"save to {file_path}/{file_name}")

    with open(f"{file_path}/{file_name}.json", "w") as f:
        json.dump(results, f, indent=4)

    print(f"{task_name} | {evaluator.get_model_id()} | use time {time_str(timer.end())}")


def main():
    """Script entry point: evaluate every requested task with one shared model.

    Reads the CLI options, loads the per-task description dictionary, builds a
    single ``Evaluator`` for ``--model_id`` and calls ``inference`` once per
    comma-separated name in ``--tasks``.
    """
    cli = parse_args()
    requested_tasks = cli.tasks.split(",")
    print("this run contain tasks: ", requested_tasks)
    descriptions = load_json(cli.description_dict_path)
    stopwatch = Timer()

    model = Evaluator(cli.model_id, device)

    # Settings that are identical for every task in this run.
    shared_options = {
        "evaluator": model,
        "n_shot": cli.num_fewshot,
        "n_shot2": cli.num_fewshot2,
        "kc": cli.kc,
        "limit": cli.limit,
        "output_base_path": cli.output_base_path,
    }

    for name in requested_tasks:
        # A missing or empty description falls back to the empty string.
        task_description = descriptions.get(name)
        if not task_description:
            task_description = ""
        inference(task_name=name, description=task_description, **shared_options)

    print(f"{cli.model_id} all done | use time {time_str(stopwatch.end())}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()