import os
import torch
from torch.utils.data import TensorDataset, DataLoader
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.default_config import DefaultConfig
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sw2 import Wasserstein_Distance
import numpy as np
import matplotlib as mpl
from copy import deepcopy
    
# --- Experiment settings -----------------------------------------------------
# NOTE(review): num_samples_training, eval_baseline and estimate_alpha_general
# are not referenced anywhere in the visible code; kept as module-level names.
num_samples_training = 200
epoch_test = 256            # forwarded to DefaultConfig(epochs=...)
eval_baseline = True
estimate_alpha_general = True

method_label = "RG-seo Wormhole"                    # label used in plot titles / file names
dataset_path = "preprocessed_dataset/ModelNet40"    # root holding one folder per category
# Full path of the checkpoint file whose "model_state_dict" is loaded further down.
prefix_saved_model = "fast_wormhole_interpolation_modelnet40_10classes/optimal_alpha_general/num_200/sw_pwd_ebsw_est_maxsw_minswgg/lr0.0001_epoch2000.pth"


# --- Model configuration -----------------------------------------------------
config = DefaultConfig(
    # data
    n_points=2048,
    input_dim=3,
    batch_size=32,
    # runtime placement: GPU when available, otherwise CPU
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
    # transformer architecture
    emb_dim=128,
    num_heads=4,
    num_layers=3,
    mlp_dim=512,
    attention_dropout_rate=0.1,
    # optimisation schedule
    lr=1e-4,
    epochs=epoch_test,
    decay_steps=200,
)


# --- Model + Wormhole wrapper ------------------------------------------------
# One sequence element per point, each with ``input_dim`` coordinates.
model = TransformerAutoencoder(config=config, seq_len=config.n_points, inp_dim=config.input_dim)
model = model.to(config.device).to(config.dtype)

wormhole = Wormhole(transformer=model, config=config, run_dir=None)


def load_each_dataloader_category(
    root_dir,
    category,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=False
):
    """Build a DataLoader over one category's saved training split.

    Reads ``X_train.pt`` and ``y_train.pt`` from ``root_dir/category`` with
    ``torch.load`` and wraps them in a ``TensorDataset`` (so each batch is an
    ``[X, y]`` pair). The remaining arguments are passed straight to
    ``DataLoader``.

    Returns:
        ``(loader, X)`` -- the DataLoader and the raw point tensor that was loaded.
    """
    category_dir = os.path.join(root_dir, category)
    points = torch.load(os.path.join(category_dir, "X_train.pt"))
    targets = torch.load(os.path.join(category_dir, "y_train.pt"))

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
    return DataLoader(TensorDataset(points, targets), **loader_options), points



def load_dataloader_categories(
    root_dir: str,
    categories: list[str],
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
):
    """Merge several categories' training splits into one relabelled DataLoader.

    For every name in ``categories`` the files ``X_train.pt`` / ``y_train.pt``
    are loaded from ``root_dir/<name>``. The stored labels are discarded; each
    sample instead receives the position of its category in ``categories`` as a
    ``torch.long`` label. All splits are concatenated along dim 0.

    Returns:
        ``(loader, categories, counts_per_cat)`` where ``counts_per_cat`` maps
        each category name to the number of samples loaded for it.
    """
    assert len(categories) > 0, "categories rỗng."

    point_chunks = []
    label_chunks = []
    counts_per_cat = {}

    for label, category in enumerate(categories):
        folder = os.path.join(root_dir, category)
        points = torch.load(os.path.join(folder, "X_train.pt"))
        # Original label mapping is ignored -- only the number of entries is used.
        n_items = len(torch.load(os.path.join(folder, "y_train.pt")))

        point_chunks.append(points)
        # New label = index of the category in the input list.
        label_chunks.append(torch.full((n_items,), label, dtype=torch.long))
        counts_per_cat[category] = n_items

    dataset = TensorDataset(torch.cat(point_chunks, dim=0), torch.cat(label_chunks, dim=0))
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )

    print(f"=> Loaded {len(dataset)} samples from {len(categories)} categories.")
    for name, count in counts_per_cat.items():
        print(f"   - {name}: {count}")

    return loader, categories, counts_per_cat



def _set_unit_box_aspect(ax):
    """Best-effort cube-shaped box for a 3D axes.

    Older matplotlib releases have no ``set_box_aspect``; in that case the
    default aspect is kept instead of failing the whole plot.
    """
    try:
        ax.set_box_aspect([1, 1, 1])
    except (AttributeError, NotImplementedError):
        pass


def _batch_points(batch):
    """Return the point-cloud tensor from a DataLoader batch.

    ``TensorDataset`` batches arrive as a list of tensors -- ``[X]`` or
    ``[X, y]`` (the loaders in this file build ``TensorDataset(X, y)``); the
    points are taken to be the first entry. A bare tensor is returned as-is.
    """
    if isinstance(batch, (list, tuple)):
        return batch[0]
    return batch


def _draw_cloud_panel(ax, pts, lo, hi, title):
    """Scatter one (N, 3) numpy point cloud on a 3D axes with fixed limits."""
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=2, c=pts[:, 2], cmap="viridis",
               alpha=0.95, linewidths=0, edgecolors="none")
    ax.set_xlim(lo[0], hi[0]); ax.set_ylim(lo[1], hi[1]); ax.set_zlim(lo[2], hi[2])
    _set_unit_box_aspect(ax)
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.set_title(title, fontsize=12, pad=4)


def plot_decode_and_barycenter(
    wormhole, dataloader, prefix_plot="saved_plot_barycenter", num_plot_decode=5,
    method_label="fastSP-Wormhole",
):
    """Plot reconstructions and the decoded mean embedding of a dataset.

    Every batch of ``dataloader`` is encoded with ``wormhole.encoder`` and
    decoded with ``wormhole.decoder`` (under ``torch.no_grad``), after being
    moved to the module-level ``config.device`` / ``config.dtype``.

    Writes into ``prefix_plot``:
      * ``pair_{i}.png`` for the first ``num_plot_decode`` samples: original vs.
        reconstruction on shared axis limits, captioned with
        ``Wasserstein_Distance(orig, recon)``;
      * ``barycenter.png``: the decoder output of the mean of all embeddings.

    Args:
        wormhole: object exposing ``encoder`` and ``decoder`` callables.
        dataloader: iterable of batches; each batch is either a tensor of
            points or a list/tuple whose first element is that tensor.
        prefix_plot: output directory (created if missing).
        num_plot_decode: maximum number of (orig, recon) pair figures.
        method_label: figure super-title.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    import os, numpy as np, torch, matplotlib.pyplot as plt
    os.makedirs(prefix_plot, exist_ok=True)

    emb_list, recon_list, orig_list = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            # Accept both ``[X]`` and ``[X, y]`` batches (the original
            # ``for (xb,) in ...`` failed on the (X, y) loaders in this file).
            xb = _batch_points(batch).to(config.device).to(config.dtype)
            z = wormhole.encoder(xb)
            x_recon = wormhole.decoder(z)
            emb_list.append(z)
            recon_list.append(x_recon)
            orig_list.append(xb)

    if not emb_list:
        raise ValueError("plot_decode_and_barycenter: dataloader yielded no batches.")

    X_recon = torch.cat(recon_list, dim=0)
    X_orig = torch.cat(orig_list, dim=0)

    k = min(num_plot_decode, X_recon.shape[0])
    for i in range(k):
        x = X_orig[i].detach().cpu().numpy()
        xr = X_recon[i].detach().cpu().numpy()

        # Shared cubic bounding box so both panels use the same scale.
        mins = np.minimum(x.min(axis=0), xr.min(axis=0))
        maxs = np.maximum(x.max(axis=0), xr.max(axis=0))
        ctr = (mins + maxs) / 2
        half = (maxs - mins).max() / 2
        lo, hi = ctr - half, ctr + half

        with torch.no_grad():
            xt = X_orig[i].to(config.device, dtype=config.dtype)
            xrt = X_recon[i].to(config.device, dtype=config.dtype)
            wd = Wasserstein_Distance(xt, xrt)
            wd_val = float(wd.detach().cpu().item())

        fig = plt.figure(figsize=(6, 3), dpi=200)
        _draw_cloud_panel(fig.add_subplot(1, 2, 1, projection="3d"), x, lo, hi, "orig")
        _draw_cloud_panel(fig.add_subplot(1, 2, 2, projection="3d"), xr, lo, hi, "recon")

        fig.suptitle(method_label, fontsize=26, weight="bold", y=0.98)
        plt.tight_layout(rect=[0, 0.12, 1, 0.92])
        fig.text(0.5, 0.045, f"Wasserstein(orig, recon) = {wd_val:.6f}",
                 ha="center", va="bottom", fontsize=12)

        plt.savefig(f"{prefix_plot}/pair_{i}.png")
        plt.close(fig)

    embeddings = torch.cat(emb_list, dim=0)
    print("embeddings:", embeddings.shape)
    # "Barycenter" = decoder applied to the mean embedding over all samples.
    mean_embed = embeddings.mean(dim=0, keepdim=True)
    with torch.no_grad():
        mean_decode = wormhole.decoder(mean_embed)

    cloud_np = mean_decode.squeeze(0).detach().cpu().numpy()
    fig = plt.figure(figsize=(4, 4), dpi=200)
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(cloud_np[:, 0], cloud_np[:, 1], cloud_np[:, 2], s=3, c=cloud_np[:, 2],
               cmap="viridis", alpha=0.9, linewidths=0, edgecolors="none")
    _set_unit_box_aspect(ax)
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    fig.suptitle(method_label, fontsize=20, weight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f"{prefix_plot}/barycenter.png")
    plt.close(fig)




def plot_grid_samples_recon_compact(
    wormhole,
    samples,
    class_names,
    out_dir="saved_wormhole_reconstruction",
    method_label="fastSP-Wormhole",
    s=2.0,
    elev=18, azim=125,
    max_cols=5,
    view_classes=("chair",),
    shared_limits=False,
):
    """Save a 2-row grid: input point clouds (top) vs. reconstructions (bottom).

    For each of the first ``min(max_cols, len(samples))`` samples, a deep copy
    of ``wormhole`` is passed to ``wormhole_warmup(copy, x, dec_epochs=10)``;
    the returned copy then encodes and decodes the sample. Each bottom panel is
    captioned with ``Wasserstein_Distance`` between input and reconstruction.
    Device/dtype are taken from ``wormhole.config``. The figure is saved as
    ``{out_dir}/{method_label}_samples_recon10.png`` and ``.pdf`` and closed.

    Args:
        wormhole: model wrapper exposing ``config``, ``encoder`` and ``decoder``;
            it is not modified (only deep copies are warmed up).
        samples: sequence of point clouds, each (N, 3) or (1, N, 3).
        class_names: column titles, aligned with ``samples``.
        out_dir: output directory (created if missing).
        method_label: row label and output file-name stem.
        s: scatter marker size.
        elev, azim: camera angles applied to columns whose name is in
            ``view_classes``.
        max_cols: maximum number of columns.
        view_classes: class names that get the custom ``elev``/``azim`` view.
        shared_limits: if True, each column's top and bottom panels share one
            cubic bounding box; the default keeps matplotlib's autoscaling.

    Raises:
        ValueError: if there is nothing to plot (no samples or ``max_cols < 1``).
    """
    os.makedirs(out_dir, exist_ok=True)
    C = min(max_cols, len(samples))
    if C < 1:
        raise ValueError(
            "plot_grid_samples_recon_compact: nothing to plot "
            f"(len(samples)={len(samples)}, max_cols={max_cols})."
        )
    samples = samples[:C]
    class_names = class_names[:C]

    cfg = wormhole.config

    col_w = 2.0
    fig_w = col_w * C + 1.2
    fig_h = 4.2
    # squeeze=False keeps ``axes`` a (2, C) array even when C == 1, so
    # ``axes[row, col]`` indexing works for every column count.
    fig, axes = plt.subplots(
        2, C,
        figsize=(fig_w, fig_h), dpi=200,
        subplot_kw={'projection': '3d'},
        gridspec_kw={'wspace': 0.02, 'hspace': 0.04},
        squeeze=False,
    )

    try:
        fig.text(0.02, 0.5, method_label, rotation=90, va="center",
                 ha="center", fontsize=12, weight="bold")

        for j, (pc, cname) in enumerate(zip(samples, class_names)):
            x1 = torch.as_tensor(pc)
            if x1.dim() == 2:
                x1 = x1.unsqueeze(0)  # add a leading batch dimension
            x1 = x1.to(cfg.device, dtype=cfg.dtype)

            # Warm up a throwaway copy so the caller's ``wormhole`` is untouched.
            # NOTE(review): ``wormhole_warmup`` is neither defined nor imported in
            # this file; it must be provided elsewhere or this raises NameError.
            copy_wormhole = deepcopy(wormhole)
            copy_wormhole = wormhole_warmup(copy_wormhole, x1, dec_epochs=10)

            # Inference + metric only: no autograd graph is needed.
            with torch.no_grad():
                z1 = copy_wormhole.encoder(x1)
                xr1 = copy_wormhole.decoder(z1)
                wd = Wasserstein_Distance(
                    x1.squeeze(0), xr1.squeeze(0),
                    numItermax=100000, device=cfg.device
                )
            wd_val = float(wd.detach().cpu().item())
            del copy_wormhole  # free the per-sample model copy before the next one

            x_np = x1.squeeze(0).detach().cpu().numpy()
            xr_np = xr1.squeeze(0).detach().cpu().numpy()

            ax_top, ax_bot = axes[0, j], axes[1, j]
            for ax, pts in ((ax_top, x_np), (ax_bot, xr_np)):
                ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                           s=s, c=pts[:, 2], cmap="viridis",
                           alpha=0.95, linewidths=0, edgecolors="none")
                if cname in view_classes:
                    ax.view_init(elev=elev, azim=azim)
                ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])

            if shared_limits:
                # One cubic bounding box covering both clouds of this column.
                mins = np.minimum(x_np.min(0), xr_np.min(0))
                maxs = np.maximum(x_np.max(0), xr_np.max(0))
                ctr = (mins + maxs) / 2
                half = (maxs - mins).max() / 2
                lo, hi = ctr - half, ctr + half
                for ax in (ax_top, ax_bot):
                    ax.set_xlim(lo[0], hi[0])
                    ax.set_ylim(lo[1], hi[1])
                    ax.set_zlim(lo[2], hi[2])

            ax_top.set_title(cname, fontsize=12, pad=1, weight="bold")
            ax_bot.text2D(
                0.5, -0.12, f"W = {wd_val:.3f}",
                transform=ax_bot.transAxes,
                ha="center", va="top", fontsize=12
            )

        fig.subplots_adjust(left=0.06, right=1.1, bottom=0.08, top=1.0,
                            wspace=0.02, hspace=0.04)

        # Row labels are placed just left of the first column, vertically
        # centred on each row (positions read after subplots_adjust).
        pos_top = axes[0, 0].get_position()
        pos_bot = axes[1, 0].get_position()

        y_top_center = 0.5 * (pos_top.y0 + pos_top.y1)
        y_bot_center = 0.5 * (pos_bot.y0 + pos_bot.y1)

        x_text = max(0.03, pos_top.x0 - 0.02)

        fig.text(x_text, y_top_center, "data",
                 rotation=90, ha="center", va="center", fontsize=11, weight="bold")

        fig.text(x_text, y_bot_center, "reconstruction",
                 rotation=90, ha="center", va="center", fontsize=11, weight="bold")

        fig.savefig(f"{out_dir}/{method_label}_samples_recon10.png",
                    bbox_inches="tight", pad_inches=0.02)
        fig.savefig(f"{out_dir}/{method_label}_samples_recon10.pdf",
                    bbox_inches="tight", pad_inches=0.02)
    finally:
        # The original never closed this figure, leaking it into pyplot's state.
        plt.close(fig)



# Categories shown in the reconstruction grid, one column each, left to right.
labels = ["airplane", "cup", "lamp", "vase", "chair"]

# Load each category's training split in ``labels`` order; every call returns
# (DataLoader, raw X tensor), unpacked into one pair of names per category.
(
    (dataloader_airplane, dataset_airplane),
    (dataloader_cup, dataset_cup),
    (dataloader_lamp, dataset_lamp),
    (dataloader_vase, dataset_vase),
    (dataloader_chair, dataset_chair),
) = [
    load_each_dataloader_category(
        root_dir=dataset_path,
        category=category,
        batch_size=32,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    for category in labels
]

# First sample of every category, except the chair, which uses index 1.
list_samples = [
    dataset_airplane[0],
    dataset_cup[0],
    dataset_lamp[0],
    dataset_vase[0],
    dataset_chair[1],
]

# Restore trained weights from the checkpoint's "model_state_dict" entry.
# NOTE(review): no ``.eval()`` call is visible in this file although the config
# sets attention_dropout_rate=0.1 -- confirm whether dropout should be active
# while producing these plots.
wormhole.load_state_dict(
    torch.load(prefix_saved_model, map_location=config.device)["model_state_dict"]
)

plot_grid_samples_recon_compact(
    wormhole=wormhole,
    samples=list_samples,
    class_names=labels,
    out_dir="saved_reconstruction",
    method_label=method_label,
    s=6,
    elev=18,
    azim=125,
    max_cols=5,
)

