import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sw2 import Wasserstein_Distance
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.default_config import DefaultConfig
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import time
import fnmatch

def batched_wormhole_and_ws(wormhole, pcs_X, pcs_Y, batch_size=256, device="cuda"):
    """Evaluate paired point clouds in mini-batches.

    For every batch of pairs, this collects the output of
    ``wormhole.compute_wormhole`` on the whole batch and the result of
    ``Wasserstein_Distance`` on each individual pair.

    Args:
        wormhole: Object exposing ``eval()`` and ``compute_wormhole(x, y)``.
        pcs_X: Tensor holding the first element of each pair, indexed along
            dim 0.
        pcs_Y: Tensor holding the second element of each pair; it must have
            the same length along dim 0 as ``pcs_X``.
        batch_size: Number of pairs processed per batch.
        device: Device each batch is moved to, also forwarded to
            ``Wasserstein_Distance``.

    Returns:
        A tuple ``(enc_arr, ws_arr)`` of NumPy arrays: the concatenated
        ``compute_wormhole`` outputs and the concatenated per-pair
        ``Wasserstein_Distance`` values, both in input order.
    """
    pair_loader = DataLoader(
        TensorDataset(pcs_X, pcs_Y),
        batch_size=batch_size,
        shuffle=False,  # keep outputs aligned with the input pair order
        drop_last=False,
    )
    wormhole.eval()

    ws_chunks = []
    wormhole_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in pair_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # One call covers the whole batch.
            batch_enc = wormhole.compute_wormhole(batch_x, batch_y)
            wormhole_chunks.append(batch_enc.cpu().numpy())

            # The reference distance is evaluated one pair at a time.
            per_pair_ws = [
                Wasserstein_Distance(cloud_x, cloud_y, device=device).cpu().numpy()
                for cloud_x, cloud_y in zip(batch_x, batch_y)
            ]
            ws_chunks.append(np.stack(per_pair_ws))

    ws_arr = np.concatenate(ws_chunks, axis=0)
    enc_arr = np.concatenate(wormhole_chunks, axis=0)
    return enc_arr, ws_arr


def evaluate_wormhole(num_samples, epoch_test=1000, num_pairs_test=10000, save_prefix="wormhole_run", dataset_root="compare_wormhole", batch_size=256):
    """Compare wormhole outputs with Wasserstein distances on the test pairs.

    Results are cached in ``{save_prefix}/num_{num_samples}`` as
    ``wormhole_np_{num_samples}.npy`` and ``ws_np_{num_samples}.npy``. When
    both cache files exist they are loaded directly; otherwise the trained
    checkpoint is loaded, both quantities are computed on the test pairs,
    and the results (plus the elapsed time) are written to that directory.
    R², MSE and MAE between the two arrays are then printed.

    Args:
        num_samples: Training-set size identifying the run; selects the
            run directory, the training data file, the checkpoint name and
            the cache file names.
        epoch_test: Epoch count used in the checkpoint file name and passed
            to ``DefaultConfig`` as ``epochs``.
        num_pairs_test: Number of test pairs; selects the test directory
            ``{dataset_root}/test/num_pairs_{num_pairs_test}``.
        save_prefix: Root directory containing the per-run sub-directories.
        dataset_root: Root directory containing ``train`` and ``test`` data.
        batch_size: Batch size forwarded to ``batched_wormhole_and_ws``.

    Returns:
        None. Metrics are printed to stdout.

    Raises:
        FileNotFoundError: If results are not cached and either the training
            data file or the checkpoint file is missing.
    """
    SAVE_PATH = f"{save_prefix}/num_{num_samples}"
    print(f"saved path: {SAVE_PATH}")
    DATA_PATH = os.path.join(dataset_root, "train", f"num_samples_{num_samples}", "samples.pt")
    TEST_DIR = os.path.join(dataset_root, "test", f"num_pairs_{num_pairs_test}")

    # The cache names must be keyed by num_samples so that the lookup matches
    # the files written at the end of the compute branch. A fixed "200" suffix
    # would only ever hit for num_samples == 200.
    wormhole_npy_file = os.path.join(SAVE_PATH, f"wormhole_np_{num_samples}.npy")
    ws_npy_file = os.path.join(SAVE_PATH, f"ws_np_{num_samples}.npy")

    # Both arrays are needed for the metrics, so a half-written cache is
    # treated as a miss and recomputed instead of crashing in np.load.
    cache_available = os.path.isfile(wormhole_npy_file) and os.path.isfile(ws_npy_file)

    if cache_available:
        # NOTE(review): allow_pickle=True is kept for compatibility with
        # existing cache files; only load caches produced by this script.
        ws_np = np.load(ws_npy_file, allow_pickle=True)
        enc_L2_np = np.load(wormhole_npy_file, allow_pickle=True)
    else:
        if not os.path.exists(DATA_PATH):
            raise FileNotFoundError(f"Dataset not found: {DATA_PATH}")
        # The training set is only used here for the size report below.
        X_train = torch.load(DATA_PATH)
        print(f"=> Loaded {X_train.shape[0]} point clouds from {DATA_PATH}")

        pcs_X = torch.load(os.path.join(TEST_DIR, "pcs1.pt"))
        pcs_Y = torch.load(os.path.join(TEST_DIR, "pcs2.pt"))
        print(f"=> Loaded {len(pcs_X)} test pairs from {TEST_DIR}")

        config = DefaultConfig(
            n_points=100,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            batch_size=32, dtype=torch.float32,
            emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512, attention_dropout_rate=0.1,
            input_dim=2, lr=1e-4, epochs=epoch_test, decay_steps=200
        )
        # Override the placeholder sizes above with the dimensions of the
        # loaded test tensor (dim 0, dim 1, dim 2 respectively).
        config.n_points = pcs_X.shape[1]
        config.n_samples = pcs_X.shape[0]
        config.input_dim = pcs_X.shape[2]

        checkpoint_path = os.path.join(SAVE_PATH, f"num{num_samples}_lr0.0001_epoch{epoch_test}.pth")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        else:
            print(checkpoint_path)

        print("=> Computing wormhole latent and true Wasserstein distances in batches...")
        # The timer covers model construction, checkpoint loading, the batched
        # computation and the filtering step below.
        start = time.time()
        model = TransformerAutoencoder(
            config=config, seq_len=config.n_points, inp_dim=config.input_dim
        ).to(config.device)
        wormhole = Wormhole(transformer=model, config=config, run_dir=SAVE_PATH)
        # The checkpoint is a dict whose "model_state_dict" entry holds the weights.
        state_dict = torch.load(checkpoint_path, map_location=config.device)["model_state_dict"]
        wormhole.load_state_dict(state_dict)

        pcs_X = pcs_X.to(config.device)
        pcs_Y = pcs_Y.to(config.device)

        enc_L2_np, ws_np = batched_wormhole_and_ws(wormhole, pcs_X, pcs_Y, batch_size=batch_size, device=config.device)

        # Drop entries where both values are exactly zero
        # (presumably identical pairs -- TODO confirm).
        mask = ~((ws_np == 0) & (enc_L2_np == 0))
        ws_np = ws_np[mask]
        enc_L2_np = enc_L2_np[mask]

        elapsed = time.time() - start

        os.makedirs(SAVE_PATH, exist_ok=True)
        with open(f"{SAVE_PATH}/compute_wormhole_time.txt", "w") as f:
            f.write(f"compute_both_time_sec: {elapsed:.4f}\n")

        # Save under the same paths the cache lookup uses.
        np.save(ws_npy_file, ws_np)
        np.save(wormhole_npy_file, enc_L2_np)

    r2 = r2_score(ws_np, enc_L2_np)
    mse = mean_squared_error(ws_np, enc_L2_np)
    mae = mean_absolute_error(ws_np, enc_L2_np)
    print(f"R²: {r2:.4f}, MSE: {mse:.4f}, MAE: {mae:.4f}")

if __name__ == "__main__":
    # Settings shared by every run; only the training-set size varies.
    shared_settings = dict(
        epoch_test=2000,
        num_pairs_test=10000,
        save_prefix="saved_compare_wormhole/pointcloud/wormhole",
        dataset_root="preprocessed_dataset/point_cloud",
        batch_size=256,
    )
    # Evaluate one trained run per training-set size, smallest first.
    for num_samples in (10, 50, 100, 200):
        print(f"\n===== EVALUATING num_samples = {num_samples} =====")
        evaluate_wormhole(num_samples=num_samples, **shared_settings)