import json
import logging
import os
import re
from collections import defaultdict

import matplotlib
import numpy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer

from rome.util import nethook
from rome.util.globals import DATA_DIR
from rome.experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from rome.experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from rome.dsets import KnownsDataset

# Print the path of the matplotlibrc file in use (a quick check of which rc
# settings the figures produced by this script will be rendered with).
print(f"{matplotlib.matplotlib_fname()}")


def trace_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    """
    Run the model once with corrupted token embeddings, restoring selected
    clean hidden states, and return the resulting answer probability.

    Batch row 0 of ``inp`` is the clean run. For rows 1.., the embedding
    outputs of tokens ``tokens_to_mix[0]:tokens_to_mix[1]`` get
    standard-normal noise scaled by ``noise``. Then, for every
    ``(token, layer)`` in ``states_to_patch``, the corrupted rows' hidden
    state at that token and layer is overwritten with row 0's value.

    Args:
        model: Model whose submodules are hooked via ``nethook.TraceDict``.
        inp: Keyword inputs for ``model(**inp)``; row 0 is kept clean.
        states_to_patch: Iterable of ``(token_index, layer_name)`` pairs.
        answers_t: Index (or indices) into the last-position vocabulary
            distribution.
        tokens_to_mix: ``(begin, end)`` token range to corrupt, or ``None``
            to add no noise.
        noise: Scale applied to the standard-normal noise.
        trace_layers: Optional list of layer names whose outputs to return.

    Returns:
        ``probs`` -- softmax probability of ``answers_t`` at the last
        position, averaged over the corrupted rows 1..; or, when
        ``trace_layers`` is given, ``(probs, all_traced)`` where
        ``all_traced`` stacks those layers' outputs along dim 2 on the CPU.

    Notes:
        The noise generator is re-seeded with 1 on every call, so calls with
        the same shapes see the same noise.
        A restore requested at the embedding layer itself is never applied,
        because ``patch_rep`` returns before consulting ``patch_spec`` there.
    """
    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    # Group the requested token indices by layer name: {layer: [token, ...]}.
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)
    embed_layername = layername(model, 0, "embed")

    def untuple(x):
        # Layer outputs may be tuples; the hidden state is the first element.
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.
    # Installed as TraceDict's edit_output hook: it receives (output, layer
    # name) for each traced layer and returns the output to use.
    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                # In-place add on a slice of x. Assumes x is laid out as
                # (batch, seq, hidden) -- TODO confirm for every supported model.
                x[1:, b:e] += noise * torch.from_numpy(
                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            # h is the tensor held by x (not a copy), so this in-place write
            # is visible through the x returned below.
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # We report softmax probabilities for the answers_t token predictions of interest.
    # Only the corrupted rows (1..) at the final position are averaged.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs



def calculate_hidden_flow(
    mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None
):
    """
    Runs causal tracing over every token/layer combination in the network
    and returns a dictionary numerically summarizing the results.

    The prompt is replicated ``samples + 1`` times: row 0 stays clean and
    rows 1.. have the tokens of ``subject`` corrupted with noise.

    Args:
        mt: Holder exposing ``model``, ``tokenizer`` and ``num_layers``.
        prompt: Text prompt to trace.
        subject: Substring of ``prompt`` whose tokens are corrupted.
        samples: Number of corrupted copies of the prompt.
        noise: Noise scale forwarded to ``trace_with_patch``.
        window: Layer window size; only used when ``kind`` is set.
        kind: ``None``/empty to restore single hidden states; otherwise the
            layer kind passed to ``layername`` (e.g. "mlp", "attn").

    Returns:
        dict with ``scores`` (CPU tensor indexed [token, layer]),
        ``low_score`` (probability with no state restored, as a float),
        ``high_score`` (row-0 score as returned by ``predict_from_input``),
        ``input_ids``, ``input_tokens``, ``subject_range``, ``answer``,
        ``window`` and ``kind`` ("" when not set).
    """
    log = logging.getLogger(__name__)

    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))
    with torch.no_grad():
        answer_t, base_score = [d[0] for d in predict_from_input(mt.model, inp)]
    [answer] = decode_tokens(mt.tokenizer, [answer_t])

    # Diagnostics are emitted at DEBUG level only, so normal runs stay quiet.
    log.debug("input_ids: %s", inp["input_ids"][0])
    log.debug("subject: %r", subject)
    e_range = find_token_range(mt.tokenizer, inp["input_ids"][0], subject)

    # Baseline: corrupt the subject and restore nothing.
    low_score = trace_with_patch(
        mt.model, inp, [], answer_t, e_range, noise=noise
    ).item()

    if not kind:
        log.debug("layers: %s", mt.num_layers)
        differences = trace_important_states(
            mt.model, mt.num_layers, inp, e_range, answer_t, noise=noise
        )
    else:
        differences = trace_important_window(
            mt.model,
            mt.num_layers,
            inp,
            e_range,
            answer_t,
            noise=noise,
            window=window,
            kind=kind,
        )
    differences = differences.detach().cpu()
    return dict(
        scores=differences,
        low_score=low_score,
        high_score=base_score,
        input_ids=inp["input_ids"][0],
        input_tokens=decode_tokens(mt.tokenizer, inp["input_ids"][0]),
        subject_range=e_range,
        answer=answer,
        window=window,
        kind=kind or "",
    )


def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1):
    """
    Restore one (token, layer) hidden state at a time and record the score.

    Every input token is paired with every layer ``0 .. num_layers - 1``; for
    each pair only that single state is restored while ``e_range`` stays
    corrupted.

    Returns:
        Tensor stacked as [token, layer, ...] from the ``trace_with_patch``
        results.
    """
    seq_len = inp["input_ids"].shape[1]

    def _score(token_idx, layer_idx):
        # Restore exactly one hidden state back to its clean value.
        return trace_with_patch(
            model,
            inp,
            [(token_idx, layername(model, layer_idx))],
            answer_t,
            tokens_to_mix=e_range,
            noise=noise,
        )

    rows = [
        torch.stack([_score(tok, lyr) for lyr in range(num_layers)])
        for tok in range(seq_len)
    ]
    return torch.stack(rows)


def trace_important_window(
    model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1
):
    """
    Sweep every (token, centre layer) pair, restoring a window of layers.

    For token ``t`` and centre layer ``c``, the clean states of layers from
    ``c - window // 2`` up to (but excluding) ``c + ceil(window / 2)`` are
    restored at token ``t`` only. The range is clipped to
    ``[0, num_layers)``, and layer names come from
    ``layername(model, L, kind)``.

    Returns:
        Tensor stacked as [token, centre layer, ...] from the
        ``trace_with_patch`` results.
    """
    table = []
    for tok in range(inp["input_ids"].shape[1]):
        scores = []
        for centre in range(num_layers):
            lo = max(0, centre - window // 2)
            # -(-w // 2) is ceil(w / 2) expressed with floor division.
            hi = min(num_layers, centre + -(-window // 2))
            restore = [(tok, layername(model, L, kind)) for L in range(lo, hi)]
            scores.append(
                trace_with_patch(
                    model, inp, restore, answer_t, tokens_to_mix=e_range, noise=noise
                )
            )
        table.append(torch.stack(scores))
    return torch.stack(table)

def plot_hidden_flow(
    mt,
    prompt,
    subject=None,
    samples=10,
    noise=0.1,
    window=10,
    kind=None,
    modelname=None,
    savepdf=None,
):
    """
    Run causal tracing for one prompt and plot the resulting heatmap.

    Args:
        mt: Holder passed through to ``calculate_hidden_flow``.
        prompt: Text prompt to trace.
        subject: Substring to corrupt; guessed from ``prompt`` via
            ``guess_subject`` when ``None``.
        samples, noise, window, kind: Forwarded to ``calculate_hidden_flow``.
        modelname: Forwarded to ``plot_trace_heatmap``.
        savepdf: Output path forwarded to ``plot_trace_heatmap``.
    """
    if subject is None:
        subject = guess_subject(prompt)
    result = calculate_hidden_flow(
        mt, prompt, subject, samples=samples, noise=noise, window=window, kind=kind
    )
    # The raw result holds full tensors; dump it at DEBUG level only (the
    # %-style argument is formatted lazily, only if DEBUG is enabled).
    logging.getLogger(__name__).debug("result: %s", result)
    plot_trace_heatmap(result, savepdf, modelname=modelname)

# Directory prefix for the generated figures, plus a counter that
# plot_all_flow bumps once per prompt to number its output files.
save_path, save_pre = "./out_figure/", 0


def make_plot(
    modelname=None,
    savepdf=None,
    scores=None,
    input_tokens=None,
    kind="mlp",
):
    """
    Plot a causal-trace heatmap from precomputed scores, without a model.

    With no extra arguments this reproduces the built-in example: a single
    "<IMG>" token traced over 12 layers with ``kind="mlp"``. Pass
    ``scores``/``input_tokens``/``kind`` to plot other precomputed results
    the same way.

    Args:
        modelname: Forwarded to ``plot_trace_heatmap``.
        savepdf: Output path forwarded to ``plot_trace_heatmap``.
        scores: Tensor indexed [token, layer], like the ``scores`` entry
            returned by ``calculate_hidden_flow``. Defaults to the example.
        input_tokens: One label per row of ``scores``. Defaults to
            ``["<IMG>"]``; supply it whenever ``scores`` has other rows.
        kind: Trace kind label stored in the result dict, using the same
            convention as ``calculate_hidden_flow`` ("" for single
            hidden-state traces).
    """
    if scores is None:
        scores = torch.tensor(
            [
                [
                    6.6434e-03, 9.4418e-03, 4.0742e-02, 3.4974e-02,
                    1.5949e-02, 5.2621e-03, 4.9600e-03, 3.6832e-03,
                    3.5186e-03, 5.1095e-02, 5.0784e-02, 5.1095e-03,
                ]
            ]
        )
    if input_tokens is None:
        input_tokens = ["<IMG>"]

    result = {
        "scores": scores,
        "input_tokens": input_tokens,
        "low_score": 0.0055756657384335995,
        # Kept on the CPU: this offline plot needs no GPU, and a hard-coded
        # "cuda:0" device fails on machines without CUDA.
        "high_score": torch.tensor(0.99),
        "subject_range": (0, 0),
        "answer": "donald trump",
        "kind": kind,
    }

    plot_trace_heatmap(result, savepdf, modelname=modelname)

def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None):
    """
    Plot hidden-state, MLP and attention causal traces for one prompt.

    Each call increments the global ``save_pre`` counter. The three PDFs are
    written as ``<save_path>/<n>null.pdf``, ``<n>mlp.pdf`` and ``<n>attn.pdf``,
    so repeated calls in one run do not overwrite each other.
    """
    global save_pre
    save_pre += 1
    for kind in (None, "mlp", "attn"):
        # kind=None selects the single hidden-state trace; label its file "null".
        suffix = "null" if kind is None else kind
        plot_hidden_flow(
            mt,
            prompt,
            subject,
            modelname=modelname,
            noise=noise,
            kind=kind,
            savepdf=os.path.join(save_path, f"{save_pre}{suffix}.pdf"),
        )

if __name__ == "__main__":
    # Render the demo figure only when run as a script, not on import.
    make_plot(savepdf=os.path.join(save_path, "attn_inter2.pdf"))


