from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.default_config import DefaultConfig

import os, json, time, math
import torch
from torch.utils.data import TensorDataset, DataLoader


# --- Experiment constants and output location --------------------------------
# The training-set size is encoded in the dataset directory name, so it is
# derived from the per-class sample count and the number of classes.
samples_per_class = 50
num_classes = 10
num_samples_training = num_classes * samples_per_class

dataset_path = "preprocessed_dataset/pointcloud/train/num_samples_{}".format(
    num_samples_training
)

# Every artifact of this run goes under <prefix_run_dir>/wormhole.
prefix_run_dir = "saved_embeddings"
run_dir = os.path.join(prefix_run_dir, "wormhole")
os.makedirs(run_dir, exist_ok=True)  # idempotent: fine if it already exists

# Hyper-parameters for the run. The device falls back to CPU when CUDA is
# not available.
config = DefaultConfig(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    batch_size=10,
    dtype=torch.float32,
    coeff_dec=0.1,
    # Transformer architecture.
    emb_dim=128,
    num_heads=4,
    num_layers=3,
    mlp_dim=512,
    attention_dropout_rate=0.1,
    # Optimisation. `epochs` is only a placeholder here: it is overwritten
    # further down once the dataset size is known.
    lr=1e-4,
    epochs=0,
    decay_steps=0,  # NOTE(review): meaning is defined by DefaultConfig -- confirm
)

# Load the preprocessed training point clouds and move them to the configured
# device/dtype. `map_location` remaps the storage at load time, so a file that
# was serialized from a CUDA tensor still loads on a CPU-only machine (the
# script explicitly supports the CPU fallback via `config.device`).
X_train = torch.load(
    os.path.join(dataset_path, "samples.pt"),
    map_location=config.device,
).to(device=config.device, dtype=config.dtype)

# The code below indexes three axes (samples, points, coordinates); fail with
# a clear message instead of an IndexError if the file has another layout.
if X_train.ndim != 3:
    raise ValueError(
        f"Expected a 3-D tensor (n_samples, n_points, input_dim) in "
        f"{os.path.join(dataset_path, 'samples.pt')}, got shape {tuple(X_train.shape)}"
    )

print(f"=> Loaded {X_train.shape[0]} point clouds with {X_train.shape[1]} points each from {dataset_path}")

# Record the dataset geometry on the config so the model can be sized from it.
config.n_samples = X_train.shape[0]
config.n_points  = X_train.shape[1]
config.input_dim = X_train.shape[2]

# Wrap the tensor in a dataset and iterate it in shuffled mini-batches; the
# last, possibly smaller, batch is kept.
dataset = TensorDataset(X_train)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
)

# Derive the epoch budget from a target number of optimisation steps, but
# never train for fewer than 2000 epochs.
# NOTE(review): if the dataset holds the 500 samples implied by its path, the
# batch size of 10 gives 50 steps/epoch, so the target alone would need only
# 200 epochs and the 2000-epoch floor always wins -- confirm this is intended.
TARGET_STEPS = 10000
steps_per_epoch = math.ceil(len(dataset) / config.batch_size)
epochs_for_target = math.ceil(TARGET_STEPS / steps_per_epoch)
config.epochs = epochs_for_target if epochs_for_target > 2000 else 2000
print(f"=> steps/epoch ~ {steps_per_epoch}, target steps={TARGET_STEPS}, epochs set to {config.epochs}")


# Build the autoencoder, sized from the dataset geometry stored on the config,
# and place it on the configured device.
model = TransformerAutoencoder(
    config=config,
    seq_len=config.n_points,
    inp_dim=config.input_dim,
)
model = model.to(config.device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Exponential decay chosen so that gamma ** config.epochs == final_factor,
# i.e. the learning rate ends at 10% of its initial value after
# `config.epochs` scheduler steps.
# NOTE(review): this assumes train_model steps the scheduler once per epoch
# -- confirm against Wormhole.train_model.
final_factor = 0.1
gamma = final_factor ** (1.0 / max(1, config.epochs))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# Training wrapper; it receives the model, the config and the output directory.
wormhole = Wormhole(transformer=model, config=config, run_dir=run_dir)


# Run the training loop and measure how long it takes.
# The duration is taken from time.perf_counter(), a monotonic clock, so the
# reported training time cannot be distorted (or go negative) if the system
# wall clock is adjusted during a long run.
print(f"Start training baseline Wormhole (num_samples = {num_samples_training}) ---")
t0 = time.perf_counter()
wormhole.train_model(
    dataloader=dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=config.epochs,
    save_every=50,
    verbose=True
)
train_sec = time.perf_counter() - t0  # elapsed seconds, reused by the checkpoint and stats below
print(f"Finished training in {train_sec/60:.2f} minutes ---")

# Persist the final model/optimizer state together with the epoch count and
# the measured training time. The file name encodes the learning rate and the
# number of epochs.
ckpt_name = f"baseline_lr{config.lr}_epoch{config.epochs}.pth"
ckpt_path = os.path.join(run_dir, ckpt_name)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": int(config.epochs),
    "train_time_sec": float(train_sec),
}
torch.save(checkpoint, ckpt_path)
print(f"=> Saved checkpoint at {ckpt_path}")

# Summary of the run (setup, hyper-parameters, timing), written as JSON next
# to the checkpoint. Values are cast to plain Python types so they serialize
# regardless of how the config stores them.
stats = dict(
    experiment="baseline_wormhole",
    dataset_path=dataset_path,
    run_dir=run_dir,
    num_samples=int(num_samples_training),
    epochs=int(config.epochs),
    lr=float(config.lr),
    batch_size=int(config.batch_size),
    train_time_sec=round(float(train_sec), 3),
    train_time_min=round(float(train_sec) / 60.0, 3),
    device=str(config.device),
    input_dim=int(config.input_dim),
    n_points=int(config.n_points),
    emb_dim=int(config.emb_dim),
    num_heads=int(config.num_heads),
    num_layers=int(config.num_layers),
    mlp_dim=int(config.mlp_dim),
    attention_dropout_rate=float(config.attention_dropout_rate),
    coeff_dec=float(config.coeff_dec),
    target_steps=TARGET_STEPS,
    steps_per_epoch=steps_per_epoch,
    scheduler="ExponentialLR",
    gamma=gamma,
)

stats_path = os.path.join(run_dir, "time_stats.json")
with open(stats_path, "w") as f:
    json.dump(stats, f, indent=2)
print(f"=> Saved stats to {stats_path}")

# Final cleanup: when CUDA is available, ask PyTorch to empty its CUDA cache
# now that training and saving are done. Skipped entirely on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
