import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Module-wide matplotlib defaults, applied once at import time: use pdflatex
# as the PGF backend's TeX system, enlarge the default font sizes for text,
# axis labels, axes titles and figure titles, and keep TeX text rendering off.
matplotlib.rcParams.update({
    'pgf.texsystem': 'pdflatex',
    'font.size': 18,
    'axes.labelsize': 20,
    'axes.titlesize': 24,
    'figure.titlesize': 28,
    'text.usetex': False,
})

# def plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size):
#     model_title={"dinov2_reg": f"DINOv2-reg ViT-{model_size}", "mistral_7b": "Mistral-7B", 
#                "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", "mistral_moe":"Mixtral-8x7B",
#                "CLIP":f"CLIP-{model_size}"}

#     num_channels = feat.shape[2]

#     inp_seq = ["CLS",
#                "patch 1", "patch 67", "patch 68", "patch n"]

#     xbase_index = [0,1, 67,68,197]
#     num_tokens = len(xbase_index)
#     xdata = np.array([xbase_index for i in range(num_channels)])
#     ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
#     zdata = feat[0,:num_tokens,:].abs().numpy().T
#     ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)

#     ax.set_title(model_title[model_name]+f", Layer {layer_id+1}", fontsize=20, fontweight="bold", y=1.015)

#     ax.set_yticks([980, 2372, 61,1849], [980, 2372, 61,1849], fontsize=15)

#     xbase_index = [68,67]
#     inp_seq = ["68", "67"]
#     ax.set_xticks(xbase_index, inp_seq, rotation=60, fontsize=16)
#     ax.tick_params(axis='x', which='major', pad=-4)
#     plt.setp(ax.get_xticklabels(), rotation=50, ha="right", va="center", rotation_mode="anchor")

#     ax.set_zticks([0, 50, 100], ["0", "50", "100"], fontsize=16)
#     plt.setp(ax.get_yticklabels(), ha="left", va="center",rotation_mode="anchor")
#     plt.setp(ax.get_zticklabels(), ha="left", va="top", rotation_mode="anchor")

#     ax.tick_params(axis='x', which='major', pad=-5)
#     ax.tick_params(axis='y', which='major', pad=-3)
#     ax.tick_params(axis='z', which='major', pad=-5)

from mpl_toolkits.mplot3d import Axes3D
# def plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size):
#     model_title={"dinov2_reg": f"DINOv2-reg ViT-{model_size}", "mistral_7b": "Mistral-7B", 
#                "llama2_13b": "LLaMA-2-13B", "llama2_70b": "LLaMA-2-70B", "mistral_moe":"Mixtral-8x7B",
#                "CLIP":f"CLIP-{model_size}"}

#     num_channels = feat.shape[2]

#     xbase_index_data = [140, 143,  55,  68,  99,  54, 0,  39,  78, 117, 156, 196]
#     # num_tokens는 이제 zdata의 shape을 통해 얻는 것이 더 정확하지만,
#     # ydata 생성에만 사용되고 len(xbase_index_data)와 같으므로 그대로 둬도 문제는 없습니다.
#     num_tokens = len(xbase_index_data)
    
#     xdata = np.array([xbase_index_data for i in range(num_channels)])
#     ydata = np.array([np.ones(num_tokens) * i for i in range(num_channels)])
    
#     # --- ✨ 여기가 수정된 부분입니다 ✨ ---
#     # :num_tokens 대신 xbase_index_data를 직접 사용해 원하는 토큰의 데이터를 정확히 선택합니다.
#     zdata = feat[0, xbase_index_data, :].abs().numpy().T
    
#     ax.plot_wireframe(xdata, ydata, zdata, rstride=0, color="royalblue", linewidth=2.5)

#     ax.set_title(model_title[model_name]+f", Layer {layer_id+1}", fontsize=20, fontweight="bold", y=1.015)

#     # ax.set_yticks([1435, 2078, 2922], [1435, 2078, 2922], fontsize=15)
    
#     ax.set_zlim(0, zdata.max() * 1.1)
    
#     # # ax.set_xticks([])
#     # ax.set_xticklabels([])

#     # ax.text(67, 0, 0, '67', fontsize=16, ha='center', va='bottom')
#     # ax.text(68, 0, 0, '68', fontsize=16, ha='left', va='center')

#     ax.set_zticks([50], ["50"], fontsize=16)
#     plt.setp(ax.get_yticklabels(), ha="left", va="center", rotation_mode="anchor")
#     plt.setp(ax.get_zticklabels(), ha="left", va="top", rotation_mode="anchor")

#     ax.view_init(elev=25, azim=-75)

#     ax.tick_params(axis='x', which='major', pad=-5)
#     ax.tick_params(axis='y', which='major', pad=-3)
#     ax.tick_params(axis='z', which='major', pad=-5)
def _token_wireframe_grid(feat, tokens, num_channels):
    """Return (X, Y, Z) wireframe grids of shape (num_channels, len(tokens)).

    ``X[c, j]`` is ``tokens[j]``, ``Y[c, j]`` is ``c`` and ``Z[c, j]`` is
    ``|feat[0, tokens[j], c]|``.
    """
    xgrid, ygrid = np.meshgrid(np.asarray(tokens), np.arange(num_channels))
    # A list index selects several token rows at once (advanced indexing).
    zgrid = feat[0, list(tokens), :].abs().numpy().T
    return xgrid, ygrid, zgrid


def plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size,
                         normal_tokens=(0, 18, 114, 196),
                         outlier_tokens=(55, 68, 143, 99),
                         normal_color="royalblue", outlier_color="red"):
    """Draw per-channel feature magnitudes of selected tokens as a 3D wireframe.

    Each selected token becomes one line running across all channels, with
    height ``|feat[0, token, channel]|``. ``rstride=0`` disables the lines that
    would connect different tokens at the same channel. Normal and outlier
    tokens are drawn as two separate wireframes so they can be told apart.

    Args:
        ax: A matplotlib 3D axes (created with ``projection='3d'``).
        feat: Features indexed as ``feat[0, token, channel]``. It must support
            ``.abs()`` and ``.numpy()``. Presumably a CPU ``torch.Tensor`` of
            shape (batch, tokens, channels) — TODO confirm against callers.
        layer_id: Zero-based layer index; the title shows ``layer_id + 1``.
        model_name: Key into the title table below; unknown names are shown
            verbatim.
        model_size: Size suffix used in the DINOv2-reg and CLIP titles.
        normal_tokens: Token indices drawn with ``normal_color`` (line width
            2.0). Indices that also appear in ``outlier_tokens`` are drawn only
            as outliers.
        outlier_tokens: Token indices drawn with ``outlier_color`` (line width
            2.5).
        normal_color: Wireframe color for normal tokens.
        outlier_color: Wireframe color for outlier tokens.

    Raises:
        ValueError: If both token groups are empty, so nothing can be plotted.
    """
    model_title = {
        "dinov2_reg": f"DINOv2-reg ViT-{model_size}",
        "mistral_7b": "Mistral-7B",
        "llama2_13b": "LLaMA-2-13B",
        "llama2_70b": "LLaMA-2-70B",
        "mistral_moe": "Mixtral-8x7B",
        "CLIP": f"CLIP-{model_size}",
    }
    title = model_title.get(model_name, model_name)

    num_channels = feat.shape[2]

    # Normal tokens exclude anything already highlighted as an outlier.
    outlier_tokens = list(outlier_tokens)
    outlier_set = set(outlier_tokens)
    normal_tokens = [t for t in normal_tokens if t not in outlier_set]

    # Normal tokens first, then outliers on top (thicker line).
    groups = (
        (normal_tokens, normal_color, 2.0),
        (outlier_tokens, outlier_color, 2.5),
    )
    zmax = None
    for tokens, color, width in groups:
        if not tokens:
            # An empty selection would make plot_wireframe / .max() fail.
            continue
        xgrid, ygrid, zgrid = _token_wireframe_grid(feat, tokens, num_channels)
        ax.plot_wireframe(xgrid, ygrid, zgrid, rstride=0, color=color, linewidth=width)
        group_max = zgrid.max()
        zmax = group_max if zmax is None else max(zmax, group_max)

    if zmax is None:
        raise ValueError("plot_3d_feat_vit_sub: no token indices to plot")

    ax.set_title(title + f", Layer {layer_id+1}",
                 fontsize=20, fontweight="bold", y=1.015)

    # 10% headroom above the tallest line.
    ax.set_zlim(0, zmax * 1.1)

    ax.set_zticks([50], ["50"], fontsize=16)
    plt.setp(ax.get_yticklabels(), ha="left", va="center", rotation_mode="anchor")
    plt.setp(ax.get_zticklabels(), ha="left", va="top", rotation_mode="anchor")

    ax.view_init(elev=25, azim=-75)

    # Negative pads pull tick labels closer to the axes.
    ax.tick_params(axis='x', which='major', pad=-5)
    ax.tick_params(axis='y', which='major', pad=-3)
    ax.tick_params(axis='z', which='major', pad=-5)

def plot_3d_feat_vit(feat, layer_id, model_name, model_size, savedir, close=True):
    """Render one 3D feature-magnitude plot and save it as a PNG.

    The image is written to
    ``{savedir}/{model_name}_{model_size}_layer_{layer_id+1}.png`` at 200 dpi
    with a tight bounding box.

    Args:
        feat, layer_id, model_name, model_size: Forwarded unchanged to
            ``plot_3d_feat_vit_sub``.
        savedir: Directory to write into; it is not created here.
        close: Close the figure after saving (default). When this function is
            called once per layer in a loop, leaving figures open makes pyplot
            keep every one of them alive.
    """
    fig = plt.figure(figsize=(8, 6))
    # The original's tight_layout() ran before any axes existed and did
    # nothing; bbox_inches="tight" below handles cropping.
    fig.subplots_adjust(wspace=0.0)

    ax = fig.add_subplot(1, 1, 1, projection='3d')
    plot_3d_feat_vit_sub(ax, feat, layer_id, model_name, model_size)

    out_path = os.path.join(savedir, f"{model_name}_{model_size}_layer_{layer_id+1}.png")
    # Save through the explicit figure rather than pyplot's "current" figure.
    fig.savefig(out_path, bbox_inches="tight", dpi=200)
    if close:
        plt.close(fig)


def plot_layer_ax_vit_sub(ax, mean, model_family, model_size,
                          colors=("royalblue", "darkorange", "forestgreen", "black")):
    """Plot per-layer top-3 and median magnitudes as line curves on ``ax``.

    Args:
        ax: A 2D matplotlib axes.
        mean: 2D array-like indexed as ``mean[row, layer]``. Rows 0-2 are drawn
            as "Top 1".."Top 3" and the last row as "median". Presumably shape
            (4, num_layers) — TODO confirm against callers.
        model_family: Key into the title table below; unknown families are
            shown verbatim.
        model_size: Appended to the title as ``" ViT-{model_size}"``.
        colors: At least four colors. ``colors[0..2]`` are used for Top 1..3
            and ``colors[-1]`` for the median. The default is a tuple, which
            avoids a shared mutable default; any indexable sequence works.
    """
    model_title = {
        "dinov2_reg": "DINOv2-reg",
        "dinov2": "DINOv2",
        "mae": "MAE",
        "open_clip": "Open CLIP",
        "openai_clip": "OpenAI CLIP",
        "vit_orig": "ViT",
        "samvit": "SAM-ViT",
    }
    title = model_title.get(model_family, model_family)

    num_layers = mean.shape[-1]
    layers = np.arange(1, num_layers + 1)  # 1-based layer numbers on the x axis

    for rank in range(3):
        ax.plot(layers, mean[rank], label=f"Top {rank + 1}", color=colors[rank],
                linestyle="-", marker="o", markerfacecolor='none', markersize=5)

    ax.plot(layers, mean[-1], label="median", color=colors[-1],
            linestyle="-", marker="v", markerfacecolor='none', markersize=5)

    ax.set_title(title + f" ViT-{model_size}", fontsize=18, fontweight="bold")
    ax.set_ylabel("Magnitudes", fontsize=18)

    # First, quarter, half, three-quarter and last layer. Deduplicate and drop
    # 0 so small layer counts do not produce repeated or out-of-range ticks.
    candidates = (1, num_layers // 4, num_layers // 2, num_layers * 3 // 4, num_layers)
    xticks = sorted({t for t in candidates if t >= 1})
    ax.set_xticks(xticks, xticks, fontsize=16)

    ax.set_xlabel('Layers', fontsize=18, labelpad=4.0)
    ax.tick_params(axis='x', which='major', pad=2.0)
    ax.tick_params(axis='y', which='major', pad=0.4)
    ax.grid(axis='x', color='0.75')
    ax.grid(axis='y', color='0.75')

def plot_layer_ax_vit(mean, model_family, model_size, savedir, close=True):
    """Render the per-layer magnitude curves and save them as a PNG.

    The image is written to ``{savedir}/{model_family}_{model_size}.png`` at
    200 dpi with a tight bounding box.

    Args:
        mean, model_family, model_size: Forwarded unchanged to
            ``plot_layer_ax_vit_sub``.
        savedir: Directory to write into; it is not created here.
        close: Close the figure after saving (default). Without this, repeated
            calls leave every figure open in pyplot.
    """
    fig = plt.figure(figsize=(8, 6))
    # The original's tight_layout() ran before any axes existed and did
    # nothing; bbox_inches="tight" below handles cropping.
    fig.subplots_adjust(wspace=0.0)

    ax = fig.add_subplot(1, 1, 1)
    plot_layer_ax_vit_sub(ax, mean, model_family, model_size)

    # Save through the explicit figure rather than pyplot's "current" figure.
    fig.savefig(os.path.join(savedir, f"{model_family}_{model_size}.png"),
                bbox_inches="tight", dpi=200)
    if close:
        plt.close(fig)