import re
import json
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, GPT2Wrapper
#from bigbench_tasks import load_bigbench_task, multiple_choice_query, PromptBuilder, DefaultOpenPrompter
from bigbench_tasks import PromptBuilder, DefaultOpenPrompter, load_bigbench_from_results
from rich.progress import track
import rich
from console import console, timer
import numpy as np
from utils import get_probs_and_mrrs, from_layer_logits_to_prob_distros #model, logits, answer ; logits
import random
from torch import nn
import torch

# Seed Python's `random` module for reproducibility. Nothing visible in this
# file draws from it directly -- presumably for imported helpers; confirm.
random.seed(a=42)
# First 0-based transformer block whose MLP rm_ffn ablates.
LAYER_START = 4
# Running number of prompts whose gold target is ranked first at the last
# layer; reset by get_open_generations and incremented by generate_ans.
TOTAL_CORRECT = 0
class ColoredObjectsPrompter(DefaultOpenPrompter):
    """Open-ended prompter for the colored-objects reasoning datapoints.

    Each datapoint is a dict with ``inputs`` (prompt text, which may embed
    worked examples whose answers sit on lines starting with ``A:``),
    ``answer`` (the gold answer) and ``targets`` (the multiple-choice
    candidates). In extractive mode the in-context answer lines, the gold
    answer and the targets are lowercased; otherwise they are used verbatim.
    """

    # Upper bound on the number of datapoints exposed through ``len()``.
    MAX_EXAMPLES = 200
    # An in-context answer line: newline, "A:", the answer text, newline.
    _ANSWER_LINE = re.compile(r'\nA:.*?\n')

    def __init__(self, dataset, is_extractive):
        """
        dataset: indexable sequence of datapoint dicts (see class docstring).
        is_extractive: if True, lowercase answers and multiple-choice targets.
        """
        super().__init__(dataset)
        self.is_extractive = is_extractive

    @staticmethod
    def _lowercase_answer(match):
        """Lowercase a matched answer line except its leading newline and ``A``."""
        line = match.group(0)
        return line[:2] + line[2:].lower()

    def build_open_prompt(self, index, include_answer=False):
        """Build the open-ended prompt for ``self.dataset[index]``.

        index: position of the test datapoint in ``self.dataset``.
        include_answer: if True, append the gold answer to the prompt.
        returns: ``(prompt, gt)`` -- the prompt string and the gold answer
            string. In extractive mode every in-context answer line (after its
            ``A`` marker) and ``gt`` are lowercased; otherwise the datapoint's
            ``inputs`` and ``answer`` are returned unchanged.
        """
        datapoint = self.dataset[index]
        if self.is_extractive:
            prompt = self._ANSWER_LINE.sub(self._lowercase_answer, datapoint['inputs'])
            gt = datapoint['answer'].lower()
        else:
            prompt = datapoint['inputs']
            gt = datapoint['answer']
        if include_answer:
            prompt += gt
        return prompt, gt

    def get_mc_targets(self, idx):
        """Return a fresh list of the multiple-choice targets of ``self.dataset[idx]``.

        Targets are lowercased in extractive mode so they stay comparable with
        the lowercased gold answer from ``build_open_prompt``.
        """
        targets = self.dataset[idx]['targets']
        if self.is_extractive:
            return [t.lower() for t in targets]
        return list(targets)

    def __len__(self):
        """Number of datapoints to evaluate: at most ``MAX_EXAMPLES``.

        Clamped to the dataset size so a shorter dataset does not make callers
        index past its end.
        """
        return min(self.MAX_EXAMPLES, len(self.dataset))

def generate_ans(model, prompter, idx):
    """Score every multiple-choice target of datapoint ``idx`` layer by layer.

    Builds a 0-shot open prompt via ``prompter.nshot_open_prompt``, runs it
    through ``model.get_layers`` and, for each target, records the per-layer
    probabilities and reciprocal ranks returned by ``get_probs_and_mrrs``.
    When the gold target's last-layer reciprocal rank equals 1, the
    module-level ``TOTAL_CORRECT`` counter is incremented and a running
    accuracy is printed.

    Returns ``(prompt_results, [])``: a JSON-serialisable dict of results and
    an empty placeholder list (vocabulary-wide distributions are not computed).
    """
    global TOTAL_CORRECT
    prompt, gt = prompter.nshot_open_prompt(0, idx)
    targets = prompter.get_mc_targets(idx)
    gt_idx = targets.index(gt)

    layer_logits = model.get_layers(model.tokenize(prompt))
    print(prompt)
    print(gt_idx, targets)

    probs_by_target = {}
    rrs_by_target = {}
    for target_pos, target in enumerate(targets):
        target_probs, target_rrs = get_probs_and_mrrs(model, layer_logits, target)
        probs_by_target[target_pos] = target_probs.tolist()
        rrs_by_target[target_pos] = target_rrs.tolist()
        if target_pos != gt_idx:
            continue
        final_rr = rrs_by_target[target_pos][-1]
        print(target, final_rr)
        if final_rr == 1:
            TOTAL_CORRECT += 1
            console.print(target, "ACC SO FAR:", TOTAL_CORRECT / (idx + 1.))

    prompt_results = {
        'inputs': prompt,
        'targets': targets,
        'answer': gt,
        'answer_idx': gt_idx,
        'probs': probs_by_target,
        'rrs': rrs_by_target,
        'top10_per_layer': model.topk_per_layer(layer_logits, 10),
    }
    return prompt_results, []

def get_open_generations(model, dataset):
    """Score every prompt of ``dataset`` and dump the results to a JSON file.

    Resets the module-level ``TOTAL_CORRECT`` counter, runs ``generate_ans``
    on each datapoint exposed by a ``ColoredObjectsPrompter`` and writes the
    list of per-prompt result dicts to
    ``{model_name}_{nshots}_open_cobjs_{variant}_{LAYER_START}.json``.
    ``model_name``, ``nshots``, ``extractive`` and ``remove_ffn`` are read
    from module scope (they are set by the ``__main__`` block); ``variant``
    is one of ``ext_rm_ffns``, ``ext_full``, ``rm_ffns`` or ``full``.
    """
    global TOTAL_CORRECT
    TOTAL_CORRECT = 0
    prompter = ColoredObjectsPrompter(dataset, extractive)
    print("# Prompts ", len(prompter))
    output = []
    # NOTE(review): index 0 is never scored -- confirm this skip is intended.
    for i in track(range(1, len(prompter)), description='iterating'):
        json_out, _prob_distros = generate_ans(model, prompter, i)
        output.append(json_out)

    # Branch on the boolean ``remove_ffn`` flag -- NOT ``rm_ffn``, which is the
    # ablation function defined in this module and therefore always truthy.
    if extractive:
        variant = 'ext_rm_ffns' if remove_ffn else 'ext_full'
    else:
        variant = 'rm_ffns' if remove_ffn else 'full'
    out_path = f'{model_name}_{nshots}_open_cobjs_{variant}_{LAYER_START}.json'
    with open(out_path, 'w') as fp:
        json.dump(output, fp, indent=4)


class ZeroFFN(nn.Module):
    """Parameter-free stand-in for a transformer block's MLP that contributes nothing.

    In the Hugging Face GPT-2 / GPT-J block implementations (which the
    ``model.transformer.h[i].mlp`` layout used below suggests -- confirm), the
    MLP is fed the layer-normed hidden states and its output is added to the
    residual stream. ``nn.Identity`` would therefore add ``LayerNorm(x)``
    instead of removing the FFN; returning zeros makes its contribution
    exactly zero.
    """

    def forward(self, hidden_states, *args, **kwargs):
        # Same shape/dtype/device as the input, all zeros.
        return torch.zeros_like(hidden_states)


def rm_ffn(model, layer_start=None):
    """Ablate the MLP (FFN) of every transformer block from ``layer_start`` on.

    model: a model exposing its blocks as ``model.transformer.h``, each with an
        ``mlp`` attribute; modified in place.
    layer_start: first 0-based block index to ablate; defaults to the
        module-level ``LAYER_START``.
    returns: the same ``model`` object.
    """
    if layer_start is None:
        layer_start = LAYER_START

    def count_trainable():
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    b4_params = count_trainable()
    console.print(f"STAETING AT {layer_start}")
    for i in range(layer_start, len(model.transformer.h)):
        model.transformer.h[i].mlp = ZeroFFN()
    after_params = count_trainable()
    console.print(f"Original # of parameters {b4_params}. After Rm FFN: {after_params}")
    return model

if __name__ == "__main__":
    # Usage: MODEL_NAME NSHOTS FFN_MODE MODE
    #   MODEL_NAME: a name containing 'gpt2' (loaded with load_gpt2) or 'gptj'
    #   NSHOTS:     one shot count or a comma-separated list, e.g. "0,1,3"
    #   FFN_MODE:   'no_ffn' (case-insensitive) ablates MLPs via rm_ffn;
    #               anything else keeps the full model
    #   MODE:       'extractive' lowercases answers/targets; anything else
    #               keeps them verbatim
    if len(sys.argv) < 5:
        sys.exit(f"usage: {sys.argv[0]} MODEL_NAME NSHOTS[,NSHOTS...] no_ffn|<other> extractive|<other>")

    # These stay module-level on purpose: get_open_generations reads
    # model_name, nshots, remove_ffn and extractive from module scope.
    model_name = sys.argv[1]
    nshots = sys.argv[2]
    remove_ffn = sys.argv[3].lower() == 'no_ffn'
    extractive = sys.argv[4] == "extractive"

    console.print(f"Removing FFNs? {remove_ffn} ... Is Extractive? {extractive}")
    console.print(model_name, nshots, 'shot(s)')
    timer_task = timer.add_task("Loading model")
    with timer:
        if 'gpt2' in model_name:
            model, tokenizer = load_gpt2(model_name)
            wrapper_cls = GPT2Wrapper
        elif model_name == 'gptj':
            model, tokenizer = load_gptj()
            wrapper_cls = GPTJWrapper
        else:
            # Fail fast instead of hitting a NameError on `model` further down.
            raise ValueError(f"Unsupported model name: {model_name!r}")
        if remove_ffn:
            model = rm_ffn(model)
        model = wrapper_cls(model, tokenizer)
    timer.stop_task(timer_task)
    timer.update(timer_task, visible=False)

    if not torch.cuda.is_available():
        model.model = model.model.float()

    # A single count ("3") and a list ("0,1,3") are handled the same way.
    shot_counts = [int(s) for s in nshots.split(',')]
    for n in shot_counts:
        # get_open_generations names its output file from the module-level
        # `nshots`; rebind it so each shot count gets its own file instead of
        # every run overwriting "<model>_0,1,3_open_cobjs_...json".
        nshots = n
        dataset = load_bigbench_from_results(f"{model_name}_{n}_open_cobjs_results.json")
        dataset = dataset[:200]
        console.print("NSHOTS", n)

        get_open_generations(model, dataset)
        # NOTE(review): get_open_generations scores indices from 1 upward
        # (at most 199 prompts), so dividing by 200 slightly understates
        # accuracy -- confirm the intended denominator.
        console.print("ACCURACY", TOTAL_CORRECT/200)
