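"""Optimal-transport distances between point clouds: exact Wasserstein and sliced variants (SW, PWD, EBSW, Max-SW, Min-SWGG, EST). Still in progress."""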

import torch 
import torch.nn as nn 
import ot
import matplotlib.pyplot as plt
import numpy as np
from utils import generate_uniform_unit_sphere_projections, quantile_function, _SWGG_loss
from sw import Wasserstein_Distance


def Wasserstein_Distance(X, Y, p=2, numItermax=100000, device="cpu"):
    """
    Compute the exact Wasserstein-p distance between two uniform empirical measures.

    The OT problem is solved with POT's ``ot.emd2``, which supports
    backpropagation for torch inputs. Computational complexity: O(n^3).

    :param X: M source samples. Has shape == (M, d)
    :param Y: N target samples. Has shape == (N, d)
    :param p: order of the Wasserstein distance; the ground cost is ||x - y||_2^p
    :param numItermax: maximum number of iterations forwarded to ``ot.emd2``
    :param device: device on which the uniform weight vectors are allocated
    :return: W_p(X, Y) = (optimal OT cost) ** (1/p), a scalar tensor
    """

    assert X.shape[1] == Y.shape[1], "source and target must have the same dimension"

    # Cost matrix between source and target, shape (M, N), entries ||x_i - y_j||_2^p.
    if p == 2:
        cost_matrix = ot.dist(x1=X, x2=Y, metric='sqeuclidean')
    else:
        # ot.dist ignores `p` for the 'sqeuclidean' metric, which silently used a
        # squared cost for every p; build the Euclidean distance raised to p instead.
        cost_matrix = torch.cdist(X, Y, p=2).pow(p)

    num_supports_source = X.shape[0]
    num_supports_target = Y.shape[0]

    # Uniform weights on both point clouds.
    a = torch.full((num_supports_source,), 1.0 / num_supports_source, device=device)
    b = torch.full((num_supports_target,), 1.0 / num_supports_target, device=device)

    ws = ot.emd2(a=a,
                 b=b,
                 M=cost_matrix,
                 processes=1,
                 numItermax=numItermax,
                 log=False,
                 return_matrix=False,
                 center_dual=True,
                 numThreads=1,
                 check_marginals=True)

    # ws is the optimal transport cost (W_p^p); take the p-th root to get W_p.
    return ws.pow(1 / p)


def Sliced_Wasserstein_Distance(X, Y, num_projections=1000, projection_matrix=None, p=2, device="cuda", chunk=1000, dtype=torch.float16, return_vectors=False):
    """
    Compute the Sliced Wasserstein distance between two point clouds. Supports backpropagation.

    Both point clouds are projected onto ``num_projections`` directions, ``chunk``
    directions at a time. Each 1D Wasserstein-p distance is obtained by sorting
    the projections and comparing them elementwise.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d); must have the same number of rows as X
    :param num_projections: number of random directions (overridden by projection_matrix)
    :param projection_matrix: optional precomputed directions, shape (num_projections, d)
    :param p: Wasserstein-p parameter
    :param device: device on which random directions and the accumulator are created
    :param chunk: number of projections processed per iteration (memory control);
        the last chunk may be smaller so that every projection is used
    :param dtype: dtype used for the projections and sorting
    :param return_vectors: if True, return the per-projection 1D distances instead of the average
    :return: scalar SW_p tensor, or a tensor of shape (num_projections,) if return_vectors
    :raises ValueError: if X and Y have a different number of points, or there are no projections
    """

    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    if X.shape[0] != Y.shape[0]:
        # Sorted projections are subtracted elementwise; unequal sizes would otherwise
        # broadcast silently (e.g. 1 vs N points) and give a wrong value.
        raise ValueError("Source and target must have the same number of supports")

    dims = X.shape[-1]

    if projection_matrix is not None:
        num_projections = projection_matrix.shape[0]
    if num_projections < 1:
        raise ValueError("num_projections must be at least 1")

    chunk = max(1, min(chunk, num_projections))

    # Accumulate in at least float32 so half-precision chunks do not lose accuracy in the running sum.
    acc_dtype = torch.promote_types(dtype, torch.float32)
    per_projection = []
    total = torch.zeros((), device=device, dtype=acc_dtype)

    for start in range(0, num_projections, chunk):
        # The last chunk may be smaller: no projection is dropped, so dividing by
        # num_projections below is an unbiased mean.
        size = min(chunk, num_projections - start)
        if projection_matrix is None:
            projection_vectors = generate_uniform_unit_sphere_projections(dim=dims, num_projections=size, device=device)
        else:
            projection_vectors = projection_matrix[start:start + size]

        projection_vectors = projection_vectors.to(dtype)

        X_projection = torch.matmul(X.to(dtype), projection_vectors.t())  # (n, size)
        Y_projection = torch.matmul(Y.to(dtype), projection_vectors.t())  # (n, size)

        X_sorted, _ = torch.sort(X_projection, dim=0)
        Y_sorted, _ = torch.sort(Y_projection, dim=0)

        diff_quantiles = torch.abs(X_sorted - Y_sorted)
        if p == 1:
            w_1d_pow_p = diff_quantiles.mean(dim=0)
        else:
            w_1d_pow_p = diff_quantiles.pow(p).mean(dim=0)  # (size,)

        if return_vectors:
            per_projection.append(w_1d_pow_p.pow(1 / p))
        else:
            total = total + w_1d_pow_p.sum().to(acc_dtype)

    if return_vectors:
        return torch.cat(per_projection, dim=0)

    mean_w_p = total / num_projections
    return mean_w_p.pow(1 / p) if p != 1 else mean_w_p


def Projected_Wasserstein_Distance(X, Y, num_projections=1000, projection_matrix=None, p=2, device="cuda", chunk=1000, dtype=torch.float16):
    """
    Compute the Projected Wasserstein distance between two point clouds.

    For each direction, X and Y are matched by sorting their 1D projections. The
    cost of that matching is the mean of ||x - y||_p^p over matched pairs of the
    ORIGINAL (unprojected) points. Costs are averaged over all directions.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d); must have the same number of rows as X
    :param num_projections: number of random directions (overridden by projection_matrix)
    :param projection_matrix: optional precomputed directions, shape (num_projections, d)
    :param p: Wasserstein-p parameter (also the order of the norm used on matched pairs)
    :param device: device on which random directions and the accumulator are created
    :param chunk: number of projections processed per iteration (memory control);
        the last chunk may be smaller so that every projection is used
    :param dtype: dtype used for the projections and the argsort
    :return: Projected Wasserstein distance, a scalar tensor
    :raises ValueError: if X and Y have a different number of points, or there are no projections
    """

    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    if X.shape[0] != Y.shape[0]:
        raise ValueError("Source and target must have the same number of supports")

    dims = X.shape[-1]

    if projection_matrix is not None:
        num_projections = projection_matrix.shape[0]
    if num_projections < 1:
        raise ValueError("num_projections must be at least 1")

    chunk = max(1, min(chunk, num_projections))

    # Matched differences are taken in X's own dtype; accumulate in at least float32.
    acc_dtype = torch.promote_types(X.dtype, torch.float32)
    total = torch.zeros((), device=device, dtype=acc_dtype)

    for start in range(0, num_projections, chunk):
        # The last chunk may be smaller so no projection is dropped from the mean.
        size = min(chunk, num_projections - start)
        if projection_matrix is None:
            projection_vectors = generate_uniform_unit_sphere_projections(dim=dims, num_projections=size, device=device)
        else:
            projection_vectors = projection_matrix[start:start + size]

        projection_vectors = projection_vectors.to(dtype)

        X_projection = torch.matmul(X.to(dtype), projection_vectors.t())  # (n, size)
        Y_projection = torch.matmul(Y.to(dtype), projection_vectors.t())  # (n, size)

        X_sort_indices = torch.argsort(X_projection, dim=0)
        Y_sort_indices = torch.argsort(Y_projection, dim=0)

        # Index the original points so gradients flow through X and Y directly.
        diff_quantiles = X[X_sort_indices.t()] - Y[Y_sort_indices.t()]  # (size, n, d)
        cost_pow_p = diff_quantiles.norm(p=p, dim=-1).pow(p).mean(dim=-1)  # (size,)

        total = total + cost_pow_p.sum().to(acc_dtype)

    mean_w_p = total / num_projections
    return mean_w_p.pow(1 / p) if p != 1 else mean_w_p


def Energy_based_Sliced_Wasserstein(X, Y, f_type="identity", num_projections=1000, projection_matrix=None, p=2, device="cuda", chunk=1000, dtype=torch.float16, rho=2.0):
    """
    Compute the Energy-based Sliced Wasserstein distance. Supports backpropagation.

    The 1D Wasserstein-p^p values of all projections are combined with weights
    that depend on those values. The weights are normalised over ALL projections,
    not per chunk, so the result does not depend on ``chunk``.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d); must have the same number of rows as X
    :param f_type: weighting function: "exp" (softmax of the values), "identity"
        (weights proportional to value + eps) or "poly" (proportional to value**rho + eps)
    :param num_projections: number of random directions (overridden by projection_matrix)
    :param projection_matrix: optional precomputed directions, shape (num_projections, d)
    :param p: Wasserstein-p parameter
    :param device: device on which random directions are created
    :param chunk: number of projections processed per iteration (memory control);
        the last chunk may be smaller so that every projection is used
    :param dtype: dtype used for the projections and sorting
    :param rho: exponent used by f_type="poly". NOTE(review): the original code
        referenced an undefined `rho`; the default of 2.0 is a choice made here.
    :return: EBSW distance, a scalar tensor
    :raises ValueError: on an unknown f_type, unequal point counts, or no projections
    """

    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    if X.shape[0] != Y.shape[0]:
        raise ValueError("Source and target must have the same number of supports")
    if f_type not in ("exp", "identity", "poly"):
        raise ValueError(f"Unknown f_type {f_type!r}; expected 'exp', 'identity' or 'poly'")

    dims = X.shape[-1]

    if projection_matrix is not None:
        num_projections = projection_matrix.shape[0]
    if num_projections < 1:
        raise ValueError("num_projections must be at least 1")

    chunk = max(1, min(chunk, num_projections))

    chunk_values = []
    for start in range(0, num_projections, chunk):
        # The last chunk may be smaller so no projection is dropped.
        size = min(chunk, num_projections - start)
        if projection_matrix is None:
            projection_vectors = generate_uniform_unit_sphere_projections(dim=dims, num_projections=size, device=device)
        else:
            projection_vectors = projection_matrix[start:start + size]

        projection_vectors = projection_vectors.to(dtype)

        X_projection = torch.matmul(X.to(dtype), projection_vectors.t())  # (n, size)
        Y_projection = torch.matmul(Y.to(dtype), projection_vectors.t())  # (n, size)

        X_sorted, _ = torch.sort(X_projection, dim=0)
        Y_sorted, _ = torch.sort(Y_projection, dim=0)

        diff_quantiles = torch.abs(X_sorted - Y_sorted)
        if p == 1:
            chunk_values.append(diff_quantiles.mean(dim=0))
        else:
            chunk_values.append(diff_quantiles.pow(p).mean(dim=0))  # (size,)

    # Weight computation and the final reduction use at least float32.
    acc_dtype = torch.promote_types(dtype, torch.float32)
    w_1d_pow_p = torch.cat(chunk_values, dim=0).to(acc_dtype)  # (num_projections,)

    eps = 1e-6
    if f_type == "exp":
        weights = torch.softmax(w_1d_pow_p, dim=0)
    elif f_type == "identity":
        shifted = w_1d_pow_p + eps
        weights = shifted / shifted.sum()
    else:  # "poly"
        shifted = w_1d_pow_p.pow(rho) + eps
        weights = shifted / shifted.sum()

    total = torch.sum(weights * w_1d_pow_p)
    return total.pow(1 / p) if p != 1 else total


def Max_Sliced_Wasserstein_Distance(X, Y, require_optimize=False, lr=1e-2, num_iter=1000, num_projections=1000, projection_matrix=None, p=2, dtype=torch.float16, device="cuda", chunk=1000):
    """
    Compute the Max Sliced Wasserstein distance. Supports backpropagation.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d)
    :param require_optimize: if True, search for the maximising direction by gradient
        ascent (SGD followed by renormalisation onto the unit sphere); if False, evaluate
        'num_projections' directions and return the largest 1D distance
    :param lr: SGD learning rate (require_optimize=True only)
    :param num_iter: number of ascent steps (require_optimize=True only)
    :param num_projections: number of directions (require_optimize=False only)
    :param projection_matrix: optional precomputed directions (require_optimize=False only)
    :param p: Wasserstein-p parameter
    :param dtype: dtype used for the projections
    :param device: device on which directions are created
    :param chunk: number of projections per chunk for memory management
    :return: if require_optimize, a tuple (max_sw, list_loss, optimal_projection):
        - max_sw: SW_p of (X, Y) along the best direction found (differentiable
          w.r.t. X and Y);
        - list_loss: the negated objective at each step;
        - optimal_projection: the best unit direction, shape (1, d).
        Otherwise, a 0-dim tensor holding the largest per-projection SW value.
    """
    if require_optimize:
        # Search on detached copies so the optimisation loop does not build a graph through X and Y.
        X_noopt = X.detach()
        Y_noopt = Y.detach()

        projection_vector = generate_uniform_unit_sphere_projections(X_noopt.shape[-1], requires_grad=True, num_projections=1, dtype=dtype, device=device)  # shape == (1, dim)
        optimizer = torch.optim.SGD([projection_vector], lr=lr)
        list_loss = []
        # Start from -inf and seed with the initial direction so a result always exists,
        # even when every SW value is 0 (e.g. X == Y).
        max_sw = float("-inf")
        optimal_projection = projection_vector.clone().detach()
        for _ in range(num_iter):
            optimizer.zero_grad()

            # Gradient ascent on SW == gradient descent on its negation.
            loss = -1 * Sliced_Wasserstein_Distance(X_noopt, Y_noopt, projection_matrix=projection_vector, p=p, device=device, chunk=chunk, dtype=dtype, return_vectors=True)
            loss_value = loss.item()
            list_loss.append(loss_value)

            if -loss_value > max_sw:
                max_sw = -loss_value
                optimal_projection = projection_vector.clone().detach()

            loss.backward()
            optimizer.step()

            # Project the updated direction back onto the unit sphere.
            with torch.no_grad():
                projection_vector.div_(projection_vector.norm(dim=1, keepdim=True).clamp_min_(1e-12))

        optimal_projection.div_(optimal_projection.norm(dim=1, keepdim=True).clamp_min_(1e-12))
        # Re-evaluate on the non-detached inputs so the result is differentiable w.r.t. X and Y.
        max_sw_opt = Sliced_Wasserstein_Distance(X, Y, projection_matrix=optimal_projection, p=p, device=device, chunk=chunk, dtype=dtype, return_vectors=False)
        return max_sw_opt, list_loss, optimal_projection

    # Just compute Sliced Wasserstein along 'num_projections' directions and keep the largest value.
    batch_sw_value = Sliced_Wasserstein_Distance(X, Y, num_projections=num_projections, projection_matrix=projection_matrix, p=p, device=device, chunk=chunk, dtype=dtype, return_vectors=True)
    # batch_sw_value.shape == (num_projections,)
    return torch.max(batch_sw_value, dim=0)[0]


def Min_SWGG(X, Y, lr=1e-2, num_iter=10000, p=2, s=20, std=0.5, dtype=torch.float16, device="cuda"):
    """
    Approximate min-SWGG by optimising a single projection direction with Adam.

    The objective is ``_SWGG_loss`` (from utils). It is treated as a p-th power:
    its 1/p power is what gets logged, tracked and returned.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d)
    :param lr: Adam learning rate
    :param num_iter: number of optimisation steps
    :param p: Wasserstein-p parameter (used for the 1/p root)
    :param s: currently unused. NOTE(review): presumably meant for a smoothed SWGG
        surrogate; _SWGG_loss is called without it.
    :param std: currently unused (see `s`)
    :param dtype: dtype of the projection direction
    :param device: device on which the direction is created
    :return: tuple (min_swgg, list_loss, optimal_projection):
        - min_swgg: loss ** (1/p) along the best direction, evaluated on the
          non-detached X, Y so it is differentiable w.r.t. them;
        - list_loss: loss ** (1/p) at every step;
        - optimal_projection: the best unit direction, shape (1, d).
    """
    # Optimise on detached copies so the search does not build a graph through X and Y.
    X_noopt = X.detach()
    Y_noopt = Y.detach()
    projection_vector = generate_uniform_unit_sphere_projections(X_noopt.shape[-1], requires_grad=True, num_projections=1, dtype=dtype, device=device)  # shape == (1, dim)
    optimizer = torch.optim.Adam([projection_vector], lr=lr)
    list_loss = []
    min_swgg = float("inf")
    # Seed with the initial direction so a result exists even if no step improves on +inf.
    optimal_projection = projection_vector.clone().detach()
    for _ in range(num_iter):
        optimizer.zero_grad()
        loss = _SWGG_loss(X_noopt, Y_noopt, theta=projection_vector, device=device)
        value = loss.pow(1 / p).item()
        list_loss.append(value)
        if value < min_swgg:
            min_swgg = value
            optimal_projection = projection_vector.clone().detach()
        loss.backward()
        optimizer.step()
        # Project the updated direction back onto the unit sphere.
        with torch.no_grad():
            projection_vector.div_(projection_vector.norm(dim=1, keepdim=True).clamp_min_(1e-12))

    optimal_projection.div_(optimal_projection.norm(dim=1, keepdim=True).clamp_min_(1e-12))
    # Return the same quantity that was tracked above (loss ** (1/p)), not the raw p-th power.
    min_swgg = _SWGG_loss(X, Y, theta=optimal_projection, device=device).pow(1 / p)
    return min_swgg, list_loss, optimal_projection


def Expected_Sliced_Transport(X, Y, num_projections=1000, projection_matrix=None, p=2, device="cuda", chunk=1000, dtype=torch.float16, tau=10.0):
    """
    Compute the Expected Sliced Transport distance.

    Each direction induces a matching of X and Y by sorting their 1D projections.
    Its cost is the mean of ||x - y||_p^p over matched pairs of the ORIGINAL points.
    The costs are averaged with weights softmax(-tau * cost) taken over ALL
    projections, not per chunk, so the result does not depend on ``chunk``. The
    p-th root of that weighted average is returned.

    :param X: source samples, shape (n, d)
    :param Y: target samples, shape (n, d); must have the same number of rows as X
    :param num_projections: number of random directions (overridden by projection_matrix)
    :param projection_matrix: optional precomputed directions, shape (num_projections, d)
    :param p: Wasserstein-p parameter (also the order of the norm used on matched pairs)
    :param device: device on which random directions are created
    :param chunk: number of projections processed per iteration (memory control);
        the last chunk may be smaller so that every projection is used
    :param dtype: dtype used for the projections and the argsort
    :param tau: softmax temperature; larger values concentrate weight on the cheapest directions
    :return: Expected Sliced Transport distance, a scalar tensor
    :raises ValueError: if X and Y have a different number of points, or there are no projections
    """

    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    if X.shape[0] != Y.shape[0]:
        raise ValueError("Source and target must have the same number of supports")

    dims = X.shape[-1]

    if projection_matrix is not None:
        num_projections = projection_matrix.shape[0]
    if num_projections < 1:
        raise ValueError("num_projections must be at least 1")

    chunk = max(1, min(chunk, num_projections))

    chunk_costs = []
    for start in range(0, num_projections, chunk):
        # The last chunk may be smaller so no projection is dropped.
        size = min(chunk, num_projections - start)
        if projection_matrix is None:
            projection_vectors = generate_uniform_unit_sphere_projections(dim=dims, num_projections=size, device=device)
        else:
            projection_vectors = projection_matrix[start:start + size]

        projection_vectors = projection_vectors.to(dtype)

        X_projection = torch.matmul(X.to(dtype), projection_vectors.t())  # (n, size)
        Y_projection = torch.matmul(Y.to(dtype), projection_vectors.t())  # (n, size)

        X_sort_indices = torch.argsort(X_projection, dim=0)
        Y_sort_indices = torch.argsort(Y_projection, dim=0)

        # Index the original points so gradients flow through X and Y directly.
        diff_quantiles = X[X_sort_indices.t()] - Y[Y_sort_indices.t()]  # (size, n, d)
        chunk_costs.append(diff_quantiles.norm(p=p, dim=-1).pow(p).mean(dim=-1))  # (size,)

    acc_dtype = torch.promote_types(X.dtype, torch.float32)
    pwd_pow_p = torch.cat(chunk_costs, dim=0).to(acc_dtype)  # (num_projections,)

    weights = torch.softmax(-tau * pwd_pow_p, dim=0)
    total = torch.sum(weights * pwd_pow_p)

    return total.pow(1 / p) if p != 1 else total


# if __name__ == '__main__':

#     device = "cpu"
#     dtype = torch.float32 
#     num_ex = 5
#     dims = 20


#     X = torch.randn(num_ex, dims, device=device, dtype=dtype)
#     Y = torch.randn(num_ex, dims, device=device, dtype=dtype)

#     sw = Sliced_Wasserstein_Distance(X, Y, num_projections=1000, device=device, dtype=dtype)
#     pwd = Projected_Wasserstein_Distance(X, Y, num_projections=1000, device=device, dtype=dtype)
#     ebsw = Energy_based_Sliced_Wasserstein(X, Y, num_projections=1000, device=device, dtype=dtype)
#     ws = Wasserstein_Distance(X, Y)
#     est = Expected_Sliced_Transport(X, Y, num_projections=1000, device=device, dtype=dtype)

#     max_sw = Max_Sliced_Wasserstein_Distance(X, Y, require_optimize=True, lr=1e-1, num_iter=1000, device=device, dtype=dtype)[0]
#     min_swgg = Min_SWGG(X, Y, lr=5e-2, num_iter=1000, s=20, std=0.5, device=device, dtype=dtype)[0]

#     print(sw, ebsw, max_sw, ws, min_swgg, pwd, est)


import torch

# === your metric funcs must be imported/defined above ===
# SW, PWD, EBSW, EST, Wasserstein_Distance, Max_Sliced_Wasserstein_Distance, Min_SWGG

def check_grad(name, fn, X, Y, **kw):
    """
    Print whether ``fn(X, Y, **kw)`` can be backpropagated to X and Y.

    If ``fn`` returns a tuple, its first torch.Tensor element is used. The mean of
    that tensor is backpropagated and the gradient norms of X and Y are printed.
    A missing tensor, or an exception raised during backward, is reported instead.
    Nothing is returned.
    """
    out = fn(X, Y, **kw)
    # If fn returned a tuple, keep only its first Tensor element (None if there is none).
    if isinstance(out, tuple):
        out = next(filter(lambda item: isinstance(item, torch.Tensor), out), None)
    if not isinstance(out, torch.Tensor):
        print(f"[{name}] -> NO_TENSOR (không thể backprop)")
        return
    objective = out.mean()
    # Drop gradients from any previous check so the reported norms come from this call only.
    X.grad = None
    Y.grad = None
    try:
        objective.backward()
        gx = X.grad.norm().item() if X.grad is not None else 0.0
        gy = Y.grad.norm().item() if Y.grad is not None else 0.0
        print(f"[{name}] req_grad={out.requires_grad}  ||dX||={gx:.3e}  ||dY||={gy:.3e}")
    except Exception as e:
        # Diagnostic harness: report the failure and let the caller continue with the next check.
        print(f"[{name}] BACKWARD ERROR: {e}")

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32
    N, D = 16, 8

    # One X, Y pair shared by all checks; check_grad clears their .grad before each backward pass.
    X = torch.randn(N, D, device=device, dtype=dtype, requires_grad=True)
    Y = torch.randn(N, D, device=device, dtype=dtype, requires_grad=True)

    common = dict(device=device, dtype=dtype)
    sliced_kwargs = dict(num_projections=64, chunk=64, **common)

    # (label, metric, keyword arguments), run in this order.
    checks = [
        ("SW", Sliced_Wasserstein_Distance, sliced_kwargs),
        ("PWD", Projected_Wasserstein_Distance, sliced_kwargs),
        ("EBSW", Energy_based_Sliced_Wasserstein, sliced_kwargs),
        ("EST", Expected_Sliced_Transport, sliced_kwargs),
        ("Wasserstein (EMD2)", Wasserstein_Distance,
         dict(p=2, numItermax=1000, device=device)),
        ("MaxSW (random max)", Max_Sliced_Wasserstein_Distance,
         dict(require_optimize=False, **sliced_kwargs)),
        ("MaxSW (optimize v*)", Max_Sliced_Wasserstein_Distance,
         dict(require_optimize=True, lr=1e-1, num_iter=50, num_projections=1, chunk=1, **common)),
        ("Min-SWGG", Min_SWGG,
         dict(lr=5e-2, num_iter=50, **common)),
    ]

    for label, metric, kwargs in checks:
        check_grad(label, metric, X, Y, **kwargs)
