# train_fast_wormhole_modelnet40.py
import os, json, time, math
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import trimesh

from wassersteinwormhole_pytorch.wormhole import fast_Wormhole
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.default_config import DefaultConfig

from sw import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import generate_uniform_unit_sphere_projections

def _minmax_neg1_pos1(pc):
    mn = pc.min(axis=0, keepdims=True)
    mx = pc.max(axis=0, keepdims=True)
    denom = np.clip(mx - mn, 1e-12, None)
    return ((pc - mn) / denom * 2.0 - 1.0).astype(np.float32)

def _unit_sphere(pc):
    pc = pc - pc.mean(axis=0, keepdims=True)
    scale = np.linalg.norm(pc, axis=1).max()
    return (pc / max(scale, 1e-12)).astype(np.float32)

def _sample_surface(mesh, n_points):
    """Draw ``n_points`` points from the surface of ``mesh`` as float32.

    Delegates to ``trimesh.sample.sample_surface`` and discards the second
    value it returns; only the sampled coordinates are kept.
    """
    points, _face_ids = trimesh.sample.sample_surface(mesh, n_points)
    as_f32 = points.astype(np.float32)
    return as_f32

def _load_pc(path, n_points):
    """Load the mesh at ``path`` and sample ``n_points`` surface points from it.

    A plain ``trimesh.Trimesh`` is sampled directly. Anything else is expected
    to expose a ``geometry`` mapping (as a loaded scene does); its geometries
    are concatenated into one mesh and sampled.
    NOTE(review): the raw geometries are concatenated as-is; scene-graph
    transforms, if any, do not appear to be applied here — confirm.

    Raises:
        RuntimeError: if the loaded object is not a mesh and has no geometry.
    """
    loaded = trimesh.load(path, process=False)
    if not isinstance(loaded, trimesh.Trimesh):
        parts = list(getattr(loaded, "geometry", {}).values())
        if len(parts) == 0:
            raise RuntimeError(f"Unsupported file: {path}")
        loaded = trimesh.util.concatenate(parts)
    return _sample_surface(loaded, n_points)

class ModelNet40PointCloud(Dataset):
    """Point-cloud dataset over a ModelNet40-style directory of meshes.

    Expected layout: ``<root_dir>/<class_name>/<split>/*.off`` (``*.ply`` files
    are picked up too). Every sub-directory of ``root_dir`` is a class. Class
    indices follow the sorted list of *all* class directories, so selecting a
    subset via ``labels`` keeps the same label ids as the full dataset.

    Each item re-loads its mesh, samples ``n_points`` surface points and then
    normalizes them according to ``normalize``:

    * ``"minmax"`` - each axis rescaled independently into [-1, 1];
    * ``"unit"``   - centred and scaled so the farthest point has norm 1;
    * anything else - coordinates left as sampled (cast to float32).
    """

    def __init__(self, root_dir, split="train", n_points=2048, normalize="minmax", labels=None):
        """
        Args:
            root_dir: dataset root containing one directory per class.
            split: per-class sub-folder to read (e.g. ``"train"``).
            n_points: number of surface points sampled per item.
            normalize: ``"minmax"``, ``"unit"``, or anything else for none.
            labels: optional iterable restricting the classes. Each entry is
                either an integer index into the sorted class list (a Python
                or NumPy integer) or a case-insensitive class name. Unmatched
                entries are ignored. If nothing matches, all classes are used
                and a warning is issued.

        Raises:
            RuntimeError: if no ``.off``/``.ply`` files are found for ``split``.
        """
        import glob
        import warnings

        self.root_dir = root_dir
        self.split = split
        self.n_points = n_points
        self.normalize = normalize

        all_classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        if labels is not None:
            requested = list(labels)
            lower_map = {c.lower(): c for c in all_classes}
            selected = set()  # built once instead of once per class in the filter below
            for lb in requested:
                # np.integer covers labels taken from NumPy arrays; np.int64 and
                # friends are not instances of Python's int.
                if isinstance(lb, (int, np.integer)) and 0 <= lb < len(all_classes):
                    selected.add(all_classes[lb])
                elif isinstance(lb, str) and lb.lower() in lower_map:
                    selected.add(lower_map[lb.lower()])
            # Keep the sorted class order regardless of the order labels were given in.
            classes = [c for c in all_classes if c in selected]
            if not classes:
                if requested:
                    warnings.warn(
                        f"None of labels={requested!r} matched a class under {root_dir}; "
                        "using all classes.",
                        stacklevel=2,
                    )
                classes = all_classes
        else:
            classes = all_classes

        # Label ids index the full class list, not the (possibly reduced) `classes`.
        self.class_to_idx = {c: i for i, c in enumerate(all_classes)}
        self.samples = []
        for cls in classes:
            folder = os.path.join(root_dir, cls, split)
            files = sorted(
                glob.glob(os.path.join(folder, "*.off")) + glob.glob(os.path.join(folder, "*.ply"))
            )
            self.samples.extend((f, self.class_to_idx[cls]) for f in files)
        if not self.samples:
            raise RuntimeError(f"No files found in {root_dir} split={split}")

    def __len__(self):
        """Return the number of mesh files in the selected classes/split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(points, label)``.

        ``points`` is a float32 tensor of the sampled (and normalized) surface
        points; ``label`` is a 0-d ``torch.long`` tensor.
        """
        f, y = self.samples[idx]
        pc = _load_pc(f, self.n_points)
        if self.normalize == "minmax":
            pc = _minmax_neg1_pos1(pc)
        elif self.normalize == "unit":
            pc = _unit_sphere(pc)
        else:
            pc = pc.astype(np.float32)
        return torch.from_numpy(pc), torch.tensor(y, dtype=torch.long)

def main():
    """Train one fast-Wormhole model per metric group on ModelNet40 point clouds.

    For every ``(group_name, metric_list)`` in ``GROUPS`` this:

    1. builds a fresh ``TransformerAutoencoder`` and a ``fast_Wormhole`` wrapper
       whose ground-truth distance is the exact Wasserstein distance ``W``;
    2. estimates the per-metric weights (``alphas``) on ``ALPHA_SAMPLES``
       training clouds;
    3. trains with Adam and an ``ExponentialLR`` whose gamma is chosen so the
       LR reaches ``0.1 * lr`` after ``config.epochs`` scheduler steps;
    4. writes ``alphas.npy``, ``alphas.json``, a ``.pth`` checkpoint and
       ``time_stats.json`` under ``<prefix_run_dir>/<group_name>/``.
    """
    import gc

    estimate_alpha_general = True
    root_modelnet = "data/ModelNet40"
    labels_subset = None
    batch_size = 16
    n_points = 2048

    # Separate output trees so the two alpha-estimation modes never overwrite each other.
    if estimate_alpha_general:
        prefix_run_dir = "saved_fast_wormhole/optimal_alpha_general"
    else:
        prefix_run_dir = "saved_fast_wormhole/optimal_alpha_simplex"
    os.makedirs(prefix_run_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # epochs=0 is a placeholder; it is overwritten below once the dataset size is known.
    config = DefaultConfig(
        device=device,
        batch_size=batch_size, dtype=torch.float32, coeff_dec=0.1,
        emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512, attention_dropout_rate=0.0,
        lr=1e-4, epochs=0, decay_steps=0
    )

    train_set = ModelNet40PointCloud(
        root_dir=root_modelnet, split="train",
        n_points=n_points, normalize="minmax",
        labels=labels_subset
    )
    pin = torch.cuda.is_available()
    dataloader = DataLoader(
        train_set, batch_size=config.batch_size, shuffle=True,
        num_workers=8, pin_memory=pin, drop_last=False, persistent_workers=True
    )

    config.n_samples = len(train_set)
    config.n_points = n_points
    config.input_dim = 3

    TARGET_STEPS = 10000
    steps_per_epoch = math.ceil(config.n_samples / config.batch_size)
    # NOTE(review): ceil(TARGET_STEPS / steps_per_epoch) <= 2000 whenever
    # steps_per_epoch >= 5 (i.e. more than 64 samples at batch_size 16), so this
    # max() yields 2000 epochs and TARGET_STEPS has no effect for any realistic
    # dataset. Confirm the 2000-epoch floor is intended.
    config.epochs = max(2000, math.ceil(TARGET_STEPS / steps_per_epoch))

    # Fixed random directions shared by all sliced/projected metrics below.
    projection_matrix = generate_uniform_unit_sphere_projections(
        dim=config.input_dim, requires_grad=False, num_projections=100,
        dtype=config.dtype, device=config.device
    )

    def W(x, y):
        return Wasserstein_Distance(x, y, numItermax=10000, device=config.device)

    def SW(x, y):
        return Sliced_Wasserstein_Distance(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def PWD(x, y):
        return Projected_Wasserstein_Distance(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def EBSW(x, y):
        return Energy_based_Sliced_Wasserstein(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def EST(x, y):
        return Expected_Sliced_Transport(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def MinSWGG_fn(x, y):
        # Only the first element of Min_SWGG's return value is used as the distance.
        return Min_SWGG(
            x, y, lr=5e-2, num_iter=5, s=10, std=0.5, device=config.device, dtype=config.dtype
        )[0]

    def MaxSW_fn(x, y):
        # Only the first element of the return value is used as the distance.
        return Max_Sliced_Wasserstein_Distance(
            x, y, require_optimize=True, lr=1e-2, num_iter=5, device=config.device, dtype=config.dtype
        )[0]

    w_fn = W
    GROUPS = [
        ("sw_pwd",                        [SW, PWD]),
        ("ebsw_est",                      [EBSW, EST]),
        ("maxsw_minswgg",                 [MaxSW_fn, MinSWGG_fn]),
        ("sw_pwd_ebsw_est",               [SW, PWD, EBSW, EST]),
        ("sw_pwd_ebsw_est_maxsw_minswgg", [SW, PWD, EBSW, EST, MaxSW_fn, MinSWGG_fn]),
    ]

    def take_alpha_samples(loader, n):
        """Collect the first ``n`` clouds from ``loader`` on the config device/dtype."""
        xs, c = [], 0
        for xb, _ in loader:
            xs.append(xb)
            c += xb.shape[0]
            if c >= n:
                break
        return torch.cat(xs, dim=0)[:n].to(config.device).to(config.dtype)

    ALPHA_SAMPLES = min(10, config.n_samples)

    for group_name, metric_list in GROUPS:
        model_g = TransformerAutoencoder(config=config, seq_len=config.n_points, inp_dim=config.input_dim).to(config.device)
        run_dir = os.path.join(prefix_run_dir, group_name)
        os.makedirs(run_dir, exist_ok=True)
        metric_names = [fn.__name__ for fn in metric_list]

        fast_wh = fast_Wormhole(
            estimate_alpha_general=estimate_alpha_general,
            transformer=model_g,
            config=config,
            metric_funcs=metric_list,
            ground_truth_func=w_fn,
            run_dir=run_dir,
            metric_names=metric_names,
            compute_stats=True,
            save_best=True
        )

        optimizer_g = torch.optim.Adam(model_g.parameters(), lr=config.lr)
        # gamma ** epochs == 0.1: the LR decays by 10x over config.epochs scheduler steps.
        gamma = (0.1) ** (1.0 / config.epochs)
        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=gamma)

        # perf_counter is monotonic; time.time() can jump with wall-clock changes.
        # Note: est_alpha_sec also includes loading the alpha batch from the DataLoader.
        t0 = time.perf_counter()
        Xalpha = take_alpha_samples(dataloader, ALPHA_SAMPLES)
        fast_wh.estimate_alpha(samples=Xalpha)
        est_alpha_sec = time.perf_counter() - t0

        t1 = time.perf_counter()
        fast_wh.train_model(
            dataloader=dataloader,
            optimizer=optimizer_g,
            scheduler=scheduler_g,
            epochs=config.epochs,
            save_every=50,
            verbose=False
        )
        train_sec = time.perf_counter() - t1
        total_sec = est_alpha_sec + train_sec

        alphas_np = fast_wh.alphas.detach().cpu().numpy()
        alpha_by_metric = {name: float(val) for name, val in zip(metric_names, alphas_np)}
        np.save(os.path.join(run_dir, "alphas.npy"), alphas_np)
        with open(os.path.join(run_dir, "alphas.json"), "w") as f:
            json.dump(alpha_by_metric, f, indent=2)

        ckpt_path = os.path.join(run_dir, f"lr{config.lr}_epoch{config.epochs}.pth")
        torch.save({
            "model_state_dict": model_g.state_dict(),
            "optimizer_state_dict": optimizer_g.state_dict(),
            "epoch": int(config.epochs),
            "estimate_alpha_sec": float(est_alpha_sec),
            "train_time_sec": float(train_sec),
            "total_time_sec": float(total_sec),
            "alphas": alphas_np.tolist(),
            # Stored as a tensor, not a NumPy array, so the checkpoint still loads with
            # torch.load(..., weights_only=True) (the default since PyTorch 2.6).
            "optimal_alpha": torch.tensor(alphas_np),
            "metric_names": metric_names,
        }, ckpt_path)

        time_stats = {
            "group": group_name,
            "estimate_alpha_sec": round(float(est_alpha_sec), 3),
            "train_time_sec": round(float(train_sec), 3),
            "total_time_sec": round(float(total_sec), 3),
            "epochs": int(config.epochs),
            "lr": float(config.lr),
            "batch_size": int(config.batch_size),
            "alphas": alpha_by_metric,
        }
        with open(os.path.join(run_dir, "time_stats.json"), "w") as f:
            json.dump(time_stats, f, indent=2)

        # Release this group's model/optimizer/wrapper before the next group builds
        # its own. Otherwise they stay referenced until reassignment, two models
        # coexist on the device, and empty_cache() cannot return their memory.
        del fast_wh, model_g, optimizer_g, scheduler_g, Xalpha
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Run the training sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
