"""
This code breaks down self-repair into three components:
  - LN
  - attn heads
  - mlp layers
  
And determines how much each component contributes to self-repair
"""
# %%
from utils import KnownsDataset, LamatrexDataset, DATA_DIR
from imports import *
from collections import OrderedDict
from path_patching import act_patch
from GOOD_helpers import is_notebook, show_input, collect_direct_effect, get_single_correct_logit, topk_of_Nd_tensor, return_item, get_correct_logit_score, prepare_dataset
# %% Constants
in_notebook_mode = is_notebook()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FOLDER_TO_WRITE_GRAPHS_TO = "results/figures/breakdown_self_repair_graphs/"
FOLDER_TO_STORE_PICKLES = "pickle_storage/breakdown_self_repair/"

if in_notebook_mode:
    # A notebook has no command line, so build the namespace the parser would
    # have produced. The rest of the script reads `args.*` (and
    # `safe_model_name`) unconditionally, which previously raised NameError
    # in this mode.
    # NOTE(review): `--custom_data` has no CLI default; "knowns1000" is one of
    # the two dataset names this script accepts -- confirm it is the one
    # wanted for interactive runs.
    args = argparse.Namespace(
        model_name="pythia-160m",
        batch_size=2,
        percentile=0.02,
        min_tokens=1_000,
        custom_data="knowns1000",
        eot=False,
        id=0,
    )
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', default='gpt2-xs')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--percentile', type=float, default=0.02)
    parser.add_argument('--min_tokens', type=int, default=100_000)
    parser.add_argument('--custom_data', type=str, required=True)
    parser.add_argument("--eot", action="store_true", default=False, help="Whether to use end of text token")
    parser.add_argument('--id', type=int, default=0)
    args = parser.parse_args()

# Settings are derived from `args` on a single path so both modes agree.
model_name = args.model_name
BATCH_SIZE = args.batch_size
PERCENTILE = args.percentile
MIN_TOKENS = args.min_tokens

# The filesystem-safe name is taken from the name as given, *before* it is
# expanded to a hub id below.
safe_model_name = model_name.replace("/", "_")

# Expand short aliases to the full hub ids; any other name is passed through.
if args.model_name == "gpt2-mini":
    model_name = "erwanf/gpt2-mini"
elif args.model_name == "gpt2-xs":
    model_name = "AlgorithmicResearchGroup/gpt2-xs"
elif args.model_name in ("pythia-1b", "pythia-14m"):
    model_name = "EleutherAI/" + model_name


# %%

from transformers import LlamaForCausalLM, LlamaTokenizer

def _convert_gpt2_state_dict(hf_state_dict, cfg, strict_keys: bool = False):
    """Rename and reshape a Hugging Face GPT-2 state dict for HookedTransformer.

    Args:
        hf_state_dict: mapping of Hugging Face parameter names to tensors.
        cfg: config object providing `n_heads`, `d_head` and `d_model`.
        strict_keys: when True, a key that is neither a known top-level
            parameter nor under ``transformer.h.`` raises ``ValueError``.
            Unrecognised entries *inside* a transformer block are always
            skipped silently.

    Returns:
        An ``OrderedDict`` keyed by HookedTransformer parameter names, in the
        iteration order of `hf_state_dict`.

    Raises:
        ValueError: for an unknown top-level key when `strict_keys` is True.
    """
    n_heads, d_head, d_model = cfg.n_heads, cfg.d_head, cfg.d_model

    # Parameters that are copied as-is under a new name.
    top_level_names = {
        "transformer.wte.weight": "embed.W_E",
        "transformer.wpe.weight": "pos_embed.W_pos",
        "transformer.ln_f.weight": "ln_final.w",
        "transformer.ln_f.bias": "ln_final.b",
    }
    per_layer_names = {
        "attn.c_proj.bias": "attn.b_O",
        "ln_1.weight": "ln1.w",
        "ln_1.bias": "ln1.b",
        "ln_2.weight": "ln2.w",
        "ln_2.bias": "ln2.b",
        "mlp.c_fc.weight": "mlp.W_in",
        "mlp.c_fc.bias": "mlp.b_in",
        "mlp.c_proj.weight": "mlp.W_out",
        "mlp.c_proj.bias": "mlp.b_out",
    }

    converted = OrderedDict()
    for name, param in hf_state_dict.items():
        if name in top_level_names:
            converted[top_level_names[name]] = param
        elif name.startswith("transformer.h."):
            parts = name.split(".")
            layer = int(parts[2])
            subname = ".".join(parts[3:])
            prefix = f"blocks.{layer}."

            if subname == "attn.c_attn.weight":
                # Fused QKV weight: transpose to [3*d_model, d_model], split
                # into Q/K/V, then split each into heads.
                qkv_w = param.T.reshape(3, d_model, d_model)
                for idx, proj in enumerate(["Q", "K", "V"]):
                    per_head = qkv_w[idx].reshape(n_heads, d_head, d_model).permute(0, 2, 1)  # [n_heads, d_model, d_head]
                    converted[f"{prefix}attn.W_{proj}"] = per_head.contiguous()
            elif subname == "attn.c_attn.bias":
                qkv_b = param.reshape(3, d_model)
                for idx, proj in enumerate(["Q", "K", "V"]):
                    converted[f"{prefix}attn.b_{proj}"] = qkv_b[idx].reshape(n_heads, d_head).contiguous()
            elif subname == "attn.c_proj.weight":
                w_o = param.T.reshape(d_model, n_heads, d_head).permute(1, 2, 0)  # [n_heads, d_head, d_model]
                converted[f"{prefix}attn.W_O"] = w_o.contiguous()
            elif subname in per_layer_names:
                converted[prefix + per_layer_names[subname]] = param
        elif name == "lm_head.weight":
            # HookedTransformer expects [d_model, vocab] -> transpose
            converted["unembed.W_U"] = param.T
        elif strict_keys:
            raise ValueError(f"Unexpected key in state dict: {name}")
    return converted


if "llama" in model_name:
    tokenizer = LlamaTokenizer.from_pretrained(LLAMA_MODEL_PATH) 
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '<unk>'})
    
    hf_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH, low_cpu_mem_usage=True)
    
    model = HookedTransformer.from_pretrained("llama-7b", hf_model=hf_model, device="cpu", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)
    model: HookedTransformer = model.to("cuda" if torch.cuda.is_available() else "cpu") #type: ignore
elif "gpt2-mini" in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = HookedTransformerConfig(
        d_model=512,
        n_layers=4,
        n_heads=8,
        d_head=64,
        d_mlp=2048,
        n_ctx=512,
        d_vocab=50257,
        act_fn="gelu",
        normalization_type="LN",
        seed=42,
    )
    model = HookedTransformer(cfg)
    if hasattr(model.unembed, "b_U"):
        # Frozen zero unembed bias (the checkpoint conversion supplies none).
        # Created on the device the model already lives on instead of a
        # hard-coded .cuda(), which crashed on CPU-only machines.
        model.unembed.b_U = torch.nn.Parameter(
            torch.zeros(model.cfg.d_vocab, device=model.unembed.b_U.device),
            requires_grad=False,
        )
        print("b_U exists in mini")

    hf_state_dict = torch.load("../saved_model_tmp/erwanf_gpt2-mini/pytorch_model.bin")
    new_state_dict = _convert_gpt2_state_dict(hf_state_dict, cfg)

    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
    model.tokenizer = tokenizer
    
    model.eval()
    custom_data_path = None
    # NOTE(review): --custom_data is a required CLI argument, so this guard
    # always raises and the debugging code below is currently unreachable.
    if args.custom_data is not None:
        custom_data_path = "somepath"
        raise NotImplementedError("Custom data path not implemented")
    target_data_idx = sorted([int(i.split("R")[-1]) for i in os.listdir(os.path.join(custom_data_path, "results"))])
    target_data_zip = {"inp":[], "y":[]}
    for i in target_data_idx:
        # Read the whole info file; the context manager closes the handle
        # that the original bare open() leaked.
        with open(os.path.join(custom_data_path, "inp_info", "I{:06d}.txt".format(i)), "r", encoding="utf-8") as inp_info_file:
            bufffer = list(inp_info_file)
        # Line 1 holds "prompt:<text>", line 2 holds "y:<answer>".
        inp_text = bufffer[1].split("prompt:")[-1].split("\n")[0]
        inp_text = "<|endoftext|>"+inp_text
        target_data_zip["inp"].append(inp_text)
        target_data_zip["y"].append(bufffer[2].split("y:")[-1].split("\n")[0])
        
        cache = {}
        def hook_fn(value, hook):
            # Record each hooked activation under its hook name.
            cache[hook.name] = value
    
        debug_inp = bufffer[1].split("prompt:")[-1].split("\n")[0]
        debug_inp = model.tokenizer(debug_inp)
        debug_inp = torch.tensor(debug_inp['input_ids']).unsqueeze(0)
        
        hook_list = [("blocks.0.hook_resid_pre", hook_fn)]
        for l in range(model.cfg.n_layers):
            hook_list += [
                (f"blocks.{l}.hook_resid_post", hook_fn),
                (f"blocks.{l}.hook_attn_out", hook_fn),
                (f"blocks.{l}.hook_mlp_out", hook_fn),
            ]
        with model.hooks(fwd_hooks=hook_list):
            debug_out = model(debug_inp)
        # Print the arg-max next token at the last position and its probability.
        debug_out_sftmx = torch.softmax(debug_out[:, -1, :], dim=1)[0]
        debug_answer = torch.max(debug_out_sftmx, dim=0).indices.unsqueeze(0)
        print("[{}] {} : {}".format(i, debug_answer.item(), debug_out_sftmx[debug_answer].item()))
    target_idx=2
    target_inp = torch.tensor(model.tokenizer(target_data_zip['inp'][target_idx])['input_ids']).unsqueeze(0)
elif "gpt2-xs" in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = HookedTransformerConfig(
        d_model = 384,
        n_layers = 6,
        n_heads = 6,
        d_head = 64,
        d_mlp = 4*384, # defaults to 4 * d_model
        n_ctx = 256, # config.n_positions
        d_vocab = 50257,
        act_fn = "gelu",
        normalization_type="LN",
        seed = 42,
    )
    model = HookedTransformer(cfg)
    if hasattr(model.unembed, "b_U"):
        # Frozen zero unembed bias, created on the model's current device
        # (the whole model is moved to `device` below) rather than via a
        # hard-coded .cuda().
        model.unembed.b_U = torch.nn.Parameter(
            torch.zeros(model.cfg.d_vocab, device=model.unembed.b_U.device),
            requires_grad=False,
        )
    
    hf_state_dict = torch.load("gpt2-xs/pytorch_model.bin")
    # Unknown top-level checkpoint keys are treated as an error for this model.
    new_state_dict = _convert_gpt2_state_dict(hf_state_dict, cfg, strict_keys=True)
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    model.tokenizer = tokenizer
    model = model.to(device).eval()

else:
    model = HookedTransformer.from_pretrained(
        model_name,
        center_unembed = True,
        center_writing_weights = True,
        fold_ln = True, # TODO; understand this
        refactor_factored_attn_matrices = False,
        device = device,
    )
    tokenizer = model.tokenizer
# Expose per-head attention outputs in the cache.
model.set_use_attn_result(True)
print("Model loaded Done")


# Name of the custom prompt dataset chosen on the command line (--custom_data).
# NOTE(review): this variable is not read again in the visible part of the
# file; the later prepare_dataset call passes dataset_name="pile" literally.
dataset_name = args.custom_data
# all_dataset_tokens = model.to_tokens(dataset["text"]).to(device)

def ablate_top_instances_and_get_breakdown(head: tuple, clean_tokens: Tensor, corrupted_tokens: Tensor,
                                           per_head_direct_effect: Union[Tensor, None] = None,
                                           all_layer_direct_effect: Union[Tensor, None] = None,
                                           cache: Union[ActivationCache, None] = None,
                                           logits: Union[Tensor, None] = None,):
    """Ablate one attention head and split the resulting self-repair by source.

    The head's ``z`` activation is patched through `act_patch` with
    `corrupted_tokens` as the alternative input, and the change in the
    correct-token logit is compared with the change in direct effects.

    Args:
        head: ``(layer, head_index)`` of the attention head to ablate.
        clean_tokens: tokens of the clean run; also used as the correct tokens.
        corrupted_tokens: tokens passed to `act_patch` as the patching source.
        per_head_direct_effect: clean per-head direct effects, indexed
            ``[layer, head, ...]``. Recomputed when this or
            `all_layer_direct_effect` is None.
        all_layer_direct_effect: clean per-layer direct effects, indexed
            ``[layer, ...]``.
        cache: activation cache of the clean run. The clean run is redone
            when this or `logits` is None.
        logits: logits of the clean run.

    Returns:
        A 7-tuple, in this order:
        ``logit_diffs`` (ablated minus clean correct-token logit),
        ``direct_effects`` (clean direct effect of `head`),
        ``ablated_direct_effects`` (direct effect of `head` after ablation),
        ``self_repair_from_heads`` (summed direct-effect change of all other heads),
        ``self_repair_from_layers`` (summed direct-effect change of the layers),
        ``self_repair_from_LN`` (what remains of ``self_repair`` after the two
        terms above), and
        ``self_repair`` (``logit_diffs`` minus the direct-effect change of `head`).
        NOTE(review): tensor shapes depend on the helper functions and are not
        visible here; the caller indexes the results as ``[:, -1]``.
    """
    if cache is None or logits is None:
        print("Cache not provided")
        # run_with_cache returns (logits, cache) in that order; the previous
        # unpacking had the two swapped.
        logits, cache = model.run_with_cache(clean_tokens)
        assert cache is not None
        assert logits is not None
        
    if per_head_direct_effect is None or all_layer_direct_effect is None:
        print("Per head direct effect not provided")
        per_head_direct_effect, all_layer_direct_effect = collect_direct_effect(cache, correct_tokens=clean_tokens, model = model, display = False, collect_individual_neurons = False, cache_for_scaling = cache)
        assert per_head_direct_effect is not None
        assert all_layer_direct_effect is not None
    
    # Run ablation and get new cache/logit/DEs. act_patch is called twice
    # because one call yields the patched cache and the other the patched logits.
    ablated_cache: ActivationCache = act_patch(model, clean_tokens, [Node("z", head[0], head[1])], return_item, corrupted_tokens, apply_metric_to_cache = True) #type: ignore
    ablated_logits = act_patch(model, clean_tokens, [Node("z", head[0], head[1])], return_item, corrupted_tokens, apply_metric_to_cache = False)
    # The ablated direct effects are scaled with the *clean* cache.
    ablated_per_head_direct_effect, ablated_all_layer_direct_effect = collect_direct_effect(ablated_cache, correct_tokens=clean_tokens, model = model, display = False, collect_individual_neurons = False, cache_for_scaling = cache)

    # get the logit difference between everything
    correct_logits = get_correct_logit_score(logits, clean_tokens)
    ablated_logits = get_correct_logit_score(ablated_logits, clean_tokens)
    logit_diffs = ablated_logits - correct_logits
    
    # Get Direct Effect of Heads Pre/Post-Ablation
    direct_effects = per_head_direct_effect[head[0],head[1]] 
    ablated_direct_effects = ablated_per_head_direct_effect[head[0],head[1]]
    
    # Calculate self-repair values
    per_head_change = ablated_per_head_direct_effect - per_head_direct_effect
    self_repair = logit_diffs - (ablated_direct_effects - direct_effects)
    # Change summed over every head, minus the ablated head's own change.
    self_repair_from_heads = per_head_change.sum((0,1)) - per_head_change[head[0],head[1]]
    self_repair_from_layers = (ablated_all_layer_direct_effect - all_layer_direct_effect).sum(0)
    self_repair_from_LN = self_repair - self_repair_from_heads - self_repair_from_layers   # "self repair from LN" is the residual
    
    return logit_diffs, direct_effects, ablated_direct_effects, self_repair_from_heads, self_repair_from_layers, self_repair_from_LN, self_repair

CPT_ROOT = Path("./path/to/your/cpt/root")

# Batches from the "pile" dataset are prepared up front; in the loop below
# they only supply the corrupted tokens used when ablating each head.
PROMPT_LEN = 128
# Length the dataset is prepared with. PROMPT_LEN itself is overwritten
# further down, so keep a copy for the bounds check on the custom prompts.
DATASET_PROMPT_LEN = PROMPT_LEN
TOTAL_TOKENS = ((MIN_TOKENS // (PROMPT_LEN * BATCH_SIZE)) + 1) * (PROMPT_LEN * BATCH_SIZE)

dataset, num_batches = prepare_dataset(model, device, TOTAL_TOKENS, BATCH_SIZE, PROMPT_LEN, False, dataset_name="pile")

if args.custom_data is None:
    TOTAL_PROMPTS_TO_ITERATE_THROUGH = num_batches * BATCH_SIZE
else:
    BATCH_SIZE = args.batch_size
    TOTAL_PROMPTS_TO_ITERATE_THROUGH = args.batch_size
    PROMPT_LEN = 10
    

# Only the two named prompt datasets are supported.
if args.custom_data is None:
    raise NotImplementedError("Custom data not implemented")
elif args.custom_data == "knowns1000":
    data = KnownsDataset(DATA_DIR)
elif args.custom_data == "lama":
    data = LamatrexDataset(DATA_DIR)
else:
    raise NotImplementedError("Custom data not implemented")

# Example ids are taken from the result files ("R<id>.<ext>") of an earlier run.
dir_name = "20250505_0612_debug_" + args.model_name + ("_lama_trex" if args.custom_data == "lama" else "")  + "_exStWd"
path = CPT_ROOT / dir_name / "results" 
files = os.listdir(path)
example_ids = [
    int(file.split("R")[-1].split(".")[0])
    for file in files
    if file.startswith("R")
]

# Output location pieces that do not change per example.
FOLDER_TO_STORE_SCORE_RAW = Path("./self_repair_mean")
if "AlgorithmicResearchGroup" in model_name:
    safe_model_name = "gpt2-xs"
eot = "eot" if args.eot else "no_eot"

for batch_idx, clean_tokens, corrupted_tokens in dataset:

    # Only the first dataset batch is used as the source of corrupted tokens.
    if batch_idx != 0:
        continue

    for example_id in tqdm(example_ids):
        path = FOLDER_TO_STORE_SCORE_RAW / eot / f"{safe_model_name}" / f"{args.custom_data}" / f"{example_id}"

        # if results already exist, skip
        if os.path.exists(path / "self_repair_across_everything.pt"):
            continue

        datum = data[example_id]
        prompt_text = datum["prompt"] + " " + datum["attribute"]  # the question + answer
        if args.eot:
            prompt_text = "<|endoftext|>" + prompt_text
        tok = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) # [1, T]

        PROMPT_LEN = tok.shape[1]
        # The corrupted slice below needs up to PROMPT_LEN + 1 dataset columns.
        assert PROMPT_LEN < DATASET_PROMPT_LEN
        # Every row of the clean batch is the same prompt; rows differ only
        # in the corrupted tokens they are paired with.
        clean_tokens = tok.repeat(BATCH_SIZE, 1)
        
        if args.eot:
            corrupted_token = corrupted_tokens[:BATCH_SIZE, :PROMPT_LEN]
        else:
            # Skip column 0 of the dataset rows.
            # NOTE(review): presumably column 0 is a BOS token -- confirm
            # against prepare_dataset.
            corrupted_token = corrupted_tokens[:BATCH_SIZE, 1:PROMPT_LEN+1]
        
        assert clean_tokens.shape == corrupted_token.shape == (BATCH_SIZE, PROMPT_LEN)

        # Cache clean model activations + direct effects
        logits, cache = model.run_with_cache(clean_tokens)
        per_head_direct_effect, all_layer_direct_effect = collect_direct_effect(cache, correct_tokens=clean_tokens, model = model, display = False, collect_individual_neurons = False)
        # save only the mean values in the last token
        self_repair_across_everything = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
        self_repair_from_heads_across_everything = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
            
        for layer in range(model.cfg.n_layers):
            for head in range(model.cfg.n_heads):

                (
                    logit_diffs,
                    direct_effects,
                    ablated_direct_effects,
                    self_repair_from_heads,
                    self_repair_from_layers,
                    self_repair_from_LN,
                    self_repair,
                ) = ablate_top_instances_and_get_breakdown(
                    (layer, head), clean_tokens, corrupted_token,
                    per_head_direct_effect, all_layer_direct_effect, cache, logits,
                )

                # Take the last position and average over the batch -> one value.
                self_repair_from_heads = self_repair_from_heads[:, -1].mean(dim=0)
                self_repair_from_heads_across_everything[layer, head] = self_repair_from_heads

                self_repair = self_repair[:, -1,].mean(dim=0)
                self_repair_across_everything[layer, head] = self_repair

        tensors_to_save = {
            "self_repair_across_everything": self_repair_across_everything,
            "self_repair_from_heads_across_everything": self_repair_from_heads_across_everything,
        }
        
        os.makedirs(path, exist_ok=True)
        for tensor_name, tensor_data in tensors_to_save.items():
            torch.save(tensor_data, path / f"{tensor_name}.pt")

    # Every later batch would be skipped by the guard above, so stop here
    # instead of iterating through the rest of the dataset.
    break