import torch 
import torch.nn as nn 
import ot
import matplotlib.pyplot as plt
import numpy as np


import numpy as np
import cvxpy as cp

def optimal_alpha_simplex(X, y, ridge=0.0, solver="OSQP"):
    """
    Solve a least-squares problem whose weights lie on the probability simplex.

        min_a  ||X a - y||^2 + ridge * ||a||^2
        s.t.   a >= 0, sum(a) = 1

    :param X: Design matrix, array-like of shape (n, d).
    :param y: Targets, array-like with n elements (flattened to shape (n,)).
    :param ridge: L2 penalty lambda >= 0. 0 disables regularization; a
        positive value stabilizes the solution when columns are collinear.
    :param solver: cvxpy solver name tried first (OSQP, ECOS and SCS all
        handle this QP). SCS is used as a fallback.
    :return: np.ndarray of shape (d,) with the optimal weights.
    :raises RuntimeError: If no solver produced a solution.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).reshape(-1)

    _, d = X.shape
    a = cp.Variable(d)

    obj = cp.sum_squares(X @ a - y)
    if ridge and ridge > 0:
        obj += ridge * cp.sum_squares(a)

    constraints = [a >= 0, cp.sum(a) == 1]
    prob = cp.Problem(cp.Minimize(obj), constraints)

    # Try the requested solver first. A SolverError (solver not installed or
    # numerical failure) is a deliberate trigger for the SCS fallback below,
    # as is a solve that finishes without producing a value.
    try:
        prob.solve(solver=solver, verbose=False)
    except cp.error.SolverError:
        pass
    if a.value is None and solver != cp.SCS:
        prob.solve(solver=cp.SCS, verbose=False)

    if a.value is None:
        raise RuntimeError("Solver failed to find a solution.")
    return a.value  # np.ndarray shape (d,)


def optimal_alpha_general(X, y, ridge=0.0):
    """
    Fit unconstrained (optionally ridge-regularized) least-squares weights.

    Solves the normal equations ``(X^T X + ridge * I) a = X^T y``.

    :param X: Feature matrix of shape (n_samples, n_features), e.g. columns
        [sw, pwd, ebsw, est] (fewer columns are allowed). A 1-D array is
        treated as a single feature column.
    :param y: Targets with n_samples elements (flattened).
    :param ridge: L2 penalty; values <= 0 disable regularization.
    :return: np.ndarray of shape (n_features,) with the fitted coefficients.
        If the (regularized) Gram matrix is singular, the minimum-norm
        least-squares solution is returned.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    y = np.asarray(y, dtype=float).reshape(-1, 1)

    # Gram matrix and moment vector of the normal equations.
    XtX = X.T @ X
    Xty = X.T @ y

    if ridge > 0:
        XtX = XtX + ridge * np.eye(X.shape[1])

    try:
        coeffs = np.linalg.solve(XtX, Xty)
    except np.linalg.LinAlgError:
        # Singular Gram matrix (collinear/duplicate features, or more
        # features than samples without ridge): return the minimum-norm
        # solution instead of failing.
        coeffs = np.linalg.lstsq(XtX, Xty, rcond=None)[0]
    return coeffs.flatten()


def optimal_alpha_unconstrained(x1, x2, y, rtol=1e-10):
    """
    Least-squares weights (w1, w2) minimizing ``||w1*x1 + w2*x2 - y||^2``.

    Solves the 2x2 normal equations ``G w = b`` with
    ``G = [[x1.x1, x1.x2], [x1.x2, x2.x2]]`` and ``b = [x1.y, x2.y]``.
    Only elementwise ops and ``.sum()`` are used, so NumPy arrays and torch
    tensors are both accepted.

    :param x1: First regressor.
    :param x2: Second regressor (same shape as x1).
    :param y: Target (same shape as x1).
    :param rtol: Relative tolerance for treating G as singular. ``det(G)`` is
        compared with ``(x1.x1)(x2.x2)``, whose ratio is ``1 - cos^2`` of the
        angle between x1 and x2, so the test does not depend on data scale.
    :return: Tuple ``(w1, w2)``. When x1 and x2 are (numerically) collinear,
        a message is printed and the minimum-norm least-squares solution is
        returned; when both are zero, ``(0, 0)`` is returned.
    """
    s11 = (x1 ** 2).sum()
    s22 = (x2 ** 2).sum()
    s12 = (x1 * x2).sum()
    s1y = (x1 * y).sum()
    s2y = (x2 * y).sum()

    delta = s11 * s22 - s12 ** 2
    scale = s11 * s22
    if scale == 0 or abs(delta) <= rtol * scale:
        print("regression is ill-conditioned.")
        trace = s11 + s22
        if trace == 0:
            # Both regressors are identically zero: every w fits equally well.
            return 0.0 * s1y, 0.0 * s2y
        # For a rank-1 symmetric PSD G, pinv(G) = G / trace(G)^2, which
        # yields the minimum-norm solution of the normal equations.
        w1 = (s11 * s1y + s12 * s2y) / trace ** 2
        w2 = (s12 * s1y + s22 * s2y) / trace ** 2
        return w1, w2

    # Cramer's rule for the well-conditioned 2x2 system.
    w1 = (s22 * s1y - s12 * s2y) / delta
    w2 = (s11 * s2y - s12 * s1y) / delta
    return w1, w2

def optimal_alpha(x1, x2, y):
    """
    Best convex mixing weight alpha in [0, 1] for ``alpha*x1 + (1-alpha)*x2 ~ y``.

    Rewrites the fit as ``alpha * (x1 - x2) ~ (y - x2)``, solves the scalar
    least-squares problem, and clips the result to [0, 1]. Returns 0.5 when
    x1 and x2 are (numerically) identical, since alpha has no effect then.
    """
    direction = x1 - x2
    residual = y - x2
    norm_sq = (direction ** 2).sum()
    if not np.isclose(norm_sq, 0):
        ratio = (direction * residual).sum() / norm_sq
        return np.clip(ratio, 0.0, 1.0)
    return 0.5

    
def Batch_Wasserstein_One_Dimension(X, Y, a=None, b=None, p=2, device="cuda"):
    """
    Compute the exact 1-D Wasserstein-p distance for a batch of uniform empirical measures.

    Each column (last axis) of X and Y is an independent 1-D sample set, so
    all columns of all batches are handled at once.

    :param X: Source samples, shape (num_batches, M, d)
    :param Y: Target samples, shape (num_batches, N, d)
    :param a: Source weights; only ``None`` (uniform) is supported.
    :param b: Target weights; only ``None`` (uniform) is supported.
    :param p: Wasserstein-p order
    :param device: Kept for backward compatibility. Helper tensors are
        created on ``X.device`` so that they always match the inputs.
    :return: Tensor of shape (num_batches, d) with Wasserstein distances
    :raises NotImplementedError: If ``a`` or ``b`` is given.
    """
    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    assert X.shape[0] == Y.shape[0], "Source batch and target batch must have the same number of batches"

    if a is not None or b is not None:
        raise NotImplementedError("Weighted Wasserstein not implemented")

    num_supports_source = X.shape[-2]
    num_supports_target = Y.shape[-2]

    # Sorted samples are the empirical quantile functions.
    X_sorted, _ = torch.sort(X, dim=1)
    Y_sorted, _ = torch.sort(Y, dim=1)

    if num_supports_source == num_supports_target:
        # Equal sizes: the optimal coupling matches sorted samples one-to-one.
        diff_quantiles = torch.abs(X_sorted - Y_sorted)
        if p == 1:
            return diff_quantiles.mean(dim=1)
        return diff_quantiles.pow(p).mean(dim=1).pow(1 / p)

    # Unequal sizes: integrate |F_X^-1(q) - F_Y^-1(q)|^p over the merged grid
    # of cumulative weights of both measures.
    helper_device = X.device
    a_cum_weights = torch.linspace(1.0 / num_supports_source, 1.0, steps=num_supports_source, device=helper_device)
    b_cum_weights = torch.linspace(1.0 / num_supports_target, 1.0, steps=num_supports_target, device=helper_device)
    qs = torch.sort(torch.cat((a_cum_weights, b_cum_weights), 0), dim=0)[0]

    X_quantiles = quantile_function(qs, a_cum_weights, X_sorted)  # (num_batches, Q, d)
    Y_quantiles = quantile_function(qs, b_cum_weights, Y_sorted)  # (num_batches, Q, d)
    diff_quantiles = torch.abs(X_quantiles - Y_quantiles)

    # Width of each quantile interval. Duplicate levels (both grids end at
    # 1.0) get zero width, so the widths sum to exactly 1. Shape (1, Q, 1)
    # broadcasts against (num_batches, Q, d) without adding a leading axis.
    qs_extended = torch.cat((torch.zeros(1, device=helper_device), qs), dim=0)
    delta = (qs_extended[1:] - qs_extended[:-1]).view(1, -1, 1)

    if p == 1:
        return torch.sum(delta * diff_quantiles, dim=-2)
    return torch.sum(delta * torch.pow(diff_quantiles, p), dim=-2).pow(1 / p)


def Batch_Sliced_Wasserstein_Distance(X, Y, num_projections=1000, list_projection_vectors=None, p=2, device="cuda", chunk=1000, dtype=torch.float16, return_vectors=False):
    """
    Compute the Sliced Wasserstein distance for a batch of measure pairs. Supports backpropagation.

    Projections are processed in chunks to bound memory.

    :param X: Batch of source measures, shape (num_batches, num_supports_source, d)
    :param Y: Batch of target measures, shape (num_batches, num_supports_target, d)
    :param num_projections: Number of random projection directions (ignored
        when ``list_projection_vectors`` is given)
    :param list_projection_vectors: Optional list of precomputed projection
        matrices, each of shape (k_i, d); each entry is processed as one chunk
    :param p: Wasserstein-p parameter
    :param device: Device for the random projections and the accumulator
    :param chunk: Maximum number of projections per chunk. If num_projections
        is not a multiple of chunk, a final smaller chunk covers the remainder.
    :param dtype: Dtype used for the projection matmuls
    :param return_vectors: If True, return the per-projection 1-D distances
        instead of the aggregated distance
    :return: Tensor of shape (num_batches,), or (num_batches, total_projections)
        when ``return_vectors`` is True
    :raises ValueError: If random projections are requested with a
        non-positive ``num_projections`` or ``chunk``
    """
    assert X.shape[-1] == Y.shape[-1], "Source and target must have the same dimension"
    assert X.shape[0] == Y.shape[0], "Source and target must have the same number of batches"

    dims = X.shape[-1]
    num_batches = X.shape[0]

    if list_projection_vectors is not None:
        # One chunk per precomputed matrix; the total count normalizes the mean.
        chunk_sizes = [v.shape[0] for v in list_projection_vectors]
        num_projections = sum(chunk_sizes)
    else:
        if num_projections < 1 or chunk < 1:
            raise ValueError("num_projections and chunk must be positive integers")
        chunk = min(chunk, num_projections)
        num_full_chunks, remainder = divmod(num_projections, chunk)
        # Include a final partial chunk so exactly num_projections directions
        # are drawn; the mean below divides by num_projections.
        chunk_sizes = [chunk] * num_full_chunks + ([remainder] if remainder else [])

    # Cast once; the same casted inputs are reused for every chunk.
    X_cast = X.to(dtype)
    Y_cast = Y.to(dtype)

    per_chunk_distances = []
    sum_w_p = torch.zeros((num_batches), device=device)

    for i, size in enumerate(chunk_sizes):
        if list_projection_vectors is None:
            projection_vectors = generate_uniform_unit_sphere_projections(dim=dims, num_projections=size, device=device)
        else:
            projection_vectors = list_projection_vectors[i]
        projection_vectors = projection_vectors.to(dtype)

        X_projection = torch.matmul(X_cast, projection_vectors.t())  # (num_batches, num_supports_source, size)
        Y_projection = torch.matmul(Y_cast, projection_vectors.t())  # (num_batches, num_supports_target, size)

        w_1d = Batch_Wasserstein_One_Dimension(X=X_projection, Y=Y_projection, p=p, device=device)  # (num_batches, size)

        if return_vectors:
            per_chunk_distances.append(w_1d)
        else:
            sum_w_p += torch.sum(torch.pow(w_1d, p), dim=-1)

    if return_vectors:
        return torch.cat(per_chunk_distances, dim=1)

    mean_w_p = sum_w_p / num_projections
    return torch.pow(mean_w_p, 1/p) if p != 1 else mean_w_p


def generate_uniform_unit_sphere_projections(dim, requires_grad=False, num_projections=1000, dtype=torch.float32, device="cpu"):
    """
    Sample ``num_projections`` directions uniformly on the unit sphere in R^dim.

    Gaussian vectors are drawn and each row is rescaled to unit L2 norm
    (norms are floored at 1e-12 to avoid division by zero).

    :param dim: Ambient dimension.
    :param requires_grad: If True, the result is a fresh leaf tensor that
        tracks gradients; otherwise it does not require grad.
    :param num_projections: Number of directions (rows).
    :param dtype: Dtype of the returned tensor.
    :param device: Device of the returned tensor.
    :return: Tensor of shape (num_projections, dim) with unit-norm rows.
    """
    with torch.no_grad():
        directions = torch.randn((num_projections, dim), dtype=dtype, device=device)
        row_norms = directions.norm(dim=1, keepdim=True).clamp_min(1e-12)
        directions = directions / row_norms
    return directions.requires_grad_(True) if requires_grad else directions


def quantile_function(qs, cws, xs):
    """
    Evaluate empirical quantile functions at the levels ``qs``.

    For each level q, picks the first support whose cumulative weight is
    >= q (``torch.searchsorted``, left side). Indices are clamped to the
    last support, so levels slightly above the final cumulative weight (for
    example from floating-point rounding) do not index out of range.

    :param qs: Quantile levels, shape (Q,) shared by all projections, or
        (Q, P) per projection.
    :param cws: Cumulative weights, shape (M,) or (M, P), matching ``qs``.
    :param xs: Sorted supports, shape (num_dist, M, P).
    :return: Tensor of shape (num_dist, Q, P) with the quantile values.
    """
    cws, _ = torch.sort(cws, dim=0)
    qs, _ = torch.sort(qs, dim=0)
    num_dist = xs.shape[0]
    num_supports = xs.shape[-2]
    num_projections = xs.shape[-1]

    # searchsorted works along the last axis, so move the level axis last
    # (.t() is a no-op for 1-D grids), then bring it back to the front.
    idx = torch.searchsorted(cws.t().contiguous(), qs.t().contiguous()).t()
    idx = idx.clamp(0, num_supports - 1)

    if idx.dim() == 1:
        # Shared grid: the same indices apply to every projection.
        idx = idx.unsqueeze(-1).expand(-1, num_projections)  # (Q, P)
    idx = idx.unsqueeze(0).expand(num_dist, -1, -1)  # (num_dist, Q, P)
    return torch.take_along_dim(xs, idx, dim=-2)



def _SWGG_loss(X, Y, theta, s=1, std=0, device='cpu'):
    """Compute the (optionally smoothed) SWGG objective for one direction.

    Returns ``-4 * W_baryZ + 2 * W_XZ + 2 * W_YZ``. Each ``W_*`` is a sum of
    squared Euclidean distances divided by ``n``:

    * ``Z``: points ``Z_line * theta``, where ``Z_line`` holds the midpoints
      of the sorted projections of X and Y.
    * ``W_XZ`` / ``W_YZ``: distances between Z and X (resp. Y), with the
      points reordered by their sorted projection.
    * ``W_baryZ``: distances between Z and a blurred midpoint set. Each point
      is repeated ``s`` times and reordered by its projection plus
      ``0.5 * std`` Gaussian noise. X/Y copies are then paired and averaged,
      and each group of ``s`` consecutive midpoints is averaged.

    :param X: Source points; ``n, dim = X.shape`` requires a 2-D tensor (n, dim).
    :param Y: Target points; assumed (n, dim) like X -- TODO confirm, since the
        row-wise pairing and the division by ``n`` rely on equal sizes.
    :param theta: Projection direction. The ``squeeze(1)`` and the broadcast
        in ``Z`` only work for shape (1, dim). It is presumably unit-norm so
        that Z lies on the projection line -- TODO confirm.
    :param s: Number of noisy copies per point.
    :param std: Noise level; the added noise is ``0.5 * std * randn``. With
        ``std=0`` the noise is zero, but random numbers are still drawn.
    :param device: Device on which the noise is sampled; presumably must
        match the device of X and Y -- verify against callers.
    :return: Scalar tensor.
    """
    n,dim=X.shape

    # 1-D projections of every point onto theta.
    X_line=torch.matmul(X,theta.t()).squeeze(1)
    Y_line=torch.matmul(Y,theta.t()).squeeze(1)

    # u, v are the sorting permutations; they reorder the original points.
    X_line_sort,u=torch.sort(X_line,axis=0)
    Y_line_sort,v=torch.sort(Y_line,axis=0)

    X_sort=X[u]
    Y_sort=Y[v]

    # Midpoints of the paired sorted projections, lifted back along theta.
    Z_line=(X_line_sort+Y_line_sort)/2
    Z=Z_line.unsqueeze(1) * theta


    W_XZ=torch.sum((X_sort-Z)**2)/n
    W_YZ=torch.sum((Y_sort-Z)**2)/n

    # Repeat each sorted projection s times and perturb it with noise.
    X_line_extend = X_line_sort.repeat_interleave(s,dim=0)
    X_line_extend_blur = X_line_extend + 0.5 * std * torch.randn(X_line_extend.shape,device=device)
    Y_line_extend = Y_line_sort.repeat_interleave(s,dim=0)
    Y_line_extend_blur = Y_line_extend + 0.5 * std * torch.randn(Y_line_extend.shape,device=device)

    # Re-sort the perturbed projections; u_b, v_b reorder the repeated points.
    X_line_extend_blur_sort,u_b=torch.sort(X_line_extend_blur,axis=0)
    Y_line_extend_blur_sort,v_b=torch.sort(Y_line_extend_blur,axis=0)


    X_extend=X_sort.repeat_interleave(s,dim=0)
    Y_extend=Y_sort.repeat_interleave(s,dim=0)
    X_sort_extend=X_extend[u_b]
    Y_sort_extend=Y_extend[v_b]

    # Pair the reordered copies, then average each group of s rows back to n points.
    bary_extend=(X_sort_extend+Y_sort_extend)/2
    bary_blur=torch.mean(bary_extend.reshape((n,s,dim)),dim=1)

    W_baryZ=torch.sum((bary_blur-Z)**2)/n
    return -4*W_baryZ+2*W_XZ+2*W_YZ

