import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import numpy as np
import trimesh
import glob

from torch.utils.data import TensorDataset, DataLoader
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.default_config import DefaultConfig
from mpl_toolkits.mplot3d import Axes3D  # cần import để kích hoạt 3D projection
from sw2 import Wasserstein_Distance
import numpy as np

def plot_pointclouds(X_path=None, X=None, label="airplane", k=5, save_dir="saved_img"):
    """Save a single-row figure of 3-D scatter plots for the first ``k`` point clouds.

    Args:
        X_path: Path handed to ``torch.load`` when ``X`` is not given.
        X: Point clouds as a torch tensor or array-like; sample ``i`` is read
            as ``X[i]`` and its columns 0-2 are used as x/y/z (presumably
            shape ``(N, P, 3)`` — TODO confirm against callers).
        label: Output file stem; the image is written to ``{save_dir}/{label}.png``.
        k: Maximum number of samples to draw.
        save_dir: Output directory, created if it does not exist.

    Returns:
        The path of the saved PNG.

    Raises:
        ValueError: If neither ``X`` nor ``X_path`` is given, or there is
            nothing to plot (``k <= 0`` or empty ``X``).
    """
    os.makedirs(save_dir, exist_ok=True)

    if X is None:
        if X_path is None:
            raise ValueError("Either X or X_path must be provided")
        # map_location lets a tensor saved from a GPU run load on a CPU-only host.
        X = torch.load(X_path, map_location="cpu")

    if isinstance(X, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        X = X.detach().cpu().numpy()
    else:
        X = np.asarray(X)

    n_samples = min(k, X.shape[0])
    if n_samples <= 0:
        raise ValueError(f"Nothing to plot: k={k}, number of samples={X.shape[0]}")

    fig = plt.figure(figsize=(4 * n_samples, 4))
    try:
        for i in range(n_samples):
            pc = X[i]
            ax = fig.add_subplot(1, n_samples, i + 1, projection="3d")
            ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], s=1, c=pc[:, 2], cmap="viridis")
            ax.set_title(f"Sample {i}", fontsize=15)
            ax.axis("off")
        out_path = os.path.join(save_dir, f"{label}.png")
        fig.savefig(out_path)
    finally:
        # Release the figure; otherwise every call keeps one alive in pyplot.
        plt.close(fig)
    return out_path


def _decode_interpolation(wormhole, sample1, sample2, n_cols):
    """Decode ``n_cols`` evenly spaced linear blends of the two samples' encodings.

    Returns a list of NumPy arrays, one decoded point cloud per step
    t in linspace(0, 1, n_cols). The model runs without autograd and in eval
    mode; its previous train/eval mode is restored afterwards.
    """
    device, dtype = wormhole.config.device, wormhole.config.dtype
    ts = torch.linspace(0, 1, steps=n_cols, device=device, dtype=dtype).tolist()

    sample1 = sample1.to(device=device, dtype=dtype)
    sample2 = sample2.to(device=device, dtype=dtype)

    # Dropout must be off during inference or the decoded clouds are random.
    was_training = bool(getattr(wormhole, "training", False))
    if was_training:
        wormhole.eval()
    try:
        with torch.no_grad():
            z1 = wormhole.encoder(sample1)
            z2 = wormhole.encoder(sample2)
            # squeeze(0) drops the batch axis; assumes a batch of one sample.
            return [
                wormhole.decoder((1 - t) * z1 + t * z2).squeeze(0).cpu().numpy()
                for t in ts
            ]
    finally:
        if was_training:
            wormhole.train()


def _cube_limits(pc):
    """Return per-axis (lo, hi) limits of a cube centred on pc's xyz bounding box."""
    xyz = pc[:, :3]
    mins = xyz.min(axis=0)
    maxs = xyz.max(axis=0)
    ctr = (mins + maxs) / 2.0
    # Floor the half-width so a degenerate (single-point) cloud gets non-zero limits.
    half = max(float((maxs - mins).max()) / 2.0, 1e-6)
    return ctr - half, ctr + half


def plot_interpolation(
    wormhole,
    sample1,
    sample2,
    class_names=("A","B"),
    out_dir="saved_interpolation_data",
    method_label="fastSP-Wormhole",
    row_label=None,
    s=6.0,
    n_cols=5
):
    """Plot a latent-space interpolation between two samples as one row of 3-D scatters.

    Both samples are encoded with ``wormhole.encoder``. The two latents are
    blended linearly at ``n_cols`` evenly spaced t in [0, 1], and each blend is
    decoded with ``wormhole.decoder``. The figure is saved as PNG and PDF.

    Args:
        wormhole: Model exposing ``encoder``, ``decoder`` and ``config.device`` /
            ``config.dtype``.
        sample1, sample2: Input tensors for the two endpoints (presumably
            ``(1, n_points, 3)``; the decoder output is ``squeeze(0)``-ed, so a
            batch of one is assumed — TODO confirm).
        class_names: Captions drawn under the first and last panel.
        out_dir: Output directory, created if missing.
        method_label: Used in the file names ``interp_{method_label}.png/.pdf``.
        row_label: Optional bold vertical label left of the row.
        s: Scatter marker size.
        n_cols: Number of interpolation steps (panels).

    Returns:
        ``(out_png, out_pdf)`` paths of the saved files.

    Raises:
        ValueError: If ``n_cols < 1``.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")
    os.makedirs(out_dir, exist_ok=True)

    pcs_decoded = _decode_interpolation(wormhole, sample1, sample2, n_cols)

    # ===== plot =====
    fig, axes = plt.subplots(
        1, n_cols,
        figsize=(2.0 * n_cols, 2.2), dpi=300,
        subplot_kw={'projection': '3d'},
        gridspec_kw={'wspace': 0.02, 'hspace': 0.0},
        squeeze=False,
    )
    axes = axes[0]  # squeeze=False always returns a 2-D grid; take the single row

    try:
        for j, pc in enumerate(pcs_decoded):
            lo, hi = _cube_limits(pc)

            ax = axes[j]
            ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2],
                       s=s, c=pc[:, 2], cmap="viridis",
                       alpha=0.95, linewidths=0, edgecolors="none")
            ax.set_xlim(lo[0], hi[0]); ax.set_ylim(lo[1], hi[1]); ax.set_zlim(lo[2], hi[2])
            try:
                ax.set_box_aspect([1, 1, 1])
            except (AttributeError, NotImplementedError):
                # set_box_aspect is missing on older matplotlib; fall back to default aspect.
                pass
            ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])

            # Caption the two endpoints; with a single panel both captions go on it.
            end_labels = []
            if j == 0:
                end_labels.append(class_names[0])
            if j == n_cols - 1:
                end_labels.append(class_names[1])
            if end_labels:
                ax.text2D(0.5, -0.12, " / ".join(end_labels), transform=ax.transAxes,
                          ha="center", va="top", fontsize=10)

        # Optional row label, placed just left of the first panel.
        if row_label:
            pos = axes[0].get_position()
            y_center = 0.5 * (pos.y0 + pos.y1)
            x_text = max(0.015, pos.x0 - 0.02)
            fig.text(x_text, y_center, row_label, rotation=90,
                     ha="center", va="center", fontsize=10, weight="bold")

        out_png = os.path.join(out_dir, f"interp_{method_label}.png")
        out_pdf = os.path.join(out_dir, f"interp_{method_label}.pdf")
        fig.savefig(out_png, bbox_inches="tight", pad_inches=0.02)
        fig.savefig(out_pdf, bbox_inches="tight", pad_inches=0.02, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"[plot_interpolation] Saved to {out_png} and {out_pdf}")
    return out_png, out_pdf




if __name__ == "__main__":

    label_name = "lamp"
    # map_location lets a tensor saved from a GPU run load on a CPU-only machine;
    # plot_interpolation moves the samples to config.device itself.
    X = torch.load(
        f"compare_interpolation_modelnet40_all_classes/{label_name}/X_train.pt",
        map_location="cpu",
    )
    method_label = "RG-seo Wormhole"
    prefix_saved_model = "fast_wormhole_interpolation_modelnet40_10classes/optimal_alpha_general/fake_decode_loss/num_200/sw_pwd_ebsw_est_maxsw_minswgg/sw_pwd_ebsw_est_maxsw_minswgg_lr0.0001_epoch2000.pth"

    config = DefaultConfig(
        n_points=2048,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        batch_size=32,
        dtype=torch.float32,
        emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512, attention_dropout_rate=0.1,
        input_dim=3, lr=1e-4, epochs=250, decay_steps=200
    )

    model = TransformerAutoencoder(
        config=config, seq_len=config.n_points, inp_dim=config.input_dim
    ).to(config.device).to(config.dtype)

    wormhole = Wormhole(transformer=model, config=config, run_dir=None)
    wormhole.load_state_dict(torch.load(prefix_saved_model, map_location=config.device)["model_state_dict"])
    # Inference only: disable dropout (attention_dropout_rate=0.1 above) so the
    # decoded interpolation is deterministic.
    wormhole.eval()

    plot_interpolation(
        wormhole=wormhole,
        sample1=X[12].unsqueeze(0),
        sample2=X[13].unsqueeze(0),
        class_names=("Lamp A", "Lamp B"),
        n_cols=5,
        row_label=method_label,
        method_label=method_label,
        out_dir="saved_interpolation"
    )


