import json
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, load_bloom, load_bloom_petals, GPT2Wrapper, LambdaLayer, BloomIdentityLayer, BloomWrapper, BloomPetalsWrapper
from bigbench_tasks import multiple_choice_query, PromptBuilder
from rich.progress import track
from console import console, timer
import numpy as np
import time
from utils import load_json, get_probs_and_mrrs, from_layer_logits_to_prob_distros #model, logits, answer ; logits
import random

# Seed Python's global RNG so the random.shuffle of the dataset examples in __main__ is reproducible.
random.seed(42)

class PastTenseBuilder(PromptBuilder):
    """Prompt builder for the present-tense -> past-tense verb completion task.

    Each entry of ``dataset['examples']`` is read as a dict with an ``'input'``
    (present-tense verb) and a ``'target'`` (past-tense verb).
    """

    def __init__(self, dataset):
        # The dataset carries its own few-shot separator (the original note says it is '. ').
        separator = dataset['few_shot_example_separator']
        super().__init__(dataset, example_separator=separator)

    def __len__(self):
        examples = self.dataset['examples']
        return len(examples)

    def build_mc_prompt(self, index, include_answer=False):
        # Multiple-choice prompts are not supported for this task.
        return None

    def build_open_prompt(self, index, include_answer=False):
        """Build the open-ended prompt for the example at ``index``.

        index: position of the example in ``self.dataset['examples']``.
        include_answer: when True, the answer (``' ' + target``) is appended to
            the prompt text.
        returns: ``(prompt, answer)`` where prompt is
            ``"Today I <input>. Yesterday I"`` (plus the answer if requested) and
            answer is the target verb prefixed with a single space.
        """
        example = self.dataset['examples'][index]
        present = example['input']
        past = example['target']
        answer = ' ' + past
        text = f"""Today I {present}. Yesterday I"""
        if include_answer:
            text = text + answer
        return text, answer


def generate_ans(model, prompter, idx, n_shots=None):
    """Score one prompt with ``model`` and collect per-layer probabilities and ranks.

    model: a model wrapper exposing ``tokenize``, ``get_layers`` and
        ``topk_per_layer`` (e.g. the wrappers from ``modeling``).
    prompter: a prompt builder providing ``nshot_open_prompt(nshots, idx)``.
    idx: index of the test example in the prompter's dataset.
    n_shots: number of few-shot examples. When omitted, falls back to the
        module-level ``nshots`` global (set by the ``__main__`` block) for
        backward compatibility. The value is coerced to ``int``, because the
        command line may supply it as a string.

    returns: ``(prompt_results, final_rr)``.
        - ``prompt_results`` is a dict with keys ``inputs``, ``targets``,
          ``answer``, ``answer_idx``, ``probs``, ``rrs`` and ``top10_per_layer``.
        - ``final_rr`` is the last-layer entry of the ranks returned by
          ``get_probs_and_mrrs`` for the gold answer (presumably a reciprocal
          rank; callers count a value of 1 as correct).
    """
    if n_shots is None:
        n_shots = nshots  # module-level global assigned in __main__
    n_shots = int(n_shots)

    prompt, gt = prompter.nshot_open_prompt(n_shots, idx)
    # Only the gold continuation is scored. The list form keeps the output
    # schema (targets / answer_idx) uniform.
    targets = [gt]
    gt_idx = targets.index(gt)

    prompt_ids = model.tokenize(prompt)
    logits = model.get_layers(prompt_ids)

    probs_results = {}
    rrs_results = {}
    final_rr = None
    for i, tgt in enumerate(targets):
        probs, ffn_rrs = get_probs_and_mrrs(model, logits, tgt)
        if i == gt_idx:
            final_rr = ffn_rrs[-1]
        probs_results[i] = probs.tolist()
        rrs_results[i] = ffn_rrs.tolist()

    top10_per_layer = model.topk_per_layer(logits, 10)
    prompt_results = {
        'inputs': prompt,
        'targets': targets,
        'answer': gt,
        'answer_idx': gt_idx,
        'probs': probs_results,
        'rrs': rrs_results,
        'top10_per_layer': top10_per_layer,
    }
    return prompt_results, final_rr

def get_open_generations(model, dataset, debug=False):
    """Run every example of ``dataset`` through ``model`` and collect the results.

    A ``PastTenseBuilder`` is built over ``dataset``, and ``generate_ans`` is
    called for each index under ``torch.no_grad()``. ``None`` results are skipped.
    After each kept example, the running accuracy is printed. An example counts
    as correct when its returned ``final_rr`` equals 1.

    debug: when True, also print the input, last-layer top prediction, gold
        answer and timing for each example.
    returns: ``(all_output, attn_top1s, ffn_top1s)``. The last two lists are
        never populated here, so they are always empty.
    """
    attn_first_ranks = []
    ffn_first_ranks = []
    results = []
    prompter = PastTenseBuilder(dataset)
    n_correct = 0
    with torch.no_grad():
        for idx in track(range(len(prompter)), description='iterating...'):
            started = time.time()
            generated = generate_ans(model, prompter, idx)
            finished = time.time()

            if generated is None:
                continue
            record, final_rr = generated
            if debug:
                console.print(f"Input: {record['inputs']}, Prediction: {record['top10_per_layer'][-1][0]} GT Answer: {record['answer']}")
                console.print(f"In {finished-started} seconds")
            results.append(record)
            if final_rr == 1:
                n_correct += 1
            seen = len(results)
            console.print(f"Acc so far: {n_correct}/{seen} = {n_correct/seen}")

    return results, attn_first_ranks, ffn_first_ranks


def calc_proportion_attn_before_ffn(attn_top1s, ffn_top1s):
    """Count how often the attention path reaches rank 1 no later than the FFN path.

    Both arguments are parallel sequences of first-rank-1 layer indices, with
    ``None`` meaning that path never reached rank 1. They are paired with
    ``zip``, so the shorter sequence bounds the comparison.

    A pair is ignored when both entries are ``None``. It counts toward attention when:
        - only attention reached rank 1, or
        - both reached rank 1 and attention's index is <= the FFN's index.

    returns: ``(attn_first_count, total_considered)``, with the total as a float.
    """
    attn_first = 0
    considered = 0.
    for attn_idx, ffn_idx in zip(attn_top1s, ffn_top1s):
        if attn_idx is None and ffn_idx is None:
            continue
        considered += 1
        if ffn_idx is None:
            attn_first += 1
        elif attn_idx is not None and not (ffn_idx < attn_idx):
            attn_first += 1
    return attn_first, considered

def save_output(output, fname):
    """Write ``output`` to ``fname`` as JSON indented by 4 spaces, overwriting any existing file."""
    with open(fname, mode='w') as out_file:
        json.dump(obj=output, fp=out_file, indent=4)


def rm_ffn_from_model(model, rm_layers_num):
    """Replace the MLP sublayer of the last ``rm_layers_num`` blocks with ``nn.Identity``.

    The model is modified in place and also returned. Parameter counts are logged
    before and after the change; the "before" count is the module-level
    ``starting_num_params`` global computed in ``__main__``.

    model: a model with a ``transformer.h`` sequence of blocks, each having an
        ``mlp`` attribute.
    rm_layers_num: number of trailing blocks whose MLP is replaced. The start
        index is clamped at 0, so values larger than the depth replace every block.
    returns: the same ``model`` object.
    """
    num_blocks = len(model.transformer.h)
    layer_start = max(0, num_blocks - rm_layers_num)
    console.print(f"REMOVING LAYERS STARTING AT {layer_start}")
    for i in range(layer_start, num_blocks):
        model.transformer.h[i].mlp = nn.Identity()
    after_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
    console.print(f"Original # of parameters {starting_num_params}. After Rm FFN: {after_params}")
    # Guard against ZeroDivisionError when no parameter has requires_grad set.
    # The result is exact whenever the starting count is non-zero.
    removed_pct = 100 * (starting_num_params - after_params) / max(1, starting_num_params)
    console.print(f"% params removed: {removed_pct}")
    return model

def bloom_ffn_from_model(model, rm_layers_num):
    """Replace the MLP sublayer of the last ``rm_layers_num`` Bloom blocks with ``BloomIdentityLayer``.

    This is the Bloom counterpart of the generic FFN-removal helper. It installs
    ``BloomIdentityLayer`` (from ``modeling``) instead of ``nn.Identity``,
    presumably because Bloom's block calls its MLP with a different signature
    (TODO: confirm in ``modeling``).

    The model is modified in place and also returned. Parameter counts are logged
    before and after the change; the "before" count is the module-level
    ``starting_num_params`` global computed in ``__main__``.

    model: a model with a ``transformer.h`` sequence of blocks, each having an
        ``mlp`` attribute.
    rm_layers_num: number of trailing blocks whose MLP is replaced. The start
        index is clamped at 0.
    returns: the same ``model`` object.
    """
    num_blocks = len(model.transformer.h)
    layer_start = max(0, num_blocks - rm_layers_num)
    console.print(f"REMOVING LAYERS STARTING AT {layer_start}")
    for i in range(layer_start, num_blocks):
        model.transformer.h[i].mlp = BloomIdentityLayer()
    after_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
    console.print(f"Original # of parameters {starting_num_params}. After Rm FFN: {after_params}")
    # The former `.01 + starting_num_params` denominator avoided a division by zero
    # (no parameter with requires_grad) but skewed the reported value. max(1, ...)
    # avoids the crash and is exact whenever the starting count is non-zero.
    removed_pct = 100 * (starting_num_params - after_params) / max(1, starting_num_params)
    console.print(f"% params removed: {removed_pct}")
    return model


def count_correct(data):
    """Count results whose gold target has a last-layer rank value equal to 1.

    data: list of per-prompt result dicts as produced by ``generate_ans``. Each
        has an ``'answer_idx'`` and an ``'rrs'`` mapping from target index to a
        per-layer sequence (presumably reciprocal ranks). Only its last element
        is inspected.
    returns: the number of entries counted as correct.

    ``'rrs'`` keys are ints for in-memory results, but ``json.dump`` turns them
    into strings. Both forms are therefore accepted, so results reloaded from
    the saved JSON files can be counted too.
    """
    correct = 0
    for d in data:
        rrs = d['rrs']
        target_idx = d['answer_idx']
        per_layer = rrs[target_idx] if target_idx in rrs else rrs[str(target_idx)]
        if per_layer[-1] == 1:
            correct += 1
    return correct

if __name__ == "__main__":
    # Usage: <script> MODEL_NAME NSHOTS[,NSHOTS...] [NUM_FFN_LAYERS_TO_REMOVE]
    #   MODEL_NAME: contains 'gpt2', or is exactly 'gptj', 'bloom' or 'bloom-petals'.
    #   NSHOTS: a single shot count, or a comma-separated list of shot counts.
    #   NUM_FFN_LAYERS_TO_REMOVE: if given, evaluate only that removal setting.
    model_name = sys.argv[1]
    nshots = sys.argv[2]
    parallel_r_num = None
    if len(sys.argv) > 3:
        parallel_r_num = int(sys.argv[3])
    console.print(model_name, nshots, 'shot(s)')

    # Parse the shot spec once, up front. `nshots` must end up as an int module
    # global, because generate_ans reads it. Previously:
    #   - the single-value form was still a str during the first irregular run;
    #   - the comma-separated form was honoured only for the first removal step,
    #     because `nshots` was overwritten with an int inside the loop.
    multi_shot = ',' in nshots
    shot_list = [int(n) for n in nshots.split(',')]
    nshots = shot_list[0]

    timer_task = timer.add_task("Loading model")
    with timer:
        if 'gpt2' in model_name:
            model, tokenizer = load_gpt2(model_name)
            model = GPT2Wrapper(model, tokenizer)
        elif 'gptj' == model_name:
            model, tokenizer = load_gptj()
            model = GPTJWrapper(model, tokenizer)
        elif 'bloom' == model_name:
            model, tokenizer = load_bloom('bigscience/bloom')
            model = BloomWrapper(model, tokenizer)
        elif 'bloom-petals' == model_name:
            model, tokenizer = load_bloom_petals("bigscience/bloom-petals")
            model = BloomPetalsWrapper(model, tokenizer)
        else:
            # Fail clearly instead of with a NameError on `model` further down.
            raise ValueError(f"Unknown model name: {model_name!r}")
    timer.stop_task(timer_task)
    timer.update(timer_task, visible=False)

    # Module global, read by rm_ffn_from_model / bloom_ffn_from_model.
    starting_num_params = sum(param.numel() for param in model.model.parameters() if param.requires_grad)

    regular_data = load_json("past_tense_regular.json")
    random.shuffle(regular_data['examples'])
    irregular_data = load_json("past_tense_irregular.json")
    random.shuffle(irregular_data['examples'])

    # Numbers of trailing FFN layers to remove.
    num_layers = len(model.model.transformer.h)
    # max(1, ...) avoids a zero range step (ValueError) for models with fewer than 6 layers.
    rrange = list(range(0, num_layers, max(1, num_layers // 6)))
    rrange.append(num_layers)
    if model_name == 'gptj':
        rrange = [0, 5, 10, 14, 18, 23, 28]
    if 'bloom' in model_name:
        rrange = [0, 12, 24, 36, 48, 60, 70]

    if parallel_r_num is not None:
        rrange = [parallel_r_num]  # use only the specified number
    console.print(f"RRANGE: {rrange} .. Num Layers:", num_layers)

    def _run_split(label, split_name, data, r, debug=False):
        """Evaluate one dataset split, print summary stats, save the results and return accuracy."""
        console.print(label)
        output, attn_top1s, ffn_top1s = get_open_generations(model, data, debug=debug)
        total_seen = len(output)
        attns_first, _ = calc_proportion_attn_before_ffn(attn_top1s, ffn_top1s)
        total_correct = count_correct(output)
        # max(1, ...) guards against ZeroDivisionError when nothing is correct / seen.
        denom = max(1, total_correct)
        console.print(f"attns_first/total_correct = {attns_first/denom}\n ffns_first/total_correct={(total_correct-attns_first)/denom}")
        acc = total_correct / max(1, total_seen)
        console.print(f"Total correct/total seen: {total_correct}/{total_seen} = {acc}")
        save_output(output, f'past_tense_{split_name}_{model_name}_{nshots}_rm_{r}_open_results.json')
        return acc

    regular_accs = []
    irregular_accs = []
    for r in rrange:
        if r > 0:
            # The model is modified in place, so removals persist into later iterations.
            if 'bloom' in model_name:
                model.model = bloom_ffn_from_model(model.model, r)
            else:
                model.model = rm_ffn_from_model(model.model, r)

        if not torch.cuda.is_available():
            model.model = model.model.float()

        if multi_shot:
            for n in shot_list:
                console.print("NSHOTS", n)
                nshots = n  # module global read by generate_ans and _run_split
                _run_split("REGULAR", "regular", regular_data, r)
                _run_split("IRREGULAR", "irregular", irregular_data, r)
        else:
            irregular_accs.append(_run_split("IRREGULAR", "irregular", irregular_data, r, debug=True))
            regular_accs.append(_run_split("REGULAR", "regular", regular_data, r))

        console.print("REGULAR:", regular_accs)
        console.print("IRREGULAR:", irregular_accs)

