import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from mpl_toolkits.mplot3d.art3d import Line3DCollection

def _color_limits(zflat, top_k_ignore, cap_percentile):
    """Return ``(vmin, vmax)`` for the colour scale of the flattened data ``zflat``.

    ``vmin`` is the data minimum. ``vmax`` starts at the data maximum. It is then
    lowered to exclude the ``top_k_ignore`` largest values and/or capped at the
    ``cap_percentile`` percentile. It is always kept strictly above ``vmin``.
    """
    zmin = float(zflat.min())
    vmax_cap = float(zflat.max())
    if top_k_ignore is not None and top_k_ignore > 0 and zflat.size > top_k_ignore:
        # The (top_k_ignore + 1)-th largest value, so the k largest are ignored.
        kth = np.partition(zflat, -(top_k_ignore + 1))[-(top_k_ignore + 1)]
        vmax_cap = float(kth)
    if cap_percentile is not None:
        vmax_cap = min(vmax_cap, float(np.percentile(zflat, cap_percentile)))
    if vmax_cap <= zmin:
        # Scale the epsilon with |zmin|: a fixed 1e-12 is lost to float64
        # rounding once |zmin| is around 1e4, which would make vmin == vmax.
        vmax_cap = zmin + max(1e-12, abs(zmin) * 1e-12)
    return zmin, vmax_cap


def _resolve_cmap(cmap_name):
    """Return the colormap called ``cmap_name``, or "viridis" if it cannot be found."""
    try:
        return plt.get_cmap(cmap_name)
    except (ValueError, KeyError, TypeError, AttributeError):
        pass
    try:
        return mpl.colormaps.get_cmap(cmap_name)
    except (ValueError, KeyError, TypeError, AttributeError):
        return mpl.colormaps.get_cmap("viridis")


def _index_ticks(n, step):
    """Return 0-based tick positions for an axis of length ``n``.

    Positions are ``step-1, 2*step-1, ...``, so 1-based labels read
    ``step, 2*step, ...``. If none fit, up to five evenly spaced positions
    are returned instead.
    """
    ticks = np.arange(step - 1, n, step, dtype=int)
    if ticks.size == 0:
        ticks = np.linspace(0, max(n - 1, 0), num=min(5, n), dtype=int)
    return ticks


def plot_3d_feat(obj, layer_id, model_name=None,
                 token_step=10, channel_step=500,
                 cmap_name="coolwarm",
                 c_stride=1, t_stride=1,
                 top_k_ignore=1,
                 cap_percentile=None,
                 view=None,
                 save_path=None, dpi=600):
    """Draw one layer's hidden values as vertical segments on a 3D (channel, token) grid.

    Each plotted (channel, token) value becomes a line from z=0 to z=value. Its
    colour comes from a colormap whose upper limit can be capped, so a few
    outliers do not wash out the rest.

    Args:
        obj: Mapping indexed by ``f"{layer_id}"``. Element ``[0]`` of that entry
            must support ``.detach().cpu().numpy()`` (e.g. a torch tensor). It is
            transposed before plotting, so it is presumably (tokens, channels) --
            TODO confirm.
        layer_id: Key of the layer to plot.
        model_name: Currently unused; kept for API compatibility.
        token_step: Label every ``token_step``-th token (labels are 1-based).
        channel_step: Label every ``channel_step``-th channel (labels are 1-based).
        cmap_name: Colormap name; falls back to "viridis" if it cannot be resolved.
        c_stride: Draw only every ``c_stride``-th channel. Axis limits and the
            colour scale still use all data.
        t_stride: Draw only every ``t_stride``-th token.
        top_k_ignore: Exclude this many largest values when setting the colour max.
        cap_percentile: Optionally cap the colour max at this percentile (0-100).
        view: Optional ``(elev, azim)`` camera angles.
        save_path: If given, the figure is saved here before it is shown.
        dpi: Resolution used when saving.
    """
    fig = plt.figure(figsize=(14, 6))
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    z_all = obj[f"{layer_id}"][0].detach().cpu().numpy().T
    C_all, T_all = z_all.shape

    # Subsample the grid for drawing only.
    ch_idx = np.arange(0, C_all, c_stride, dtype=int)
    tk_idx = np.arange(0, T_all, t_stride, dtype=int)
    Z = z_all[np.ix_(ch_idx, tk_idx)]
    C, T = Z.shape

    X = np.repeat(ch_idx[:, None], T, axis=1).astype(float)
    Y = np.repeat(tk_idx[None, :], C, axis=0).astype(float)

    # The colour scale and z-limits come from the full (unstrided) data.
    zflat = z_all.ravel()
    zmin, vmax_cap = _color_limits(zflat, top_k_ignore, cap_percentile)
    zmax = float(zflat.max())

    norm = Normalize(vmin=zmin, vmax=vmax_cap, clip=True)
    cmap = _resolve_cmap(cmap_name)

    # One segment per point: (x, y, 0) -> (x, y, value).
    x, y, z = X.ravel(), Y.ravel(), Z.ravel()
    starts = np.stack([x, y, np.zeros_like(z)], axis=1)
    ends = np.stack([x, y, z], axis=1)
    segments = np.stack([starts, ends], axis=1)

    lc = Line3DCollection(segments, colors=cmap(norm(z)), linewidths=0.4, alpha=1.0)
    ax.add_collection3d(lc)

    ax.set_xlim(0, C_all - 1)
    ax.set_ylim(0, T_all - 1)
    ax.set_zlim(min(0.0, zmin), max(0.0, zmax))

    xticks = _index_ticks(C_all, channel_step)
    ax.set_xticks(xticks)
    ax.set_xticklabels([str(v + 1) for v in xticks], fontsize=10)

    yticks = _index_ticks(T_all, token_step)
    ax.set_yticks(yticks)
    ax.set_yticklabels([str(v + 1) for v in yticks], fontsize=10)

    ax.set_xlabel("Channels", labelpad=10, fontsize=12, fontweight="bold")
    ax.set_ylabel("Tokens", labelpad=12, fontsize=12, fontweight="bold")
    ax.set_zlabel("Hidden", labelpad=6, fontsize=11)

    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array(z_all)
    cbar = fig.colorbar(sm, ax=ax, pad=0.12, fraction=0.03)
    cbar.set_label("Hidden (color scale; capped)", rotation=90)

    ax.tick_params(axis='x', which='major', pad=-2)
    ax.tick_params(axis='y', which='major', pad=-2)
    ax.tick_params(axis='z', which='major', pad=-1)
    if view is not None:
        elev, azim = view
        ax.view_init(elev=elev, azim=azim)

    # Finalise the layout *before* saving so the file matches what is shown.
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
        # Prints "image saved to: <path>" (message text kept as-is).
        print(f"图像已保存到: {save_path}")

    plt.show()
    plt.close(fig)



def plot_2d_channel_feat(obj, layer_id, model_name, channels=(753, 3848)):
    """Plot |activation| across token positions for selected channels of one layer.

    Args:
        obj: Mapping with a ``"seq"`` list of token strings and, under
            ``f"{layer_id}"``, an entry whose element ``[0]`` supports
            ``.detach().cpu().abs().numpy()`` (e.g. a torch tensor). It is
            transposed, so it is presumably (tokens, channels) -- TODO confirm.
        layer_id: Layer key. It must also support ``+ 1``, because the title
            shows it 1-based.
        model_name: Shown in the title.
        channels: Channel indices to plot. Indices outside
            ``[0, n_channels)`` are skipped.
    """
    num_tokens = len(obj["seq"])
    # Detach and move to CPU first (as plot_3d_feat does) so GPU or
    # grad-tracking tensors can be converted to numpy.
    zdata = obj[f"{layer_id}"][0].detach().cpu().abs().numpy().T
    # Show the "<0x0A>" token (presumably a newline byte token) as a literal "\n".
    inp_seq = [x if x != "<0x0A>" else r"\n" for x in obj["seq"]]

    fig, ax = plt.subplots(figsize=(14, 6))

    positions = np.arange(num_tokens)
    n_channels = zdata.shape[0]
    for channel_id in channels:
        # Reject negative ids too; they would silently index from the end.
        if 0 <= channel_id < n_channels:
            ax.plot(positions, zdata[channel_id, :], linewidth=2, label=f'Channel {channel_id}')

    ax.set_xticks(positions)
    ax.set_xticklabels(inp_seq, rotation=60, ha='right', fontsize=10)
    ax.set_xlabel('Token Position', fontsize=12)

    ax.set_ylabel('Activation Magnitude', fontsize=12)
    # The y-range covers every channel, not just the plotted ones. Avoid a
    # degenerate (0, 0) range when the data are all zero.
    ymax = float(np.max(zdata))
    ax.set_ylim(0, ymax * 1.1 if ymax > 0 else 1.0)

    ax.set_title(f'Channel Activation Patterns\nLayer {layer_id+1} - {model_name}', fontsize=14, pad=15)
    if ax.lines:  # legend() warns when there are no labelled artists
        ax.legend(fontsize=12, loc='upper right')
