import json
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, GPT2Wrapper, LambdaLayer, BloomIdentityLayer
from bigbench_tasks import multiple_choice_query, PromptBuilder
from rich.progress import track
from console import console, timer
import numpy as np
from utils import load_json, get_probs_and_mrrs, from_layer_logits_to_prob_distros #model, logits, answer ; logits
import random
import argparse

# Seed Python's `random` at import time for reproducibility; the __main__
# block also re-seeds with the same value before each shot setting.
random.seed(42)

#MAX_ALLOWED_LEN = 153

class PastTenseBuilder(PromptBuilder):
    """Prompt builder for the open-ended present -> past tense task.

    Each example in ``dataset['examples']`` is a dict whose ``'input'`` is the
    present-tense verb and whose ``'target'`` is the past-tense verb. Prompts
    have the form ``"Today I <present>. Yesterday I"``.
    """

    def __init__(self, dataset, uppercase_context_verbs=False, uppercase_answer=False):
        # The few-shot separator comes from the dataset itself (expected to be
        # a period followed by a space, '. ').
        separator = dataset['few_shot_example_separator']
        super().__init__(dataset, example_separator=separator)
        self.uppercase_context_verbs = uppercase_context_verbs
        self.uppercase_answer = uppercase_answer

    def __len__(self):
        # `self.dataset` is presumably stored by PromptBuilder.__init__.
        examples = self.dataset['examples']
        return len(examples)

    def build_mc_prompt(self, index, include_answer=False):
        """Multiple-choice prompts are not supported for this task; always None."""
        return None

    def build_open_prompt(self, index, include_answer=False):
        """Build the open-ended prompt for example ``index``.

        Args:
            index: position of the example in ``self.dataset['examples']``.
            include_answer: when True, the answer (with its leading space) is
                appended to the prompt.

        Returns:
            ``(prompt, answer)`` where ``answer`` is ``' ' + past_verb``. The
            context verb / answer are title-cased via ``str.title()`` when
            ``uppercase_context_verbs`` / ``uppercase_answer`` are set.
        """
        example = self.dataset['examples'][index]
        present = example['input']
        past = example['target']
        if self.uppercase_context_verbs:
            present = present.title()
        if self.uppercase_answer:
            past = past.title()
        answer = ' ' + past
        prompt = f"Today I {present}. Yesterday I"
        if include_answer:
            prompt = prompt + answer
        return prompt, answer


def generate_id_ans(model, in_domain_prompter, idx, n_shots=None):
    """Build the ``idx``-th n-shot open prompt and score the model on it.

    Args:
        model: model wrapper forwarded to ``generate_ans``.
        in_domain_prompter: prompt builder exposing
            ``nshot_open_prompt(nshots, idx) -> (prompt, ground_truth)``.
        idx: index of the test example in the prompter's dataset.
        n_shots: number of in-context examples. When None (the default, which
            keeps the original call signature working), the module-level
            ``nshots`` global set by the ``__main__`` block is used.

    Returns:
        The result dict produced by ``generate_ans``.

    Raises:
        ValueError: if ``n_shots`` is None and no module-level ``nshots`` is set
            (e.g. the module was imported rather than run as a script).
    """
    if n_shots is None:
        # Legacy fallback: the script rebinds this global per shot setting.
        n_shots = globals().get('nshots')
        if n_shots is None:
            raise ValueError(
                "n_shots must be passed when the module-level `nshots` global is not set"
            )
    prompt, gt = in_domain_prompter.nshot_open_prompt(n_shots, idx)
    return generate_ans(model, prompt, gt)


def generate_ans(model, prompt, gt):
    """Run ``prompt`` through ``model`` and collect per-layer diagnostics for ``gt``.

    Only the ground truth is scored, so ``targets`` is ``[gt]``, ``answer_idx``
    is 0, and ``probs`` / ``rrs`` each hold a single entry keyed by the int 0.

    Returns:
        dict with keys 'inputs', 'targets', 'answer', 'answer_idx', 'probs',
        'rrs' and 'top10_per_layer'.
    """
    targets = [gt]
    gt_idx = targets.index(gt)
    layer_logits = model.get_layers(model.tokenize(prompt))

    probs_by_target = {}
    rrs_by_target = {}
    for target_idx, target in enumerate(targets):
        target_probs, target_rrs = get_probs_and_mrrs(model, layer_logits, target)
        probs_by_target[target_idx] = target_probs.tolist()
        rrs_by_target[target_idx] = target_rrs.tolist()

    return {
        'inputs': prompt,
        'targets': targets,
        'answer': gt,
        'answer_idx': gt_idx,
        'probs': probs_by_target,
        'rrs': rrs_by_target,
        'top10_per_layer': model.topk_per_layer(layer_logits, 10),
    }


def get_id_open_generations(model, in_domain_dataset, uppercase_context_verbs, uppercase_answer):
    """Score every example of ``in_domain_dataset`` and return the result dicts.

    Runs under ``torch.no_grad()``; examples for which ``generate_id_ans``
    returns None are skipped.
    """
    prompter = PastTenseBuilder(in_domain_dataset, uppercase_context_verbs, uppercase_answer)
    results = []
    with torch.no_grad():
        for example_idx in track(range(len(prompter)), description='iterating...'):
            result = generate_id_ans(model, prompter, example_idx)
            if result is not None:
                results.append(result)
    return results

def save_output(output, fname):
    """Write ``output`` to ``fname`` as 4-space-indented JSON, overwriting any existing file."""
    with open(fname, 'w') as handle:
        json.dump(output, handle, indent=4)

def count_correct(data):
    """Count examples whose gold answer has reciprocal rank 1 in the last entry.

    Args:
        data: iterable of result dicts as produced by ``generate_ans``. Each has
            ``'answer_idx'`` and ``'rrs'``, a mapping from target index to a
            per-layer list (presumably ordered by layer, so the last entry is the
            final layer).

    Returns:
        Number of examples whose final ``rrs`` entry equals 1.
    """
    def _final_rr(example):
        key = example['answer_idx']
        rrs = example['rrs']
        # A JSON round-trip (see save_output) turns int dict keys into strings;
        # accept either form so reloaded results can be re-scored.
        if isinstance(rrs, dict) and key not in rrs:
            key = str(key)
        return rrs[key][-1]

    return sum(1 for example in data if _final_rr(example) == 1)


def print_acc_and_save(output, fname):
    """Print the accuracy of ``output``, save it as JSON to ``fname``, and return counts.

    Returns:
        ``(total_correct, total_seen)``. For an empty ``output`` this is
        ``(0, 0)``; the accuracy is printed as 'n/a' instead of raising
        ZeroDivisionError, and the (empty) results are still saved.
    """
    total_seen = len(output)
    total_correct = count_correct(output)
    accuracy = total_correct / total_seen if total_seen else 'n/a'
    console.print(f"Total correct/total seen: {total_correct}/{total_seen} = {accuracy}")
    save_output(output, fname)
    return total_correct, total_seen

def gpt2_medium_interventions(model, scheme, start_layer):
    """Build MLP replacements for consecutive GPT-2-medium layers.

    ``scheme[i]`` describes what replaces the MLP of layer ``start_layer + i``.
    Here ``o_upper`` / ``o_past`` are vectors loaded from
    ``../capitalization_update_gpt2-med.npy`` and
    ``../regular_o_past_ffn18_gpt2_medium.npy`` (moved to CUDA, fp16), and
    ``MLP`` is that layer's original MLP:

        'L'   -> -o_upper               'L+'  -> MLP(x) - o_upper
        'LP_' -> -o_upper + o_past
        'U'   -> o_upper                'U+'  -> MLP(x) + o_upper
        'UP_' -> o_upper + o_past       'up_' -> MLP(x) + o_upper + o_past
        'P'   -> o_past                 'P+'  -> MLP(x) + o_past
        'A'   -> nn.Identity()
        'O'   -> None (the caller leaves that layer's MLP untouched)

    Args:
        model: wrapper exposing ``model.model.transformer.h[i].mlp``.
        scheme: sequence of the codes above.
        start_layer: index of the layer that ``scheme[0]`` applies to.

    Returns:
        List of replacement modules / None, aligned index-for-index with ``scheme``.

    Raises:
        ValueError: if ``scheme`` contains an unknown code. Previously unknown
            codes were dropped silently, which shifted every later intervention
            onto the wrong layer.
    """
    known_codes = {'L', 'L+', 'LP_', 'U', 'U+', 'UP_', 'up_', 'P', 'P+', 'A', 'O'}
    unknown = [mod for mod in scheme if mod not in known_codes]
    if unknown:
        raise ValueError(
            f"Unknown intervention code(s) {unknown}; expected one of {sorted(known_codes)}"
        )
    if not scheme:
        # Nothing to do: skip loading the vectors (and touching CUDA) entirely.
        return []

    o_upper = torch.tensor(np.load("../capitalization_update_gpt2-med.npy")).cuda().half()
    o_past = torch.tensor(np.load("../regular_o_past_ffn18_gpt2_medium.npy")).cuda().half()

    interventions = []
    for offset, mod in enumerate(scheme):
        # Capture this layer's ORIGINAL MLP now, via a default argument. The
        # lambdas only run at forward time, after the caller has swapped `.mlp`
        # out; closing over the loop variable (as before) made every '+' code
        # call the MLP of the *last* layer in the scheme.
        mlp = model.model.transformer.h[start_layer + offset].mlp
        if mod == 'L':
            layer = LambdaLayer(lambda x: -o_upper)
        elif mod == 'L+':
            layer = LambdaLayer(lambda x, mlp=mlp: mlp(x) - o_upper)
        elif mod == 'LP_':
            layer = LambdaLayer(lambda x: -o_upper + o_past)
        elif mod == 'U':
            layer = LambdaLayer(lambda x: o_upper)
        elif mod == 'U+':
            layer = LambdaLayer(lambda x, mlp=mlp: mlp(x) + o_upper)
        elif mod == 'UP_':
            layer = LambdaLayer(lambda x: o_upper + o_past)
        elif mod == 'up_':
            layer = LambdaLayer(lambda x, mlp=mlp: mlp(x) + o_upper + o_past)
        elif mod == 'P':
            layer = LambdaLayer(lambda x: o_past)
        elif mod == 'P+':
            layer = LambdaLayer(lambda x, mlp=mlp: mlp(x) + o_past)
        elif mod == 'A':
            layer = nn.Identity()
        else:  # 'O' (validated above): keep the original MLP
            layer = None
        interventions.append(layer)
    return interventions

def parse_intervention_scheme(model_name, scheme, model=None, start_layer=18):
    """Translate an intervention scheme into per-layer MLP replacements.

    Args:
        model_name: model identifier; only 'gpt2-medium' builds interventions.
        scheme: sequence of intervention codes (see gpt2_medium_interventions).
            An empty list or empty string means "no intervention".
        model: wrapped model whose layers the interventions target; required
            for 'gpt2-medium'.
        start_layer: layer that ``scheme[0]`` applies to (default 18, matching
            the CLI default).

    Returns:
        The list built by ``gpt2_medium_interventions``, or None when the scheme
        is empty or the model is unsupported.

    Raises:
        ValueError: for 'gpt2-medium' with a non-empty scheme but no ``model``.
            (Previously this path always failed with a TypeError because the
            helper was called with the wrong number of arguments.)
    """
    # Truthiness covers both '' and the list form produced by the CLI parser.
    if not scheme:
        return None

    console.print("Modifications", scheme)
    if model_name == 'gpt2-medium':
        if model is None:
            raise ValueError("`model` is required to build gpt2-medium interventions")
        return gpt2_medium_interventions(model, scheme, start_layer)
    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str)
    # Comma-separated shot counts, e.g. "0,1,5"; a single value such as "0" also works.
    parser.add_argument("nshots", type=str)
    parser.add_argument("--intervention_scheme", default='', type=str, help='comma-separated codes, one per layer from --start_layer; example: "L,P,U" = lowercase, past tense, uppercase')
    parser.add_argument("--start_layer", type=int, default=18, help='which layer to start the intervention (if applicable)')
    parser.add_argument("--uppercase_context", action='store_true')
    parser.add_argument("--uppercase_answer", action='store_true', help='should the ground truth be capitalized?')
    args = parser.parse_args()
    model_name = args.model_name
    # '' -> [], 'L+,U,' -> ['L+', 'U'] (a single trailing comma is tolerated).
    intervention_scheme = args.intervention_scheme.split(',')
    if intervention_scheme[-1] == '':
        del intervention_scheme[-1]
    dataset_name = 'past_tense'
    start_layer = args.start_layer
    uppercase_answer = args.uppercase_answer
    uppercase_context = args.uppercase_context
    # NOTE: `nshots` is a module-level global that generate_id_ans() reads; it is
    # rebound to an int for each shot setting in the loop below.
    nshots = args.nshots
    interventions = None
    console.print(args)
    timer_task = timer.add_task("Loading model")
    with timer:
        if 'gpt2' in model_name:
            model, tokenizer = load_gpt2(model_name)
            model = GPT2Wrapper(model, tokenizer)
            # NOTE(review): the steering vectors are gpt2-medium specific but are
            # applied to any 'gpt2*' model here — confirm this is intended.
            interventions = gpt2_medium_interventions(model, intervention_scheme, start_layer)
        elif 'gptj' == model_name:
            model, tokenizer = load_gptj()
            model = GPTJWrapper(model, tokenizer)
        elif 'bloom' in model_name:
            # These names were never imported at file level, so this branch used to
            # die with NameError. Import lazily so other models are unaffected
            # (assumes `modeling` provides them — TODO confirm).
            from modeling import load_bloom, BloomWrapper
            model, tokenizer = load_bloom(model_name)
            model = BloomWrapper(model, tokenizer)
        else:
            parser.error(f"unsupported model_name: {model_name!r}")

    if interventions is not None:
        for offset, intervention in enumerate(interventions):
            layer = start_layer + offset
            console.print("Intervening at layer", layer, "with", intervention_scheme[offset])
            if intervention is None:
                continue  # 'O': keep the original MLP at this layer
            model.model.transformer.h[layer].mlp = intervention

    timer.stop_task(timer_task)
    timer.update(timer_task, visible=False)

    num_layers = len(model.model.transformer.h)

    upper_con = 'up_con' if uppercase_context else 'low_con'
    upper_ans = 'up_ans' if uppercase_answer else 'low_ans'
    # `intervention_scheme` is a list, so test emptiness rather than comparing to ''.
    if not intervention_scheme:
        exp_setting = f'cntrl_{upper_con}_{upper_ans}'
    else:
        exp_setting = f'{start_layer}_{"".join(intervention_scheme)}_{upper_con}_{upper_ans}'

    # Accept both "0" and "0,1,5" (empty entries such as a trailing comma are ignored).
    shot_counts = [int(n) for n in nshots.split(",") if n.strip()]
    for n in shot_counts:
        # Re-seed so every shot setting sees the same data order.
        random.seed(42)
        regular_data = load_json("Filt_reg_data.json")
        print(len(regular_data['examples']))
        random.shuffle(regular_data['examples'])
        irregular_data = load_json("Filt_irreg_data.json")
        print(len(irregular_data['examples']))
        random.shuffle(irregular_data['examples'])
        nshots = n  # read by generate_id_ans()
        accs = []
        repeats = 1 if n == 0 else 5
        for _ in range(repeats):
            irreg_output = get_id_open_generations(model, irregular_data, uppercase_context, uppercase_answer)
            reg_output = get_id_open_generations(model, regular_data, uppercase_context, uppercase_answer)

            output = irreg_output + reg_output
            # NOTE(review): the filename does not include the repeat index, so each
            # repeat overwrites the previous one's file.
            total_correct, total_seen = print_acc_and_save(output, f'{dataset_name}_{exp_setting}_{model_name}_{nshots}.json')
            accs.append(total_correct / total_seen if total_seen else float('nan'))

            random.shuffle(regular_data['examples'])
            random.shuffle(irregular_data['examples'])

        console.print("NSHOTS", n)
        console.print("Regular&Irregular Accuracy", accs, 'mean:', np.mean(accs))
            
