import os, re, json
import torch, numpy
from collections import defaultdict
from rome.util import nethook
from rome.util.globals import DATA_DIR
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2Model, LlavaForConditionalGeneration, AutoProcessor
import torch
import random
from rome.experiments.casual_trace_for_MM import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from rome.experiments.casual_trace_for_MM import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
) 

# from rome.dsets import KnownsDataset
# from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer

import matplotlib

# from InstructBLIP2_model import load_m_model_and_preprocess
import torch
from PIL import Image
# Directory of resized imSitu images; the __main__ demo below appends an image file name to it.
event_path = "/MMFact/Imsitu/ft_local/of500_images_resized/"


def trace_with_patch(
    mt,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    # tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    """
    Run one corrupted forward pass with selected clean states restored.

    Batch item 0 is the clean run. Every other batch item has Gaussian noise
    (scaled by ``noise``) added to the output of the visual encoder's
    embedding layer, as named by ``layername(mt, 0, "visual_encoder", "embed")``.
    For each ``(token, layer)`` pair in ``states_to_patch``, the hidden state
    of item 0 at that token is copied into all corrupted items at that layer.

    Returns:
        The softmax probability of ``answers_t`` at the last position,
        averaged over the corrupted items (batch ``1:``). If ``trace_layers``
        is given, a tuple ``(probs, all_traced)`` is returned instead, where
        ``all_traced`` stacks those layers' outputs along dim 2.
    """
    # Re-seeded on every call so each traced configuration sees the same noise.
    prng = numpy.random.RandomState(1)
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)

    embed_layername = layername(mt, 0, "visual_encoder", "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.
    def patch_rep(x, layer):
        if layer == embed_layername:
            # Corrupt every batch item except the clean one at index 0. The
            # noise takes its shape from the tensor itself instead of a
            # hard-coded 16x16 patch grid. For a (B, C, 16, 16) output the
            # draw is identical to randn(B - 1, C, 16, 16).
            h = untuple(x)
            h[1:] += noise * torch.from_numpy(prng.randn(*h[1:].shape)).to(h.device)
            return x
        if layer not in patch_spec:
            return x
        # Restore the uncorrupted hidden state for the selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        mt.model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = mt.model(**inp)

    # Softmax probability of the answer token at the last position, averaged
    # over the corrupted batch items.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing extra layers, collect their activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs

def calculate_hidden_flow(
    mt, image, prompt, samples=10, noise=0.1, window=10, kind=None
):
    """
    Run causal tracing for one (image, prompt) pair and summarise the outcome.

    The pair is replicated ``samples + 1`` times. Item 0 is the clean run and
    the others are corrupted inside ``trace_with_patch``. With ``kind`` unset,
    single-layer state tracing is used. Otherwise windowed tracing of the
    ``kind`` sublayers is used.

    Returns:
        dict with the per-component score tensors (``scores``), the
        corrupted-baseline probability (``low_score``), the clean probability
        (``high_score``), the first row of ``input_ids`` and its decoded
        tokens, the predicted ``answer`` string, ``window`` and ``kind``
        (``""`` when unset).
    """
    batch_size = samples + 1
    inp = make_inputs(mt, [prompt] * batch_size, [image] * batch_size)
    model_inp = inp["inp"]
    tokenizer = mt.proccessor.tokenizer

    with torch.no_grad():
        preds, probs = predict_from_input(mt, inp)
    answer_t, base_score = preds[0], probs[0]

    [answer] = decode_tokens(tokenizer, [answer_t])

    # Probability of the answer with nothing restored (fully corrupted run).
    low_score = trace_with_patch(mt, model_inp, [], answer_t, noise=noise).item()

    if kind:
        differences = trace_important_window(
            mt,
            model_inp,
            answer_t,
            noise=noise,
            window=window,
            kind=kind,
        )
    else:
        differences = trace_important_states(mt, model_inp, answer_t, noise=noise)

    input_ids = model_inp["input_ids"][0]
    return {
        "scores": [table.detach().cpu() for table in differences],
        "low_score": low_score,
        "high_score": base_score,
        "input_ids": input_ids,
        "input_tokens": decode_tokens(tokenizer, input_ids),
        # subject_range=e_range,
        "answer": answer,
        "window": window,
        "kind": kind or "",
    }


def trace_important_states(mt, inp, answer_t, noise=0.1):
    """
    Single-layer causal tracing over the visual encoder, interface and LLM.

    Each entry is the answer probability returned by ``trace_with_patch``
    when only the listed clean states are restored into the corrupted run.

    Args:
        mt: model wrapper. ``mt.visual_layer_names``, ``mt.interface_layer_names``
            and ``mt.llm_layer_names`` fix how many layers are traced.
        inp: model inputs containing ``input_ids``.
        answer_t: token id whose probability is reported.
        noise: scale of the noise used by ``trace_with_patch``.

    Returns:
        ``(ve, mi, llm)`` tensors:
        - ``ve`` has one row: per layer, all patch positions restored together.
        - ``mi`` has one row per interface query token and one column per layer.
        - ``llm`` has one row per LLM position (query tokens then text tokens)
          and one column per layer.
    """
    patch_size = 16 * 16 + 1  # presumably a 16x16 patch grid plus one class token
    # Number of interface query tokens; same value trace_important_window uses.
    visual_tokens = 32
    ntoks = inp["input_ids"].shape[1]

    def restored_prob(states):
        # Answer probability with `states` restored from the clean run.
        return trace_with_patch(mt, inp, states, answer_t, noise=noise)

    # Visual encoder: restore every patch position of one layer at a time.
    ve_row = [
        restored_prob(
            [(tnum, layername(mt, layer, "visual_encoder")) for tnum in range(patch_size)]
        )
        for layer in range(len(mt.visual_layer_names))
    ]
    table_ve = [torch.stack(ve_row)]

    # Interface: one (query token, layer) state at a time. This sweeps the
    # query tokens rather than the text length, matching trace_important_window.
    table_mi = [
        torch.stack(
            [
                restored_prob([(tnum, layername(mt, layer, "MM_interface"))])
                for layer in range(len(mt.interface_layer_names))
            ]
        )
        for tnum in range(visual_tokens)
    ]

    # LLM: one (position, layer) state at a time.
    # NOTE(review): assumes the LLM sequence is the visual_tokens query outputs
    # followed by the ntoks text tokens (as trace_important_window does).
    # TODO confirm against the installed transformers version.
    table_llm = [
        torch.stack(
            [
                restored_prob([(tnum, layername(mt, layer, "LLM"))])
                for layer in range(len(mt.llm_layer_names))
            ]
        )
        for tnum in range(visual_tokens + ntoks)
    ]

    return torch.stack(table_ve), torch.stack(table_mi), torch.stack(table_llm)



## for Instruct Blip
# def trace_important_window(
#     mt, inp, answer_t, kind, window=10, noise=0.1
# ):
#     patch_size =  16 * 16 + 1
#     visual_tokens = 32
#     ntoks = inp["input_ids"].shape[1]


#     table_ve = []
#     table_mi = []
#     table_llm = []

#     row = []
#     for layer in range (0, len(mt.visual_layer_names)):
#         layerlist = [
#             (L, layername(mt, layer,"visual_encoder", kind))
#             for L in range(
#                patch_size
#             )
#         ]
#         r = trace_with_patch(
#             mt, inp, layerlist, answer_t, noise=noise
#         )
#         row.append(r)
#     table_ve.append(torch.stack(row))
    

#     if kind == "attn":
#         all_num = visual_tokens + ntoks
#     elif kind == "mlp":
#         all_num = visual_tokens

#     for tnum in range(all_num):
#         row = []
#         for layer in range (0, len(mt.interface_layer_names)):
#                 layerlist = [
#                     (tnum, layername(mt, L, "MM_interface",  kind))
#                     for L in range(
#                         max(0, layer - window // 2), min(len(mt.interface_layer_names), layer - (-window // 2))
#                     )
#                 ]
#                 r = trace_with_patch(
#                     mt, inp, layerlist, answer_t, noise=noise
#                 )
#                 row.append(r)
#         table_mi.append(torch.stack(row))
    

#     for tnum in range(visual_tokens + ntoks):
#         row = []

#         for layer in range (0, len(mt.llm_layer_names)):


#             layerlist = [
#                 (tnum, layername(mt, L, "LLM", kind))
#                 for L in range(
#                     max(0, layer - window // 2), min(len(mt.llm_layer_names), layer - (-window // 2))
#                 )
#             ]
#             r = trace_with_patch(
#                 mt, inp, layerlist, answer_t, noise=noise
#             )
#             row.append(r)
#         table_llm.append(torch.stack(row))
    

#     return torch.stack(table_ve),torch.stack(table_mi),torch.stack(table_llm)


def trace_important_window(
    mt, inp, answer_t, kind, window=10, noise=0.1
):
    """
    Windowed causal tracing of the ``kind`` ("attn" or "mlp") sublayers.

    - Visual encoder: for each layer, all patch positions of that single
      layer are restored together. No windowing is applied here.
    - Interface and LLM: for each (token, layer), the ``kind`` state of that
      token is restored at every layer in a window centred on ``layer``.
      The window is ``[layer - window // 2, layer + ceil(window / 2))``,
      clipped to the model's depth.

    Returns:
        ``(ve, mi, llm)`` tensors of answer probabilities from ``trace_with_patch``.

    Raises:
        ValueError: if ``kind`` is not "attn" or "mlp". This is checked up
        front, before any expensive tracing.
    """
    if kind not in ("attn", "mlp"):
        raise ValueError(f"kind must be 'attn' or 'mlp', got {kind!r}")

    patch_size = 16 * 16 + 1  # presumably a 16x16 patch grid plus one class token
    visual_tokens = 32  # interface query tokens swept for both kinds
    ntoks = inp["input_ids"].shape[1]

    def run(states):
        # Answer probability with `states` restored from the clean run.
        return trace_with_patch(mt, inp, states, answer_t, noise=noise)

    def layer_window(layer, n_layers):
        # `layer - (-window // 2)` is layer + ceil(window / 2).
        return range(max(0, layer - window // 2), min(n_layers, layer - (-window // 2)))

    ve_row = []
    for layer in range(len(mt.visual_layer_names)):
        states = [
            (tnum, layername(mt, layer, "visual_encoder", kind))
            for tnum in range(patch_size)
        ]
        ve_row.append(run(states))
    table_ve = [torch.stack(ve_row)]

    n_mi = len(mt.interface_layer_names)
    table_mi = []
    for tnum in range(visual_tokens):
        row = [
            run([(tnum, layername(mt, L, "MM_interface", kind)) for L in layer_window(layer, n_mi)])
            for layer in range(n_mi)
        ]
        table_mi.append(torch.stack(row))

    n_llm = len(mt.llm_layer_names)
    table_llm = []
    for tnum in range(visual_tokens + ntoks):
        row = [
            run([(tnum, layername(mt, L, "LLM", kind)) for L in layer_window(layer, n_llm)])
            for layer in range(n_llm)
        ]
        table_llm.append(torch.stack(row))

    return torch.stack(table_ve), torch.stack(table_mi), torch.stack(table_llm)

def calculate_random_c_center(result):
    llm_score = []
    ve_score = []
    mi_score = []

    ### ve
    ve = result["scores"][0][0]
    ve = list(ve.cpu().numpy())
    ve = [r/sum(ve) for r in ve]
    ve_layers = [0]*len(ve)


    print(len(ve))

    for i in range(50):
        this_layer_ran = random.randint(0,10000)

        count = 0
        for i in range(len(ve)):
            count += 10000 * ve[i]
            
            if count >= this_layer_ran:
                ve_layers[i]+=1
                break
    print(ve_layers)

    ### mi
    mi = result["scores"][1]
    mi = list(mi.cpu().numpy())
    mi_sum = 0

    for r in mi:
        mi_sum += sum(r)
    print(mi_sum)
    print(sum([r[0] for r in mi]))
    mi = [sum([r[i] for r in mi])/mi_sum for i in range(len(mi[0]))]
    mi_layers = [0]*len(mi)


    print(len(mi))

    for i in range(50):
        this_layer_ran = random.randint(0,10000)

        count = 0
        for i in range(len(mi)):
            count += 10000 * mi[i]
            
            if count >= this_layer_ran:
                mi_layers[i]+=1
                break
    print(mi_layers)

    ### llm

    llm = result["scores"][2]
    llm = list(llm.cpu().numpy())
    llm_sum = 0

    for r in llm:
        llm_sum += sum(r)
    print(llm_sum)
    print(sum([r[0] for r in llm]))
    llm = [sum([r[i] for r in llm])/llm_sum for i in range(len(llm[0]))]
    llm_layers = [0]*len(llm)


    print(len(llm))

    for i in range(50):
        this_layer_ran = random.randint(0,10000)

        count = 0
        for i in range(len(llm)):
            count += 10000 * llm[i]
            
            if count >= this_layer_ran:
                llm_layers[i]+=1
                break
    print(llm_layers)



def plot_hidden_flow(
    mt,
    prompt,
    image,
    subject=None,
    samples=10,
    noise=0.1,
    window=4,
    kind=None,
    modelname=None,
    savepdf=None,
):
    """
    Compute causal-tracing results for one example, print them and plot a heatmap.

    ``subject`` is accepted for call compatibility but is not used. ``kind``
    selects windowed tracing of that sublayer type. ``None`` selects
    single-layer hidden-state tracing. ``savepdf`` and ``modelname`` are
    forwarded to ``plot_trace_heatmap``.
    """
    result = calculate_hidden_flow(
        mt, image, prompt, samples=samples, noise=noise, window=window, kind=kind
    )

    print("_*_" * 20)
    # kind defaults to None; concatenating it to a str directly raised TypeError.
    print("result" + (kind or ""))
    print(result)

    calculate_random_c_center(result)

    plot_trace_heatmap(result, savepdf, modelname=modelname)

if __name__ == "__main__":
    # Demo: load BLIP-2 (OPT-6.7b), check plain generation on one imSitu
    # image, then run windowed causal tracing and save the heatmap.
    print(matplotlib.matplotlib_fname())
    torch.set_grad_enabled(False)
    model_name = "Blip2OPT"
    print("changed V")

    processor = Blip2Processor.from_pretrained("/blip2-opt-6.7b")
    model = Blip2ForConditionalGeneration.from_pretrained("/blip2-opt-6.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16)  # doctest: +IGNORE_RESULT
    # device_map={"": 0} pins the whole model to device 0, so send inputs to
    # wherever the model actually lives. A CPU fallback would not match it.
    device = model.device

    # model = LlavaForConditionalGeneration.from_pretrained("Model/llava-hf--llava-1.5-7b-hf.main.05ae2434cbb430be33edcba0c5203e7023f785b7/",torch_dtype=torch.float16).to(device)
    # processor = AutoProcessor.from_pretrained("Model/llava-hf--llava-1.5-7b-hf.main.05ae2434cbb430be33edcba0c5203e7023f785b7/")

    prompt = "QUESTION: What action is the woman undertaking at outside that involves moving quickly without walking? \n ANSWER: "
    image = event_path + "running_89.jpg"

    # prompt = "Please use one word to answer the question. Which state does the place shown in the image belong to? The answer is: "
    # image = "MMFact/FB15k/FB15k-images/m.0q_xk/bing_23.jpg"

    # prompt = "Question: what animal is presented in the image?\n Answer: "
    # image = "MMFact/Oven/oven_00623726.JPEG"

    image = Image.open(image).convert("RGB")
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    inputs["max_length"] = 42
    out = model.generate(**inputs)

    print("+" * 30)
    print("PRE:")
    print(prompt)
    print(processor.batch_decode(out, skip_special_tokens=True))

    noise_level = 0.9
    print(f"Using noise level {noise_level}")

    mt = ModelAndTokenizer(model=model, tokenizer=processor.tokenizer, proccessor=processor, model_name=model_name)

    print(mt)
    print(len(mt.visual_layer_names))
    print(len(mt.interface_layer_names))
    print(len(mt.llm_layer_names))
    print("here = " * 30)

    inp = make_inputs(mt, [prompt], [image])
    print(predict_from_input(mt, inp))

    save_path = "./out_figure_entity/"
    save_pre = 0
    # Make sure the output directory exists before the heatmap is written there.
    os.makedirs(save_path, exist_ok=True)

    # plot_hidden_flow(
    #         mt, prompt, image, None, modelname=model_name, noise=noise_level, kind="attn", savepdf=save_path + str(save_pre)
    #     )

    plot_hidden_flow(
        mt, prompt, image, None, modelname=model_name, noise=noise_level, kind="mlp", savepdf=save_path + str(save_pre)
    )
