import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import numpy as np
import trimesh
import glob

from torch.utils.data import TensorDataset, DataLoader
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.default_config import DefaultConfig
from mpl_toolkits.mplot3d import Axes3D
from sw2 import Wasserstein_Distance
import numpy as np


def plot_pointclouds(X_path=None, X=None, label="airplane", k=5, save_dir="saved_img"):
    """Save a one-row figure showing the first ``k`` point clouds of a batch.

    Each point cloud is drawn as a 3D scatter plot coloured by its third
    coordinate, and the figure is written to ``<save_dir>/<label>.png``.

    Args:
        X_path: Path passed to ``torch.load`` when ``X`` is not given.
        X: Batch of point clouds, either a torch tensor or anything
            ``np.asarray`` accepts. It is indexed as ``X[i][:, 0..2]``, so it
            is assumed to be shaped (num_clouds, num_points, >=3).
        label: File stem of the saved PNG.
        k: Maximum number of point clouds to draw.
        save_dir: Output directory; created if it does not exist.

    Raises:
        ValueError: If neither ``X`` nor ``X_path`` is given, or if there is
            nothing to draw (empty batch or ``k < 1``).
    """
    if X is None:
        if X_path is None:
            raise ValueError("plot_pointclouds needs either X or X_path")
        # NOTE: torch.load unpickles the file -- only load trusted files.
        X = torch.load(X_path)

    if torch.is_tensor(X):
        # detach() so tensors that still track gradients can be converted.
        X = X.detach().cpu().numpy()
    else:
        X = np.asarray(X)

    n_samples = min(k, X.shape[0])
    if n_samples < 1:
        raise ValueError(
            f"nothing to plot: k={k}, number of point clouds={X.shape[0]}"
        )

    os.makedirs(save_dir, exist_ok=True)
    fig = plt.figure(figsize=(4 * n_samples, 4))
    try:
        for i, pc in enumerate(X[:n_samples]):
            ax = fig.add_subplot(1, n_samples, i + 1, projection='3d')
            ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], s=1, c=pc[:, 2], cmap="viridis")
            ax.set_title(f"Sample {i}", fontsize=15)
            ax.axis("off")

        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(os.path.join(save_dir, f"{label}.png"))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)





def plot_barycenters(
    wormhole,
    list_samples,
    class_names,
    out_dir="saved_barycenter",
    method_label="fastSP-Wormhole",
    row_label=None,
    elev=18, azim=125,
    s=6.0,
    max_cols=5
):
    """Decode the mean latent code of each class and plot the results in a row.

    For every ``(samples, class name)`` pair the samples are passed through
    ``wormhole.encoder``, the codes are averaged along dim 0, and the mean is
    passed through ``wormhole.decoder``. The decoded point clouds are drawn
    side by side as 3D scatter plots and saved as PNG and PDF under
    ``out_dir``.

    Args:
        wormhole: Object exposing ``encoder``, ``decoder`` and
            ``config.device`` / ``config.dtype``.
        list_samples: One tensor of samples per class (first dim is assumed
            to index samples -- it is the dim that gets averaged).
        class_names: Caption for each entry of ``list_samples``. Pairs are
            formed with ``zip``, so extra entries on either side are ignored.
        out_dir: Output directory; created if it does not exist.
        method_label: Used in the output file names and as the side label.
        row_label: When truthy, a rotated label is drawn left of the row.
            The text drawn is ``method_label``; ``row_label`` only switches
            it on. NOTE(review): confirm whether ``row_label`` itself was
            meant to be the text.
        elev: Elevation angle of the 3D view, in degrees.
        azim: Azimuth angle of the 3D view, in degrees.
        s: Scatter marker size.
        max_cols: Currently unused; kept for backward compatibility.

    Raises:
        ValueError: If there is no ``(samples, class name)`` pair to plot.
    """
    os.makedirs(out_dir, exist_ok=True)

    device = wormhole.config.device
    dtype = wormhole.config.dtype

    bary_list = []
    used_names = []

    # Inference only: switch off dropout etc. if the model is in training
    # mode, and restore the previous mode afterwards.
    was_training = bool(getattr(wormhole, "training", False))
    if was_training:
        wormhole.eval()
    try:
        with torch.no_grad():
            for samples, cname in zip(list_samples, class_names):
                z = wormhole.encoder(samples.to(device).to(dtype))
                mean_z = torch.mean(z, dim=0, keepdim=True)
                pc = wormhole.decoder(mean_z.to(device, dtype=dtype))
                bary_list.append(pc.squeeze(0).detach().cpu().numpy())
                used_names.append(cname)
    finally:
        if was_training:
            wormhole.train()

    n_cols = len(bary_list)
    if n_cols == 0:
        raise ValueError(
            "plot_barycenters needs at least one (samples, class name) pair"
        )

    col_w = 2.0
    fig_h = 2.0
    fig, axes = plt.subplots(
        1, n_cols,
        figsize=(col_w * n_cols, fig_h), dpi=300,
        subplot_kw={'projection': '3d'},
        gridspec_kw={'wspace': 0.02, 'hspace': 0.0},
        squeeze=False,  # always a 2D array, so a single column needs no special case
    )
    axes = axes[0]

    try:
        for ax, pc, cname in zip(axes, bary_list, used_names):
            # Cube centred on the cloud, with the largest extent as edge
            # length, so all three axes share one scale.
            mins = pc.min(axis=0)
            maxs = pc.max(axis=0)
            ctr = (mins + maxs) / 2.0
            half = (maxs - mins).max() / 2.0
            lo, hi = ctr - half, ctr + half

            ax.scatter(
                pc[:, 0], pc[:, 1], pc[:, 2],
                s=s, c=pc[:, 2], cmap="viridis",
                alpha=0.95, linewidths=0, edgecolors="none"
            )

            if half > 0:
                # A degenerate (single-point) cloud keeps matplotlib's autoscale.
                ax.set_xlim(lo[0], hi[0])
                ax.set_ylim(lo[1], hi[1])
                ax.set_zlim(lo[2], hi[2])
            ax.view_init(elev=elev, azim=azim)

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_zticks([])
            ax.text2D(0.5, -0.12, cname, transform=ax.transAxes,
                      ha="center", va="top", fontsize=12)

        fig.subplots_adjust(left=0.02, right=0.998, bottom=0.14, top=0.92, wspace=0.02)

        if row_label:
            # Place the rotated label just left of the first subplot,
            # vertically centred on it.
            pos = axes[0].get_position()
            y_center = 0.5 * (pos.y0 + pos.y1)
            x_text = max(0.015, pos.x0 - 0.02)
            fig.text(x_text, y_center, method_label, rotation=90,
                     ha="center", va="center", fontsize=10, weight="bold")

        out_png = os.path.join(out_dir, f"barycenters_{method_label}.png")
        out_pdf = os.path.join(out_dir, f"barycenters_{method_label}.pdf")
        fig.savefig(out_png, bbox_inches="tight", pad_inches=0.02)
        fig.savefig(out_pdf, bbox_inches="tight", pad_inches=0.02, dpi=300)
    finally:
        # Close the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"[plot_barycenters] Saved to {out_png} and {out_pdf}")


def main():
    """Load a trained Wormhole checkpoint and plot one decoded barycenter per class.

    Reads ``saved_barycenter_data/barycenter_<class>.pt`` for each class and
    writes the resulting figure into ``saved_barycenter_data``.
    """
    method_label = "RG-seo Wormhole"
    checkpoint_path = "fast_wormhole_interpolation_modelnet40_10classes/optimal_alpha_general/fake_decode_loss/num_200/sw_pwd_ebsw_est_maxsw_minswgg/sw_pwd_ebsw_est_maxsw_minswgg_lr0.0001_epoch2000.pth"
    config = DefaultConfig(
        n_points=2048,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        batch_size=32,
        dtype=torch.float32,
        emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512, attention_dropout_rate=0.1,
        input_dim=3, lr=1e-4, epochs=250, decay_steps=200
    )

    model = TransformerAutoencoder(
        config=config, seq_len=config.n_points, inp_dim=config.input_dim
    ).to(config.device).to(config.dtype)

    wormhole = Wormhole(transformer=model, config=config, run_dir=None)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=config.device)
    wormhole.load_state_dict(checkpoint["model_state_dict"])

    class_names = ["airplane", "bed", "chair", "cup", "lamp", "toilet"]
    # One tensor of samples per class; plot_barycenters consumes them whole,
    # so no DataLoader is needed here.
    list_samples = [
        torch.load(f"saved_barycenter_data/barycenter_{label}.pt")
        for label in class_names
    ]

    plot_barycenters(
        wormhole=wormhole,
        list_samples=list_samples,
        class_names=class_names,
        out_dir="saved_barycenter_data",
        method_label=method_label,
        row_label="barycenter",
        elev=18, azim=125,
        s=6.0,
        max_cols=5
    )


if __name__ == "__main__":
    main()