import json
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, GPT2Wrapper, LambdaLayer, BloomIdentityLayer
from bigbench_tasks import multiple_choice_query, PromptBuilder
from rich.progress import track
from console import console, timer
import numpy as np
from utils import load_json, get_probs_and_mrrs, from_layer_logits_to_prob_distros #model, logits, answer ; logits
import random

# Seed the module-global RNG at import time so any shuffling done with `random` is reproducible.
random.seed(a=42)

#MAX_ALLOWED_LEN = 153

class PastTenseBuilder(PromptBuilder):
    """Prompt builder for the past-tense task: "Today I <verb>. Yesterday I <past>".

    Each entry of ``dataset['examples']`` is read as a dict with an ``'input'``
    (present-tense verb) and a ``'target'`` (past-tense verb). Only open-ended
    prompts are supported; ``build_mc_prompt`` is a stub.
    """

    def __init__(self, dataset):
        # The few-shot separator comes from the dataset itself
        # (original note: it is a period followed by a space, '. ').
        super().__init__(dataset, example_separator=dataset['few_shot_example_separator'])

    def __len__(self):
        """Return the number of examples in ``self.dataset['examples']``."""
        return len(self.dataset['examples'])

    def build_mc_prompt(self, index, include_answer=False):
        """Multiple-choice prompts are not supported for this task; always returns ``None``."""
        return None

    def build_open_prompt(self, index, include_answer=False):
        """Build the single-example open-ended prompt for ``self.dataset['examples'][index]``.

        Args:
            index: position of the example in ``self.dataset['examples']``.
            include_answer: when true, append the answer (``' ' + target``) to the
                prompt, as needed for few-shot demonstrations.

        Returns:
            ``(prompt, answer)``: ``prompt`` is ``"Today I <input>. Yesterday I"``
            (followed by the answer when ``include_answer`` is true) and ``answer``
            is ``' ' + target``. The leading space makes the answer tokenize like a
            continuation of the prompt.
        """
        example = self.dataset['examples'][index]
        answer = ' ' + example['target']
        prompt = f"Today I {example['input']}. Yesterday I"
        if include_answer:
            prompt += answer
        return prompt, answer


def build_ood_example(index, in_domain_prompter, out_domain_prompter, nshots):
    """Build an n-shot prompt with in-domain demonstrations and an out-of-domain query.

    Mirrors the layout used by the linguistic-mapping repo, e.g.
      in-domain (irregular):  Today I eat. Yesterday I ate. Today I sleep. Yesterday I ___  (slept)
      out-of-domain (regular): Today I eat. Yesterday I ate. Today I cough. Yesterday I ___  (coughed)

    Args:
        index: query position; the query is ``out_domain_prompter[index % len]``.
        in_domain_prompter: source of the demonstrations, taken from positions
            ``index - nshots .. index - 1`` modulo its length (wrapping around).
        out_domain_prompter: source of the query example.
        nshots: number of demonstrations (0 yields just the query).

    Returns:
        ``(prompt, answer)`` where ``answer`` comes from the out-of-domain prompter.
    """
    n_in_domain = len(in_domain_prompter)
    demos = [
        in_domain_prompter.build_open_prompt(pos % n_in_domain, include_answer=True)[0]
        for pos in range(index - nshots, index)
    ]
    query, answer = out_domain_prompter.build_open_prompt(index % len(out_domain_prompter))
    separator = in_domain_prompter.example_separator
    # Every demonstration (including the last) is followed by the separator.
    return ''.join(demo + separator for demo in demos) + query, answer

def generate_id_ans(model, in_domain_prompter, idx, nshots=None):
    """Run the model on an in-domain n-shot prompt and collect per-layer results.

    Args:
        model: model wrapper handed to ``generate_ans``.
        in_domain_prompter: prompt builder whose ``nshot_open_prompt(nshots, idx)``
            (defined on the ``PromptBuilder`` base, not shown here) is expected to
            return ``(prompt, answer)``.
        idx: index of the query example.
        nshots: number of demonstrations. When ``None`` (the default), the
            module-level ``nshots`` set by the ``__main__`` section is used, which
            preserves the original calling convention.

    Returns:
        The result dict produced by ``generate_ans``.

    Raises:
        NameError: if ``nshots`` is omitted and no module-level ``nshots`` exists
            (e.g. when the module is imported rather than run as a script).
    """
    if nshots is None:
        try:
            nshots = globals()['nshots']
        except KeyError:
            raise NameError(
                "generate_id_ans: pass nshots explicitly or define a module-level 'nshots'"
            ) from None
    prompt, gt = in_domain_prompter.nshot_open_prompt(nshots, idx)
    # Filtering of multi-token, prefix-sharing targets is not done here; see
    # filter_out_id_words_with_multiple_tokens for that criterion.
    return generate_ans(model, prompt, gt)


def generate_ood_ans(model, in_domain_prompter, out_domain_prompter, idx, nshots=None):
    """Run the model on an out-of-domain query preceded by in-domain demonstrations.

    Args:
        model: model wrapper handed to ``generate_ans``.
        in_domain_prompter: source of the few-shot demonstrations.
        out_domain_prompter: source of the query example.
        idx: query index, passed to ``build_ood_example``.
        nshots: number of demonstrations. When ``None`` (the default), the
            module-level ``nshots`` set by the ``__main__`` section is used, which
            preserves the original calling convention.

    Returns:
        The result dict produced by ``generate_ans``.

    Raises:
        NameError: if ``nshots`` is omitted and no module-level ``nshots`` exists
            (e.g. when the module is imported rather than run as a script).
    """
    if nshots is None:
        try:
            nshots = globals()['nshots']
        except KeyError:
            raise NameError(
                "generate_ood_ans: pass nshots explicitly or define a module-level 'nshots'"
            ) from None
    prompt, gt = build_ood_example(idx, in_domain_prompter, out_domain_prompter, nshots)
    # Filtering of multi-token, prefix-sharing targets is not done here; see
    # filter_out_ood_words_with_multiple_tokens for that criterion.
    return generate_ans(model, prompt, gt)


def generate_ans(model, prompt, gt):
    """Score ``gt`` as the continuation of ``prompt`` at every layer of ``model``.

    Only a single candidate (the ground truth itself) is scored, so
    ``answer_idx`` is always 0.

    Args:
        model: wrapper exposing ``tokenize``, ``get_layers`` and ``topk_per_layer``.
        prompt: prompt text fed to the model.
        gt: ground-truth continuation string.

    Returns:
        A dict with keys ``inputs``, ``targets``, ``answer``, ``answer_idx``,
        ``probs`` and ``rrs`` (both keyed by candidate position, each value a list
        produced by ``get_probs_and_mrrs(...).tolist()``), and ``top10_per_layer``.
    """
    candidates = [gt]
    answer_position = candidates.index(gt)

    layer_outputs = model.get_layers(model.tokenize(prompt))

    probs_by_candidate, rrs_by_candidate = {}, {}
    for position, candidate in enumerate(candidates):
        candidate_probs, candidate_rrs = get_probs_and_mrrs(model, layer_outputs, candidate)
        probs_by_candidate[position] = candidate_probs.tolist()
        rrs_by_candidate[position] = candidate_rrs.tolist()

    return {
        'inputs': prompt,
        'targets': candidates,
        'answer': gt,
        'answer_idx': answer_position,
        'probs': probs_by_candidate,
        'rrs': rrs_by_candidate,
        'top10_per_layer': model.topk_per_layer(layer_outputs, 10),
    }

def filter_out_ood_words_with_multiple_tokens(prompter, o_prompter, model=None):
    """Drop out-of-domain examples whose tokenization could confound the evaluation.

    An example of ``o_prompter`` is dropped when both its answer (from
    ``o_prompter.build_open_prompt``) and ``' ' + input`` tokenize to more than one
    token AND share their first token. The original inline comment says such words
    "could confound results".

    NOTE(review): the previous version checked ``o_prompter`` words but deleted
    from ``prompter``. It also referenced an undefined ``idx`` and deleted while
    index-iterating, so it could never complete. This version filters
    ``o_prompter``, which matches the function name.

    Args:
        prompter: in-domain prompter. Kept for interface compatibility and not modified.
        o_prompter: out-of-domain prompter. Its ``dataset['examples']`` list is
            filtered in place, so other references to that list see the result.
        model: wrapper exposing ``tokenize``. When ``None``, the module-level
            ``model`` set by the ``__main__`` section is used.

    Raises:
        NameError: if ``model`` is omitted and no module-level ``model`` exists.
    """
    if model is None:
        model = globals().get('model')
        if model is None:
            raise NameError(
                "filter_out_ood_words_with_multiple_tokens: pass model explicitly "
                "or define a module-level 'model'"
            )
    examples = o_prompter.dataset['examples']
    kept = []
    # Decide for every example first, then rewrite the list once: deleting while
    # iterating by index would skip the element after each deletion.
    for i, example in enumerate(examples):
        _, gt = o_prompter.build_open_prompt(i)
        tokenized_gt = model.tokenize(gt)[0]
        tokenized_pres = model.tokenize(' ' + example['input'])[0]
        confounded = (
            len(tokenized_gt) > 1
            and len(tokenized_pres) > 1
            and tokenized_gt[0] == tokenized_pres[0]
        )
        if not confounded:
            kept.append(example)
    examples[:] = kept

        
def filter_out_id_words_with_multiple_tokens(prompter, model=None):
    """Drop in-domain examples whose tokenization could confound the evaluation.

    An example is dropped when both its answer (from ``prompter.build_open_prompt``)
    and ``' ' + input`` tokenize to more than one token AND share their first token.
    The original inline comment says such words "could confound results".

    Args:
        prompter: prompter whose ``dataset['examples']`` list is filtered in place,
            so other references to that list see the result.
        model: wrapper exposing ``tokenize``. When ``None``, the module-level
            ``model`` set by the ``__main__`` section is used.

    Raises:
        NameError: if ``model`` is omitted and no module-level ``model`` exists.
    """
    if model is None:
        model = globals().get('model')
        if model is None:
            raise NameError(
                "filter_out_id_words_with_multiple_tokens: pass model explicitly "
                "or define a module-level 'model'"
            )
    examples = prompter.dataset['examples']
    kept = []
    # Decide for every example first, then rewrite the list once. The previous
    # `del` inside an index loop skipped the element after each deletion and
    # ran past the end of the shrunken list.
    for i, example in enumerate(examples):
        _, gt = prompter.build_open_prompt(i)
        tokenized_gt = model.tokenize(gt)[0]
        tokenized_pres = model.tokenize(' ' + example['input'])[0]
        confounded = (
            len(tokenized_gt) > 1
            and len(tokenized_pres) > 1
            and tokenized_gt[0] == tokenized_pres[0]
        )
        if not confounded:
            kept.append(example)
    examples[:] = kept


def get_ood_open_generations(model, in_domain_dataset, out_domain_dataset):
    """Evaluate out-of-domain queries preceded by in-domain demonstrations.

    Iterates over ``max(len(out_domain), len(in_domain))`` query indices, so the
    shorter dataset wraps around via the modulo indexing in ``build_ood_example``.
    Runs under ``torch.no_grad()``. ``None`` results are skipped.

    Returns:
        List of result dicts from ``generate_ood_ans``.
    """
    in_domain_prompter = PastTenseBuilder(in_domain_dataset)
    out_domain_prompter = PastTenseBuilder(out_domain_dataset)
    # Token-confound filtering (filter_out_*_words_with_multiple_tokens) is currently disabled.
    n_queries = max(len(out_domain_prompter), len(in_domain_prompter))

    results = []
    with torch.no_grad():
        for query_idx in track(range(n_queries), description='iterating...'):
            result = generate_ood_ans(model, in_domain_prompter, out_domain_prompter, query_idx)
            if result is not None:
                results.append(result)
    return results

def get_id_open_generations(model, in_domain_dataset):
    """Evaluate every example of ``in_domain_dataset`` as an in-domain n-shot query.

    Runs under ``torch.no_grad()``. ``None`` results are skipped.

    Returns:
        List of result dicts from ``generate_id_ans``.
    """
    prompter = PastTenseBuilder(in_domain_dataset)
    # Token-confound filtering (filter_out_id_words_with_multiple_tokens) is currently disabled.

    results = []
    with torch.no_grad():
        for example_idx in track(range(len(prompter)), description='iterating...'):
            result = generate_id_ans(model, prompter, example_idx)
            if result is not None:
                results.append(result)
    return results

def save_output(output, fname, save=False):
    """Optionally write ``output`` to ``fname`` as indented JSON.

    Saving is off by default. Without ``save=True`` this only prints
    "NOT SAVING", which preserves the script's current behavior.

    Args:
        output: JSON-serializable results. Integer dict keys (as produced by
            ``generate_ans``) become strings in the file, per ``json.dump``.
        fname: destination path; overwritten if it exists.
        save: set to True to actually write the file.
    """
    if not save:
        console.print("NOT SAVING")
        return
    with open(fname, 'w', encoding='utf-8') as fp:
        json.dump(output, fp, indent=4)

def count_correct(data):
    """Count results whose target's last ``rrs`` entry equals 1.

    ``rrs`` values come from ``get_probs_and_mrrs``. Presumably they are
    reciprocal ranks per layer, so a last entry of 1 means the target ranked first
    at the final layer. TODO confirm against utils.

    Args:
        data: iterable of result dicts with ``'answer_idx'`` and ``'rrs'``.
            ``rrs`` may be keyed by int (fresh from ``generate_ans``) or by str
            (after a JSON round-trip, which stringifies dict keys).

    Returns:
        Number of correct results (0 for empty input).
    """
    correct = 0
    for d in data:
        rrs = d['rrs']
        target_idx = d['answer_idx']
        try:
            target_rrs = rrs[target_idx]
        except KeyError:
            target_rrs = rrs[str(target_idx)]
        if target_rrs[-1] == 1:
            correct += 1
    return correct


def print_acc_and_save(output, fname):
    """Print accuracy for ``output``, pass it to ``save_output``, and return the counts.

    Returns:
        ``(total_correct, total_seen)``. An empty ``output`` no longer raises
        ZeroDivisionError; its accuracy is reported as n/a.
    """
    total_seen = len(output)
    total_correct = count_correct(output)
    if total_seen:
        console.print(f"Total correct/total seen: {total_correct}/{total_seen} = {total_correct / total_seen}")
    else:
        console.print(f"Total correct/total seen: {total_correct}/{total_seen} = n/a (no outputs)")
    save_output(output, fname)
    return total_correct, total_seen



if __name__ == "__main__":
    # Command line: MODEL_NAME NSHOTS DATASET_NAME INTERVENTION
    #   NSHOTS: a single shot count ("3") or a comma-separated list ("0,1,3,").
    #   INTERVENTION: 'intervene' swaps the GPT-2 MLP at the chosen layer for
    #     LambdaLayer(lambda x: o_func) (presumably a constant output); 'ablate'
    #     swaps it for nn.Identity(); any other value runs the unmodified model.
    if len(sys.argv) < 5:
        sys.exit(f"usage: {sys.argv[0]} MODEL_NAME NSHOTS DATASET_NAME {{intervene|ablate|control}}")

    # NOTE: `nshots` and `model` are deliberately module-level names. The
    # generate_* and filter_out_* helpers above may read them as globals.
    model_name = sys.argv[1]
    nshots = sys.argv[2]
    dataset_name = sys.argv[3]
    do_intervention = sys.argv[4]
    intervene = do_intervention == 'intervene'
    ablate = do_intervention == 'ablate'
    console.print(model_name, nshots, 'shot(s)', dataset_name, 'intervene?', intervene)

    timer_task = timer.add_task("Loading model")
    with timer:
        if 'gpt2' in model_name:
            model, tokenizer = load_gpt2(model_name)
            model = GPT2Wrapper(model, tokenizer)
            if intervene or ablate:
                o_func = None
                if dataset_name == 'past_tense':
                    # o_func = np.load("../o_past_ffn19_gpt2_medium.npy")  # irregular
                    o_func = np.load("../regular_o_past_ffn18_gpt2_medium.npy")  # regular
                    o_func = torch.tensor(o_func).cuda().half()
                if intervene and o_func is None:
                    raise ValueError(f"no o_func vector is configured for dataset {dataset_name!r}")
                # Only layer 18 is modified (the vector file name says ffn18).
                for layer_idx in range(18, 19):
                    console.print("interv. layer", layer_idx)
                    if intervene:
                        model.model.transformer.h[layer_idx].mlp = LambdaLayer(lambda x: o_func)
                    elif ablate:
                        model.model.transformer.h[layer_idx].mlp = nn.Identity()
        elif model_name == 'gptj':
            model, tokenizer = load_gptj()
            model = GPTJWrapper(model, tokenizer)
        elif 'bloom' in model_name:
            # NOTE(review): these names were never imported at file level; assumes
            # `modeling` provides them (previously this branch raised NameError).
            from modeling import load_bloom, BloomWrapper
            model, tokenizer = load_bloom(model_name)
            model = BloomWrapper(model, tokenizer)
        else:
            raise ValueError(f"unsupported model name: {model_name!r}")
        if (intervene or ablate) and 'gpt2' not in model_name:
            console.print("WARNING: intervene/ablate is only implemented for gpt2 models; running unmodified")
    timer.stop_task(timer_task)
    timer.update(timer_task, visible=False)

    ood_setting = 'irreg_to_reg'
    if intervene:
        exp_setting = 'reg_ofunc'
    elif ablate:
        exp_setting = 'ablate'
    else:
        exp_setting = 'cntrl'

    # A value without a comma (e.g. "3") used to run nothing; accept it as a one-element list.
    shot_counts = [int(s) for s in nshots.split(',') if s.strip() != '']

    for n in shot_counts:
        # Re-seed and reload so every shot count starts from the same example orderings.
        random.seed(42)
        regular_data = load_json(f"{dataset_name}_regular.json")
        random.shuffle(regular_data['examples'])
        irregular_data = load_json(f"{dataset_name}_irregular.json")
        random.shuffle(irregular_data['examples'])
        nshots = n  # rebinds the module-level value read by the generate_* helpers
        reg_in_domain_accs = []
        irreg_in_domain_accs = []
        out_domain_accs = []
        # A single run for zero-shot; presumably extra repeats only matter when demonstrations vary.
        repeats = 1 if n == 0 else 5
        for _ in range(repeats):
            output = get_ood_open_generations(model, irregular_data, regular_data)
            total_correct, total_seen = print_acc_and_save(output, f'{dataset_name}_{exp_setting}_{ood_setting}_ood_{model_name}_{nshots}.json')
            out_domain_accs.append(total_correct / total_seen if total_seen else float('nan'))

            output = get_id_open_generations(model, irregular_data)
            total_correct, total_seen = print_acc_and_save(output, f'{dataset_name}_{exp_setting}_irregular_in_domain_{model_name}_{nshots}.json')
            irreg_in_domain_accs.append(total_correct / total_seen if total_seen else float('nan'))

            output = get_id_open_generations(model, regular_data)
            total_correct, total_seen = print_acc_and_save(output, f'{dataset_name}_{exp_setting}_regular_in_domain_{model_name}_{nshots}.json')
            reg_in_domain_accs.append(total_correct / total_seen if total_seen else float('nan'))

            # Reshuffle so the next repeat sees different demonstrations.
            random.shuffle(regular_data['examples'])
            random.shuffle(irregular_data['examples'])

        console.print("NSHOTS", n)
        console.print("OOD Accuracy", out_domain_accs, 'mean:', np.mean(out_domain_accs))
        console.print("Irregular Accuracy", irreg_in_domain_accs, 'mean:', np.mean(irreg_in_domain_accs))
        console.print("Regular Accuracy", reg_in_domain_accs, 'mean:', np.mean(reg_in_domain_accs))

            
