import sys
import os
import json
import numpy as np
from pathlib import Path 
# Absolute path of this script. `parent` is the directory containing it and
# `root` is the directory one level above that.
file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]
# Put `root` on the import path. Presumably this is what lets the
# project-local imports below resolve when the script is run directly --
# TODO confirm against the repository layout.
sys.path.append(str(root))
import argparse
import traceback
import torch
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
from string import ascii_uppercase
from distutils.util import strtobool
from collections import defaultdict
from icl.utils.other import set_gpu, dict_to
from decoding_algorithm import ContrastiveDecoding
from icl.utils.prepare_model_and_tokenizer import get_label_id_dict_for_args
from icl.utils.prepare_model_and_tokenizer import load_model_and_tokenizer
from utils.format_data_bbh import format_example_pairs as format_example_pairs_bbh
from utils.format_data_bbh import Config
import transformers
# 40 is the numeric value of the ERROR log level, so transformers only
# reports errors (warnings and informational messages are suppressed).
transformers.logging.set_verbosity(40)
# Maps each uppercase choice letter to its zero-based option index
# ("A" -> 0, "B" -> 1, ..., "Z" -> 25).
ans_map = {letter: index for index, letter in enumerate(ascii_uppercase)}
# Not referenced anywhere in the visible part of this file.
SAMPLE_NUM = 20

def extract_answer(model_answer, cot):
    """Pull the predicted multiple-choice letter out of a model completion.

    Args:
        model_answer: Raw text generated by the model.
        cot: When true, the completion is expected to contain a trigger of
            the form "is: (X)" (a newline instead of the space before the
            parenthesis is also accepted); the character following the LAST
            trigger is returned. When false, the first character of the
            completion is returned, because the trigger phrase is already
            part of the prompt.

    Returns:
        The single-character prediction on success. On any failure (missing
        trigger, missing closing parenthesis, empty completion, ...) the
        formatted traceback is returned as a string instead of raising, so
        that a lookup of the result in a letter table simply misses.
    """
    try:
        if not cot:
            # 'the answer is:' is part of the prompt when not doing CoT.
            return model_answer[0]

        parts = model_answer.split('is: (')
        if len(parts) == 1:
            parts = model_answer.split('is:\n(')
        # Explicit raises instead of `assert`: assertions are removed under
        # `python -O`, which would let a completion without the trigger
        # return its first character as if it were a valid prediction.
        if len(parts) <= 1:
            raise ValueError("model didn't output trigger")
        tail = parts[-1]
        # Slicing keeps a too-short tail from surfacing as an IndexError.
        if tail[1:2] != ')':
            raise ValueError("didnt output letter for choice")
        return tail[0]
    except Exception:
        return traceback.format_exc()


def _shard_indices(indices, total_shard, shard_id):
    """Return the contiguous slice of *indices* owned by shard *shard_id*.

    Every shard gets ``len(indices) // total_shard`` items; the last shard
    additionally takes the remainder so that no example is dropped.
    """
    chunk_size = len(indices) // total_shard
    start = shard_id * chunk_size
    end = len(indices) if shard_id == total_shard - 1 else start + chunk_size
    return indices[start:end]


# Script entry point: for each configured BBH task, generate a chain-of-thought
# answer for the second half of the validation examples, optionally generate a
# second "cot-enhance" answer with modified attention temperature, and write
# the per-example results to a JSON file under --output-path.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--amateur-model-name", type=str, default=None)
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--amateur-model-nums-gpus", type=str, default="1")
    parser.add_argument("--max_gpu_memory", type=int, default=27)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--data-path", type=str, default="./tfqa")
    parser.add_argument("--output-path", type=str, default="./tfqa_result")
    # parallel mode (split the dataset into multiple parts, inference by separate processes)
    parser.add_argument("--early-exit-layers", type=str, default="-1")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--total-shard", type=int, default=8)
    parser.add_argument("--shard-id", type=int, default=None)
    parser.add_argument("--do-rating", action="store_true")
    parser.add_argument("--is-chat", action="store_true")
    parser.add_argument("--mode", type=str, choices=["baseline", "cot-enhance"], default="baseline")
    parser.add_argument("--gpt3-config", type=str, default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--bias", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--relative_top", type=float, default=0.0)
    parser.add_argument("--relative_top_value", type=float, default=-1000.0)
    args = parser.parse_args()

    # Validate the sharding options before the (expensive) model load.
    # Previously --parallel sliced an undefined `list_data_dict` here and
    # therefore always crashed with NameError.
    if args.parallel:
        if args.total_shard < 1:
            parser.error("--total-shard must be a positive integer")
        if args.shard_id is None or not 0 <= args.shard_id < args.total_shard:
            parser.error("--parallel requires --shard-id in [0, --total-shard)")

    llm = ContrastiveDecoding(args.model_name, args.device, num_gpus=int(args.num_gpus))
    llm.set_stop_words(["Q:"])

    # The directory listing is kept so a missing --data-path still fails here,
    # but NOTE that the discovered task list is immediately overridden by the
    # hard-coded task below.
    bbh_task_list = sorted(os.listdir(args.data_path))
    bbh_task_list = ["sports_understanding"]
    # bbh_task_list = ["ruin_names"]
    print("=== bbh task list {} ===".format(bbh_task_list))

    bias_texts = ['I think the answer might be {text} but curious to hear what you think.']
    configs = []
    for fs in [True]:
        for task in bbh_task_list:
            for i, bt in enumerate(bias_texts):
                configs.append(
                    Config(task,
                        # bias_type='suggested_answer', # 'ans_always_a',
                        bias_type='ans_always_a',
                        bias_text=bt,
                        bias_text_id=i,
                        few_shot=fs,
                        prompt_num=3,
                        mode=args.mode,
                        dataset_path=args.data_path,
                        model=args.model_name,
                        get_pre_cot_answer=True,
                        batch=5))

    for c in configs:
        if c.task not in bbh_task_list:
            raise NotImplementedError("task name {} is not supported".format(c.task))
        with open(os.path.join(args.data_path, f'{c.task}/val_data.json'), 'r') as f:
            data = json.load(f)['data']

        if args.debug:
            print('======DEBUG======')
            data = data[:5]

        biased_inps, no_biased_inps, _, _ = format_example_pairs_bbh(data, c)
        inps = biased_inps if args.bias else no_biased_inps

        if args.mode == "cot-enhance":
            print("Begin to find interaction layer")
            # NOTE(review): looks like {layer: (head indices, temperature)};
            # the exact meaning is defined by ContrastiveDecoding.generate --
            # confirm there.
            attn_t = [0, {14: ([24, 4, 20, 31], 0.5), 18: ([30, 10, 25, 28], 0.5)}]
            print("Use attention {}".format(attn_t))
        else:
            attn_t = 1

        # Only the second half of the examples is evaluated.
        eval_indices = list(range(len(data) // 2, len(data)))
        if args.parallel:
            eval_indices = _shard_indices(eval_indices, args.total_shard, args.shard_id)

        acc = 0
        acc_interaction = 0
        interaction_num = 0  # enhanced run correct where the plain run was wrong
        bug_num = 0          # enhanced run wrong where the plain run was correct
        data_to_save = []

        for i in tqdm(eval_indices):
            y_true = data[i]['multiple_choice_scores'].index(1) if c.task != 'bbq' else None
            cot_inp = inps[i]
            problem = cot_inp.split("###\n\n")[-1]
            # NOTE: generation length is fixed at 128 here; --max-new-tokens is not used.
            cot_gen = llm.generate(input_text=cot_inp, attention_temperature=1, max_new_tokens=128)[0].split('\n\n')[0]
            cot_pred = extract_answer(cot_gen, cot=True)
            text = problem + cot_gen
            origin = cot_inp + cot_gen
            # Anything that is not a known choice letter (including the
            # traceback text extract_answer returns on failure) maps to -1.
            y_pred = int(ans_map.get(cot_pred, -1))
            if c.bias_type == 'suggested_answer':
                y_bias = data[i]['random_ans_idx']
            else:
                y_bias = 0
            result = {"id": i, "text": text, "origin": origin, "y_true": y_true, "y_bias": y_bias, "y_pred": y_pred}
            if y_pred == y_true:
                acc += 1

            # cot-enhance: generate a second time with the modified attention temperature
            if args.mode == "cot-enhance":
                cot_gen = llm.generate(input_text=cot_inp, temperature=1, attention_temperature=attn_t, max_new_tokens=128)[0].split('\n\n')[0]
                cot_pred = extract_answer(cot_gen, cot=True)
                y_interaction = int(ans_map.get(cot_pred, -1))
                if y_interaction == y_true:
                    acc_interaction += 1
                if y_interaction == y_true and y_pred != y_true:
                    print("========{}=======".format(i))
                    print(text)
                    print(cot_gen)
                    interaction_num += 1
                    print("### True/Interaction {}/{}, interaction num {} bug_num {}".format(y_true, y_interaction, interaction_num, bug_num))
                if y_interaction != y_true and y_true == y_pred:
                    bug_num += 1
                    print("### True/Interaction {}/{}, interaction num {} bug_num {}".format(y_true, y_interaction, interaction_num, bug_num))
                result["y_interaction"] = y_interaction
                result["interaction"] = cot_gen
            data_to_save.append(result)

        # Shards write to distinct files so parallel processes do not
        # overwrite each other's results.
        if args.parallel:
            file_path = f"{c.task}_result_{args.shard_id}.json"
        else:
            file_path = f"{c.task}_result.json"
        os.makedirs(args.output_path, exist_ok=True)
        with open(os.path.join(args.output_path, file_path), "w") as json_file:
            json.dump(data_to_save, json_file)

        # Accuracy is taken over the examples actually evaluated (not over all
        # of `data`, of which only the second half is run); max() avoids a
        # ZeroDivisionError when nothing was evaluated.
        num_evaluated = max(len(eval_indices), 1)
        print("task {}: origin acc {}, interaction acc {}".format(c.task, acc / num_evaluated, acc_interaction / num_evaluated))