"""
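This script decomposes self-repair into three components:
  - LayerNorm (LN)
  - attention heads
  - MLP layers

and measures how much each component contributes to self-repair when each
attention head is ablated in turn. The LN contribution is computed as the
residual that the head and layer terms do not explain.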
"""
# %%
import sys
from pathlib import Path
def add_to_sys_path(path: Path) -> None:
    """Prepend *path* to ``sys.path`` unless it is already present.

    ``sys.path`` holds strings, so membership must be tested with ``str(path)``;
    testing the ``Path`` object itself never matches and would insert a
    duplicate entry every time this cell/script is re-run.
    """
    entry = str(path)
    if entry not in sys.path:
        sys.path.insert(0, entry)


vit_path = Path("path/to/CausalPathTracing_for_ViT/main")
add_to_sys_path(vit_path)
vit_prisma_path = Path("path/to/cpt/ViT-Prisma/src")  # https://github.com/Prisma-Multimodal/ViT-Prisma.git
add_to_sys_path(vit_prisma_path)
from vit_prisma.models.base_vit import HookedViT

from lib.utils import get_model, get_data
from utils import KnownsDataset, LamatrexDataset, DATA_DIR
from imports import *
from collections import OrderedDict
from path_patching import act_patch
from ViT_helpers import is_notebook, show_input, collect_direct_effect, get_single_correct_logit, topk_of_Nd_tensor, return_item, get_correct_logit_score, prepare_dataset
# %% Constants
in_notebook_mode = is_notebook()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLDER_TO_WRITE_GRAPHS_TO = "results/figures/breakdown_self_repair_graphs/"
FOLDER_TO_STORE_PICKLES = "pickle_storage/breakdown_self_repair/"

# Short CLI model names -> Hugging Face hub ids. Names not listed are used as-is.
MODEL_NAME_ALIASES = {
    "gpt2-mini": "erwanf/gpt2-mini",
    "gpt2-xs": "AlgorithmicResearchGroup/gpt2-xs",
    "pythia-1b": "EleutherAI/pythia-1b",
    "pythia-14m": "EleutherAI/pythia-14m",
}

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='gpt2-xs')
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--percentile', type=float, default=0.02)
parser.add_argument('--min_tokens', type=int, default=100_000)
parser.add_argument('--custom_data', type=str, required=True)
parser.add_argument('--id', type=int, default=0)

if in_notebook_mode:
    # A notebook has no usable CLI, but the rest of the script reads `args`
    # (e.g. `args.custom_data`), so parse a fixed argument list instead.
    # NOTE(review): "imagenet" is a placeholder dataset for interactive runs.
    args = parser.parse_args([
        "--model_name", "pythia-160m",
        "--batch_size", "2",
        "--percentile", "0.02",
        "--min_tokens", "1000",
        "--custom_data", "imagenet",
    ])
else:
    args = parser.parse_args()

model_name = args.model_name
BATCH_SIZE = args.batch_size
PERCENTILE = args.percentile
MIN_TOKENS = args.min_tokens
# Filesystem-safe name derived from the short (pre-alias) model name.
safe_model_name = model_name.replace("/", "_")
model_name = MODEL_NAME_ALIASES.get(model_name, model_name)
    


from transformers import LlamaForCausalLM, LlamaTokenizer

# Top-level Hugging Face GPT-2 parameter names copied unchanged onto
# HookedTransformer parameter names.
_GPT2_TOP_LEVEL_KEYS = {
    "transformer.wte.weight": "embed.W_E",
    "transformer.wpe.weight": "pos_embed.W_pos",
    "transformer.ln_f.weight": "ln_final.w",
    "transformer.ln_f.bias": "ln_final.b",
}
# Per-block (``transformer.h.<i>.<subname>``) names copied unchanged onto
# ``blocks.<i>.<target>``.
_GPT2_BLOCK_KEYS = {
    "attn.c_proj.bias": "attn.b_O",
    "ln_1.weight": "ln1.w",
    "ln_1.bias": "ln1.b",
    "ln_2.weight": "ln2.w",
    "ln_2.bias": "ln2.b",
    "mlp.c_fc.weight": "mlp.W_in",
    "mlp.c_fc.bias": "mlp.b_in",
    "mlp.c_proj.weight": "mlp.W_out",
    "mlp.c_proj.bias": "mlp.b_out",
}


def _convert_hf_gpt2_state_dict(hf_state_dict, cfg, strict_keys=False):
    """Map a Hugging Face GPT-2 state dict onto HookedTransformer parameter names.

    Attention weights are transposed before reshaping: the fused QKV weight is
    split into per-head ``W_Q/W_K/W_V`` of shape ``[n_heads, d_model, d_head]``,
    the QKV bias into ``b_Q/b_K/b_V`` of shape ``[n_heads, d_head]``, and the
    attention output weight into ``W_O`` of shape ``[n_heads, d_head, d_model]``.
    MLP, LayerNorm and embedding tensors are copied as-is; ``lm_head.weight`` is
    transposed into ``unembed.W_U``.

    Args:
        hf_state_dict: mapping of HF parameter names to tensors.
        cfg: ``HookedTransformerConfig`` of the target model (reads ``d_model``,
            ``n_heads`` and ``d_head``).
        strict_keys: if True, raise ``ValueError`` on an unrecognised key outside
            ``transformer.h.*``. Unrecognised per-block sub-names are always skipped.

    Returns:
        ``OrderedDict`` keyed by HookedTransformer parameter names.
    """
    new_state_dict = OrderedDict()
    for name, param in hf_state_dict.items():
        if name in _GPT2_TOP_LEVEL_KEYS:
            new_state_dict[_GPT2_TOP_LEVEL_KEYS[name]] = param
        elif name == "lm_head.weight":
            # HookedTransformer expects [d_model, vocab] -> transpose
            new_state_dict["unembed.W_U"] = param.T
        elif name.startswith("transformer.h."):
            parts = name.split(".")
            prefix = f"blocks.{int(parts[2])}."
            subname = ".".join(parts[3:])

            if subname == "attn.c_attn.weight":
                # Transposed: [3*d_model, d_model] -> split into Q, K, V.
                W = param.T.reshape(3, cfg.d_model, cfg.d_model)
                for idx, proj in enumerate("QKV"):
                    W_proj = W[idx].reshape(cfg.n_heads, cfg.d_head, cfg.d_model).permute(0, 2, 1)
                    new_state_dict[f"{prefix}attn.W_{proj}"] = W_proj.contiguous()
            elif subname == "attn.c_attn.bias":
                b = param.reshape(3, cfg.d_model)
                for idx, proj in enumerate("QKV"):
                    b_proj = b[idx].reshape(cfg.n_heads, cfg.d_head)
                    new_state_dict[f"{prefix}attn.b_{proj}"] = b_proj.contiguous()
            elif subname == "attn.c_proj.weight":
                W_O = param.T.reshape(cfg.d_model, cfg.n_heads, cfg.d_head).permute(1, 2, 0)
                new_state_dict[f"{prefix}attn.W_O"] = W_O.contiguous()
            elif subname in _GPT2_BLOCK_KEYS:
                new_state_dict[prefix + _GPT2_BLOCK_KEYS[subname]] = param
            # Any other per-block key is intentionally skipped.
        elif strict_keys:
            raise ValueError(f"Unexpected key in state dict: {name}")
    return new_state_dict


def _load_hooked_gpt2(cfg, checkpoint_path, tokenizer, *, strict_keys=False, verbose=False):
    """Build a HookedTransformer from ``cfg`` and load a local HF GPT-2 checkpoint into it.

    If the model has an unembedding bias, it is replaced by frozen zeros on
    ``device`` and never loaded from the checkpoint. The returned model is on
    ``device`` and in eval mode.
    """
    model = HookedTransformer(cfg)
    if hasattr(model.unembed, "b_U"):
        # Allocate on `device` rather than `.cuda()` so CPU-only runs work.
        model.unembed.b_U = torch.nn.Parameter(
            torch.zeros(model.cfg.d_vocab, device=device), requires_grad=False
        )

    hf_state_dict = torch.load(checkpoint_path, map_location="cpu")
    new_state_dict = _convert_hf_gpt2_state_dict(hf_state_dict, cfg, strict_keys=strict_keys)
    # strict=False: at least `unembed.b_U` is deliberately absent from the converted dict.
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    if verbose:
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)
    model.tokenizer = tokenizer
    return model.to(device).eval()


if "llama" in model_name:
    tokenizer = LlamaTokenizer.from_pretrained(LLAMA_MODEL_PATH)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '<unk>'})

    hf_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH, low_cpu_mem_usage=True)

    model = HookedTransformer.from_pretrained("llama-7b", hf_model=hf_model, device="cpu", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)
    model: HookedTransformer = model.to("cuda" if torch.cuda.is_available() else "cpu") #type: ignore
elif "gpt2-mini" in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = HookedTransformerConfig(
        d_model=512,
        n_layers=4,
        n_heads=8,
        d_head=64,
        d_mlp=2048,
        n_ctx=512,
        d_vocab=50257,
        act_fn="gelu",
        normalization_type="LN",
        seed=42,
    )
    model = _load_hooked_gpt2(
        cfg, "../saved_model_tmp/erwanf_gpt2-mini/pytorch_model.bin", tokenizer, verbose=True
    )
elif "gpt2-xs" in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = HookedTransformerConfig(
        d_model=384,
        n_layers=6,
        n_heads=6,
        d_head=64,
        d_mlp=4 * 384,  # 4 * d_model
        n_ctx=256,  # config.n_positions
        d_vocab=50257,
        act_fn="gelu",
        normalization_type="LN",
        seed=42,
    )
    model = _load_hooked_gpt2(
        cfg, "../saved_model_tmp/gpt2-xs/pytorch_model.bin", tokenizer, strict_keys=True
    )
elif "vit_tiny_patch16_224" == model_name or "deit_tiny_patch16_224" == model_name:
    model = HookedViT.from_pretrained(model_name,
                center_writing_weights=True,
                center_unembed=True,
                fold_ln=True,
                refactor_factored_attn_matrices=True,
            )

    model.cfg.use_attn_result = True
    model.cfg.use_attn_in = True
    model.cfg.use_hook_mlp_in = True

    model.cfg.use_split_qkv_input = False
    model.cfg.tokenizer_prepends_bos = False
    model.cfg.default_prepend_bos = False

    model.cfg.return_type = "logits"

    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    tokenizer = None
else:
    model = HookedTransformer.from_pretrained(
        model_name,
        center_unembed = True,
        center_writing_weights = True,
        fold_ln = True, # TODO; understand this
        refactor_factored_attn_matrices = False,
        device = device,
    )
    tokenizer = model.tokenizer
model.set_use_attn_result(True)
print("Model loaded Done")

dataset_name = args.custom_data
# %% 

def ablate_top_instances_and_get_breakdown(head: tuple, clean_inps: Tensor, corrupted_tokens: Tensor,
                                           correct_labels: Tensor,
                                           per_head_direct_effect: Union[Tensor, None] = None,
                                           all_layer_direct_effect: Union[Tensor, None] = None,
                                           cache: Union[ActivationCache, None] = None,
                                           logits: Union[Tensor, None] = None,
                                           ):
    """Ablate one attention head and decompose the change in the correct-label logit.

    The head's ``z`` activation is patched in from ``corrupted_tokens`` (via
    ``act_patch``) while the rest of the model runs on ``clean_inps``. The
    change in the correct-label score that is *not* explained by the head's
    own direct-effect change (``self_repair``) is split into:

    * direct-effect changes of all other heads (``self_repair_from_heads``),
    * direct-effect changes summed over dim 0 of ``all_layer_direct_effect``
      (``self_repair_from_layers``),
    * the remainder (``self_repair_from_LN``), attributed to LayerNorm.

    The ablated direct effects are computed with ``cache_for_scaling`` set to
    the clean cache (presumably the clean LayerNorm scale), so scale changes
    are not captured in the head/layer terms and land in the residual.

    Args:
        head: ``(layer, head_index)`` of the head to ablate.
        clean_inps: input for the clean forward pass.
        corrupted_tokens: input whose ``z`` activation replaces the head's.
        correct_labels: labels scored by ``get_correct_logit_score`` and
            ``collect_direct_effect``.
        per_head_direct_effect: clean per-head direct effects, assumed
            ``[n_layers, n_heads, batch]`` (TODO confirm). Recomputed from
            ``cache`` when this or ``all_layer_direct_effect`` is None.
        all_layer_direct_effect: clean per-layer direct effects (summed over dim 0).
        cache, logits: clean-run cache and logits; both recomputed when either is None.

    Returns:
        ``(logit_diffs, direct_effects, ablated_direct_effects,
        self_repair_from_heads, self_repair_from_layers, self_repair_from_LN,
        self_repair)`` where
        ``self_repair = logit_diffs - (ablated_direct_effects - direct_effects)`` and
        ``self_repair == self_repair_from_heads + self_repair_from_layers + self_repair_from_LN``.
    """
    layer_idx, head_idx = head[0], head[1]

    if cache is None or logits is None:
        print("Cache not provided")
        # run_with_cache returns (logits, cache) -- in that order.
        logits, cache = model.run_with_cache(clean_inps)
        assert cache is not None
        assert logits is not None

    if per_head_direct_effect is None or all_layer_direct_effect is None:
        print("Per head direct effect not provided")
        per_head_direct_effect, all_layer_direct_effect = collect_direct_effect(
            cache, correct_labels=correct_labels, model=model, display=False,
            collect_individual_neurons=False, cache_for_scaling=cache)
        assert per_head_direct_effect is not None
        assert all_layer_direct_effect is not None

    # Patched run: `apply_metric_to_cache` selects whether act_patch yields the
    # patched cache or the patched logits, so it is called once for each.
    ablated_cache: ActivationCache = act_patch(model, clean_inps, [Node("z", layer_idx, head_idx)], return_item, corrupted_tokens, apply_metric_to_cache=True)  # type: ignore
    ablated_logits = act_patch(model, clean_inps, [Node("z", layer_idx, head_idx)], return_item, corrupted_tokens, apply_metric_to_cache=False)
    ablated_per_head_direct_effect, ablated_all_layer_direct_effect = collect_direct_effect(
        ablated_cache, correct_labels=correct_labels, model=model, display=False,
        collect_individual_neurons=False, cache_for_scaling=cache)

    # Change in the correct-label score caused by the ablation.
    clean_scores = get_correct_logit_score(logits, correct_labels)
    ablated_scores = get_correct_logit_score(ablated_logits, correct_labels)
    logit_diffs = ablated_scores - clean_scores

    # Direct effect of the ablated head itself, before and after ablation.
    direct_effects = per_head_direct_effect[layer_idx, head_idx]
    ablated_direct_effects = ablated_per_head_direct_effect[layer_idx, head_idx]

    # Self-repair: the part of the score change not explained by the head's own DE change.
    self_repair = logit_diffs - (ablated_direct_effects - direct_effects)
    delta = ablated_per_head_direct_effect - per_head_direct_effect  # [L, H, B]
    # Exclude the ablated head's own change from the "other heads" term.
    self_repair_from_heads = delta.sum((0, 1)) - delta[layer_idx, head_idx]
    self_repair_from_layers = (ablated_all_layer_direct_effect - all_layer_direct_effect).sum(0)
    # "Self repair from LN" is the residual.
    self_repair_from_LN = self_repair - self_repair_from_heads - self_repair_from_layers

    return logit_diffs, direct_effects, ablated_direct_effects, self_repair_from_heads, self_repair_from_layers, self_repair_from_LN, self_repair



PROMPT_LEN = 128
# Smallest multiple of PROMPT_LEN * BATCH_SIZE strictly greater than MIN_TOKENS.
TOTAL_TOKENS = (MIN_TOKENS // (PROMPT_LEN * BATCH_SIZE) + 1) * PROMPT_LEN * BATCH_SIZE

# A dataset name is mandatory; bail out before touching any custom settings.
if args.custom_data is None:
    raise NotImplementedError("Custom data not implemented")

# Custom datasets override the batch/prompt settings.
BATCH_SIZE = args.batch_size
TOTAL_PROMPTS_TO_ITERATE_THROUGH = args.batch_size
PROMPT_LEN = 10

# Image datasets come from `get_data` (first element of its result); the
# text datasets are constructed from DATA_DIR.
if args.custom_data in ("imagenet", "officehome"):
    data, _ = get_data(args.custom_data)
    print("data loaded")
elif args.custom_data == "knowns1000":
    data = KnownsDataset(DATA_DIR)
elif args.custom_data == "lama":
    data = LamatrexDataset(DATA_DIR)
else:
    raise NotImplementedError("Custom data not implemented")

CPT_ROOT = Path('/data8/baek/cpt/EVALUATION/main/results')
KNOWNS_ROOT = Path('/data8/baek/cpt/EVALUATION/main/results/ROME2CPT/all_corrupted/wo_all/top_1/')


def shuffle_patches_across_batch(images, patch_size=16, generator=None):
    """Return a copy of *images* whose square patches are shuffled across the whole batch.

    Each image in ``images`` (``[B, C, H, W]``) is cut into a grid of
    non-overlapping ``patch_size`` x ``patch_size`` patches. All ``B * P``
    patches are permuted jointly, so a patch may move to another image and/or
    grid position, and are written back onto the grid. Pixels outside the
    whole-patch grid (when ``H`` or ``W`` is not a multiple of ``patch_size``)
    are left as zeros.

    Args:
        images: image batch of shape ``[B, C, H, W]``.
        patch_size: side length of each square patch, in pixels.
        generator: optional ``torch.Generator`` for a reproducible permutation;
            ``None`` uses the global RNG.

    Returns:
        Tensor with the same shape, dtype and device as ``images``.
    """
    B, C, H, W = images.shape
    n_rows, n_cols = H // patch_size, W // patch_size
    n_patches = n_rows * n_cols

    # [B, C, n_rows, n_cols, ps, ps] -> [B * n_patches, C, ps, ps];
    # flat index = b * n_patches + row * n_cols + col.
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_patches, C, patch_size, patch_size)

    # Target slot k receives source patch perm[k].
    perm = torch.randperm(B * n_patches, generator=generator).to(images.device)
    shuffled = patches[perm]

    # Re-tile: [B, n_rows, n_cols, C, ps, ps] -> [B, C, n_rows * ps, n_cols * ps].
    tiled = (
        shuffled.reshape(B, n_rows, n_cols, C, patch_size, patch_size)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(B, C, n_rows * patch_size, n_cols * patch_size)
    )
    corrupted = torch.zeros_like(images)
    corrupted[:, :, : n_rows * patch_size, : n_cols * patch_size] = tiled
    return corrupted


def get_corrupted(data_name, patch_size=16):
    """Build one batch of patch-shuffled images to use as the corrupted input.

    Loads the dataset with ``get_data(data_name)``, takes the first batch of
    ``args.batch_size`` images from a shuffled ``DataLoader``, shuffles their
    patches across the batch with ``shuffle_patches_across_batch`` and moves the
    result to ``device``.

    Args:
        data_name: dataset name passed to ``get_data``.
        patch_size: side length of the shuffled square patches (default 16).

    Raises:
        ValueError: if the loader yields no batch.
    """
    dataset, _ = get_data(data_name)
    dataloader_for_corrupt = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    first_batch = next(iter(dataloader_for_corrupt), None)
    if first_batch is None:
        raise ValueError("Could not load images from the loader")
    batch_images, _ = first_batch

    # Use `device` rather than `.cuda()` so CPU-only runs work.
    return shuffle_patches_across_batch(batch_images, patch_size).to(device)

corrupted_images = get_corrupted(args.custom_data).to(device)
# The model is called on raw pixel tensors, not a {"pixel_values": ...} dict.
corrupted_inps = corrupted_images

# Example ids come from files in the results folder whose names start with "R";
# the id is the integer between the last "R" and the following ".".
path = KNOWNS_ROOT / model_name / args.custom_data / "results"
files = os.listdir(path)
example_ids = [int(file.split("R")[-1].split(".")[0]) for file in files if file.startswith("R")]

# Loop-invariant output settings.
FOLDER_TO_STORE_SCORE_RAW = Path("./self_repair_mean")
if "AlgorithmicResearchGroup" in model_name:
    safe_model_name = "gpt2-xs"

for example_id in tqdm(example_ids):
    path = FOLDER_TO_STORE_SCORE_RAW / "ViT" / safe_model_name / args.custom_data / str(example_id)

    # Skip finished examples. This file is written last below, so its presence
    # implies every output for this example was saved.
    if os.path.exists(path / "self_repair_across_everything.pt"):
        continue

    image, label = data[example_id]
    # labels (torch.Tensor): A 1D tensor of label indices with shape (batch_size,).
    labels = torch.full((BATCH_SIZE,), label, device=device)
    # Repeat the single clean image to form a batch: [B, C, H, W].
    clean_images = image.unsqueeze(0).repeat(BATCH_SIZE, 1, 1, 1).to(device)
    clean_inps = clean_images

    if clean_images.shape != corrupted_images.shape:
        raise ValueError(
            f"Clean batch shape {tuple(clean_images.shape)} does not match corrupted "
            f"batch shape {tuple(corrupted_images.shape)}"
        )

    # Cache clean model activations + direct effects.
    logits, cache = model.run_with_cache(clean_images)  # logits : [Batch, number of labels]
    # cache_for_scaling=cache matches how ablate_top_instances_and_get_breakdown
    # computes both its ablated direct effects and its own clean fallback.
    per_head_direct_effect, all_layer_direct_effect = collect_direct_effect(
        cache, correct_labels=labels, model=model, display=False,
        collect_individual_neurons=False, cache_for_scaling=cache)

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    self_repair_across_everything = torch.zeros(n_layers, n_heads)
    self_repair_from_heads_across_everything = torch.zeros(n_layers, n_heads)

    for layer in range(n_layers):
        for head in range(n_heads):
            _, _, _, self_repair_from_heads, _, _, self_repair = ablate_top_instances_and_get_breakdown(
                (layer, head), clean_inps, corrupted_inps, labels,
                per_head_direct_effect, all_layer_direct_effect, cache, logits)
            # Average over the batch dimension -> one scalar per head.
            self_repair_from_heads_across_everything[layer, head] = self_repair_from_heads.mean(dim=0)
            self_repair_across_everything[layer, head] = self_repair.mean(dim=0)

    os.makedirs(path, exist_ok=True)
    # Insertion order matters: the skip-marker file is saved last.
    tensors_to_save = {
        "self_repair_from_heads_across_everything": self_repair_from_heads_across_everything,
        "self_repair_across_everything": self_repair_across_everything,
    }
    for tensor_name, tensor_data in tensors_to_save.items():
        torch.save(tensor_data, path / f"{tensor_name}.pt")
        
        