# Run the input(s) at certain position(s) through a component and add a scaled
# copy of the result to the output of a head (e.g. a scale of -1 subtracts it).
from imports import *
from utils import load_json
from torch.utils.data import DataLoader, Dataset

from tango import step
import time
import tango
from tango.common import FromParams
import termplotlib as tpl
from weights_composer import re_get_single_component, get_ov, remove_components
import sys
from ioi_inhibition_exp import (
    PromptDataset, 
    load_dataset, 
    load_model, 
    DataParams, 
    ModelParams, 
    calc_inhib_score,
    calc_inhib_score_to_all_S,#(model, prompt, cache, mover_layer, mover_head)
    calc_inhib_score_to_S2,
    get_inhibition_scores,
    get_token_idx
)
from ioi_inhib_add_comp_patch_exp import calc_component_resid, attn_result_hook
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from my_plotly import *
import plotly.graph_objects as go
from fancy_einsum import einsum

from ioi_inhib_add_comp_patch_exp import get_all_token_idx_occurrences,  get_attn_score_from_to

import rich


def patch_head_pattern(
    head_pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    hook,
    head_idx,
    query_idxs,
    key_idxs,
):
    """Set one head's attention pattern to 1 at the given (query, key) positions.

    The signature follows the ``(tensor, hook, ...)`` hook convention used by
    the other hook functions in this file; ``hook`` itself is ignored.
    The tensor is modified in place and also returned.

    Args:
        head_pattern: pattern indexed as [batch, head_index, query_pos, key_pos].
        hook: hook point object (unused).
        head_idx: head whose pattern is patched.
        query_idxs: query position(s); broadcast together with ``key_idxs``
            under PyTorch advanced indexing.
        key_idxs: key position(s).

    Returns:
        The same ``head_pattern`` tensor, after patching.
    """
    # Every batch element, one head, the selected (query, key) cells.
    target = (slice(None), head_idx, query_idxs, key_idxs)
    head_pattern[target] = 1.0
    return head_pattern

def add_mult_of_comp_out_at_idxs_hook(
        hook_vals: Float[torch.Tensor, "batch pos head_index d_model"],
        hook: HookPoint,
        head_idx: int,
        pos_idxs: list, #list of ints
        comp: Float[torch.Tensor, 'd_model d_model'],
        scale: float,
        resid_pre_cache: Float[torch.Tensor, "batch pos d_model"],
    ) -> Float[torch.Tensor, "batch pos head_index d_model"]:
    """Add ``scale * (resid_pre @ comp)`` to one head's per-head output.

    For batch row ``b`` the cached residual at position ``pos_idxs[b]`` is
    right-multiplied by ``comp``. The scaled result is added in place to
    ``hook_vals[b, pos_idxs[b], head_idx]``.

    Args:
        hook_vals: per-head outputs indexed [batch, pos, head_index, d_model].
        hook: hook point (unused).
        head_idx: head whose output is modified.
        pos_idxs: one position per batch row.
        comp: (d_model, d_model) matrix applied to the residual.
        scale: multiplier for the added vector.
        resid_pre_cache: residual stream indexed [batch, pos, d_model].

    Returns:
        ``hook_vals`` (mutated in place).
    """
    rows = range(len(hook_vals))
    projected = resid_pre_cache[rows, pos_idxs] @ comp
    hook_vals[(rows, pos_idxs, head_idx)] += scale * projected
    return hook_vals

def add_mult_of_comp_out_to_resid_at_idxs_hook(
        hook_vals: Float[torch.Tensor, "batch pos d_model"],
        hook: HookPoint,
        pos_idxs: list, #list of ints
        comp: Float[torch.Tensor, 'd_model d_model'],
        scale: float,
        resid_pre_cache: Float[torch.Tensor, "batch pos d_model"],
    ) -> Float[torch.Tensor, "batch pos d_model"]:
    """Add ``scale * (resid_pre @ comp)`` to a residual-stream activation.

    For batch row ``b`` the cached residual at ``pos_idxs[b]`` is
    right-multiplied by ``comp``. The scaled result is added in place to
    ``hook_vals[b, pos_idxs[b]]``.

    Args:
        hook_vals: activation indexed [batch, pos, d_model].
        hook: hook point (unused).
        pos_idxs: one position per batch row.
        comp: (d_model, d_model) matrix applied to the residual.
        scale: multiplier for the added vector.
        resid_pre_cache: residual stream indexed [batch, pos, d_model].

    Returns:
        ``hook_vals`` (mutated in place).
    """
    where = (range(len(hook_vals)), pos_idxs)
    delta = scale * (resid_pre_cache[where] @ comp)
    hook_vals[where] += delta
    return hook_vals

@step(cacheable=True, deterministic=True, version='001')
def add_mult_of_comp_out_at_idxs(  
    model: FromParams,
    dataset: DataParams,
    interv_layer: int,
    interv_head: int,
    comp_idx,
    token_type: str, #S1, S2, IO
    target_layer: int,
    target_head: int,
    mover_layer: int, 
    mover_head: int,
    scales:list,
    with_C=False,
    raw_scores=False) -> list:
    """Add a scaled OV-component output to one head's result and record a
    mover head's scores from the final token.

    For every scale and batch the model is first run clean. When ``comp_idx``
    is given, the following steps run:

    * The clean ``resid_pre`` of ``interv_layer`` at each prompt's
      ``token_type`` position is multiplied by ``comp``. ``comp`` is the sum
      of the selected singular components of the OV matrix of
      (``interv_layer``, ``interv_head``).
    * The result is scaled by ``scale`` and added to ``hook_result`` of
      (``target_layer``, ``target_head``) at that same position.
    * The model is re-run with that hook.

    The scores of (``mover_layer``, ``mover_head``) from the last token to S1,
    S2, IO (and C1/C2 when ``with_C``) are then read from the resulting cache.

    Args:
        model: params object whose ``.model`` is the model to run.
        dataset: params object whose ``.dataset`` yields mapping batches with
            'text', 'S', 'IO' (and 'C' when ``with_C``) entries.
        interv_layer, interv_head: head whose OV components are used.
        comp_idx: int or list of ints selecting OV singular components, or
            ``None`` to skip the intervention and measure the clean model.
        token_type: 'S1', 'S2' or 'IO'; the position the intervention uses.
        target_layer, target_head: head whose result receives the added vector.
        mover_layer, mover_head: head whose scores are recorded.
        scales: multipliers to sweep; scores from all scales are concatenated.
        with_C: also record scores to the first two occurrences of ``C``.
        raw_scores: forwarded to ``get_attn_score_from_to``.

    Returns:
        ``[S1 scores stacked (S1-before-IO, S1-after-IO), S2 scores,
        IO scores stacked (IO-first, IO-second), C1 scores, C2 scores]``.
        C1/C2 are ``[0.]`` when ``with_C`` is False.

    Raises:
        ValueError: if ``token_type`` is not 'S1', 'S2' or 'IO' while an
            intervention is requested.
    """
    model = model.model
    model.set_use_attn_result(True)
    dataset = dataset.dataset

    if isinstance(comp_idx, int):
        comp_idx = [comp_idx]

    # comp stays None when no intervention is requested (comp_idx is None).
    comp = None
    if comp_idx is not None:
        if token_type not in ('S1', 'S2', 'IO'):
            raise ValueError(f"token_type must be 'S1', 'S2' or 'IO', got {token_type!r}")
        # Sum of the selected rank-1 SVD components of the intervention head's OV matrix.
        ov = get_ov(model, interv_layer, interv_head)
        u, s, v = ov.svd()
        comp = torch.zeros_like(ov.AB)
        for idx in comp_idx:
            comp += re_get_single_component(u, s, v, idx).AB

    s1_first_attn, s1_2nd_attn = [], []
    s2_attn = []
    io_first_attn, io_2nd_attn = [], []
    c1_attn, c2_attn = [], []

    def get_prompt(prompts, idx):
        """Slice prompt ``idx`` out of a batch of parallel per-key sequences."""
        return {key: prompts[key][idx] for key in prompts}

    def intervention_positions(batch):
        """Index of the ``token_type`` token in each prompt (BOS prepended)."""
        positions = []
        for prompt, subj, io in zip(batch['text'], batch['S'], batch['IO']):
            prompt_tokens = model.to_str_tokens(prompt, prepend_bos=True)
            if token_type == 'IO':
                positions.append(get_token_idx(prompt_tokens, ' ' + io))
            else:
                s1_idx, s2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + subj)[:2]
                positions.append(s1_idx if token_type == 'S1' else s2_idx)
        return positions

    def mover_score(cache, prompt_idx, end_idx, key_idx):
        """Mover-head score from ``end_idx`` to ``key_idx`` for one prompt."""
        return get_attn_score_from_to(cache, mover_layer, mover_head, end_idx, key_idx, prompt_idx, raw_scores=raw_scores)

    for scale in scales:
        for batch in track(dataset):
            text = batch['text']
            model.reset_hooks()
            _, cache = model.run_with_cache(text, prepend_bos=True)
            if comp is not None:
                hook_fn = partial(
                    add_mult_of_comp_out_at_idxs_hook,
                    pos_idxs=intervention_positions(batch),
                    head_idx=target_head,
                    comp=comp,
                    scale=scale,
                    # The residual is taken from the clean run above.
                    resid_pre_cache=cache['resid_pre', interv_layer],
                )
                # Always remove the hook, even if the hooked run fails.
                try:
                    model.blocks[target_layer].attn.hook_result.add_hook(hook_fn)
                    _, cache = model.run_with_cache(text, prepend_bos=True)
                finally:
                    model.reset_hooks()

            for prompt_idx in range(len(text)):
                cur_prompt = get_prompt(batch, prompt_idx)
                prompt_tokens = model.to_str_tokens(cur_prompt['text'])
                end_idx = len(prompt_tokens) - 1
                s1_idx, s2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + cur_prompt['S'])[:2]
                io_idx = get_token_idx(prompt_tokens, ' ' + cur_prompt['IO'])

                s1score = mover_score(cache, prompt_idx, end_idx, s1_idx)
                s2_attn.append(mover_score(cache, prompt_idx, end_idx, s2_idx))
                ioscore = mover_score(cache, prompt_idx, end_idx, io_idx)

                # Bucket by whether S1 comes before IO in the prompt.
                if s1_idx < io_idx:
                    s1_first_attn.append(s1score)
                    io_2nd_attn.append(ioscore)
                else:
                    s1_2nd_attn.append(s1score)
                    io_first_attn.append(ioscore)

                if with_C:
                    c1_idx, c2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + cur_prompt['C'])[:2]
                    c1_attn.append(mover_score(cache, prompt_idx, end_idx, c1_idx))
                    c2_attn.append(mover_score(cache, prompt_idx, end_idx, c2_idx))

    if not with_C:
        c1_attn = [0.]
        c2_attn = [0.]
    # NOTE(review): np.stack raises ValueError unless the dataset contains as
    # many S1-before-IO prompts as S1-after-IO prompts.
    return [np.stack([s1_first_attn, s1_2nd_attn]), np.stack([s2_attn]), np.stack([io_first_attn, io_2nd_attn]), np.array(c1_attn), np.array(c2_attn)]


@step(cacheable=True, deterministic=True, version='017')
def set_result_to_mult_of_comp_out_at_idxs(
    model: FromParams,
    dataset: DataParams,
    interv_layer: int,
    interv_head: int,
    comp_idx,
    modifiers: list,
    token_type: str, #S1, S2, IO
    target_layers: int, # sequence of layer indices
    target_heads: int, # sequence of head indices
    target_comp_idxs, # sequence (entries may be None), or None
    mover_layer: int, 
    mover_head: int,
    scales:list,
    with_C=False,
    raw_scores=False) -> list:
    """Set target heads' results from single OV components, add an OV-component
    vector to the residual stream, and measure a mover head.

    Very similar to ``add_mult_of_comp_out_at_idxs``, but it sets the target
    heads' result instead of adding to it. It combines that experiment with
    ``add_in_multiple_component_patches`` from ``ioi_inhib_add_comp_patch_exp``.

    For every scale and batch the model is run clean. When ``comp_idx`` is
    given, the following steps run:

    * The clean ``resid_pre`` of ``interv_layer`` at each prompt's
      ``token_type`` position is multiplied by ``comp``. ``comp`` is a
      ``modifiers``-weighted sum of the selected OV singular components of
      (``interv_layer``, ``interv_head``). The result is scaled by ``scale``
      and added to ``hook_resid_pre`` of the last paired target layer at
      that position.
    * For each (target layer, target head, target comp),
      ``calc_component_resid`` of that layer's clean ``resid_pre`` at the
      ``token_type`` positions is passed to ``attn_result_hook`` as the
      head's new result at each prompt's final position.
    * The model is re-run with these hooks.

    Per prompt, ``calc_inhib_score`` and the mover head's scores from the
    final token to S1, S2, IO (and C1/C2) are then recorded.

    Args:
        model: params object whose ``.model`` is the model to run.
        dataset: params object whose ``.dataset`` yields mapping batches with
            'text', 'S', 'IO' (and 'C' when ``with_C``) entries.
        interv_layer, interv_head: head whose OV components build ``comp``.
        comp_idx: int or list of ints selecting OV components, or ``None``.
            ``None`` skips every intervention, so the clean model is measured.
        modifiers: per-component weights for ``comp_idx``; ``None`` means all 1.
        token_type: 'S1', 'S2' or 'IO'.
        target_layers, target_heads, target_comp_idxs: paired element-wise.
            A ``None`` target comp index (or ``target_comp_idxs=None``) uses
            the head's full OV matrix.
        mover_layer, mover_head: head whose scores are recorded.
        scales: multipliers to sweep. Mover scores from all scales are
            concatenated; ``inhib_scores`` holds only the last scale's values.
        with_C: also record scores to the first two occurrences of ``C``.
        raw_scores: forwarded to ``get_attn_score_from_to``.

    Returns:
        ``[inhib_scores, S1 scores stacked (S1-before-IO, S1-after-IO),
        S2 scores, IO scores stacked (IO-first, IO-second), C1, C2]``.
        C1/C2 are ``[0.]`` when ``with_C`` is False.

    Raises:
        ValueError: if an intervention is requested with an invalid
            ``token_type`` or with no target heads.
    """
    model = model.model
    model.set_use_attn_result(True)
    dataset = dataset.dataset

    if isinstance(comp_idx, int):
        comp_idx = [comp_idx]
    if comp_idx is not None and token_type not in ('S1', 'S2', 'IO'):
        raise ValueError(f"token_type must be 'S1', 'S2' or 'IO', got {token_type!r}")

    # Intervention map: weighted sum of selected OV components, or the full OV.
    ov = get_ov(model, interv_layer, interv_head)
    if comp_idx is None:
        comp = ov.AB
    else:
        if modifiers is None:
            modifiers = [1. for _ in comp_idx]
        u, s, v = ov.svd()
        comp = torch.zeros_like(ov.AB)
        for i, idx in enumerate(comp_idx):
            comp += re_get_single_component(u, s, v, idx).AB * modifiers[i]

    if target_comp_idxs is None:
        target_comp_idxs = [None] * len(target_layers)
    target_specs = []  # (layer, head, comp matrix)
    for target_layer, target_head, tgt_comp_idx in zip(target_layers, target_heads, target_comp_idxs):
        target_ov = get_ov(model, target_layer, target_head)
        if comp_idx is None or tgt_comp_idx is None:
            print("USING FULL TARGET OV AS TARGET COMP")
            target_specs.append((target_layer, target_head, target_ov.AB))
        else:
            tu, ts, tv = target_ov.svd()
            target_specs.append((target_layer, target_head, re_get_single_component(tu, ts, tv, tgt_comp_idx).AB))

    if comp_idx is not None:
        if not target_specs:
            raise ValueError("at least one (target_layer, target_head, target_comp_idx) triple is required")
        # NOTE(review): the residual-stream addition goes into the LAST paired
        # target layer's resid_pre. Confirm that this is the intended layer.
        resid_add_layer = target_specs[-1][0]

    s1_first_attn, s1_2nd_attn = [], []
    s2_attn = []
    io_first_attn, io_2nd_attn = [], []
    c1_attn, c2_attn = [], []

    def get_prompt(prompts, idx):
        """Slice prompt ``idx`` out of a batch of parallel per-key sequences."""
        return {key: prompts[key][idx] for key in prompts}

    def locate_positions(batch):
        """Return (end_idxs, pos_idxs): each prompt's final-token index and
        ``token_type`` index, tokenized with BOS prepended."""
        end_idxs, pos_idxs = [], []
        for prompt, subj, io in zip(batch['text'], batch['S'], batch['IO']):
            prompt_tokens = model.to_str_tokens(prompt, prepend_bos=True)
            end_idxs.append(len(prompt_tokens) - 1)
            if token_type == 'IO':
                pos_idxs.append(get_token_idx(prompt_tokens, ' ' + io))
            else:
                s1_idx, s2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + subj)[:2]
                pos_idxs.append(s1_idx if token_type == 'S1' else s2_idx)
        return end_idxs, pos_idxs

    def mover_score(cache, prompt_idx, end_idx, key_idx):
        """Mover-head score from ``end_idx`` to ``key_idx`` for one prompt."""
        return get_attn_score_from_to(cache, mover_layer, mover_head, end_idx, key_idx, prompt_idx, raw_scores=raw_scores)

    inhib_scores = []  # defined even if ``scales`` is empty
    for scale in scales:
        inhib_scores = []
        for batch in track(dataset):
            text = batch['text']
            _, cache = model.run_with_cache(text, prepend_bos=True)
            # NOTE(review): with comp_idx None no hooks are added, so the
            # full-OV ``comp`` / target comps above are not applied.
            if comp_idx is not None:
                end_idxs, pos_idxs = locate_positions(batch)
                # Always remove the hooks, even if adding them or the run fails.
                try:
                    # Add the component output (e.g. a duplicate-token head
                    # along some components) into the residual stream.
                    model.blocks[resid_add_layer].hook_resid_pre.add_hook(partial(
                        add_mult_of_comp_out_to_resid_at_idxs_hook,
                        pos_idxs=pos_idxs,
                        comp=comp,
                        scale=scale,
                        resid_pre_cache=cache['resid_pre', interv_layer],
                    ))
                    # Make each target head output its component applied to the
                    # clean residual at the token_type position. This simulates
                    # an OV of just that component attending to that token.
                    for target_layer, target_head, tgt_comp in target_specs:
                        values_to_set = calc_component_resid(cache[get_act_name('resid_pre', target_layer)], pos_idxs, tgt_comp)
                        model.blocks[target_layer].attn.hook_result.add_hook(partial(
                            attn_result_hook, pos_idxs=end_idxs, head_idx=target_head, new_result_vecs=values_to_set,
                        ))
                    _, cache = model.run_with_cache(text, prepend_bos=True)
                finally:
                    model.reset_hooks()

            for prompt_idx in range(len(text)):
                cur_prompt = get_prompt(batch, prompt_idx)
                # Default inhibition score (S1 attention only).
                score = calc_inhib_score(model, cur_prompt, cache.apply_slice_to_batch_dim(prompt_idx), mover_layer, mover_head)
                inhib_scores.append(score.item())

                prompt_tokens = model.to_str_tokens(cur_prompt['text'])
                end_idx = len(prompt_tokens) - 1
                s1_idx, s2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + cur_prompt['S'])[:2]
                io_idx = get_token_idx(prompt_tokens, ' ' + cur_prompt['IO'])

                s1score = mover_score(cache, prompt_idx, end_idx, s1_idx)
                s2_attn.append(mover_score(cache, prompt_idx, end_idx, s2_idx))
                ioscore = mover_score(cache, prompt_idx, end_idx, io_idx)

                # Bucket by whether S1 comes before IO in the prompt.
                if s1_idx < io_idx:
                    s1_first_attn.append(s1score)
                    io_2nd_attn.append(ioscore)
                else:
                    s1_2nd_attn.append(s1score)
                    io_first_attn.append(ioscore)

                if with_C:
                    c1_idx, c2_idx = get_all_token_idx_occurrences(prompt_tokens, ' ' + cur_prompt['C'])[:2]
                    c1_attn.append(mover_score(cache, prompt_idx, end_idx, c1_idx))
                    c2_attn.append(mover_score(cache, prompt_idx, end_idx, c2_idx))
        rich.print("Inhibition Scores:", np.mean(inhib_scores))

    if not with_C:
        c1_attn = [0.]
        c2_attn = [0.]
    # NOTE(review): np.stack raises ValueError unless the dataset contains as
    # many S1-before-IO prompts as S1-after-IO prompts.
    return [inhib_scores, np.stack([s1_first_attn, s1_2nd_attn]), np.stack([s2_attn]), np.stack([io_first_attn, io_2nd_attn]), np.array(c1_attn), np.array(c2_attn)]



if __name__ == "__main__":
    # Sweep the intervention scale over -200..200 for one fixed configuration
    # of set_result_to_mult_of_comp_out_at_idxs. Print the mean inhibition
    # score for each scale; scale 0 is flagged as the default run.
    print("MAIN", sys.argv[0])
    workspace = tango.Workspace.from_url("./tango_workspace")

    model_name = 'gpt2-small'
    mname = model_name.rsplit('/', 1)[-1]  # short model name (currently unused)
    dataset_path = 'datasets/ioi_dataset_200.json'  # alt: 'datasets/pythia_ioi_dataset_200.json'
    model_params = ModelParams(model_name)
    data_params = DataParams(dataset_path)

    from tango.common import det_hash
    for params in (data_params, model_params):
        print(det_hash(params))

    for raw_scale in range(-200, 201):
        scale = float(raw_scale)
        if scale == 0.:
            rich.print("DEFAULT")
        # Keyword order is kept as-is: it is what the Tango step receives.
        sweep_step = set_result_to_mult_of_comp_out_at_idxs(
            model=model_params,
            dataset=data_params,
            interv_layer=3,
            interv_head=0,
            comp_idx=[0, 3],
            modifiers=None,
            token_type='IO',
            target_layers=[7, 7, 8, 8],
            target_heads=[3, 9, 6, 10],
            target_comp_idxs=[1, 6, 2, 1],
            mover_layer=9,
            mover_head=9,
            scales=[scale],
            with_C=False,
            raw_scores=False,
        )
        (inhib_scores, s1_first_attn, s2_attn,
         io_1st_io_2nd_attn, c1_attn, c2_attn) = sweep_step.result(workspace)
        rich.print("Scale", scale, "Inhibition Scores:", np.mean(inhib_scores))

    # NOTE(review): imported but never used in this script.
    from ioi_inhib_add_comp_patch_exp import add_in_multiple_component_patches
