"""
Basic utilities for neural SDEs.
"""
import torch
import numpy as np
from torchsde import sdeint
class MLP(torch.nn.Module):
    def __init__(self, D_in, H, D_out, layers=2, resnet=False):
        """
        Build a fully connected ReLU network.

        The stack is Linear(D_in, H), followed by ``layers - 2`` Linear(H, H)
        blocks, followed by Linear(H, D_out). Because the two outer layers are
        always created, any ``layers`` value below 2 behaves like 2.
        ``resnet`` enables skip connections around the hidden H -> H blocks.
        """
        super(MLP, self).__init__()
        self.layers = layers
        self.resnet = resnet
        # Layers are created in input -> output order so that seeded
        # initialisation consumes the RNG in a fixed sequence.
        stack = [torch.nn.Linear(D_in, H)]
        stack.extend(torch.nn.Linear(H, H) for _ in range(layers - 2))
        stack.append(torch.nn.Linear(H, D_out))
        self.linears = torch.nn.ModuleList(stack)

    def forward(self, x):
        """
        Apply the network to ``x`` (last dimension D_in) and return a tensor
        whose last dimension is D_out. ReLU follows every layer except the
        final one.
        """
        x = torch.relu(self.linears[0](x))
        for idx in range(1, self.layers - 1):
            activated = torch.relu(self.linears[idx](x))
            # With resnet enabled the block input is added back after the ReLU.
            x = activated + x if self.resnet else activated
        return self.linears[-1](x)

class neuralSDE(torch.nn.Module):
    """
    Time-independent neural SDE whose drift and diffusion are MLPs of the state.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size, brownian_size, hidden_size, batch_size, layers=2, resnet=False):
        super().__init__()
        self.state_size, self.brownian_size = state_size, brownian_size
        self.hidden_size, self.batch_size = hidden_size, batch_size
        # Drift network: state -> state.
        self.mu = MLP(state_size, hidden_size, state_size, layers, resnet=resnet)
        # Diffusion network: state -> flattened (state_size x brownian_size) matrix.
        self.sigma = MLP(state_size, hidden_size, state_size * brownian_size, layers, resnet=resnet)

    def f(self, t, y):
        """Drift; ``t`` is ignored. Output has the same leading shape as ``y``."""
        return self.mu(y)

    def g(self, t, y):
        """
        Diffusion; ``t`` is ignored. The flat network output is viewed as
        (batch_size, state_size, brownian_size), so ``y`` must contain exactly
        ``self.batch_size`` rows.
        """
        target_shape = (self.batch_size, self.state_size, self.brownian_size)
        return self.sigma(y).view(*target_shape)
    
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        Create a one-hidden-layer perceptron: Linear(D_in, H) -> ReLU ->
        Linear(H, D_out). Both linear maps are stored as member modules.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map an input tensor (last dimension D_in) to an output tensor
        (last dimension D_out).
        """
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)
    
class neuralSDE_legacy(torch.nn.Module):
    """
    Time-independent neural SDE built from two TwoLayerNet instances
    (one for the drift, one for the diffusion).
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size, brownian_size, hidden_size, batch_size):
        super().__init__()
        self.state_size, self.brownian_size = state_size, brownian_size
        self.hidden_size, self.batch_size = hidden_size, batch_size
        # Drift network: state -> state.
        self.mu = TwoLayerNet(state_size, hidden_size, state_size)
        # Diffusion network: state -> flattened (state_size x brownian_size) matrix.
        self.sigma = TwoLayerNet(state_size, hidden_size, state_size * brownian_size)

    def f(self, t, y):
        """Drift; ``t`` is ignored. Output has the same leading shape as ``y``."""
        return self.mu(y)

    def g(self, t, y):
        """
        Diffusion; ``t`` is ignored. The flat network output is viewed as
        (batch_size, state_size, brownian_size), so ``y`` must contain exactly
        ``self.batch_size`` rows.
        """
        target_shape = (self.batch_size, self.state_size, self.brownian_size)
        return self.sigma(y).view(*target_shape)

class temporal_neuralSDE_legacy(torch.nn.Module):
    """
    Time-dependent neural SDE built from TwoLayerNet instances.

    The scalar time t is fed to both networks as one extra input feature that
    is prepended to the state vector.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size, brownian_size, hidden_size, batch_size):
        super().__init__()
        self.state_size = state_size
        self.brownian_size = brownian_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # Both networks take (t, y), hence state_size + 1 input features.
        self.mu = TwoLayerNet(state_size + 1, hidden_size,
                                  state_size)
        self.sigma = TwoLayerNet(state_size + 1, hidden_size,
                                  state_size * brownian_size)

    def _prepend_time(self, t, y):
        """
        Return ``cat([t, y], dim=1)`` with the scalar ``t`` repeated per row.

        ``t`` may be a Python number or a 0-dim tensor. It is cast to the dtype
        and device of ``y`` so the concatenated input matches the state tensor
        (the solver's time tensor may differ in precision or device).
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        # t should be of size torch.Size([])
        assert t.shape == torch.Size([]), f"t should be of size torch.Size([]), got {t.shape}"
        t = t.to(dtype=y.dtype, device=y.device)
        return torch.cat([t.expand(y.size(0), 1), y], dim=1)

    # Drift
    def f(self, t, y):
        """Drift at time ``t`` and state ``y`` of shape (rows, state_size)."""
        return self.mu(self._prepend_time(t, y))

    # Diffusion
    def g(self, t, y):
        """
        Diffusion at time ``t`` and state ``y``, viewed as
        (batch_size, state_size, brownian_size); ``y`` must therefore contain
        exactly ``self.batch_size`` rows.
        """
        return self.sigma(self._prepend_time(t, y)).view(self.batch_size,
                                                         self.state_size,
                                                         self.brownian_size)


import ot
def distance(P, Q):
    """Return the pairwise cost matrix ``ot.dist(P, Q)`` with the 'sqeuclidean' metric."""
    return ot.dist(P, Q, metric='sqeuclidean')

def w2_coupled(y, y_pred):
    """
    Optimal-transport cost between whole trajectories.

    ``y`` and ``y_pred`` are indexed as (time, sample, state). Every sample's
    trajectory is flattened into a single vector of length
    t_size * state_size, and ``ot.emd2`` is evaluated with uniform weights on
    the squared-Euclidean cost matrix between the two sets of vectors.
    """
    t_size, batch_size, state_size = y.shape[0], y.shape[1], y.shape[2]
    flat_dim = t_size * state_size
    # Bring the sample axis to the front, then flatten (time, state) per sample.
    y_cat = y.permute(1, 0, 2).reshape(batch_size, flat_dim)
    y_pred_cat = y_pred.permute(1, 0, 2).reshape(batch_size, flat_dim)
    # Uniform probability mass on every sample.
    weights = torch.tensor([1/batch_size for _ in range(batch_size)])
    return ot.emd2(weights, weights, distance(y_cat, y_pred_cat))

def w2_decoupled_numerical(y, y_pred):
    """
    Sum of per-time-step optimal-transport costs.

    ``y`` and ``y_pred`` are indexed as (time, sample, state). For every time
    index except the first, ``ot.emd2`` is evaluated with uniform weights on
    the squared-Euclidean cost between the two sample sets; the costs are
    summed. Returns the integer 0 when there is at most one time step.
    """
    t_size, batch_size = y.shape[0], y.shape[1]
    state_size = y.shape[2]  # unused; kept so that non-3-D input still raises
    # Uniform probability mass on every sample (identical for every step).
    weights = torch.tensor([1/batch_size for _ in range(batch_size)])
    total = 0
    for step in range(1, t_size):
        total += ot.emd2(weights, weights, distance(y[step, :, :], y_pred[step, :, :]))
    return total
class temporal_neuralSDE(torch.nn.Module):
    """
    Time-dependent neural SDE built from MLP instances.

    The scalar time t is fed to both networks as one extra input feature that
    is prepended to the state vector.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size, brownian_size, hidden_size, batch_size):
        super().__init__()
        self.state_size = state_size
        self.brownian_size = brownian_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        # Both networks take (t, y), hence state_size + 1 input features.
        self.mu = MLP(state_size + 1, hidden_size,
                                  state_size)
        self.sigma = MLP(state_size + 1, hidden_size,
                                  state_size * brownian_size)

    def _prepend_time(self, t, y):
        """
        Return ``cat([t, y], dim=1)`` with the scalar ``t`` repeated per row.

        ``t`` may be a Python number or a 0-dim tensor. It is cast to the dtype
        and device of ``y`` so the concatenated input matches the state tensor
        (the solver's time tensor may differ in precision or device).
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        # t should be of size torch.Size([])
        assert t.shape == torch.Size([]), f"t should be of size torch.Size([]), got {t.shape}"
        t = t.to(dtype=y.dtype, device=y.device)
        return torch.cat([t.expand(y.size(0), 1), y], dim=1)

    # Drift
    def f(self, t, y):
        """Drift at time ``t`` and state ``y`` of shape (rows, state_size)."""
        return self.mu(self._prepend_time(t, y))

    # Diffusion
    def g(self, t, y):
        """
        Diffusion at time ``t`` and state ``y``, viewed as
        (batch_size, state_size, brownian_size); ``y`` must therefore contain
        exactly ``self.batch_size`` rows.
        """
        return self.sigma(self._prepend_time(t, y)).view(self.batch_size,
                                                         self.state_size,
                                                         self.brownian_size)


# def test_temporal_neuralSDE():
#     # Create an instance of the temporal_neuralSDE class
#     state_size = 1
#     brownian_size = 1
#     hidden_size = 10
#     batch_size = 4
#     sde = temporal_neuralSDE(state_size, brownian_size, hidden_size, batch_size)

#     # Generate random input data for t and y
#     t = 0.0
#     y = torch.rand(batch_size, state_size)

#     # Pass the data through the f and g functions
#     f_output = sde.f(t, y)
#     g_output = sde.g(t, y)

#     # Check the output dimensions
#     assert f_output.shape == (batch_size, state_size), f"Unexpected shape for f: {f_output.shape}"
#     assert g_output.shape == (batch_size, state_size, brownian_size), f"Unexpected shape for g: {g_output.shape}"

#     print("f output shape:", f_output.shape)
#     print("g output shape:", g_output.shape)
#     print("Test passed!")

# test_temporal_neuralSDE()
# Define the functions
def quantile(samples_sorted):
    """
    Build the empirical quantile (inverse CDF) function of sorted samples.

    Args:
        samples_sorted: 1-D tensor of samples in ascending order.

    Returns:
        A function mapping a tensor of probabilities ``p`` to
        ``samples_sorted[floor(p * n)]``. The index is clamped to
        ``[0, n - 1]`` so that ``p == 1`` yields the largest sample instead of
        indexing one past the end, and out-of-range probabilities map to the
        nearest end of the sample rather than wrapping around.
    """
    n_samples = len(samples_sorted)

    def quantile_func(p):
        idx = torch.floor(p * n_samples).long()
        idx = idx.clamp(min=0, max=n_samples - 1)
        return samples_sorted[idx]
    return quantile_func

def W2_distance(u_samples, v_samples):
    """
    Quantile-based transport cost between two 1-D sample sets.

    Both empirical quantile functions are evaluated on the merged grid of
    their jump points in [0, 1]; the squared difference is integrated over
    that grid. The value returned is the integral itself (no square root).
    """
    u_sorted = u_samples.sort().values
    v_sorted = v_samples.sort().values
    # Each empirical quantile function is piecewise constant on k/n, k = 0..n.
    u_knots = torch.linspace(0, 1, steps=len(u_samples)+1)
    v_knots = torch.linspace(0, 1, steps=len(v_samples)+1)
    grids = torch.unique(torch.cat((u_knots, v_knots))).sort()[0]
    # Evaluate at the left edge of every cell; the last knot (p = 1) is excluded.
    left_edges = grids[:-1]
    gap = quantile(u_sorted)(left_edges) - quantile(v_sorted)(left_edges)
    return torch.sum(gap ** 2 * torch.diff(grids))


# not differentiable in pytorch?
# def W1_distance(u_samples, v_samples):
#     # Sorting the samples
#     u_samples_sorted, _ = torch.sort(u_samples)
#     v_samples_sorted, _ = torch.sort(v_samples)
    
#     # Combining and sorting all samples
#     all_samples, _ = torch.sort(torch.unique(torch.cat((u_samples_sorted, v_samples_sorted), dim=0)))
    
#     # Compute CDF for u and v samples
#     u_cdf = cdf(u_samples_sorted, all_samples[:-1])
#     v_cdf = cdf(v_samples_sorted, all_samples[:-1])
    
#     # Compute the 1-Wasserstein distance
#     wsd = torch.sum(torch.abs(u_cdf - v_cdf) * (all_samples[1:] - all_samples[:-1]))
    
#     return wsd

# def cdf(u_samples, x):
#     return torch.sum(u_samples[:, None] <= x[None, :], dim=0).float() / u_samples.size(0)



def W2_loss_alt(u_pred, u_truth=None):
    """
    Sorting-based transport loss on the first state component.

    For every time index, the ground-truth values of component 0 are permuted
    so that the k-th smallest truth value is paired with the k-th smallest
    prediction. The loss is the summed squared error between ``u_pred`` and
    the permuted truth, divided by the number of samples.

    Args:
        u_pred: tensor indexed as (time, sample, state).
        u_truth: tensor of the same shape. If omitted, a module-level global
            named ``u_truth`` is used (the function previously relied on such
            a global, which this module never defines).

    Returns:
        Scalar loss tensor, or a NaN tensor when the shapes differ.

    Raises:
        NameError: if ``u_truth`` is neither passed nor available as a global.

    NOTE: only state component 0 of the target is filled; any further
    components stay zero in the target, as in the original implementation.
    """
    if u_truth is None:
        # Backward compatibility with callers that inject a module global.
        u_truth = globals().get('u_truth')
        if u_truth is None:
            raise NameError("name 'u_truth' is not defined; pass u_truth explicitly")
    if u_pred.shape != u_truth.shape:
        print("when u_pred and u_truth have different shapes, it is not implemented yet, NaN returned")
        return torch.tensor(float('nan'))

    criterion = torch.nn.MSELoss(reduction='sum')
    # The target matches u_pred's dtype/device so the criterion can compare them.
    ys_truth_sort = torch.zeros(u_truth.shape, dtype=u_pred.dtype, device=u_pred.device)
    for i in range(u_truth.shape[0]):
        # Sorting permutations are index computations only, so detach first.
        pred_order = u_pred[i, :, 0].detach().sort().indices
        truth_order = u_truth[i, :, 0].detach().sort().indices
        # Place the k-th smallest truth value where the k-th smallest prediction sits.
        ys_truth_sort[i, pred_order, 0] = u_truth[i, truth_order, 0].to(
            dtype=ys_truth_sort.dtype, device=ys_truth_sort.device)
    return criterion(u_pred, ys_truth_sort) / u_pred.shape[1]

def mean2_var_distance(u_samples, v_samples):
    """
    Moment-matching discrepancy: squared difference of the sample means plus
    the absolute difference of the sample variances (``torch.var`` defaults).
    """
    mean_gap = torch.mean(u_samples) - torch.mean(v_samples)
    var_gap = torch.var(u_samples) - torch.var(v_samples)
    return torch.abs(mean_gap)**2 + torch.abs(var_gap)

def mse_distance(u_samples, v_samples):
    """Mean of the element-wise squared differences between the two tensors."""
    residual = u_samples - v_samples
    return torch.mean(residual**2)


# def apprx_loglik(u_truth, times, sde):
#     t_size, sample_size, n = u_truth.shape
#     This version works for arbitrary dimension, but extremely memory inefficient
#     ℓ = torch.tensor(0.0)
    
#     for i in range(t_size - 1):
#         for j in range(sample_size):
#             dt = times[i + 1] - times[i]
            
#             dxdt = f(sde, u_truth[i,j,:]).squeeze(-1)
#             dσdt = g(sde, u_truth[i,j,:]).squeeze(-1)
            
            
#             if n == 1:
#                 x = u_truth[i + 1,j] - u_truth[i,j] - dxdt * dt
#                 Σ = (dσdt**2 * dt).squeeze() * torch.eye(n)
#             else:
#                 x = u_truth[i + 1,j] - u_truth[i,j] - dxdt * dt
#                 Σ = (dσdt**2).squeeze() * torch.eye(n) * dt
            
#             # Computing the log likelihood using batch operations
#             logdeterminant = torch.logdet(Σ)
#             inverse_Σ = torch.inverse(Σ)
#             xm = x.unsqueeze(-1)
#             bilinear_term = torch.mm(torch.mm(xm.transpose(0, 1), inverse_Σ), xm).squeeze()  # batched matrix multiplication

#             ℓ += (-0.5 * n * torch.log(torch.tensor(2 * 3.1415926)) - 0.5 * logdeterminant - 0.5 * bilinear_term)
        
#     return ℓ/t_size/sample_size

# u_truth = torch.randn(10, 100, 1)
# times = torch.linspace(0, 1, 10)
# sde = neuralSDE(1,1,1,100)
def apprx_loglik(u_truth, times, sde):
    """
    Approximate log-likelihood of observed trajectories under ``sde``.

    ``u_truth`` is indexed as (time, sample, state) and the implementation is
    specific to a one-dimensional state. For each consecutive pair of time
    points, the increment minus ``drift * dt`` is scored under a zero-mean
    Gaussian whose diagonal covariance has entries ``diffusion**2 * dt``, one
    per sample; the per-step log-densities are summed.

    NOTE: drift and diffusion are obtained through the module-level ``f`` and
    ``g`` helpers without passing a time, so their default time is used.
    """
    t_size, sample_size, n = u_truth.shape
    log_two_pi = torch.log(torch.tensor(2 * 3.1415926))
    identity = torch.eye(sample_size)
    total = torch.tensor(0.0)

    for i in range(t_size - 1):
        step = times[i + 1] - times[i]

        drift = f(sde, u_truth[i,:])
        diffusion = g(sde, u_truth[i,:]).squeeze(-1)

        # Part of the increment not explained by the drift.
        residual = u_truth[i + 1] - u_truth[i] - drift * step
        # Diagonal covariance with one variance per sample.
        cov = (diffusion**2).squeeze() * identity * step

        log_det = torch.logdet(cov)
        precision = torch.inverse(cov)
        quad_form = torch.mm(torch.mm(residual.transpose(0, 1), precision), residual).squeeze()

        total += (-0.5 * sample_size * log_two_pi - 0.5 * log_det - 0.5 * quad_form)

    return total

# Radial Basis Function kernel (Gaussian kernel) for kernel density estimation in MMD (Maximum Mean Discrepancy) loss
class RBF(torch.nn.Module):
    """
    Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    Args:
        n_kernels: number of bandwidths in the mixture.
        mul_factor: ratio between consecutive bandwidth multipliers; the
            multipliers are ``mul_factor ** (k - n_kernels // 2)`` for
            ``k = 0 .. n_kernels - 1``.
        bandwidth: fixed base bandwidth. When None, the mean off-diagonal
            squared distance of the input is used.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        # Registered as a non-persistent buffer so that .to()/.cuda()/.double()
        # move it together with the module while state_dict() stays unchanged.
        self.register_buffer('bandwidth_multipliers', multipliers, persistent=False)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for a square matrix of squared distances."""
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            # Mean over the off-diagonal entries; detached so that the
            # bandwidth is treated as a constant by autograd.
            return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

        return self.bandwidth

    def forward(self, X):
        """Return the (n, n) kernel matrix of the rows of ``X``, summed over bandwidths."""
        L2_distances = torch.cdist(X, X) ** 2
        scales = (self.get_bandwidth(L2_distances) * self.bandwidth_multipliers)[:, None, None]
        return torch.exp(-L2_distances[None, ...] / scales).sum(dim=0)

# Maximum Mean Discrepancy loss
# Ref: https://github.com/yiftachbeer/mmd_loss_pytorch
# Ref: https://arxiv.org/abs/1502.02761
class MMDLoss(torch.nn.Module):
    """
    Maximum Mean Discrepancy between two sample sets.

    Args:
        kernel: callable mapping an (n, d) tensor to its (n, n) kernel matrix.
            When None (default), a fresh ``RBF()`` is created for this
            instance instead of sharing one object built at class-definition
            time.
    """

    def __init__(self, kernel=None):
        super().__init__()
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return mean(K_XX) - 2 * mean(K_XY) + mean(K_YY) for the stacked samples."""
        K = self.kernel(torch.vstack([X, Y]))

        # The first X_size rows/columns of K belong to X, the rest to Y.
        X_size = X.shape[0]
        XX = K[:X_size, :X_size].mean()
        XY = K[:X_size, X_size:].mean()
        YY = K[X_size:, X_size:].mean()
        return XX - 2 * XY + YY

# Module-level MMDLoss instance built with the class's default kernel;
# call it as mmd_distance(X, Y).
mmd_distance = MMDLoss()

# sdeint requires a specific batch dimension for its input, so we pad and split
# the input accordingly and reassemble the output
def predict(sde, u0, ts, *args, **kwargs):
    """
    Integrate ``sde`` from the initial states ``u0`` with ``sdeint``.

    The SDE may be tied to a fixed batch size (``sde.batch_size``). ``u0`` is
    therefore split into chunks of that size; the last (or only) chunk is
    zero-padded up to the batch size, every chunk is integrated separately and
    the padding rows are dropped from the result.

    Args:
        sde: SDE object accepted by ``sdeint``; ``batch_size`` is optional.
        u0: initial states, one per row.
        ts: time grid passed to ``sdeint``.
        *args, **kwargs: forwarded to every ``sdeint`` call (all chunks).

    Returns:
        The ``sdeint`` output restricted to the first ``u0.shape[0]`` samples
        along dimension 1.
    """
    n_samples = u0.shape[0]
    n_batch = getattr(sde, 'batch_size', n_samples)

    def pad_to_batch(chunk):
        # Zero rows matching the chunk's trailing shape, dtype and device.
        n_pad = n_batch - chunk.shape[0]
        if n_pad <= 0:
            return chunk
        filler = torch.zeros((n_pad,) + tuple(chunk.shape[1:]),
                             dtype=chunk.dtype, device=chunk.device)
        return torch.cat((chunk, filler), 0)

    if n_samples == n_batch:
        # u0 is already of size n_batch
        return sdeint(sde, u0, ts, *args, **kwargs)

    if n_samples < n_batch:
        u_pred = sdeint(sde, pad_to_batch(u0), ts, *args, **kwargs)
    else:
        # Integrate chunk by chunk and concatenate along the sample dimension.
        u_preds = [
            sdeint(sde, pad_to_batch(u0[start:start + n_batch, :]), ts, *args, **kwargs)
            for start in range(0, n_samples, n_batch)
        ]
        u_pred = torch.cat(u_preds, 1)
    # only return the first n_samples
    return u_pred[:, :n_samples, :]

def _evaluate_sde_component(neuralsde, ts, xs, component='f'):
    """
    Evaluate ``neuralsde.f`` or ``neuralsde.g`` at the states ``xs`` for every time in ``ts``.

    The SDE may be tied to a fixed batch size (``neuralsde.batch_size``), so
    ``xs`` is processed in chunks of that size; a short chunk is zero-padded
    before the call and the padding rows are dropped afterwards.

    Args:
        neuralsde: object exposing the method named by ``component`` with
            signature ``(t, y)``; ``batch_size`` is optional.
        ts: iterable of scalar times.
        xs: states, one per row.
        component: name of the method to evaluate ('f' or 'g').

    Returns:
        Tensor of shape (t_size, n_sample, ...), whether or not ``xs`` had to
        be split into several chunks.
    """
    n_sample = xs.shape[0]
    n_batch = getattr(neuralsde, 'batch_size', n_sample)
    evaluate = getattr(neuralsde, component)

    def eval_chunk(chunk):
        n_valid = chunk.shape[0]
        n_pad = n_batch - n_valid
        if n_pad > 0:
            # Zero rows matching the chunk's trailing shape, dtype and device.
            filler = torch.zeros((n_pad,) + tuple(chunk.shape[1:]),
                                 dtype=chunk.dtype, device=chunk.device)
            chunk = torch.cat((chunk, filler), 0)
        # One slice per time point, stacked along a new leading time axis.
        return torch.stack([evaluate(t_scalar, chunk)[:n_valid] for t_scalar in ts])

    if n_sample <= n_batch:
        return eval_chunk(xs)

    chunks = [eval_chunk(xs[start:start + n_batch, :]) for start in range(0, n_sample, n_batch)]
    # Chunks are joined along the sample axis (dim 1); dim 0 is time.
    return torch.cat(chunks, 1)


def f(sde, xs, ts=torch.tensor([0.0])):
    """
    Evaluate the drift of ``sde`` at the states ``xs``.

    ``xs`` may be 1-D (treated as a column of scalar states) or 2-D. ``ts``
    may be a Python float, a 0-dim tensor or a tensor of times. When ``ts``
    ends up 1-D, a leading time axis of length one is squeezed away.
    """
    xs = xs.unsqueeze(-1) if xs.dim() == 1 else xs
    assert xs.dim() == 2
    if isinstance(ts, float):
        ts = torch.tensor([ts])
    if ts.dim() == 0:
        ts = ts.unsqueeze(0)
    drift = _evaluate_sde_component(sde, ts, xs, component='f')
    return drift.squeeze(0) if ts.dim() == 1 else drift

def g(sde, xs, ts=torch.tensor([0.0])):
    """
    Evaluate the element-wise absolute value of the diffusion of ``sde`` at ``xs``.

    ``xs`` may be 1-D (treated as a column of scalar states) or 2-D. ``ts``
    may be a Python float, a 0-dim tensor or a tensor of times. When ``ts``
    ends up 1-D, a leading time axis of length one is squeezed away.
    """
    xs = xs.unsqueeze(-1) if xs.dim() == 1 else xs
    assert xs.dim() == 2
    if isinstance(ts, float):
        ts = torch.tensor([ts])
    if ts.dim() == 0:
        ts = ts.unsqueeze(0)
    diffusion = _evaluate_sde_component(sde, ts, xs, component='g')
    if ts.dim() == 1:
        diffusion = diffusion.squeeze(0)
    return torch.abs(diffusion)

def Sigma(sde, xs, ts=torch.tensor([0.0])):
    """
    Evaluate the diffusion covariance ``sigma @ sigma^T`` of ``sde`` at ``xs``.

    ``xs`` may be 1-D (treated as a column of scalar states) or 2-D. ``ts``
    may be a Python float, a 0-dim tensor or a tensor of times. When ``ts``
    ends up 1-D, a leading time axis of length one is squeezed away. The
    computation runs under ``torch.no_grad()``.

    The product is taken over the last two axes of the diffusion tensor, so it
    stays correct when a leading time axis is present.
    """
    if xs.dim() == 1:
        xs = xs.unsqueeze(-1)
    assert xs.dim() == 2
    if isinstance(ts, float):
        ts = torch.tensor([ts])
    if ts.dim() == 0:
        ts = ts.unsqueeze(0)
    with torch.no_grad():
        sigma = _evaluate_sde_component(sde, ts, xs, component='g')
        if ts.dim() == 1:
            sigma = sigma.squeeze(0)
        # sigma @ sigma^T over the trailing (state, brownian) axes.
        return torch.matmul(sigma, sigma.transpose(-1, -2))


def rel_err_f(neuralsde, sde, u_truth, ts):
    """
    Relative L2 error of the learned drift against the reference drift.

    Both drifts are evaluated on the trajectory samples ``u_truth[idx]`` at
    time ``ts[idx]``. The result is
    sqrt(sum ||f_hat - f||^2 / sum ||f||^2), accumulated over all samples and
    time points, returned as a NumPy value.
    """
    with torch.no_grad():
        t_size, sample_size, n = u_truth.shape
        err_terms, ref_terms = [], []

        for idx, t in enumerate(ts):
            snapshot = u_truth[idx]
            drift_hat = f(neuralsde, snapshot, t)
            drift_ref = f(sde, snapshot, t)
            # Per-sample squared error and squared reference magnitude.
            err_terms.append(torch.norm(drift_hat - drift_ref, dim=-1)**2)
            ref_terms.append(torch.norm(drift_ref, dim=-1)**2)

        total_squared_error = torch.sum(torch.cat(err_terms))
        total_squared_norm = torch.sum(torch.cat(ref_terms))
        rel_error = total_squared_error / total_squared_norm

    return torch.sqrt(rel_error).detach().numpy()

def rel_err_g(neuralsde, sde, u_truth, ts):
    """
    Relative L2 error of the learned diffusion against the reference diffusion.

    Both diffusions are evaluated through ``g`` on the trajectory samples
    ``u_truth[idx]`` at time ``ts[idx]``. The result is
    sqrt(sum ||g_hat - g||^2 / sum ||g||^2), with norms taken over the last
    axis, returned as a NumPy value.
    """
    with torch.no_grad():
        t_size, sample_size, n = u_truth.shape
        err_terms, ref_terms = [], []

        for idx, t in enumerate(ts):
            snapshot = u_truth[idx]
            diff_hat = g(neuralsde, snapshot, t)
            diff_ref = g(sde, snapshot, t)
            # Squared error and squared reference magnitude along the last axis.
            err_terms.append(torch.norm(diff_hat - diff_ref, dim=-1)**2)
            ref_terms.append(torch.norm(diff_ref, dim=-1)**2)

        total_squared_error = torch.sum(torch.cat(err_terms))
        total_squared_norm = torch.sum(torch.cat(ref_terms))
        rel_error = total_squared_error / total_squared_norm

        return torch.sqrt(rel_error).detach().numpy()

def rel_err_Sigma(neuralsde, sde, u_truth, ts):
    """
    Relative error of the learned diffusion covariance against the reference one.

    Both covariances are evaluated through ``Sigma`` on the trajectory samples
    ``u_truth[idx]`` at time ``ts[idx]``. The result is
    sqrt(sum ||S_hat - S||^2 / sum ||S||^2), with matrix norms taken over the
    last two axes, returned as a NumPy value.
    """
    with torch.no_grad():
        t_size, sample_size, n = u_truth.shape
        err_terms, ref_terms = [], []

        for idx, t in enumerate(ts):
            snapshot = u_truth[idx]
            cov_hat = Sigma(neuralsde, snapshot, t)
            cov_ref = Sigma(sde, snapshot, t)
            # Per-sample squared matrix norms.
            err_terms.append(torch.norm(cov_hat - cov_ref, dim=(-2,-1))**2)
            ref_terms.append(torch.norm(cov_ref, dim=(-2,-1))**2)

        total_squared_error = torch.sum(torch.cat(err_terms))
        total_squared_norm = torch.sum(torch.cat(ref_terms))
        rel_error = total_squared_error / total_squared_norm

        return torch.sqrt(rel_error).detach().numpy()

import torch

def rotate_2d_vector(v, theta_degree):
    """
    Rotate every row of ``v`` counter-clockwise by ``theta_degree`` degrees.

    Args:
        v: tensor of shape (n_sample, 2).
        theta_degree: rotation angle in degrees, either a Python number or a
            single-element tensor.

    Returns:
        Tensor of shape (n_sample, 2) with the dtype of ``v``.

    The rotation matrix is assembled with ``torch.stack`` so that a tensor
    angle stays connected to the autograd graph, and it is moved to the
    device of ``v`` before the product.
    """
    n_sample = v.shape[0]
    assert v.shape == (n_sample, 2)
    if isinstance(theta_degree, torch.Tensor):
        theta_rad = theta_degree / 180.0 * torch.pi
    else:
        theta_rad = torch.tensor(theta_degree / 180.0 * torch.pi)
    # Accept single-element tensors of any shape by collapsing them to 0-dim.
    theta_rad = theta_rad.reshape(()).to(device=v.device)
    cos_t, sin_t = torch.cos(theta_rad), torch.sin(theta_rad)
    rotation_matrix = torch.stack((torch.stack((cos_t, -sin_t)),
                                   torch.stack((sin_t, cos_t))))
    # set the data type of rotation_matrix to be the same as v
    rotation_matrix = rotation_matrix.to(v.dtype)
    # (2, 2) @ (n_sample, 2, 1) broadcasts the matrix over all samples.
    return torch.matmul(rotation_matrix, v.unsqueeze(-1)).squeeze(-1)

def group_vectors_by_angle_intervals(vectors, N):
    """
    Partition 2-D vectors into ``N`` equal angular sectors.

    Each vector's polar angle (via ``atan2`` on columns 1 and 0) is converted
    to degrees and floor-divided by the sector width ``360 / N``. Returns a
    list of ``N`` index tensors, one per sector, holding the row indices of
    the vectors that fall into it.
    """
    angles_degrees = convert_to_degrees(torch.atan2(vectors[:, 1], vectors[:, 0]))
    interval_size = 360 / N
    # Sector number of every vector.
    group_indices = (angles_degrees // interval_size).long()
    return [torch.where(group_indices == sector)[0] for sector in range(N)]

def convert_to_degrees(angles):
    """Shift angles given in radians by one full turn, wrap them modulo 2*pi, and return degrees."""
    full_turn = 2 * np.pi
    wrapped = (angles + full_turn) % full_turn
    return wrapped * (180 / np.pi)

def radially_sliced_W2_distance(u_samples, v_samples):
    """
    Angular-sector-wise transport cost between two 2-D point clouds.

    Both clouds (shape (n_sample, 2), equal length) are split into 10 angular
    sectors. Within each sector the radii of the two clouds are compared with
    ``W2_distance``. A sector without ``u`` points contributes zero; a sector
    with ``u`` points but no ``v`` points contributes the mean squared radius
    of its ``u`` points. Sector costs are averaged with weights proportional
    to the number of ``u`` points per sector.

    ``u_samples`` plays the role of the prediction and ``v_samples`` that of
    the ground truth.
    """
    assert u_samples.shape[1] == 2
    assert v_samples.shape[1] == 2
    assert len(u_samples) == len(v_samples)
    assert u_samples.dim() == 2
    assert v_samples.dim() == 2
    N_slices = 10
    u_groups = group_vectors_by_angle_intervals(u_samples, N_slices)
    v_groups = group_vectors_by_angle_intervals(v_samples, N_slices)

    sector_costs = []
    for u_idx, v_idx in zip(u_groups, v_groups):
        if len(u_idx) == 0:
            cost = torch.tensor(0.0)
        elif len(v_idx) == 0:
            cost = torch.mean(torch.norm(u_samples[u_idx], dim=1)**2)
        else:
            u_radii = torch.norm(u_samples[u_idx], dim=1)
            v_radii = torch.norm(v_samples[v_idx], dim=1)
            cost = W2_distance(u_radii, v_radii)
        sector_costs.append(cost)

    # Weight every sector by the number of u points it contains.
    sector_sizes = torch.tensor([len(u_idx) for u_idx in u_groups])
    return torch.sum(torch.stack(sector_costs) * sector_sizes) / len(u_samples)

    # from utils.sde_utils import radially_sliced_W2_distance as distance
    # def loss(u_pred):
    #     loss_cul = 0
    #     for t in range(0, example2d_truth.t_size):
    #         loss_cul += distance(u_pred[t], u_truth[t])
    #     return loss_cul
