import json
import sys
sys.path.append('..')
from modeling import load_gptj, GPTJWrapper, load_gpt2xl, load_gpt2, GPT2Wrapper
import numpy as np
import torch
from utils import get_probs_and_mrrs, from_layer_logits_to_prob_distros #model, logits, answer ; logits
import random

# Colored-objects prompt: one answered example ("A: Mauve") followed by an
# open question ending in "A:". Both questions contain a hard line break
# inside the literal. Not referenced in the visible part of this file.
# NOTE(review): "scrunchiephone charger" looks like two item names fused
# together -- confirm whether this is intentional before reusing the prompt.
INP="""Q: On the nightstand, you see a mauve cat toy, a teal puzzle, a black jug, a magenta booklet, a yellow pen, and
a red scrunchiephone charger. What color is the cat toy?
A: Mauve
Q: On the table, there is a red dog leash, a brown teddy bear, a silver pencil, and a teal paperclip. What color is
the paperclip?
A:"""


def tokenize(inp):
    """
    Encode ``inp`` with the module-level ``tokenizer`` and decode it back token by token.

    Returns a pair ``(input_ids, toks)``: the encoded ids as requested with
    ``return_tensors='pt'``, and a list with the decoded string of every id in
    the first row of ``input_ids``.
    """
    input_ids = tokenizer.encode(inp, return_tensors='pt')
    first_row = input_ids[0]
    # Decode ids one at a time so each entry lines up with one position.
    toks = list(map(tokenizer.decode, first_row))
    return input_ids, toks

#given some input, get repeated ngrams and look for induction heads

#add forward hook to intervene on attention

#get activations prior to input to feed forward



def add_hooks_pre_ffn(model):
    """Placeholder: currently does nothing with ``model`` and returns None."""

def reset_mask():
    """
    Return a fresh all-ones head mask of shape (24, 16).

    NOTE: a second ``reset_mask`` defined further down in this module replaces
    this definition once the module has been executed.
    """
    return torch.ones((24, 16))

def mask_attn(inp, headmask):
    """
    Run the module-level ``model`` with ``headmask`` applied and collect diagnostics.

    The answer string (" Brown") and the attended-to position (59) are
    hard-coded; attention is read from the last position (-1).

    Returns ``(masked_out, mattns, mlogits, rrs, ans_at, ans_maxes)``: the raw
    model output, its attentions, the decoded per-layer logits, the result of
    ``wrapper.rr_per_layer`` for " Brown", the per-layer attention summary and
    the list of its per-layer 'max' values.

    NOTE: a five-argument ``mask_attn`` defined further down in this module
    replaces this definition once the module has been executed.
    """
    masked_out = model(
        inp,
        head_mask=headmask,
        output_attentions=True,
        output_hidden_states=True,
    )
    mattns = masked_out.attentions

    # Decode every hidden state, then stack the results into one tensor.
    decoded = wrapper.layer_decode(masked_out.hidden_states)
    mlogits = torch.stack(decoded).squeeze(-1)
    rrs = wrapper.rr_per_layer(mlogits, " Brown")

    ans_at = find_most_attentive_heads(mattns, -1, 59)
    ans_maxes = [layer_stats['max'] for layer_stats in ans_at.values()]
    return masked_out, mattns, mlogits, rrs, ans_at, ans_maxes


def find_most_attentive_heads(attentions, from_idx, to_idx, batch_idx=0, console=None):
    """
    Summarise, layer by layer, the attention weight each head puts on one (from, to) cell.

    attentions: the attention map output for model(forward, output_attentions=True).attentions,
                indexed as (num_layers, batch_number, num_heads, sequence_length, sequence_length)
    from_idx:   row index into each head's attention matrix
    to_idx:     column index into each head's attention matrix
    batch_idx:  which item of the batch to inspect
    console:    accepted but not used by this function

    Returns a dict keyed by layer index. Each value is a dict with
    'all' (list with one weight per head), 'avg', 'max' and 'argmax'
    (index of the head holding the max); the last three are NumPy scalars.
    """
    attentive_layers = {}
    for layer, layer_attn in enumerate(attentions):
        heads = layer_attn[batch_idx].detach().cpu().numpy()
        per_head = heads[:, from_idx, to_idx]  # one weight per head
        attentive_layers[layer] = {
            'all': per_head.tolist(),
            'avg': per_head.mean(),
            'max': per_head.max(),
            'argmax': per_head.argmax(),
        }
    return attentive_layers

def attn_maxes(attn_scores):
    """Return the 'max' entry of every per-layer summary in ``attn_scores``, in iteration order."""
    return [layer_stats['max'] for layer_stats in attn_scores.values()]

def attn_argmaxes(attn_scores):
    """Return the 'argmax' entry of every per-layer summary in ``attn_scores``, in iteration order."""
    return [layer_stats['argmax'] for layer_stats in attn_scores.values()]

def mask_attn(inp, headmask, from_idx, to_idx, answer):
    """
    Run the wrapped model with ``headmask`` and gather logits, ranks and attention stats.

    inp:       input passed to ``wrapper.get_layers_w_attns``
    headmask:  forwarded as ``head_mask``
    from_idx:  row index used when summarising attention
    to_idx:    column index used when summarising attention
    answer:    answer string handed to ``wrapper.rr_per_layer``

    Side effect: ``wrapper.model.activations_`` is reset to an empty dict
    before the forward pass and read back afterwards.

    Returns ``(logits, rrs, inter_logits, inter_rrs, maxes, argmaxes, attns)``.
    """
    # Start from a clean cache so only this pass's activations are read below.
    wrapper.model.activations_ = {}
    logits, attns = wrapper.get_layers_w_attns(inp, head_mask=headmask)
    rrs = wrapper.rr_per_layer(logits, answer)

    # Gather the 'intermediate_residual_<i>' activation of every block.
    cache = wrapper.model.activations_
    n_blocks = len(wrapper.model.transformer.h)
    residuals = [cache[f'intermediate_residual_{idx}'] for idx in range(n_blocks)]

    # Stack, add an axis at position 1, and decode like the regular layers.
    stacked = torch.stack(residuals).unsqueeze(1)
    inter_logits = torch.stack(wrapper.layer_decode(stacked)).squeeze(-1)
    inter_rrs = wrapper.rr_per_layer(inter_logits, answer)

    head_stats = find_most_attentive_heads(attns, from_idx, to_idx)
    maxes = attn_maxes(head_stats)
    argmaxes = attn_argmaxes(head_stats)
    return logits, rrs, inter_logits, inter_rrs, maxes, argmaxes, attns

from modeling import *
# Colored-objects prompt with one answered example ("A: Silver") and an open
# question about the sheet of paper; this is the prompt tokenized by tok(INP2)
# further down.
INP2 = """Q: On the floor, I see a silver keychain, a red pair of sunglasses, a gold sheet of paper, a black dog leash, and a blue cat toy. What color is the keychain?\nA: Silver\nQ: On the table, you see a brown sheet of paper, a red fidget spinner, a blue pair of sunglasses, a teal dog leash, and a gold cup. What color is the sheet of paper?\nA:"""

# Same text as INP2, cut off right after "What color is the sheet".
INP2sheet = """Q: On the floor, I see a silver keychain, a red pair of sunglasses, a gold sheet of paper, a black dog leash, and a blue cat toy. What color is the keychain?\nA: Silver\nQ: On the table, you see a brown sheet of paper, a red fidget spinner, a blue pair of sunglasses, a teal dog leash, and a gold cup. What color is the sheet"""

# NOTE: this first INPill value is overwritten by the assignment right below
# it, so only the truncated version is ever visible to later code.
INPill = """A group of scientists wanted to know whether spotted rats,
who are pickier eaters than other rats, liked a new kind of food.
They tested white, black, and spotted rats of both sexes.
The scientists discovered that all of the rats loved the food.
Now that they knew that some of the rats loved the food,
they decided to issue a recommendation based on their findings."""
# Truncated variant ending at "Now that they knew that"; the final INPill value.
INPill="""A group of scientists wanted to know whether spotted rats,
who are pickier eaters than other rats, liked a new kind of food.
They tested white, black, and spotted rats of both sexes.
The scientists discovered that all of the rats loved the food.
Now that they knew that"""


# Capital-city QA prompts. ext_squad1 additionally places the answer strings
# (" Paris", " Warsaw") in the text ahead of each question.
# NOTE(review): "abs"/"ext" presumably mean abstractive/extractive -- confirm.
abs_squad1 = """Q: What is the capital of France?\nA: Paris\nQ: What is the capital of Poland?\nA:"""
ext_squad1 = """ Paris\nQ: What is the capital of France?\nA: Paris\n Warsaw\nQ: What is the capital city of Poland?\nA:"""



# Runs at import time: load 'gpt2-medium' through the project loader, put the
# model in eval mode and cast it with .float(), then wrap model + tokenizer.
# The resulting `wrapper` is the global the helper functions in this file use.
model, tokenizer = load_gpt2('gpt2-medium')
model = model.eval().float()
wrapper = GPT2Wrapper(model, tokenizer)
def tok(i):
    """
    Tokenize ``i`` with the module-level ``wrapper``.

    Returns ``(inp, toks)``: the output of ``wrapper.tokenize(i)`` and the
    result of ``wrapper.list_decode`` applied to its first element.
    """
    encoded = wrapper.tokenize(i)
    return encoded, wrapper.list_decode(encoded[0])


# Tokenize the INP2 prompt; `inp` is the input reused by the script
# statements below.
inp, toks = tok(INP2)
def reset_mask(nheads=16):
    """
    Return a fresh all-ones head mask with one row per transformer block.

    nheads: number of columns (heads per layer); defaults to 16.

    The row count is read from ``wrapper.model.transformer.h``.
    """
    n_layers = len(wrapper.model.transformer.h)
    return torch.ones(n_layers, nheads)

# Start from an all-ones head mask.
headmask = reset_mask()
# NOTE(review): presumably registers the hooks that fill
# wrapper.model.activations_ -- confirm in GPT2Wrapper.add_hooks.
wrapper.add_hooks()
# Baseline run: attention is read from the last position (-1) to position 57,
# and ranks are computed for the answer " Brown".
logits, rrs, inter_logits, inter_rrs, maxes, argmaxes, attns = mask_attn(inp, headmask, -1, 57, " Brown")

# NOTE: rebinding `logits` here discards the value returned by mask_attn above.
logits = wrapper.get_layers(inp)

# Forward pass without a head mask; hidden states and attentions requested.
outputs = wrapper.model(input_ids=inp, output_hidden_states=True, output_attentions=True)
hids = list(outputs.hidden_states)
# Cached activation stored under the key 'mlp_19'.
mlp_act = wrapper.model.activations_['mlp_19']
# Collect the 'intermediate_residual_<i>' activation of every transformer block.
inter_hids = []
for i in range(len(wrapper.model.transformer.h)):
    inter_hids.append(wrapper.model.activations_['intermediate_residual_'+str(i)])


# Stack the collected activations, add an axis at position 1 and decode them;
# this overwrites the `inter_logits` returned by mask_attn above.
inter_hids = torch.stack(inter_hids).unsqueeze(1)
inter_logits = torch.stack(wrapper.layer_decode(inter_hids)).squeeze(-1)

def attn_maxes(attn_scores):
    """
    Return the 'max' entry of every per-layer summary in ``attn_scores``, in iteration order.

    NOTE: an identical ``attn_maxes`` is also defined earlier in this module;
    this later definition is the one bound after the module has executed.
    """
    return [layer_stats['max'] for layer_stats in attn_scores.values()]


def rm_ffn(model, num_rm):
    """
    Replace the ``mlp`` of the last ``num_rm`` transformer blocks with ``nn.Identity()``.

    model:   object exposing its blocks as ``model.transformer.h``
    num_rm:  how many trailing blocks to modify

    The model is modified in place and also returned.
    """
    blocks = model.transformer.h
    total = len(blocks)
    for idx in range(total - num_rm, total):
        blocks[idx].mlp = nn.Identity()
    return model



# Array loaded from disk and converted to a tensor; add_capit() reads this
# global as the term it adds to its input.
capit = torch.tensor(np.load('capitalization_update_gpt2-med.npy'))
# Run the model on `inp` again, decode every hidden state with the wrapper and
# print the result via wrapper.print_top. Rebinds outputs / hids / logits.
outputs = wrapper.model(inp, output_hidden_states=True)
hids = list(outputs.hidden_states)
logits = torch.stack(wrapper.layer_decode(hids)).squeeze(-1)
wrapper.print_top(logits)



# Normans passage prompts. norman_0ext puts the passage and a single question
# under one "Q:" with no answered example; norman_1ext gives the passage first,
# then one answered question ("A: France") and an open one about the Norse leader.
norman_0ext = """Q: The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. In what country is Normandy located?\nA:"""
norman_1ext = """The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia.\nQ: In what country is Normandy located?\nA: France\nQ: Who was the Norse leader?\nA:"""

def add_capit(model, num_layers, eps=1):
    """
    Swap the ``mlp`` of the last ``num_layers`` blocks for a layer that adds ``eps * capit``.

    model:       object exposing its blocks as ``model.transformer.h``
    num_layers:  how many trailing blocks to modify
    eps:         scale applied to the module-level ``capit`` tensor

    Each replaced ``mlp`` becomes a ``LambdaLayer`` returning
    ``x + (eps * capit)`` for its input ``x``; the original ``mlp`` module is
    no longer called. The model is modified in place and also returned.
    """
    blocks = model.transformer.h
    n_blocks = len(blocks)

    def shift(x):
        # `capit` is looked up when the layer runs, not when it is installed.
        return x + (eps * capit)

    for idx in range(n_blocks - num_layers, n_blocks):
        blocks[idx].mlp = LambdaLayer(shift)
    return model

# Module-level layer index.
# NOTE(review): save_acts() takes its own `layer_id` parameter, so it does not
# read this global -- confirm where this value is actually used.
layer_id = 18
def save_acts(layer_id):
    """
    Write one cached activation of layer ``layer_id`` to a ``.npy`` file.

    Reads ``wrapper.model.activations_`` and saves the entry stored under
    ``out_intermediate_residual_<layer_id>`` to
    ``gpt2-med_<layer_id>_sheet_iln_out.npy`` in the working directory, after
    ``squeeze().detach().numpy()``.

    Raises KeyError if either the ``in_sln_<layer_id>`` or the
    ``out_intermediate_residual_<layer_id>`` entry is missing from the cache.
    """
    cache = wrapper.model.activations_
    # Fetched but not written out; the lookup still fails loudly if absent.
    sln_in = cache[f'in_sln_{layer_id}']
    # Earlier, now-disabled dumps used the same pattern for the keys
    # in_sln_, in_attn_, attn_, intermediate_residual_ and mlp_ (file suffixes
    # sln_in, attn_in, attn_out, iln_in and mlp_out respectively).
    iln_out = cache[f'out_intermediate_residual_{layer_id}']
    out_path = f"gpt2-med_{layer_id}_sheet_iln_out.npy"
    np.save(out_path, iln_out.squeeze().detach().numpy())

class LambdaLayer(nn.Module):
    """Module that applies a stored callable to its input."""

    def __init__(self, lambd):
        """Keep ``lambd``, the callable that ``forward`` delegates to."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``self.lambd(x)``."""
        fn = self.lambd
        return fn(x)

if __name__ == '__main__':
    # The model name comes from the first command-line argument. Fall back to
    # the checkpoint the module-level code already loads, so that running the
    # script without arguments no longer ends in an IndexError.
    model_name = sys.argv[1] if len(sys.argv) > 1 else 'gpt2-medium'
    # NOTE(review): this rebinds `model` and `tokenizer` only; the module-level
    # `wrapper` still holds the objects loaded earlier in the file.
    model, tokenizer = load_gpt2(model_name)



    '''
    SQUAD NORMAN EXAMPLE:
    The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.
    '''
