
import json
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, GPT2Wrapper, load_bloom, BloomWrapper
from bigbench_tasks import load_bigbench_task, multiple_choice_query, PromptBuilder, DefaultOpenPrompter
from rich.progress import track
from console import console, timer
import numpy as np
from utils import get_probs_and_mrrs #model, logits, answer ; logits
import random
import gc
import torch
import argparse
from torch import nn


# Seed Python's `random` module so the random.shuffle() of the dataset in the
# script body is reproducible from run to run.
random.seed(42)
# Module-level counter. get_open_generations() resets it to 0; no code visible
# in this file ever increments it.
TOTAL_CORRECT =  0

class ColoredObjectsPrompter(DefaultOpenPrompter):
    """Open-ended prompter for the colored-objects task.

    Answers are re-cased according to ``is_extractive``: the whole string is
    lower-cased when extractive, otherwise ``str.title()`` is applied (every
    word capitalised). All answers are returned with a single leading space.
    """

    def __init__(self, dataset, is_extractive):
        super().__init__(dataset)
        self.is_extractive = is_extractive

    def build_open_prompt(self, index, include_answer=False):
        """
        index: position of the example inside self.dataset
        include_answer: when True, the (re-cased) answer is appended to the prompt after a space
        returns: the prompt string, and the answer string prefixed with a space
        """
        example = self.dataset[index]
        first_target = example['targets'][0]
        if self.is_extractive:
            answer = ' ' + first_target.lower()
        else:
            answer = ' ' + first_target.title()
        prompt = example['inputs']
        if include_answer:
            prompt = prompt + answer
        return prompt, answer

    def get_mc_targets(self, idx):
        """Return all multiple-choice targets of example ``idx``, re-cased like the answer and space-prefixed."""
        choices = self.dataset[idx]['multiple_choice_targets']
        if self.is_extractive:
            return [' ' + choice.lower() for choice in choices]
        return [' ' + choice.title() for choice in choices]

def generate_ans(model, prompter, idx):
    """Run one example through the model and score every multiple-choice target.

    The number of shots is read from the module-level global ``nshots``.

    model: wrapper exposing tokenize(), get_layers() and topk_per_layer()
    prompter: prompter exposing nshot_open_prompt(), get_mc_targets() and is_extractive
    idx: index of the example to evaluate
    returns: (per-prompt result dict, 1 if the ground-truth target's last
        reciprocal-rank entry equals 1 else 0, the same 0/1 check for the
        ground truth written in the *other* casing -- title-cased when the
        prompter is extractive, lower-cased otherwise)
    """
    # Build the prompt and locate the ground truth among the candidates.
    prompt, gt = prompter.nshot_open_prompt(nshots, idx)
    targets = prompter.get_mc_targets(idx)
    gt_idx = targets.index(gt)

    # Single pass of the prompt through the model.
    logits = model.get_layers(model.tokenize(prompt))

    target_correct = 0
    other_correct = 0
    probs_results = {}
    rrs_results = {}
    for choice_idx, choice in enumerate(targets):
        probs, rrs = get_probs_and_mrrs(model, logits, choice)
        probs_results[choice_idx] = probs.tolist()
        rrs_results[choice_idx] = rrs.tolist()
        if choice_idx != gt_idx:
            continue

        if rrs_results[choice_idx][-1] == 1:
            target_correct += 1

        # Score the ground truth again in the opposite casing.
        if prompter.is_extractive:
            recased = ' ' + choice.strip().title()
        else:
            recased = choice.lower()
        _, recased_rrs = get_probs_and_mrrs(model, logits, recased)
        if recased_rrs.tolist()[-1] == 1:
            other_correct += 1

    top10_per_layer = model.topk_per_layer(logits, 10)
    prompt_results = {
        'inputs': prompt,
        'targets': targets,
        'answer': gt,
        'answer_idx': gt_idx,
        'probs': probs_results,
        'rrs': rrs_results,
        'top10_per_layer': top10_per_layer,
    }
    return prompt_results, target_correct, other_correct

def get_open_generations(model, dataset, removed_layers):
    """Evaluate every example of ``dataset`` and collect the per-example results.

    Reads the module-level global ``extractive`` to configure the prompter and
    resets the global ``TOTAL_CORRECT`` to 0. ``removed_layers`` is accepted
    but not used by this function.

    returns: (list of per-example result dicts, number of examples whose
        target was ranked first, number of examples whose other-cased target
        was ranked first)
    """
    global TOTAL_CORRECT
    TOTAL_CORRECT = 0

    prompter = ColoredObjectsPrompter(dataset, extractive)
    output = []
    target_correct = other_correct = 0
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for example_idx in track(range(len(prompter)), description='iterating...'):
            record, hit, other_hit = generate_ans(model, prompter, example_idx)
            output.append(record)
            target_correct += hit
            other_correct += other_hit
    return output, target_correct, other_correct


class LambdaLayer(nn.Module):
    """``nn.Module`` wrapper around an arbitrary callable so it can be used in place of a layer."""

    def __init__(self, lambd):
        super().__init__()
        # Callable invoked with the layer input on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return whatever it returns."""
        fn = self.lambd
        return fn(x)


def add_capit_layers(model, rm_layers_num, eps=1):
    """Replace the MLP of the last ``rm_layers_num`` transformer blocks with a constant output.

    Each replaced ``mlp`` becomes a LambdaLayer that ignores its input and
    returns ``eps * capit``, where ``capit`` is the module-level global.

    model: model exposing its blocks as ``model.transformer.h``
    rm_layers_num: how many trailing blocks to modify (clamped to the number of blocks)
    eps: scaling factor applied to ``capit``
    returns: the same ``model`` object, modified in place
    """
    def count_trainable():
        # Number of scalar parameters that still require gradients.
        return sum(param.numel() for param in model.parameters() if param.requires_grad)

    b4_params = count_trainable()
    num_layers = len(model.transformer.h)
    layer_start = max(0, num_layers - rm_layers_num)
    console.print(f"REMOVING LAYERS STARTING AT {layer_start}")
    for i in range(layer_start, num_layers):
        # `capit` is looked up in the module globals each time the layer runs,
        # so a later re-assignment of `capit` (e.g. a dtype cast) is picked up.
        model.transformer.h[i].mlp = LambdaLayer(lambda x: (eps*capit))
    after_params = count_trainable()
    console.print(f"Original # of parameters {b4_params}. After Rm FFN: {after_params}")
    # A model with no trainable parameters (e.g. fully frozen) would otherwise
    # crash here with ZeroDivisionError; report 0% removed instead.
    pct_removed = 100*((b4_params-after_params)/b4_params) if b4_params else 0.0
    console.print(f"% params removed: {pct_removed}")
    return model


def save_output(output, fname):
    """Write ``output`` to ``fname`` as JSON with a 4-space indent, replacing any existing file."""
    with open(fname, mode='w') as out_file:
        json.dump(output, out_file, indent=4)

def create_output_name(model_name, nshots, is_extractive, intervention):
    """Build the JSON output file name for one run configuration.

    The trailing component is ``str(eps)``, where ``eps`` is read from the
    module-level global of that name.

    returns: "<model>_<nshots>_open_cobjs_<ext|abs>_<interv|cntrl>_<eps>.json"
    """
    if is_extractive:
        task_tag = 'ext'
    else:
        task_tag = 'abs'
    if intervention:
        run_tag = 'interv'
    else:
        run_tag = 'cntrl'
    pieces = [f"{model_name}", f"{nshots}", "open", "cobjs", task_tag, run_tag, f"{eps!s}"]
    return "_".join(pieces) + ".json"

if __name__ == "__main__":
    # Script entry point: parse the CLI, load the dataset and model, optionally
    # apply the capitalization intervention, then evaluate and save the results.
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", help="name of model to be used")
    parser.add_argument("nshots", type=str, help="number of shots to be used")
    parser.add_argument("--is_extractive", action='store_true', help="whether or not the task is extractive")
    parser.add_argument("--intervention", action='store_true', help='Whether to intervene on 19-24 of gpt2 medium')
    # NOTE: argparse applies `type` only to string defaults, so when --eps is
    # omitted eps stays the int 1 (file names end in "_1.json"), whereas
    # "--eps 1" yields 1.0 ("_1.0.json").
    parser.add_argument("--eps", default=1, type=float, help='The scaling factor for capit')
    args = parser.parse_args()

    # These module-level names are read as globals by the functions above
    # (eps, nshots, extractive), so they must keep exactly these names.
    eps = args.eps
    model_name = args.model_name
    nshots = args.nshots
    extractive = args.is_extractive
    intervention = args.intervention
    console.print('extractive?', extractive)
    console.print(model_name, nshots, 'shot(s)')
    console.print(args)

    # Only the first `num_examples` datapoints are evaluated; the same number
    # is the denominator of every accuracy printed below.
    num_examples = 200
    dataset = load_bigbench_task('reasoning_about_colored_objects')['default']
    colors_dataset = [dataset[i] for i in range(num_examples)]

    # Vector returned (scaled by eps) by the replaced MLPs in add_capit_layers.
    capit = torch.tensor(np.load('../capitalization_update_gpt2-med.npy'))
    if torch.cuda.is_available():
        capit = capit.cuda().half()

    timer_task = timer.add_task("Loading model")
    with timer:
        if 'gpt2' in model_name:
            model, tokenizer = load_gpt2(model_name)
            model = GPT2Wrapper(model, tokenizer)
            if model_name == 'gpt2-medium':
                rrange = [0,4,6,12,16,20,24]
                if intervention:
                    model.model = add_capit_layers(model.model, 5, eps=eps)
        elif 'gptj' == model_name:
            model, tokenizer = load_gptj()
            model = GPTJWrapper(model, tokenizer)
            rrange = [0,5,10,15,20,25,28]
        elif 'bloom' in model_name:
            model, tokenizer = load_bloom(model_name)
            model = BloomWrapper(model, tokenizer)
            capit = capit.bfloat16()
            rrange = [0,12,24,36,48,60,70]
        else:
            # Previously an unknown name fell through and crashed later with a
            # NameError on `model`; fail here with an actionable message.
            raise ValueError(
                f"Unsupported model_name {model_name!r}: expected a name containing "
                "'gpt2' or 'bloom', or exactly 'gptj'"
            )

    timer.stop_task(timer_task)
    timer.update(timer_task, visible=False)

    accuracies = []
    console.print(len(model.model.transformer.h), "LAYERS IN THE MODEL")
    fname = create_output_name(model_name, nshots, extractive, intervention)
    if extractive:
        other_acc_type='Abstractive'
    else:
        other_acc_type = 'Extractive'
    # The layer sweep is disabled: only r == 0 (no extra layers replaced) runs.
    for r in [0]:
        if r>0:
            model.model = add_capit_layers(model.model, r, eps=eps)

        if not torch.cuda.is_available():
            model.model = model.model.float()

        dataset = list(colors_dataset)
        if isinstance(nshots, str) and ',' in nshots:
            # Comma-separated sweep over several shot counts, one pass each.
            rnshots = [int(n) for n in nshots.split(',') if n.strip()]
            for n in rnshots:
                console.print("NSHOTS", n)
                # generate_ans reads the global `nshots`, so update it per pass.
                nshots = n
                # One output file per shot count; previously every pass wrote
                # to the same "<model>_0,1,2_..." file and overwrote the
                # results of the earlier shot counts.
                shot_fname = create_output_name(model_name, n, extractive, intervention)
                output, total_correct, other_correct = get_open_generations(model, dataset, 0)
                save_output(output, shot_fname)
                console.print("ACCURACY", (total_correct/num_examples)*100)
                accuracies.append((total_correct/num_examples)*100)

        else:
            nshots = int(nshots)
            # Few-shot runs are repeated on reshuffled data; zero-shot runs once.
            if nshots>0:
                repeats = 5
            else:
                repeats = 1
            run_accs = []
            run_other_accs = []
            for repeat in range(repeats):
                output, total_correct, other_correct = get_open_generations(model, dataset, r)
                # Each repeat overwrites the same file; the last repeat is what remains on disk.
                save_output(output, fname)
                console.print(f"{r} layers removed ACCURACY", (total_correct/num_examples)*100)
                run_accs.append((total_correct/num_examples)*100)
                run_other_accs.append((other_correct/num_examples)*100)
                random.shuffle(dataset)
            console.print("ACCURACY", run_accs, 'MEAN:', np.mean(run_accs))
            console.print(other_acc_type, "ACCURACY", run_other_accs, 'MEAN:', np.mean(run_other_accs))
    console.print("ACCURACIES", accuracies)
