import os
import copy
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import wandb
from tqdm import tqdm
import gc

from utils.schedulers import CosineAnnealingWarmUpRestarts
from utils.mmd import mmd_linear, mmd_linear_bootstrap_test
from .hills import *


def wasserstein_1d(x: torch.Tensor, y: torch.Tensor):
    """
    Empirical 1-Wasserstein (W1) distance between two 1-D samples.

    x, y: shape (N,) or (N,1); inputs are flattened, and the two lengths may differ.
    returns scalar tensor (W1 distance)

    For equal sample sizes W1 is the mean absolute difference of the sorted
    samples. For unequal sizes, the exact W1 = integral |F_x - F_y| of the two
    empirical CDFs is computed. Truncating the *sorted* arrays to a common
    length would keep only the smallest values of the larger sample and bias
    the result. Returns NaN if either sample is empty. Non-float inputs are
    promoted to a floating dtype.
    """
    dtype = torch.promote_types(x.dtype, y.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    x = x.flatten().to(dtype).sort().values
    y = y.flatten().to(dtype).sort().values

    if len(x) == 0 or len(y) == 0:
        return torch.tensor(float("nan"), dtype=dtype, device=x.device)

    if len(x) == len(y):
        # Quantile-matching formula; exact for equal-size empirical samples.
        return torch.mean(torch.abs(x - y))

    # Unequal sizes: integrate |F_x - F_y| piecewise over the merged support.
    support = torch.cat([x, y]).sort().values
    widths = support[1:] - support[:-1]
    cdf_x = torch.searchsorted(x, support[:-1], right=True).to(dtype) / len(x)
    cdf_y = torch.searchsorted(y, support[:-1], right=True).to(dtype) / len(y)
    return torch.sum(torch.abs(cdf_x - cdf_y) * widths)


def sliced_wasserstein_2d(X: torch.Tensor, Y: torch.Tensor, num_projections: int = 256, robust='mean', seed: int = 0):
    """
    Sliced 1-Wasserstein distance between two 2-D point clouds.

    X, Y: shape (N, 2)  (columns: [in_degree, out_degree])
    num_projections: number of random directions on the unit circle
    robust: 'median' aggregates the per-projection distances with the median;
        any other value uses the mean
    seed: seeds a private generator, so calls with the same seed (on the same
        device) use the same projection directions

    Only the first min(len(X), len(Y)) rows of each input are used. Each
    projection's 1-D distance is the mean absolute difference of the sorted
    projected samples. Returns a scalar tensor, or NaN when that row count is 0.

    Raises:
        ValueError: if X or Y is not a 2-D tensor with exactly two columns.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if X.dim() != 2 or Y.dim() != 2 or X.shape[1] != 2 or Y.shape[1] != 2:
        raise ValueError("X, Y must be (N, 2)")

    g = torch.Generator(device=X.device).manual_seed(seed)
    n = min(len(X), len(Y))
    dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
    X = X[:n].to(dtype)
    Y = Y[:n].to(dtype)

    theta = torch.rand(num_projections, generator=g, device=X.device) * 2 * torch.pi
    U = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1).to(dtype)  # (P, 2)

    # Project onto every direction at once, then sort each column: (n, P).
    X_proj = (X @ U.T).sort(dim=0).values
    Y_proj = (Y @ U.T).sort(dim=0).values
    dists = (X_proj - Y_proj).abs().mean(dim=0)  # (P,)

    if robust == 'median':
        return dists.median()
    return dists.mean()


def evaluate_with_mmd(model, test_loader, MMD_test_N, boot_iter, tail_cut, device, model_name="VAE"):
    """Run bootstrap linear-MMD tests between reconstructions and test data.

    Each batch of ``test_loader`` is encoded with ``model.encode`` and mapped
    back with ``model.decode``, so the compared samples are reconstructions of
    the test points. Two tests are run with ``mmd_linear_bootstrap_test``:

    * Full: the first ``MMD_test_N`` reconstructions vs. the first
      ``MMD_test_N`` test points.
    * Right tail: rows whose L2 norm is strictly greater than ``tail_cut`` and
      whose first coordinate is positive, selected independently in each set.

    Args:
        model: object with ``model_name``, ``encode`` and ``decode``. When
            ``model_name == "AE"`` the output of ``encode`` is used directly as
            the latent; otherwise the first element of its returned sequence is.
        test_loader: iterable of tensor batches.
        MMD_test_N: maximum number of rows used in the full test.
        boot_iter: forwarded as ``iteration`` to ``mmd_linear_bootstrap_test``.
        tail_cut: norm threshold that defines the right-tail subset.
        device: device on which both tests run.
        model_name: label used only in the printed messages.

    Returns:
        dict mapping ``"Metrics/full"`` and ``"Metrics/right"`` to element
        ``[1]`` of the respective test result (printed as the p-value). Either
        value is ``None`` if the test returns ``None`` in that slot.
    """
    model.eval()
    recon_batches = []
    test_batches = []

    with torch.no_grad():
        for x_batch in test_loader:
            x_batch = x_batch.to(device)
            if model.model_name == "AE":
                z = model.encode(x_batch)
            else:
                z, *_ = model.encode(x_batch)
            # Accumulate on CPU; the concatenated tensors are moved back below.
            recon_batches.append(model.decode(z).cpu())
            test_batches.append(x_batch.cpu())

    recons = torch.cat(recon_batches, dim=0).to(device)
    test_data = torch.cat(test_batches, dim=0).to(device)

    full_result = mmd_linear_bootstrap_test(
        recons[:MMD_test_N], test_data[:MMD_test_N],
        device=device, iteration=boot_iter
    )

    # Tail filtering: large-norm rows, then the positive half along axis 0.
    tail_test_data = test_data[torch.norm(test_data, dim=1) > tail_cut]
    tail_recons = recons[torch.norm(recons, dim=1) > tail_cut]
    right_test_data = tail_test_data[tail_test_data[:, 0] > 0]
    right_recons = tail_recons[tail_recons[:, 0] > 0]

    right_result = mmd_linear_bootstrap_test(
        right_recons, right_test_data,
        device=device, iteration=boot_iter
    )

    # Both p-values use the same None-guarded formatting, so a None full-test
    # p-value no longer raises TypeError on ":.4f".
    for label, result in (("Full", full_result), ("Right Tail", right_result)):
        p_value = result[1]
        shown = "None" if p_value is None else f"{p_value:.4f}"
        print(f"[{model_name}] p-value ({label}): {shown}")

    return {
        "Metrics/full": full_result[1],
        "Metrics/right": right_result[1]
    }

def train_data(
        model, device,
        train_dataloader, val_dataloader, test_dataloader,
        epochs, lr, eps, weight_decay, batch_size, seed, args, patience=100,
        swd_projections=512):
    """Train ``model`` with early stopping on validation loss, logging to wandb.

    Before training, the test set is reduced to the rows whose L2 norm is at or
    above its 0.95 quantile (the "tail"). ``fit_hill`` on the positive entries
    of columns 0 and 1 of that tail gives the reference values
    ``true_alpha1``/``true_alpha2``.

    Each epoch:
      * one Adamax pass over ``train_dataloader``. A NaN training loss ends
        the pass and marks the run to stop at the next epoch. One scheduler
        step follows the pass.
      * validation losses averaged over ``val_dataloader``. A deep copy of the
        model with the lowest average validation loss so far is kept as
        ``best_model``. Training stops after ``patience`` epochs without
        improvement.
      * ``best_model`` is evaluated on ``test_dataloader`` (average losses and
        batch-averaged sliced Wasserstein distance), with ``evaluate_with_mmd``,
        and on the test tail (absolute Hill-estimate differences and sliced
        Wasserstein distance of the tail reconstructions).

    Args:
        model: module whose call on a batch returns
            ``(recon_loss, reg_loss, total_loss)``; must expose ``recon_data``
            and ``model_name``.
        device: device for training and evaluation.
        train_dataloader, val_dataloader, test_dataloader: iterables of tensor
            batches; ``len()`` is used to average the val/test losses.
        epochs: maximum number of epochs.
        lr, eps, weight_decay: Adamax hyper-parameters (``lr`` is also the
            scheduler's ``eta_max``).
        batch_size: recorded in the wandb config only.
        seed: recorded in the wandb config and used as the sliced-Wasserstein
            projection seed.
        args: namespace read for ``num_layers``, ``m_dim``, ``reg_weight`` and
            ``seed`` (config / run name).
        patience: epochs without validation improvement before stopping.
        swd_projections: number of projections for sliced Wasserstein distances.

    Returns:
        The deep copy of the model taken at its best validation loss.
    """
    # ---- Reference tail of the test set (computed once) ----
    test_data = torch.cat([batch for batch in test_dataloader], dim=0)
    print(test_data.shape)

    vals = test_data.norm(dim=1)
    tailcut = torch.quantile(vals, 0.95)
    test_data = test_data[vals >= tailcut, :]  # rows in the top 5% by norm

    test_data_pos1 = test_data[:, 0]
    test_data_pos2 = test_data[:, 1]
    test_data_pos1 = test_data_pos1[test_data_pos1 > 0]
    test_data_pos2 = test_data_pos2[test_data_pos2 > 0]
    true_alpha1 = fit_hill(x=test_data_pos1)[0]
    true_alpha2 = fit_hill(x=test_data_pos2)[0]

    wandb.init(
        project="iclr2026_simul_try6_mdim=4_layer=512",
        group="estimating values",
        config={
            "model": getattr(model, "model_name", type(model).__name__),
            "epochs": epochs,
            "learning_rate": lr,
            "nu": getattr(model, "nu", None),
            "batch_size": batch_size,
            "dataset": '2dgraph',
            "seed": seed,
            "layers": args.num_layers,
            "m_dim": args.m_dim,
            "reg_weight": args.reg_weight,
            "optimizer": "Adamax",
            "hill's alpha": f"({true_alpha1},{true_alpha2})"
            },
        name = f"{model.model_name}_lr:{lr}_layers:{args.num_layers}_w:{args.reg_weight}_seed:{args.seed}"
    )

    # Everything after wandb.init runs inside try/finally so the run is
    # closed even when training raises.
    try:
        wandb.watch(model, log="gradients", log_freq=100)

        model = model.to(device)
        opt = optim.Adamax(model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)

        scheduler = CosineAnnealingWarmUpRestarts(
            optimizer=opt,
            T_0=10, T_mult=1, eta_max=lr, T_up=25, gamma=0.5
        )

        # The tail set never changes: move/copy it once instead of every epoch.
        tail_test_device = test_data.to(device)
        tail_test_cpu = test_data.detach().cpu()

        # --- Early stopping state ---
        best_loss = float("inf")
        best_model = copy.deepcopy(model)
        model_count = 0
        model_stop = False

        global_step = 0

        for epoch in tqdm(range(epochs), desc="Training"):
            if model_stop:
                break

            # -------- Train --------
            model.train()
            for data in train_dataloader:
                data = data.to(device)

                opt.zero_grad()
                recon_loss, reg_loss, train_loss = model(data)
                if torch.isnan(train_loss):
                    print(f"NaN occurs! Current epoch : {epoch}")
                    model_stop = True
                    break

                train_loss.backward()
                opt.step()

                wandb.log({
                    "Train/Total": float(train_loss),
                    "Train/Recon": float(recon_loss),
                    "Train/Reg": float(reg_loss),
                }, step=global_step)

                global_step += 1

            scheduler.step()

            # -------- Validation --------
            model.eval()
            total_val_loss = 0.0
            total_val_recon = 0.0
            total_val_reg = 0.0
            with torch.no_grad():
                for data in val_dataloader:
                    data = data.to(device)
                    recon_loss, reg_loss, validation_loss = model(data)
                    total_val_recon += float(recon_loss)
                    total_val_reg += float(reg_loss)
                    total_val_loss += float(validation_loss)

            N_val = max(1, len(val_dataloader))
            avg_val_loss = total_val_loss / N_val
            avg_val_recon = total_val_recon / N_val
            avg_val_reg = total_val_reg / N_val

            wandb.log({
                "Val/Total": avg_val_loss,
                "Val/Recon": avg_val_recon,
                "Val/Reg": avg_val_reg,
                "Epoch": epoch,
            }, step=global_step)

            if avg_val_loss < best_loss:
                best_loss = avg_val_loss
                best_model = copy.deepcopy(model)
                model_count = 0
            else:
                model_count += 1
                if model_count >= patience:
                    model_stop = True
                    print(f"{getattr(model, 'model_name', 'Model')} stopped early at epoch {epoch}")

            # -------- Test (best model) --------
            best_model.eval()
            total_test_loss = 0.0
            total_test_recon = 0.0
            total_test_reg = 0.0
            total_SWD = 0.0
            with torch.no_grad():
                for data in test_dataloader:
                    data = data.to(device)
                    recon_loss, reg_loss, test_loss = best_model(data)

                    total_test_recon += float(recon_loss)
                    total_test_reg += float(reg_loss)
                    total_test_loss += float(test_loss)
                    X_recon = best_model.recon_data(data).detach().cpu()
                    total_SWD += float(sliced_wasserstein_2d(data.detach().cpu(), X_recon, num_projections=swd_projections, robust='mean', seed=seed))

            N_test = max(1, len(test_dataloader))
            avg_test_loss = total_test_loss / N_test
            avg_test_recon = total_test_recon / N_test
            avg_test_reg = total_test_reg / N_test
            avg_swd = total_SWD / N_test

            mmd_dict = evaluate_with_mmd(best_model, test_dataloader, 50000, 999, tailcut, device)

            # -------- Tail metrics --------
            # Use best_model (as every other metric logged in this step does)
            # and no_grad, so no autograd graph is built over the tail set.
            with torch.no_grad():
                tail_recons = best_model.recon_data(tail_test_device).detach().cpu()
            pos_idx_1 = tail_recons[:, 0] > 0
            pos_idx_2 = tail_recons[:, 1] > 0
            estimated_hill_1 = fit_hill(tail_recons[pos_idx_1, 0])[0]
            estimated_hill_2 = fit_hill(tail_recons[pos_idx_2, 1])[0]
            tail_SWD = sliced_wasserstein_2d(tail_test_cpu, tail_recons, num_projections=swd_projections, robust='mean', seed=seed)

            log_dict = {
                "Test/Total": avg_test_loss,
                "Test/Recon": avg_test_recon,
                "Test/Reg": avg_test_reg,
                "Metrics/W_swd": avg_swd,
                "Metrics/Hill_1_diff": np.abs(estimated_hill_1-true_alpha1),
                "Metrics/Hill_2_diff": np.abs(estimated_hill_2-true_alpha2),
                "Metrics/tail_SWD": float(tail_SWD),
            }

            wandb.log(log_dict, step=global_step)
            wandb.log(mmd_dict, step=global_step)
            print(mmd_dict)

            gc.collect()
            torch.cuda.empty_cache()
    finally:
        wandb.finish()
    return best_model