import argparse
import json
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"]="0"
import re
from collections import defaultdict

import numpy
import torch
# from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsets import KnownsDataset, LamatrexDataset
from rome.tok_dataset import (
    TokenizedDataset,
    dict_to_,
    flatten_masked_batch,
    length_collation,
)
from util import nethook
from util.globals import DATA_DIR
from util.runningstats import Covariance, tally
from pprint import pprint

from nltk.corpus import stopwords

import sys
from pathlib import Path
# Make the ViT causal-tracing project importable (it provides lib.utils).
vit_path = Path("path/to/CausalPathTracing_for_ViT/main")
# sys.path holds strings: comparing the Path object itself never matches an
# entry, so compare (and insert) its string form to avoid duplicates.
if str(vit_path) not in sys.path:
    sys.path.insert(0, str(vit_path))
from lib.utils import get_model, get_data
from torch.utils.data import DataLoader
from pathlib import Path
def find_knowns_ids(model, data, root="path/to/CausalPathTracing_for_ViT/main"):
    """Collect the example ids that already have non-empty traced results.

    Scans ``<root>/jobs/<model>`` (``<root>/jobs_oh/<model>`` when ``data`` is
    ``"officehome"``).  Each job directory is expected to be named
    ``<prefix>_<id>[_...]``; its integer id is collected when the job's
    ``results`` subdirectory exists and contains at least one entry.
    Entries that are not directories, or whose second ``_``-separated field
    is not an integer, are skipped.

    Args:
        model: model name; used (via ``str``) as the job subdirectory name.
        data: dataset name; ``"officehome"`` selects the ``jobs_oh`` tree.
        root: base directory holding the ``jobs``/``jobs_oh`` trees.

    Returns:
        Sorted list of integer ids.  For ``deit_tiny_patch16_224`` on
        ``imagenet`` the id 193294 is removed (hard-coded exclusion).

    Raises:
        FileNotFoundError: if the job directory does not exist.
    """
    print("Finding knowns ids")
    jobs_dir = "jobs_oh" if data == "officehome" else "jobs"
    job_root = Path(root) / jobs_dir / str(model)

    entries = list(job_root.iterdir())
    print(len(entries))

    ids = []
    for entry in tqdm(entries):
        if not entry.is_dir():
            continue
        try:
            idx = int(entry.name.split("_")[1])
        except (IndexError, ValueError):
            # Not a job directory of the expected "<prefix>_<id>" form.
            continue
        results_dir = entry / "results"
        # Only count jobs whose results directory exists and is not empty.
        if results_dir.is_dir() and any(results_dir.iterdir()):
            ids.append(idx)

    print("Done finding knowns ids")
    print(f"Found {len(ids)} knowns ids")
    ids.sort()

    if model == "deit_tiny_patch16_224" and data == "imagenet":
        print("run deit_tiny_patch16_224, imagenet")
        # Hard-coded exclusion for this model/dataset pair (reason not recorded here).
        if 193294 in ids:
            print("removing 193294")
            ids.remove(193294)
    return ids
def _trace_correct(filename, compute, *args, **kwargs):
    """Return the ``correct_prediction`` flag of a cached or freshly computed trace.

    If ``filename`` exists it is opened with ``numpy.load`` (pickle allowed,
    since saved results may contain object arrays such as ``None`` fields)
    and closed again once the flag has been read.  Otherwise
    ``compute(*args, **kwargs)`` is called; tensor values in the returned dict
    are converted to numpy arrays and the dict is written to ``filename`` with
    ``numpy.savez`` before the flag is returned.
    """
    if os.path.isfile(filename):
        with numpy.load(filename, allow_pickle=True) as cached:
            return bool(cached["correct_prediction"])
    result = compute(*args, **kwargs)
    numpy_result = {
        k: v.detach().cpu().numpy() if torch.is_tensor(v) else v
        for k, v in result.items()
    }
    numpy.savez(filename, **numpy_result)
    return bool(numpy_result["correct_prediction"])


def main():
    """Command-line entry point for causal tracing over facts or images.

    Parses CLI arguments and loads the fact set (or image dataset) and the
    model.  Then it resolves the noise level.  For each example and each
    trace kind (``None`` = whole hidden state, ``"mlp"``, ``"attn"``) it
    computes, or reuses, a cached ``.npz`` result under
    ``<output_dir>/cases``.  Finally it prints how many traces had a
    correct model prediction.
    """
    parser = argparse.ArgumentParser(description="Causal Tracing")

    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)

    def parse_noise_rule(code):
        """Parse --noise_level: "m"/"s"/"s<x>"/"t<x>"/"u<x>" stay strings, else float."""
        if code in ["m", "s"]:
            return code
        elif re.match(r"^[uts][\d\.]+", code):
            return code
        else:
            return float(code)

    aa(
        "--model_name",
        default="AlgorithmicResearchGroup/gpt2-xs",
        choices=[  # params, size
            "gpt2-xl",  # 1.5B, 3GB
            "EleutherAI/gpt-j-6B",  # 6B, 12GB
            "EleutherAI/gpt-neox-20b",  # 20B, 40GB
            "gpt2-large",  # 774M, 1.5GB
            "gpt2-medium",  # 345M, 700MB
            "gpt2",  # 124M, 250MB
            "sshleifer/tiny-gpt2",  # 0.1M(103K), 2.5MB
            "AlgorithmicResearchGroup/gpt2-xs",
            "EleutherAI/pythia-14m",
            "EleutherAI/pythia-1b",
            "vit_tiny_patch16_224",
            "deit_tiny_patch16_224",
            "vit_base_patch16_224",
        ],
    )
    aa("--fact_file", default=None)
    aa("--output_dir", default="results/{model_name}/{fact_file}-causal_trace")
    aa("--noise_level", default="s3", type=parse_noise_rule)
    aa("--replace", default=0, type=int)
    aa("--token_range", type=str, required=True)  # e.g. corrupt_all
    aa("--window", default=1, type=int)
    aa("--image_dataset", default=None, type=str)
    aa("--except_stopword", default=False, action="store_true", help="Exclude stopwords from processing")
    args = parser.parse_args()
    # NOTE: TF_ENABLE_ONEDNN_OPTS is already set to "0" at module import time.

    # Result directory name: w<window>_n<noise>_r<replace>_<model tag>.
    if "gpt2-xs" in args.model_name:
        model_tag = "gpt2-xs"
    elif "EleutherAI" in args.model_name:
        model_tag = args.model_name.split("/")[-1]
    else:
        model_tag = args.model_name.replace("/", "_")
    modeldir = f"w{args.window}_n{args.noise_level}_r{args.replace}_{model_tag}"

    fact_label = args.fact_file or args.image_dataset or "knowns1000"
    output_dir = args.output_dir.format(model_name=modeldir, fact_file=fact_label)
    result_dir = f"{output_dir}/cases"
    pdf_dir = f"{output_dir}/pdfs"
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(pdf_dir, exist_ok=True)

    # Half precision to let the 20b model fit.
    torch_dtype = torch.float16 if "20b" in args.model_name else None

    knowns = None
    dataset = None
    num_classes = None
    if args.fact_file == "knowns1000":
        knowns = KnownsDataset(DATA_DIR)
    elif args.fact_file == "lama":
        knowns = LamatrexDataset(DATA_DIR, reverse=False)
    elif args.image_dataset:
        dataset, num_classes = get_data(args.image_dataset)
    elif args.fact_file is None:
        # No fact file given: use the default known-facts set, matching the
        # "knowns1000" label used for output_dir above.
        knowns = KnownsDataset(DATA_DIR)
    else:
        with open(args.fact_file) as f:
            knowns = json.load(f)
        print(len(knowns))

    try:
        mt = ModelAndTokenizer(args.model_name, torch_dtype=torch_dtype)
    except Exception as e:
        print(f"Error loading model {args.model_name}: {e}")
        if not args.image_dataset:
            # The ViT fallback needs an image dataset; surface the real error.
            raise
        print("Trying to load ViT model")
        mt = get_model(model_name=args.model_name, num_classes=num_classes, dataset_name=args.image_dataset)

    if args.except_stopword:
        # Ids of English stopwords that encode to a single token, with or
        # without a leading space.
        stopword_ids = set()
        print("calculating stopwords...")
        for stopword in tqdm(stopwords.words('english')):
            for text in (' ' + stopword, stopword):
                token_ids = mt.tokenizer.encode(text, add_special_tokens=False)
                if len(token_ids) == 1:
                    stopword_ids.add(token_ids[0])
        mt.stwd_ids = sorted(stopword_ids)

    noise_level = args.noise_level
    uniform_noise = False
    if isinstance(noise_level, str):
        if noise_level.startswith("s"):
            # Automatic spherical gaussian
            factor = float(noise_level[1:]) if len(noise_level) > 1 else 1.0
            if args.image_dataset:
                noise_level = factor * collect_IM_embedding_std(mt, dataset)
            else:
                noise_level = factor * collect_embedding_std(
                    mt, [k["subject"] for k in knowns]
                )
            print(f"Using noise_level {noise_level} to match model times {factor}")
        elif noise_level == "m":
            # Automatic multivariate gaussian
            noise_level = collect_embedding_gaussian(mt)
            print("Using multivariate gaussian to match model noise")
        elif noise_level.startswith("t"):
            # Automatic t-distribution with the given degrees of freedom
            degrees = float(noise_level[1:])
            noise_level = collect_embedding_tdist(mt, degrees)
        elif noise_level.startswith("u"):
            uniform_noise = True
            noise_level = float(noise_level[1:])

    correct = 0
    wrong = 0
    n_done = 0
    # Heatmap plotting (plot_trace_heatmap into pdf_dir) is currently disabled.
    if args.image_dataset:
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
        known_ids = find_knowns_ids(args.model_name, args.image_dataset)
        print(f"len(known_ids): {len(known_ids)}")
        known_id_set = set(known_ids)  # O(1) membership test per batch
        for batch_idx, (images, labels) in tqdm(enumerate(dataloader)):
            if batch_idx not in known_id_set:
                continue

            images = images.cuda()
            labels = labels.cuda()

            for kind in None, "mlp", "attn":
                kind_suffix = f"_{kind}" if kind else ""
                filename = f"{result_dir}/knowledge_{batch_idx}{kind_suffix}.npz"
                is_correct = _trace_correct(
                    filename,
                    IM_calculate_hidden_flow,
                    mt,
                    images,
                    images,
                    expect=labels,
                    window=args.window,
                    kind=kind,
                    noise=noise_level,
                    uniform_noise=uniform_noise,
                    replace=args.replace,
                    token_range=args.token_range,
                )
                if not is_correct:
                    wrong += 1
                    print(f"Wrong Skipping {batch_idx}")
                    continue
                correct += 1
            print(n_done, end=" ")
            n_done += 1
            if n_done % 10 == 0:
                print()
    else:
        for idx, knowledge in tqdm(enumerate(knowns)):
            # LAMA records are numbered by position; other sets carry known_id.
            known_id = idx if args.fact_file == "lama" else knowledge["known_id"]
            assert idx == known_id

            for kind in None, "mlp", "attn":
                kind_suffix = f"_{kind}" if kind else ""
                filename = f"{result_dir}/knowledge_{known_id}{kind_suffix}.npz"
                is_correct = _trace_correct(
                    filename,
                    calculate_hidden_flow,
                    mt,
                    knowledge["prompt"],
                    knowledge["subject"],
                    expect=knowledge["attribute"],
                    window=args.window,
                    kind=kind,
                    noise=noise_level,
                    uniform_noise=uniform_noise,
                    replace=args.replace,
                    token_range=args.token_range,
                    except_stopword=args.except_stopword,
                )
                if not is_correct:
                    wrong += 1
                    continue
                correct += 1

    print(f"correct: {correct}, wrong: {wrong}")

def trace_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end), None, or "image"
    noise=0.1,  # Level of noise to add (number) or a callable applied to raw noise
    uniform_noise=False,
    replace=False,  # True to replace with instead of add noise
    trace_layers=None,  # List of traced outputs to return
):
    """
    Runs a single causal trace.  Given a model and a batch input where
    the batch size is at least two, runs the batch in inference, corrupting
    the set of runs [1...n] while also restoring a set of hidden states to
    the values from an uncorrupted run [0] in the batch.

    The convention used by this function is that the zeroth element of the
    batch is the uncorrupted run, and the subsequent elements of the batch
    are the corrupted runs.  The argument tokens_to_mix specifies a
    (begin, end) range of token positions whose embeddings are corrupted by
    adding noise (or, with ``replace``, overwritten with noise) for every
    batch element except the first.  When tokens_to_mix is the string
    "image", the whole output of the image embedding layer is overwritten
    with noise for those elements instead (``replace`` is not consulted in
    that case).  Alternately, subsequent runs could be corrupted by simply
    providing different input tokens via the passed input batch.

    Then when running, a specified set of hidden states will be uncorrupted
    by restoring their values to the same vector that they had in the
    zeroth uncorrupted run.  This set of hidden states is listed in
    states_to_patch, by listing [(token_index, layername), ...] pairs.
    token_index is used directly as an index on the token axis, so an int
    or a slice works.  To trace the effect of just a single state, this can
    be just a single token/layer pair.  To trace the effect of restoring a
    set of states, any number of token indices and layers can be listed.

    Noise samples are drawn from a fixed seed: standard normal, or U(-1, 1)
    with uniform_noise.  A callable ``noise`` is applied to those samples;
    any other value scales them.

    Returns the softmax probability of ``answers_t`` averaged over the
    corrupted runs (rows 1..n; an image batch with a single row uses that
    row).  Text models use the last position.  If ``trace_layers`` is
    given, it returns ``(probs, traced)``, where ``traced`` stacks those
    layers' outputs on dim 2.
    """

    rs = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    if uniform_noise:
        prng = lambda *shape: rs.uniform(-1, 1, shape)
    else:
        prng = lambda *shape: rs.randn(*shape)

    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)

    is_image = tokens_to_mix == "image"
    if is_image:
        embed_layername = IM_layername(model, 0, "embed")
    else:
        embed_layername = layername(model, 0, "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.  A callable noise is a custom sampler
    # transform; any plain number (float, int, numpy/torch scalar) is a scale.
    if callable(noise):
        noise_fn = noise
    else:
        noise_fn = lambda x: noise * x

    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None and not is_image:
                b, e = tokens_to_mix
                noise_data = noise_fn(
                    torch.from_numpy(prng(x.shape[0] - 1, e - b, x.shape[2]))
                ).to(x.device)
                if replace:
                    x[1:, b:e] = noise_data
                else:
                    x[1:, b:e] += noise_data
            elif is_image:
                # Overwrite the entire image embedding of the corrupted runs.
                b, e = 0, x.shape[1]
                x[1:, b:e] = noise_fn(
                    torch.from_numpy(prng(x.shape[0] - 1, e - b, x.shape[2]))
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # Softmax probability of the answer, averaged over the corrupted runs
    # (with the selected states restored).
    if is_image:
        # Image logits are already [batch, classes].  Average over the
        # corrupted rows only; with a single row, [1:] would be empty (NaN
        # mean), so fall back to that row.
        logits = outputs_exp["logits"]
        if logits.shape[0] > 1:
            logits = logits[1:]
        probs = torch.softmax(logits, dim=1).mean(dim=0)[answers_t]
    else:
        probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs


def trace_with_repatch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    states_to_unpatch,  # A list of (token index, layername) pairs to re-randomize
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add (number) or a callable applied to raw noise
    uniform_noise=False,
):
    """Run a causal trace that restores some states and re-corrupts others.

    Batch row 0 is the clean run.  Rows 1..n get noise added to the embeddings
    of tokens ``tokens_to_mix = (begin, end)``.  When ``states_to_unpatch`` is
    non-empty the model runs twice.  The first pass applies only the embedding
    corruption and its traced outputs are kept.  The second pass restores
    ``states_to_patch`` from row 0 and overwrites ``states_to_unpatch`` in rows
    1..n with their first-pass (corrupted) values.  Noise handling matches
    ``trace_with_patch``: a callable ``noise`` transforms the raw samples, and
    any other value scales them.

    Returns:
        The softmax probability of ``answers_t`` at the last position,
        averaged over rows 1..n of the final pass.
    """
    rs = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    if uniform_noise:
        prng = lambda *shape: rs.uniform(-1, 1, shape)
    else:
        prng = lambda *shape: rs.randn(*shape)
    if callable(noise):
        noise_fn = noise
    else:
        noise_fn = lambda x: noise * x

    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)
    unpatch_spec = defaultdict(list)
    for t, l in states_to_unpatch:
        unpatch_spec[l].append(t)

    embed_layername = layername(model, 0, "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.  `first_pass` and `first_pass_trace` are
    # read from the enclosing scope at call time (set by the loop below).
    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                x[1:, b:e] += noise_fn(
                    torch.from_numpy(prng(x.shape[0] - 1, e - b, x.shape[2]))
                ).to(x.device)
            return x
        if first_pass or (layer not in patch_spec and layer not in unpatch_spec):
            return x
        # Restore the clean state for patched tokens, and reinstate the
        # first-pass corrupted state for unpatched tokens.
        h = untuple(x)
        for t in patch_spec.get(layer, []):
            h[1:, t] = h[0, t]
        for t in unpatch_spec.get(layer, []):
            h[1:, t] = untuple(first_pass_trace[layer].output)[1:, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    for first_pass in [True, False] if states_to_unpatch else [False]:
        with torch.no_grad(), nethook.TraceDict(
            model,
            [embed_layername] + list(patch_spec.keys()) + list(unpatch_spec.keys()),
            edit_output=patch_rep,
        ) as td:
            outputs_exp = model(**inp)
            if first_pass:
                first_pass_trace = td

    # We report softmax probabilities for the answers_t token predictions of interest.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    return probs


def calculate_hidden_flow(
    mt,
    prompt,
    subject,
    samples=10,
    noise=0.1,
    token_range=None,
    uniform_noise=False,
    replace=False,
    window=10,
    kind=None,
    expect=None,
    except_stopword=False,
):
    """
    Runs causal tracing over every token/layer combination in the network
    and returns a dictionary numerically summarizing the results.

    The prompt is repeated ``samples + 1`` times: row 0 is clean and the
    other rows are corrupted.

    token_range:
        None           -- corrupt the subject tokens, trace every token.
        "subject_last" -- corrupt the subject tokens, trace only the last one.
        "corrupt_all"  -- corrupt positions [0, seq_len - 1), i.e. every
                          token except the final one, and trace every token.

    Returns ``dict(correct_prediction=False)`` when ``expect`` is given and
    the clean prediction does not match it.  Otherwise it returns the score
    table, the corrupted (``low_score``) and clean (``high_score``) answer
    probabilities, and metadata.

    Raises:
        ValueError: for an unrecognised ``token_range``.
    """
    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))
    with torch.no_grad():
        answer_t, base_score = [d[0] for d in predict_from_input(mt, inp, except_stopword)]
    [answer] = decode_tokens(mt.tokenizer, [answer_t])

    if expect is not None and (answer.strip() != expect and expect not in answer):
        return dict(correct_prediction=False)

    if token_range == "corrupt_all":
        # The subject is irrelevant here, so don't require it to be found.
        e_range = (0, inp["input_ids"].shape[1] - 1)
        token_range = None
    else:
        e_range = find_token_range(mt.tokenizer, inp["input_ids"][0], subject)
        if token_range == "subject_last":
            token_range = [e_range[1] - 1]
        elif token_range is not None:
            raise ValueError(f"Unknown token_range: {token_range}")

    # Baseline: corrupted runs with nothing restored.  Use the same corruption
    # mode (`replace`) as the traces below so the two are comparable.
    low_score = trace_with_patch(
        mt.model,
        inp,
        [],
        answer_t,
        e_range,
        noise=noise,
        uniform_noise=uniform_noise,
        replace=replace,
    ).item()

    if not kind:
        differences = trace_important_states(
            mt.model,
            mt.num_layers,
            inp,
            e_range,
            answer_t,
            noise=noise,
            uniform_noise=uniform_noise,
            replace=replace,
            token_range=token_range,
        )
    else:
        differences = trace_important_window(
            mt.model,
            mt.num_layers,
            inp,
            e_range,
            answer_t,
            kind=kind,
            window=window,
            noise=noise,
            uniform_noise=uniform_noise,
            replace=replace,
            token_range=token_range,
        )
    differences = differences.detach().cpu()
    return dict(
        scores=differences,
        low_score=low_score,
        high_score=base_score,
        input_ids=inp["input_ids"][0],
        input_tokens=decode_tokens(mt.tokenizer, inp["input_ids"][0]),
        subject_range=e_range,
        answer=answer,
        window=window,
        correct_prediction=True,
        kind=kind or "",
    )

def IM_calculate_hidden_flow(
    mt,
    prompt,
    subject,
    samples=10,
    noise=0.1,
    token_range=None,
    uniform_noise=False,
    replace=False,
    window=10,
    kind=None,
    expect=None,
):
    """
    Image counterpart of ``calculate_hidden_flow``.

    ``prompt`` is an image batch; only its first image is used.  It is
    repeated ``samples + 1`` times: row 0 is clean and the remaining rows get
    their embeddings replaced by noise.  ``subject`` is unused, and
    ``token_range`` is overridden by the image tracing helpers; both are
    kept for signature parity.

    ``expect`` may be a Python int or a one-element tensor holding the true
    class.  When the clean prediction differs from it, the function returns
    ``dict(correct_prediction=False)``.  Otherwise it returns the score table
    and the corrupted (``low_score``) and clean (``high_score``) probabilities
    of the predicted class (``answer``).
    """
    image = prompt[0]
    batch_images = image.unsqueeze(0).repeat(samples + 1, 1, 1, 1)  # [B, C, H, W]
    inp = {"pixel_values": batch_images}

    with torch.no_grad():
        clean_logits = mt.model(**inp)["logits"][0]
        probs = torch.softmax(clean_logits, dim=-1)
        base_score, answer_t = torch.max(probs, dim=-1)

    # int() accepts both plain ints and one-element tensors.
    if expect is not None and int(answer_t) != int(expect):
        return dict(correct_prediction=False)

    num_layers = len(mt.model.model.blocks)

    # Baseline: corrupted runs with nothing restored.
    low_score = trace_with_patch(
        mt.model,
        inp,
        [],
        answer_t,
        "image",
        noise=noise,
        uniform_noise=uniform_noise,
        replace=replace,
    ).item()

    if not kind:
        differences = trace_important_states(
            mt.model,
            num_layers,
            inp,
            "image",
            answer_t,
            noise=noise,
            uniform_noise=uniform_noise,
            replace=replace,
            token_range=token_range,
        )
    else:
        differences = trace_important_window(
            mt.model,
            num_layers,
            inp,
            "image",
            answer_t,
            kind=kind,
            window=window,
            noise=noise,
            uniform_noise=uniform_noise,
            replace=replace,
            token_range=token_range,
        )
    differences = differences.detach().cpu()
    return dict(
        scores=differences,
        low_score=low_score,
        high_score=base_score,
        input_ids=None,
        input_tokens=None,
        subject_range="image",
        answer=answer_t,
        window=window,
        correct_prediction=True,
        kind=kind or "",
    )


def trace_important_states(
    model,
    num_layers,
    inp,
    e_range,
    answer_t,
    noise=0.1,
    uniform_noise=False,
    replace=False,
    token_range=None,
):
    """Trace the effect of restoring each (token, layer) hidden state.

    Text (``e_range`` is a (begin, end) corruption range): one row per token
    in ``token_range``, which defaults to every position.  For each layer,
    that token's layer output is restored from the clean run.

    Image (``e_range == "image"``): a single row.  For each layer, every
    token after index 0 is restored together.  Index 0 is presumably the
    class token; TODO confirm for each model.

    Returns the ``trace_with_patch`` results stacked as [rows, num_layers].
    """
    is_image = e_range == "image"
    if is_image:
        token_range = range(1)
        # slice(1, None) covers every patch token.  The previous
        # range(1, 14 * 14) stopped one short and never restored the last
        # patch of the 14x14 grid.
        image_tokens = slice(1, None)
    elif token_range is None:
        token_range = range(inp["input_ids"].shape[1])

    table = []
    for tnum in token_range:
        row = []
        for layer in range(num_layers):
            if is_image:
                states = [(image_tokens, IM_layername(model, layer))]
            else:
                states = [(tnum, layername(model, layer))]
            r = trace_with_patch(
                model,
                inp,
                states,
                answer_t,
                tokens_to_mix=e_range,
                noise=noise,
                uniform_noise=uniform_noise,
                replace=replace,
            )
            row.append(r)
        table.append(torch.stack(row))
    return torch.stack(table)


def trace_important_window(
    model,
    num_layers,
    inp,
    e_range,
    answer_t,
    kind,
    window=10,
    noise=0.1,
    uniform_noise=False,
    replace=False,
    token_range=None,
):
    """Trace the effect of restoring a window of ``kind`` sublayer outputs.

    For each centre ``layer``, the ``kind`` outputs (e.g. "mlp"/"attn") of
    the layers in ``[layer - window // 2, layer + ceil(window / 2))``,
    clipped to ``[0, num_layers)``, are restored from the clean run.

    Text: one row per token in ``token_range`` (default: every position),
    restoring just that token.  Image (``e_range == "image"``): a single row,
    restoring every token after index 0 (index 0 presumably being the class
    token; TODO confirm).

    Returns the ``trace_with_patch`` results stacked as [rows, num_layers].
    """
    is_image = e_range == "image"
    if is_image:
        token_range = range(1)
        # All patch tokens; the previous range(1, 14 * 14) skipped the last one.
        image_tokens = slice(1, None)
    elif token_range is None:
        token_range = range(inp["input_ids"].shape[1])

    table = []
    for tnum in token_range:
        row = []
        for layer in range(num_layers):
            window_layers = range(
                max(0, layer - window // 2), min(num_layers, layer - (-window // 2))
            )
            # Build names with the helper matching the model family only;
            # layername() can fail on image models.
            if is_image:
                layerlist = [(image_tokens, IM_layername(model, L, kind)) for L in window_layers]
            else:
                layerlist = [(tnum, layername(model, L, kind)) for L in window_layers]
            r = trace_with_patch(
                model,
                inp,
                layerlist,
                answer_t,
                tokens_to_mix=e_range,
                noise=noise,
                uniform_noise=uniform_noise,
                replace=replace,
            )
            row.append(r)
        table.append(torch.stack(row))
    return torch.stack(table)


class ModelAndTokenizer:
    """
    Container pairing a causal language model with its tokenizer.

    Either object may be supplied directly.  Whichever is missing is loaded
    with ``from_pretrained`` using ``model_name``.  Names containing
    "erwanf" or "AlgorithmicResearchGroup" are redirected to fixed local
    snapshot directories.  A model loaded here has gradients disabled and
    is put in eval mode on CUDA; a model passed in is used as-is.  The
    names of the top-level transformer blocks are collected to count the
    layers.
    """

    def __init__(
        self,
        model_name=None,
        model=None,
        tokenizer=None,
        low_cpu_mem_usage=False,
        torch_dtype=None,
    ):
        def resolve(name):
            # Map selected model families to local snapshot directories.
            if "erwanf" in name:
                return "/data8/baek/.cache/huggingface/hub/models--erwanf--gpt2-mini/snapshots/f12cc7ee54aa4d3e6366597f72ff7acb39b0ab3a"
            if "AlgorithmicResearchGroup" in name:
                return "/data8/baek/.cache/huggingface/hub/models--AlgorithmicResearchGroup--gpt2-xs/snapshots/7b4f6aa9dc7fe02996e78f14f6b7c138f95cf380"
            return name

        if tokenizer is None:
            assert model_name is not None
            tokenizer = AutoTokenizer.from_pretrained(resolve(model_name))
        if model is None:
            assert model_name is not None
            model = AutoModelForCausalLM.from_pretrained(
                resolve(model_name),
                low_cpu_mem_usage=low_cpu_mem_usage,
                torch_dtype=torch_dtype,
            )
            nethook.set_requires_grad(False, model)
            model.eval().cuda()

        self.tokenizer = tokenizer
        self.model = model
        block_pattern = r"^(transformer|gpt_neox|model)\.(h|layers|layers)\.\d+$"
        self.layer_names = [
            module_name
            for module_name, _ in model.named_modules()
            if re.match(block_pattern, module_name)
        ]
        self.num_layers = len(self.layer_names)
        # Stopword bookkeeping, filled in later by callers if needed.
        self.stwd_mask = None
        self.stwd_ids = None

    def __repr__(self):
        model_cls = type(self.model).__name__
        tokenizer_cls = type(self.tokenizer).__name__
        return (
            f"ModelAndTokenizer(model: {model_cls} [{self.num_layers} layers], "
            f"tokenizer: {tokenizer_cls})"
        )


def layername(model, num, kind=None):
    """Return the module path of layer ``num`` (or its sublayer) in a text model.

    ``kind`` is None for the whole block, "embed" for the token embedding,
    or a sublayer name such as "mlp"/"attn".  "attn" is mapped to the
    family's attribute name ("attention" for GPT-NeoX, "self_attn" for
    ``model.layers``-style models).

    Raises:
        AssertionError: if the model structure is not recognised.  It is
            raised explicitly so it is not stripped under ``python -O``.
    """
    suffix = lambda k: "" if k is None else "." + k
    if hasattr(model, "transformer"):
        if kind == "embed":
            return "transformer.wte"
        return f"transformer.h.{num}{suffix(kind)}"
    if hasattr(model, "gpt_neox"):
        if kind == "embed":
            return "gpt_neox.embed_in"
        if kind == "attn":
            kind = "attention"
        return f"gpt_neox.layers.{num}{suffix(kind)}"
    if hasattr(model, "model"):  # for Gemma3
        if kind == "embed":
            return "model.embed_tokens"
        if kind == "attn":
            kind = "self_attn"
        return f"model.layers.{num}{suffix(kind)}"
    raise AssertionError("unknown transformer structure")

def IM_layername(model, num, kind=None, check_subl_inp=False, verf_type="all"):
    """Return the module path of layer ``num`` (or a sublayer) for text or image models.

    GPT-2 style (``model.transformer``):
        * ``verf_type="all"``: the block, or ``.<kind>``.  With
          ``check_subl_inp`` the layer norm feeding the sublayer is returned
          instead (``ln_2`` for "mlp", ``ln_1`` for "attn").
        * ``verf_type="attn"``: kinds containing "attn_" map to ``ln_1``;
          other kinds map to ``.<kind>``, and ``None`` to the block.
    GPT-NeoX (``model.gpt_neox``): ``gpt_neox.layers.<num>[.<kind>]``, with
        "attn" mapped to "attention"; ``verf_type`` is ignored (a notice is
        printed).
    Wrapped image model (has ``.model``): ``model.patch_embed`` /
        ``model.blocks.<num>[.<kind>]``.
    Bare image model (has ``.patch_embed``): ``patch_embed`` /
        ``blocks.<num>[.<kind>]``.

    Raises:
        ValueError: ``check_subl_inp`` with a kind other than None/"mlp"/"attn".
        NotImplementedError: unsupported ``verf_type`` for GPT-2 style models.
        AssertionError: unrecognised model structure.
    """
    suffix = "" if kind is None else "." + kind

    if hasattr(model, "transformer"):
        if kind == "embed":
            return "transformer.wte"
        name = f"transformer.h.{num}"
        if verf_type == "all":
            if not check_subl_inp or kind is None:
                return name + suffix
            ln_names = {"mlp": "ln_2", "attn": "ln_1"}
            if kind not in ln_names:
                raise ValueError(
                    f"check_subl_inp supports kind 'mlp' or 'attn', got {kind!r}"
                )
            return f"{name}.{ln_names[kind]}"
        if check_subl_inp:
            print("Check!! -> check_subl_inp option cannot operate!!")
        if verf_type == "attn":
            if kind is not None and "attn_" in kind:
                return name + ".ln_1"
            return name + suffix
        raise NotImplementedError(f"Unsupported verf_type: {verf_type!r}")

    if hasattr(model, "gpt_neox"):
        if verf_type != "all":
            print(f"Check!! -> verf_type={verf_type!r} is ignored for gpt_neox models")
        if kind == "embed":
            return "gpt_neox.embed_in"
        if kind == "attn":
            kind = "attention"
        return f'gpt_neox.layers.{num}{"" if kind is None else "." + kind}'

    # Image model wrapped in a ModelWrapper: modules live under `model.`.
    if hasattr(model, "model"):
        if kind == "embed":
            return "model.patch_embed"
        return f"model.blocks.{num}{suffix}"

    # Bare image model: its own modules are top-level (no `model.` prefix).
    if hasattr(model, "patch_embed"):
        if kind == "embed":
            return "patch_embed"
        return f"blocks.{num}{suffix}"

    raise AssertionError("unknown transformer structure")

def guess_subject(prompt):
    """Heuristically extract the subject of ``prompt``.

    Returns the first run of capitalised words. The first word may contain
    any non-space characters; following words must be a capital letter
    followed by lowercase letters or apostrophes. A capitalised word
    "Who"/"What"/"Where"/"When"/"Which"/"Why" followed by a space is not
    accepted as the start of the run.

    Raises:
        ValueError: if ``prompt`` contains no usable capitalised word.
    """
    match = re.search(r"(?!Wh(o|at|ere|en|ich|y) )([A-Z]\S*)(\s[A-Z][a-z']*)*", prompt)
    if match is None:
        # Previously this surfaced as "'NoneType' object is not subscriptable".
        raise ValueError(f"could not guess a subject from prompt {prompt!r}")
    return match[0].strip()


def plot_hidden_flow(
    mt,
    prompt,
    subject=None,
    samples=10,
    noise=0.1,
    uniform_noise=False,
    window=10,
    kind=None,
    savepdf=None,
):
    """Run ``calculate_hidden_flow`` on ``prompt`` and plot the resulting heatmap.

    When ``subject`` is omitted it is guessed from ``prompt`` with
    ``guess_subject``. The remaining keyword arguments are forwarded to
    ``calculate_hidden_flow``; ``savepdf`` is forwarded to
    ``plot_trace_heatmap``.
    """
    chosen_subject = guess_subject(prompt) if subject is None else subject
    trace_options = dict(
        samples=samples,
        noise=noise,
        uniform_noise=uniform_noise,
        window=window,
        kind=kind,
    )
    trace = calculate_hidden_flow(mt, prompt, chosen_subject, **trace_options)
    plot_trace_heatmap(trace, savepdf)


def plot_trace_heatmap(result, savepdf=None, title=None, xlabel=None, modelname=None):
    """Plot a causal-tracing result as a heatmap (one row per input position).

    Args:
        result: mapping with
            ``"scores"``: 2-D array drawn with ``pcolor``;
            ``"low_score"``: colour-scale minimum;
            ``"answer"``: shown under the colour bar when ``xlabel`` is None;
            ``"kind"``: ``None``/``"None"``/``"mlp"``/``"attn"`` (any other
            value falls back to the purple colour map);
            ``"subject_range"``: half-open ``(start, end)`` row range whose
            labels get a ``*``, or the string ``"image"``;
            ``"input_tokens"``: row labels (not read for ``"image"``);
            optional ``"window"``: default 10, used in the x label.
            ``result`` is not modified.
        savepdf: output path. If falsy the figure is shown instead of saved.
        title: overrides the default title.
        xlabel: overrides the default x-axis label.
        modelname: model name in the default x label (default ``"GPT"``).
    """
    differences = result["scores"]
    low_score = result["low_score"]
    answer = result["answer"]
    raw_kind = result["kind"]
    kind = None if (not raw_kind or raw_kind == "None") else str(raw_kind)
    window = result.get("window", 10)

    # Work on a local copy so the caller's `result` keeps its "image" marker
    # (mutating it made a second call on the same result take the token path).
    subject_range = result["subject_range"]
    if isinstance(subject_range, str) and subject_range == "image":
        # Rows are labelled by index; rows 1..196 are presumably the 14x14
        # patch tokens of a 224px/16px ViT, with row 0 the class token.
        num_patches = 14 * 14
        labels = list(range(num_patches + 1))
        # Half-open range: `num_patches + 1` so the last patch is starred too.
        subject_range = (1, num_patches + 1)
    else:
        labels = list(result["input_tokens"])
    for i in range(*subject_range):
        labels[i] = f"{labels[i]}*"

    cmap = {"mlp": "Greens", "attn": "Reds"}.get(kind, "Purples")

    with plt.rc_context(rc={"font.family": "Times New Roman"}):
        _, ax = plt.subplots(figsize=(3.5, 2), dpi=200)
        h = ax.pcolor(differences, cmap=cmap, vmin=low_score)
        ax.invert_yaxis()
        ax.set_yticks([0.5 + i for i in range(len(differences))])
        ax.set_xticks([0.5 + i for i in range(0, differences.shape[1] - 6, 5)])
        ax.set_xticklabels(list(range(0, differences.shape[1] - 6, 5)))
        ax.set_yticklabels(labels)
        if not modelname:
            modelname = "GPT"
        if not kind:
            ax.set_title("Impact of restoring state after corrupted input")
            ax.set_xlabel(f"single restored layer within {modelname}")
        else:
            kindname = "MLP" if kind == "mlp" else "Attn"
            ax.set_title(f"Impact of restoring {kindname} after corrupted input")
            ax.set_xlabel(f"center of interval of {window} restored {kindname} layers")
        cb = plt.colorbar(h)
        if title is not None:
            ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        elif answer is not None:
            # The following should be cb.ax.set_xlabel, but this is broken in matplotlib 3.5.1.
            cb.ax.set_title(f"p({str(answer).strip()})", y=-0.16, fontsize=10)
        if savepdf:
            save_dir = os.path.dirname(savepdf)
            # dirname is "" for a bare filename, and os.makedirs("") raises.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(savepdf, bbox_inches="tight")
            plt.close()
        else:
            plt.show()


def plot_all_flow(mt, prompt, subject=None):
    """Plot MLP, attention, and whole-layer traces for ``prompt``, in that order."""
    for module_kind in ("mlp", "attn", None):
        plot_hidden_flow(mt, prompt, subject, kind=module_kind)


# Utilities for dealing with tokens
def make_inputs(tokenizer, prompts, device="cuda"):
    """Tokenize ``prompts`` and left-pad them into one batch.

    Each prompt is encoded with ``tokenizer.encode``. Shorter sequences are
    padded on the left to the longest length. The pad id is the id of the
    ``"[PAD]"`` special token when the tokenizer has one, otherwise 0.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` tensors on ``device``.
        The mask is 0 at padding positions and 1 at real tokens.
    """
    encoded = [tokenizer.encode(text) for text in prompts]
    width = max(map(len, encoded))
    specials = tokenizer.all_special_tokens
    pad_id = tokenizer.all_special_ids[specials.index("[PAD]")] if "[PAD]" in specials else 0

    input_rows = []
    mask_rows = []
    for ids in encoded:
        gap = width - len(ids)
        input_rows.append([pad_id] * gap + ids)
        mask_rows.append([0] * gap + [1] * len(ids))

    return {
        "input_ids": torch.tensor(input_rows).to(device),
        "attention_mask": torch.tensor(mask_rows).to(device),
    }


def decode_tokens(tokenizer, token_array):
    """Decode each token id of ``token_array`` into its own string.

    Inputs with a ``shape`` of rank > 1 are treated as batches: each row is
    decoded recursively, giving a nested list. Anything else is iterated as a
    flat sequence of ids.
    """
    is_batched = hasattr(token_array, "shape") and len(token_array.shape) > 1
    if not is_batched:
        return [tokenizer.decode([tok]) for tok in token_array]
    return [decode_tokens(tokenizer, row) for row in token_array]


def find_token_range(tokenizer, token_array, substring):
    """Return the half-open token range covering ``substring``.

    The tokens are decoded individually and concatenated. The first character
    occurrence of ``substring`` is mapped back to token indices
    ``(start, end)``, where ``end`` is exclusive.

    Raises:
        ValueError: if ``substring`` does not occur (from ``str.index``).
    """
    pieces = decode_tokens(tokenizer, token_array)
    target_start = "".join(pieces).index(substring)
    target_end = target_start + len(substring)

    start = end = None
    consumed = 0
    for idx, piece in enumerate(pieces):
        consumed += len(piece)
        if start is None and consumed > target_start:
            start = idx
        if consumed >= target_end:
            end = idx + 1
            break
    return (start, end)


def predict_token(mt, prompts, return_p=False):
    """Decode the model's top next-token prediction for each prompt.

    Returns the list of decoded strings. When ``return_p`` is true, returns
    ``(strings, probabilities)`` instead.
    """
    batch = make_inputs(mt.tokenizer, prompts)
    top_ids, top_probs = predict_from_input(mt, batch)
    decoded = [mt.tokenizer.decode(token_id) for token_id in top_ids]
    return (decoded, top_probs) if return_p else decoded


def predict_from_input(mt, inp, except_stopword=False):
    """Return the top next-token prediction and its probability per batch row.

    Args:
        mt: object whose ``model(**inp)`` returns a mapping with ``"logits"``.
            The logits are indexed as ``[batch, position, vocab]`` here, and
            only the last position is used.
        inp: keyword inputs forwarded to ``mt.model``.
        except_stopword: if true, token ids listed in ``mt.stwd_ids`` cannot
            be predicted. As a side effect, ``mt.stwd_mask`` is rebuilt as a
            boolean vocabulary mask (False at stopword ids).

    Returns:
        ``(preds, p)``: the predicted token ids and their softmax
        probabilities over the full vocabulary, one entry per batch row.
    """
    out = mt.model(**inp)["logits"]
    probs = torch.softmax(out[:, -1], dim=1)
    if not except_stopword:
        p, preds = torch.max(probs, dim=1)
        return preds, p

    mt.stwd_mask = torch.ones(probs.shape[1], dtype=torch.bool, device=probs.device)
    mt.stwd_mask[mt.stwd_ids] = False
    masked = probs.masked_fill(~mt.stwd_mask.unsqueeze(0), float("-inf"))
    preds = torch.argmax(masked, dim=1)
    # Report the chosen token's probability from the unmasked softmax, as the
    # branch above does. Re-applying softmax to already-normalised
    # probabilities (the old behaviour) produced a near-uniform number
    # unrelated to the token's actual probability.
    p = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
    return preds, p


def collect_embedding_std(mt, subjects):
    """Return the standard deviation of embedding-layer outputs over ``subjects``.

    Each subject is tokenized on its own and run through ``mt.model``. The
    output of the layer named by ``layername(mt.model, 0, "embed")`` is
    captured with ``nethook.Trace``, and its first element (presumably the
    single batch row — confirm) is kept. The std is taken over all kept
    outputs concatenated.
    """
    captured = []
    for subject in subjects:
        batch = make_inputs(mt.tokenizer, [subject])
        with nethook.Trace(mt.model, layername(mt.model, 0, "embed")) as tracer:
            mt.model(**batch)
            captured.append(tracer.output[0])
    stacked = torch.cat(captured)
    return stacked.std().item()

def collect_IM_embedding_std(mt, subjects, fpath_nlv="./vit_noise_level.pt"):
    """Return the noise scale for image-model causal tracing, cached on disk.

    The scale is the standard deviation of the ``model.patch_embed`` output
    for a batch of 100 standard-normal 3x224x224 inputs. These are random
    images, not dataset samples. On first use the value is saved to
    ``fpath_nlv``; later calls load it from there.

    Args:
        mt: object whose ``model`` accepts ``pixel_values=`` and contains a
            ``model.patch_embed`` submodule.
        subjects: unused; kept so the signature parallels
            ``collect_embedding_std``.
        fpath_nlv: cache file path. NOTE(review): the cache is keyed only by
            this path, so different models sharing the default path reuse
            each other's value.

    Returns:
        The noise level as a Python float.
    """
    print("calculate std of data")
    if os.path.isfile(fpath_nlv):
        return torch.load(fpath_nlv)

    # Run on the model's device. Fall back to the previously hard-coded CUDA
    # device when the model exposes no parameters.
    try:
        device = next(mt.model.parameters()).device
    except (AttributeError, StopIteration):
        device = "cuda"

    batch_size = 100
    # No gradients are needed; this avoids building an autograd graph for
    # 100 full-size images.
    with torch.no_grad():
        random_images = torch.randn(batch_size, 3, 224, 224).to(device)
        with nethook.Trace(mt.model, "model.patch_embed") as t:
            mt.model(pixel_values=random_images)
            patch_embeddings = t.output
        noise_level = patch_embeddings.std().item()

    torch.save(noise_level, fpath_nlv)
    return noise_level


def get_embedding_cov(mt):
    """Estimate the mean and covariance of embedding-layer outputs on wikitext.

    Builds a ``TokenizedDataset`` over the wikitext-103 training split and
    iterates it through ``tally`` (``sample_size=1000``, ``batch_size=5``).
    Each batch is run through ``mt.model`` on CUDA without gradients. The
    output of ``layername(mt.model, 0, "embed")`` is flattened with the
    attention mask and accumulated in a ``Covariance`` statistic as float64
    on the CPU.

    NOTE(review): ``load_dataset`` comes from the ``datasets`` package, whose
    import is commented out at the top of this file. Unless it is imported
    elsewhere, calling this function raises ``NameError``.

    Returns:
        ``(stat.mean(), stat.covariance())`` of the ``Covariance`` statistic.
    """
    model = mt.model
    tokenizer = mt.tokenizer

    def get_ds():
        ds_name = "wikitext"
        raw_ds = load_dataset(
            ds_name,
            dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name],
        )
        try:
            maxlen = model.config.n_positions
        except AttributeError:
            maxlen = 100  # Hack due to missing setting in GPT2-NeoX.
        return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)

    ds = get_ds()
    sample_size = 1000
    batch_size = 5
    filename = None  # no on-disk cache for the tally
    batch_tokens = 100

    stat = Covariance()
    loader = tally(
        stat,
        ds,
        cache=filename,
        sample_size=sample_size,
        batch_size=batch_size,
        collate_fn=length_collation(batch_tokens),
        pin_memory=True,
        random_sample=1,
        num_workers=0,
    )
    with torch.no_grad():
        for batch_group in loader:
            for batch in batch_group:
                batch = dict_to_(batch, "cuda")
                # The model is called without explicit position ids; tolerate
                # batches that never had them.
                batch.pop("position_ids", None)
                with nethook.Trace(model, layername(mt.model, 0, "embed")) as tr:
                    model(**batch)
                feats = flatten_masked_batch(tr.output, batch["attention_mask"])
                stat.add(feats.cpu().double())
    return stat.mean(), stat.covariance()


def make_generator_transform(mean=None, cov=None):
    """Build a frozen float64 ``Linear`` layer that shapes standard-normal noise.

    The bias is ``mean`` (or 0). The weight is the identity when ``cov`` is
    None. Otherwise it is ``sqrt(S) * V`` from ``cov = U S V^T``, so
    ``W @ W.T`` equals ``cov`` for a symmetric PSD ``cov``. The layer is
    placed on the device of ``mean`` (or of ``cov`` when ``mean`` is None).
    """
    reference = cov if mean is None else mean
    dim = len(reference)
    device = reference.device

    transform = torch.nn.Linear(dim, dim, dtype=torch.double)
    nethook.set_requires_grad(False, transform)
    transform.to(device)
    transform.bias[...] = mean if mean is not None else 0

    if cov is not None:
        _, singular_values, right_vectors = cov.svd()
        transform.weight[...] = singular_values.sqrt()[None, :] * right_vectors
    else:
        transform.weight[...] = torch.eye(dim).to(device)
    return transform


def collect_embedding_gaussian(mt):
    """Return a Linear map turning N(0, I) samples into the embedding distribution.

    The mean and covariance estimated by ``get_embedding_cov`` are passed to
    ``make_generator_transform``.
    """
    return make_generator_transform(*get_embedding_cov(mt))


def collect_embedding_tdist(mt, degree=3):
    """Return a sampler mapping standard-normal draws to Student-t embedding noise.

    The returned function applies the Gaussian transform from
    ``collect_embedding_gaussian``. It then scales each output vector by a
    fixed multiplier ``sqrt((degree - 2) / u)``, where ``u`` is drawn from
    chi2(degree) with seed 2. At most 1000 multipliers exist, so one call can
    transform at most 1000 vectors. ``degree`` must exceed 2 for the
    multipliers to be finite and real.
    """
    # Scaling a Gaussian sample by sqrt(degree / u), with u ~ chi2(degree),
    # gives a Student-t sample whose covariance is degree / (degree - 2) * cov.
    # To match the sample covariance instead, reduce it by a factor of
    # (degree - 2) / degree, i.e. scale by sqrt((degree - 2) / u).
    u_sample = torch.from_numpy(
        numpy.random.RandomState(2).chisquare(df=degree, size=1000)
    )
    fixed_sample = ((degree - 2) / u_sample).sqrt()
    mvg = collect_embedding_gaussian(mt)

    def normal_to_student(x):
        gauss = mvg(x)
        size = gauss.shape[:-1].numel()
        # Match gauss's device/dtype: fixed_sample is a CPU float64 tensor,
        # while the Gaussian sample may live on a GPU.
        factor = fixed_sample[:size].reshape(gauss.shape[:-1] + (1,)).to(gauss)
        return factor * gauss

    return normal_to_student


# Run the causal-tracing command-line entry point (main) when executed as a script.
if __name__ == "__main__":
    main()
