

# /root/yingtao/neuro_glia_network/trained_model/step_20000_lr_0.001
from transformers import LlamaForCausalLM as HF_LlamaForCausalLM
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft_pretraining.modeling_llama import LlamaForCausalLM
import numpy as np
from collections import defaultdict
import argparse
from utils import build_neuro_glia_network


# model_config = AutoConfig.from_pretrained("/root/yingtao/neuro_glia_network/configs/llama_130m.json")

# model = AutoModelForCausalLM.from_config(model_config)
# model_name = "/amax/yingtao/wanda-main/llm_weights/llama-2-7b-hf"
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.float16,
#     device_map="auto"
# )
# Build an un-trained HF LLaMA skeleton from the local config; the trained
# weights are loaded into it after it has been wrapped below.
model_config = AutoConfig.from_pretrained("configs/llama_130m.json")
model = HF_LlamaForCausalLM(model_config)

# Deserialize the checkpoint on the CPU. No visible code in this script moves
# the model to a GPU, and load_state_dict() copies values into the existing
# parameters regardless of the source device, so mapping to "cuda" would only
# require a GPU and hold a second copy of the weights in GPU memory.
# NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
model_name = "/data/yt/neuro_glia_network/checkpoints/all_model.pt"
state_dict = torch.load(model_name, map_location="cpu")

from collections import OrderedDict

# Remove every "module." occurrence from the parameter names (presumably left
# by a DataParallel/DDP wrapper at save time -- TODO confirm).
new_state_dict = OrderedDict(
    (key.replace("module.", ""), value) for key, value in state_dict.items()
)
# Skip the rotary-embedding inv_freq entries so they are not passed to
# load_state_dict() (presumably the model does not expect them -- TODO confirm).
new_state_dict = {
    key: value
    for key, value in new_state_dict.items()
    if "rotary_emb.inv_freq" not in key
}

# Settings handed to build_neuro_glia_network; their exact meaning is defined
# in utils, not here.
args = argparse.Namespace(
    nonlinear_function="learnable_sigmoid",
    channel_wise=True,
    scalar_wise=True,
    use_input_cmgllf=False,
    use_output_cmgllf=False,
    not_contextual=False,
    hidden_size=8,
)

# Wrap the plain model first so its parameter names line up with the
# checkpoint, then load the renamed weights.
model = build_neuro_glia_network(model, args)
print(model)
model.load_state_dict(new_state_dict)
# exit()


# from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
#     new_key = k.replace("module.", "") 
#     new_state_dict[new_key] = v
# new_state_dict = {k: v for k, v in new_state_dict.items() if "rotary_emb.inv_freq" not in k}
# print(model)
# model.load_state_dict(new_state_dict)
# print(model.model.layers[-1].mlp.up_proj.mlp_modulator[2].weight)
# exit()


# model_name = "/root/yingtao/neuro_glia_network/trained_model/step_20000_lr_0.001"
# The tokenizer is read from a saved training-run directory, which is a
# different location from the checkpoint file loaded for the weights.
tokenizer_dir = "/data/yt/neuro_glia_network/trained_model/step_10000_lr_0.003"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_auth_token=True)
# tokenizer.pad_token = tokenizer.eos_token
# while True:
#     try:
#         tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=args.max_length)
#         break
#     except:
#         import time
#         print("Failed to load tokenizer, retrying...")
#         # logger.warning("Failed to load tokenizer, retrying...")
#         time.sleep(1)

# Put the model in evaluation mode before the forward pass below.
model.eval()
# exit()

# results = defaultdict(list)  # Store per-layer results

# # This assumes each block has input X and output Y before residual addition.
# def register_hooks(model):
#     for layer_idx, block in enumerate(model.model.layers):
#         def hook_fn(module, input, output, idx=layer_idx):
#             input_tmp, output_tmp = input[0].clone().cpu(), output[0].clone().cpu()
            
#             x = input_tmp  # residual input
#             y = output_tmp - x  # functional output (before modulator if needed)

#             # Compute L2 norm ratio
#             norm_x = torch.norm(x, dim=-1)  # (batch, seq)
#             norm_y = torch.norm(y, dim=-1)
#             norm_ratio = norm_y / (norm_x + 1e-6)

#             # Cosine similarity
#             cosine_sim = F.cosine_similarity(x, y, dim=-1)  # (batch, seq)

#             # Save metrics
#             results[f"layer_{idx}_norm_ratio"].append(norm_ratio.detach().cpu())
#             results[f"layer_{idx}_cosine"].append(cosine_sim.detach().cpu())

#             # Optionally, store norm_y_mod / norm_y_unmod if you added a modulator
#             if hasattr(module, 'modulator_output') and hasattr(module, 'raw_output'):
#                 y_mod = module.modulator_output
#                 y_raw = module.raw_output
#                 mod_scale = torch.norm(y_mod, dim=-1) / (torch.norm(y_raw, dim=-1) + 1e-6)
#                 results[f"layer_{idx}_mod_scale"].append(mod_scale.detach().cpu())

#         block.register_forward_hook(hook_fn)

# register_hooks(model)

# Single probe sentence; the plotting code below reads `inputs` for the
# batch size and sequence length.
text = "Once upon a time in a distant land, a mysterious castle stood tall."
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(model.device)

# The output is discarded: the pass is run only for its side effects
# (presumably each wrapped layer caches its modulator values on the module
# during forward -- TODO confirm in utils).
with torch.no_grad():
    model(**inputs)






############################ Plotting modulator curves ###############################
# import matplotlib.pyplot as plt
# import numpy as np

# cmap = plt.get_cmap("tab20")  # Color map
# num_tokens = len(inputs["input_ids"][0])
# colors = [cmap(i / num_tokens) for i in range(num_tokens)]

# first_line_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
# second_line_modules = ["up_proj", "gate_proj", "down_proj"]
# all_modules = first_line_modules + second_line_modules

# # Create subplots (2 rows × 4 columns)
# fig, axes = plt.subplots(2, 4, figsize=(30, 12))
# axes = axes.flatten()

# tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# # Plot for each module
# for idx, module in enumerate(all_modules):
#     ax = axes[idx]

#     modulator_values = []
#     for n, m in model.named_modules():
#         if module in n and f"{module}." not in n:
#             modulator_values.append(m.glia_weight_scalar_wise.detach().cpu().numpy())

#     modulator_values = np.stack(modulator_values, axis=0)  # (num_layers, batch, seq_len)
#     modulator_values = modulator_values[:, 0, :]  # Take batch 0

#     for t_idx in range(modulator_values.shape[1]):
#         ax.plot(modulator_values[:, t_idx], label=tokens[t_idx], color=colors[t_idx])

#     ax.set_title(f"{module}", fontsize=18)
#     ax.set_xlabel("Layer", fontsize=18)
#     ax.set_ylabel("Modulator Value", fontsize=18)
#     ax.grid(True)

# # Now, remove the bottom-right (last) subplot and use it for legend
# fig.delaxes(axes[-1])  # Remove 8th plot (position 2,4)

# # Create a new axes at the deleted position for the legend
# from matplotlib.legend import Legend

# legend_ax = fig.add_subplot(2, 4, 8)
# legend_ax.axis('off')  # No axes

# legend = Legend(
#     legend_ax,
#     handles=[plt.Line2D([0], [0], color=colors[i], lw=3) for i in range(num_tokens)],
#     labels=[tokens[i] for i in range(num_tokens)],
#     loc='center',
#     frameon=False,
#     ncol=1,
#     title="Tokens",
#     fontsize=14,
# )
# legend_ax.add_artist(legend)

# plt.tight_layout()
# plt.savefig("modulator_values_per_token_all_modules_with_legend.png", dpi=500)
# plt.show()


import os
import re
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


# -------- CONFIG --------
# Selection of what to plot.
batch_index = 0                        # sample of the batch that gets plotted
module_name_filters = ('down_proj',)   # substrings a module name must contain
token_stride = 1                       # keep every n-th token row
channel_stride = 1                     # keep every n-th channel column

# Figure layout and appearance.
ROWS, COLS = 4, 3                      # 12 panels, one per layer
FIGSIZE = (COLS * 6.5, ROWS * 4.5)
CMAP = "coolwarm"                      # colormap shared by every panel
LEVELS = 40                            # number of contour levels

# Output location.
output_dir = "layer_3d_plots"
# ------------------------

os.makedirs(output_dir, exist_ok=True)

# Batch size and sequence length of the tokenized probe input.
B, T = (int(extent) for extent in inputs["input_ids"].shape[:2])

_LAYER_INDEX_RE = re.compile(r"(?:model\.)?layers\.(\d+)\.")


def extract_layer_index(name: str):
    """Return the layer number embedded in a dotted module name.

    The first ``layers.<digits>.`` segment found in *name* is used; the
    trailing dot is required, so a name that ends right after the number
    does not match.

    Returns:
        The layer index as an ``int``, or ``None`` when no such segment
        is present.
    """
    found = _LAYER_INDEX_RE.search(name)
    if found is None:
        return None
    return int(found.group(1))

def to_btc_array(g, B_expected: int, T_expected: int) -> np.ndarray:
    """Convert modulator values to a float32 NumPy array laid out as [B, T, C].

    Args:
        g: A tensor (or anything ``torch.as_tensor`` accepts) of rank 1, 2
            or 3.
        B_expected: Batch size the result should have.
        T_expected: Sequence length the result should have.

    Returns:
        * rank 3 -- returned as-is after checking its first two dimensions.
        * rank 2 with ``B_expected * T_expected`` rows -- split into
          ``[B_expected, T_expected, C]``.
        * rank 2 with any other row count -- best-effort ``[1, rows, C]``.
        * rank 1 -- the ``[C]`` vector tiled to
          ``[B_expected, T_expected, C]``.

    Raises:
        ValueError: If a rank-3 input does not match the expected batch and
            sequence sizes, or if the rank is not 1, 2 or 3.
    """
    if not isinstance(g, torch.Tensor):
        g = torch.as_tensor(g)
    # Detach from autograd, upcast to float32 and move to host memory so that
    # .numpy() works whatever the source dtype and device were.
    g = g.detach().float().cpu()

    rank = g.dim()
    if rank == 3:  # [B,T,C]
        # Raise explicitly rather than assert: asserts disappear under
        # `python -O`, which would let a mismatched tensor through silently.
        if g.shape[0] != B_expected or g.shape[1] != T_expected:
            raise ValueError(
                f"Expected leading dims ({B_expected}, {T_expected}), "
                f"got shape {tuple(g.shape)}"
            )
        return g.numpy()
    if rank == 2:  # [B*T,C], or some other [rows,C]
        bt, c = g.shape
        if bt == B_expected * T_expected:
            # Also covers a [T,C] input when B_expected == 1.
            return g.view(B_expected, T_expected, c).numpy()
        # Row count matches nothing we expect: expose it as one batch entry.
        return g.view(1, bt, c).numpy()
    if rank == 1:  # [C]
        return g.view(1, 1, -1).repeat(B_expected, T_expected, 1).numpy()
    raise ValueError(f"Unsupported glia_weight_channel_wise shape: {tuple(g.shape)}")

def sanitize(s: str, limit=80):
    """Make *s* safe to use as part of a file name.

    Each run of characters other than ASCII letters, digits, ``.``, ``_``
    and ``-`` is collapsed into a single underscore, and the result is cut
    to at most *limit* characters. An empty result yields ``"mod"``.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", s)
    if not cleaned:
        return "mod"
    return cleaned[:limit]

# ---- Gather one [T,C] modulator slice per layer, restricted to the modules
# ---- named in module_name_filters ----
layer_items = []  # entries: (layer_idx, V[T,C], leaf name, full dotted name)
for full_name, module in model.named_modules():
    matches_filter = any(sub in full_name for sub in module_name_filters)
    is_descendant = any(f"{sub}." in full_name for sub in module_name_filters)
    # Keep the filtered module itself, not the submodules nested under it.
    if not matches_filter or is_descendant:
        continue

    # Modules without the attribute, or with it set to None, are skipped.
    modulator = getattr(module, "glia_weight_channel_wise", None)
    if modulator is None:
        continue

    layer_idx = extract_layer_index(full_name)
    if layer_idx is None:
        continue

    values_btc = to_btc_array(modulator, B_expected=B, T_expected=T)  # [B,T,C]
    # Pick the configured sample, then subsample tokens and channels.
    V = values_btc[batch_index][::token_stride, ::channel_stride]  # [T,C]
    leaf_name = full_name.split(".")[-1]
    layer_items.append((layer_idx, V, leaf_name, full_name))

# Order the panels by layer number (stable for equal indices).
layer_items.sort(key=lambda item: item[0])

# ---- One ROWS x COLS grid of 3D contour plots, one panel per collected layer.
# ---- Every panel uses CMAP, but each has its own colour normalisation.
fig, axes = plt.subplots(
    ROWS,
    COLS,
    figsize=FIGSIZE,
    subplot_kw={"projection": "3d"},
    constrained_layout=True,
)
axes = axes.ravel()

for panel_idx, ax in enumerate(axes):
    # Panels beyond the number of collected layers are hidden.
    if panel_idx >= len(layer_items):
        ax.set_visible(False)
        continue

    layer_idx, V, leaf, full_name = layer_items[panel_idx]
    T_len, C_len = V.shape
    # X carries the channel index, Y the token index; both are shaped [T,C].
    X, Y = np.meshgrid(np.arange(C_len), np.arange(T_len))

    # Colour range of this panel only; widen it slightly when the data is
    # constant so the normalisation does not collapse to a zero-width range.
    vmin = float(np.nanmin(V))
    vmax = float(np.nanmax(V))
    if vmin == vmax:
        vmin -= 1e-6
        vmax += 1e-6
    norm = colors.Normalize(vmin=vmin, vmax=vmax)

    ct = ax.contour3D(X, Y, V, levels=LEVELS, cmap=CMAP, norm=norm)

    # Colorbar in an inset placed outside the right edge of the panel so it
    # stays clear of the Z axis; raise the 1.10 offset if it still crowds.
    cax = inset_axes(
        ax,
        width="2%",
        height="60%",
        loc="center left",
        bbox_to_anchor=(1.10, 0., 1, 1),
        bbox_transform=ax.transAxes,
        borderpad=0,
    )
    cbar = fig.colorbar(ct, cax=cax)
    cbar.set_label("modulator value")

    ax.set_title(f"Layer {layer_idx} — {leaf}", fontsize=10)
    ax.set_xlabel("Channel index")
    ax.set_ylabel("Token index")
    ax.set_zlabel("Modulator")
    ax.view_init(elev=30, azim=135)

# plt.tight_layout()
out_path = os.path.join(output_dir, "down_proj_all_layers_grid_4x3_contour_coolwarm.png")
plt.savefig(out_path, dpi=500)
plt.close(fig)
print(f"Saved grid figure to: {os.path.abspath(out_path)}")










########## PLOT ALPHA ##########
# import re
# import numpy as np
# import matplotlib.pyplot as plt
# import torch

# # Modules to visualize
# first_line_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
# second_line_modules = ["up_proj", "gate_proj", "down_proj"]
# all_modules = first_line_modules + second_line_modules  # 7 total

# # Create subplots (2 rows × 4 columns)
# fig, axes = plt.subplots(2, 4, figsize=(30, 12))
# axes = axes.flatten()

# def extract_layer_index(name: str):
#     """
#     Try to extract a numeric layer index from common patterns like:
#     'model.layers.12.self_attn.q_proj', 'layers.3.mlp.up_proj', etc.
#     Returns an int or None.
#     """
#     m = re.search(r"layers\.(\d+)\.", name)
#     return int(m.group(1)) if m else None

# for idx, module in enumerate(all_modules):
#     ax = axes[idx]

#     alphas = []
#     layer_indices = []

#     for n, m in model.named_modules():
#         # Match the module itself, not its children (nn.Linear has no submodules anyway)
#         if module in n and f"{module}." not in n:
            
#             # Only consider modules that actually have a learnable alpha
#             if hasattr(m, "expert_nonlinear") and hasattr(m.expert_nonlinear, "alpha"):
                
#                 a = m.expert_nonlinear.alpha
#                 print(n, a)
#                 if isinstance(a, torch.Tensor):
#                     # Ensure scalar: handle shape [], [1], or any accidental extra dims
#                     a = a.detach().float().cpu().view(-1).mean().item()
#                 else:
#                     a = float(a)
#                 alphas.append(a)
#                 layer_indices.append(extract_layer_index(n))

#     if not alphas:
#         ax.set_title(f"{module} (no alpha found)", fontsize=15)
#         ax.axis("off")
#         continue

#     # If we could parse layer indices, sort by them; otherwise keep discovered order
#     if any(li is not None for li in layer_indices):
#         pairs = [(li if li is not None else 10**9, a) for li, a in zip(layer_indices, alphas)]
#         pairs.sort(key=lambda x: x[0])
#         xs = [p[0] for p in pairs if p[0] != 10**9]
#         ys = [p[1] for p in pairs if p[0] != 10**9]
#         # Fallback: include unsorted items with unknown layer index appended
#         unknown_ys = [p[1] for p in pairs if p[0] == 10**9]
#         if unknown_ys:
#             xs += list(range((xs[-1] + 1) if xs else 0, (xs[-1] + 1) + len(unknown_ys)))
#             ys += unknown_ys
#     else:
#         xs = list(range(len(alphas)))
#         ys = alphas

#     ax.plot(xs, ys, marker="o", linewidth=2)
#     ax.set_title(f"{module}", fontsize=15)
#     ax.set_xlabel("Layer", fontsize=13)
#     ax.set_ylabel("α (module scalar)", fontsize=13)
#     ax.grid(True, alpha=0.3)

# # Remove the extra (8th) subplot since we have 7 modules
# fig.delaxes(axes[-1])

# plt.tight_layout()
# plt.savefig("alpha_values_per_module.png", dpi=300)
# plt.show()






# ############################### Plotting correlation heatmaps ###############################
# import matplotlib.pyplot as plt
# import seaborn as sns

# def plot_results_heatmap(results_tensor, label="Cosine Similarity"):
#     """
#     Plot heatmap from stacked result tensor: [num_layers, batch, seq]
#     """
#     data = results_tensor[:, 0, :].transpose(0, 1).numpy()  # shape: [seq_len, num_layers]
#     plt.figure(figsize=(15, 6))
#     sns.heatmap(data, cmap='viridis', center=0 if "cosine" in label.lower() else None,
#                 xticklabels=True, yticklabels=True)
#     plt.xlabel("Layer Index")
#     plt.ylabel("Token Index")
#     plt.title(f"{label} Heatmap: Token x Layer")
#     plt.tight_layout()
#     plt.savefig(f"{label}_heatmap_modulation.png")
#     plt.show()


# def stack_results(results, metric_name="cosine"):
#     """
#     Stack results[layer_0_metric, ..., layer_N_metric] -> [num_layers, batch, seq_len]
#     """
#     layers = sorted([k for k in results if metric_name in k])
#     tensors = [torch.cat(results[layer], dim=0).unsqueeze(0) for layer in layers]
#     return torch.cat(tensors, dim=0)  # shape: [num_layers, batch, seq_len]

# cosine_tensor = stack_results(results, metric_name="cosine")
# normratio_tensor = stack_results(results, metric_name="norm_ratio")
# # print(normratio_tensor)  # [num_layers, batch, seq_len]
# # plot_histogram("layer", 5)  # for norm ratio
# # plot_results_heatmap(cosine_tensor, label="Cosine Similarity")
# # plot_results_heatmap(normratio_tensor, label="Norm Ratio")

# def plot_results_heatmap_with_tokens(results_tensor, tokenizer, input_ids, label="Cosine Similarity"):
#     data = results_tensor[:, 0, :].transpose(0, 1).numpy()  # [seq_len, num_layers]
#     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
#     # plt.figure(figsize=(15, 10))
#     if label == "Cosine Similarity":
#         sns.heatmap(data, cmap='viridis', center=0 if "cosine" in label.lower() else None, vmax=1,
#                 xticklabels=True, yticklabels=tokens)
#     else:
#         sns.heatmap(data, cmap='viridis', center=0 if "cosine" in label.lower() else None, vmax=1,
#                 xticklabels=True, yticklabels=False)
#     plt.xlabel("Layer Index", fontsize=20)
#     if label == "Cosine Similarity":
#         plt.ylabel("Token", fontsize=20)
#     # plt.title(f"{label} Heatmap: Token x Layer")
#     plt.tight_layout()
#     # plt.savefig(f"{label}_heatmap_with_tokens_modulation.png")
#     # plt.show()
# print(normratio_tensor)
# plt.subplots(1, 2, figsize=(20, 6))
# plt.subplot(1, 2, 1)
# plot_results_heatmap_with_tokens(cosine_tensor, tokenizer, inputs["input_ids"], label="Cosine Similarity")
# plt.subplot(1, 2, 2)
# plot_results_heatmap_with_tokens(normratio_tensor, tokenizer, inputs["input_ids"], label="Norm Ratio")
# plt.savefig("heatmap_with_tokens_modulation.png", dpi=300)
