import torch
from torch import vmap

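# Expected input z: (70, num_samples, 10), i.e. (num_examples, num_samples, dim).
# Stated output u: (70, 10). NOTE(review): compute_u below actually returns a
# (num_examples, num_examples) matrix; confirm which output shape is intended.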
# Default compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def k_fn(x, y, sigma):
    """Gaussian (RBF) kernel between two points.

    Computes ``exp(-0.5 * ||(x - y) / sigma||^2)``, summing over every element
    of ``x - y``. Returns a 0-d tensor; it equals 1 when ``x == y``.
    """
    scaled = (x - y) / sigma
    return torch.exp(-0.5 * scaled.pow(2).sum())

# Gram-matrix builder: kernel(X, Y, sigma)[i, j] == k_fn(X[i], Y[j], sigma).
# The inner vmap sweeps the rows of Y; the outer vmap sweeps the rows of X.
kernel = vmap(
    vmap(k_fn, in_dims=(None, 0, None)),
    in_dims=(0, None, None),
)


def kme(x, y, sigma):
    """Mean of the Gram matrix between sample sets ``x`` and ``y``.

    ``mean_{i,j} k_fn(x[i], y[j], sigma)``. This is the empirical inner
    product of the kernel mean embeddings of the two sets. Returns a 0-d tensor.
    """
    gram = kernel(x, y, sigma)
    return gram.mean()


def MMD(x, y, sigma=15.0):
    """Return the unbiased estimate of the squared MMD between two sample sets.

    Uses the Gaussian kernel ``k_fn`` (through ``kernel`` / ``kme``):

        MMD^2 = sum_{i != j} k(x_i, x_j) / (m (m - 1))
              + sum_{i != j} k(y_i, y_j) / (n (n - 1))
              - 2 * mean_{i, j} k(x_i, y_j)

    Args:
        x: samples from the first distribution; dim 0 indexes the points.
        y: samples from the second distribution; dim 0 indexes the points.
        sigma: kernel bandwidth forwarded to every kernel evaluation. Defaults
            to 15.0, the same default ``compute_u`` uses.

    Returns:
        0-d tensor with the estimate. Being unbiased, it may be slightly
        negative when the two distributions are close.

    Raises:
        ValueError: if either set has fewer than 2 points. The within-set
            terms divide by m (m - 1) and n (n - 1).
    """
    m, n = x.shape[0], y.shape[0]
    if m < 2 or n < 2:
        raise ValueError(f"MMD needs at least 2 samples per set, got m={m}, n={n}")
    gram_xx = kernel(x, x, sigma)
    gram_yy = kernel(y, y, sigma)
    # Exclude the i == j terms; the Gram diagonal holds exactly k(x_i, x_i).
    within_x = (gram_xx.sum() - torch.diagonal(gram_xx).sum()) / (m * (m - 1))
    within_y = (gram_yy.sum() - torch.diagonal(gram_yy).sum()) / (n * (n - 1))
    return within_x + within_y - 2 * kme(x, y, sigma)


def sort_eigen(x):
    """Fix the sign ambiguity of eigenvectors stored as columns of ``x``.

    Eigenvectors are only unique up to a factor of -1. To obtain a canonical
    form, each column is flipped so that its first entry is non-negative.
    Despite the name, columns are not reordered.

    A column whose first entry is exactly 0 is left unchanged. ``torch.sign``
    returns 0 there, and multiplying by it would wipe out the whole column.

    Args:
        x: tensor whose columns are eigenvectors (dim 0 = components).

    Returns:
        Tensor of the same shape and dtype as ``x``.
    """
    sign = torch.sign(x[0])
    # Treat a zero leading component as "already positive" rather than zeroing the column.
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    # Broadcasting applies the per-column sign to every row.
    return x * sign

def compute_u(z_train, num_sensors = 10, sigma=15.0):
    """Build the symmetric matrix of pairwise ``kme`` values between training sets.

    Entry ``[i, j]`` is ``kme(z_train[i], z_train[j], sigma)``. Only the upper
    triangle (including the diagonal) is evaluated. The lower triangle is
    filled by mirroring.

    Args:
        z_train: tensor whose dim 0 indexes the sample sets.
        num_sensors: unused by this function; kept for interface compatibility.
        sigma: Gaussian kernel bandwidth.

    Returns:
        ``(num_examples, num_examples)`` tensor on ``z_train.device``.
    """
    n_ex = z_train.shape[0]
    rows, cols = torch.triu_indices(n_ex, n_ex)

    def _pair_kme(i, j, bw):
        # Under vmap, i and j arrive as batched 0-d tensors. Indexing with a
        # 1-element tensor and taking [0] keeps this as tensor (advanced) indexing.
        left = z_train[i[None]][0]
        right = z_train[j[None]][0]
        return kme(left, right, bw)

    upper_vals = vmap(_pair_kme, in_dims=(0, 0, None))(rows, cols, sigma).reshape(-1)
    gram = torch.zeros((n_ex, n_ex), device=z_train.device)
    gram[rows, cols] = upper_vals
    # Copy the strictly-upper part into the lower triangle; the diagonal appears once.
    return gram + torch.triu(gram, diagonal=1).T

    



def compute_u_from_samples(x_train: torch.Tensor, samples: torch.Tensor, evecs_t: torch.Tensor, row_mean: torch.Tensor, matrix_mean: torch.Tensor, sigma: float = 15.0):
    """Project one test sample set onto eigenvectors through a centred kme vector.

    Computes ``kme(x_train[i], samples, sigma)`` for every training set ``i``.
    It then applies the centring terms ``-mean + row_mean + matrix_mean`` and
    multiplies by ``evecs_t.T``.

    Args:
        x_train: ``(num_examples, num_samples, dim)`` training sample sets.
        samples: one test sample set, reshapeable to ``(-1, dim)``. Its number
            of points need not equal ``num_samples``.
        evecs_t: matrix whose first dim is ``num_examples``; the projection is
            ``evecs_t.T @ centred``.
        row_mean: ``num_examples`` values, reshaped to ``(num_examples, 1)``.
        matrix_mean: term broadcast-added to every entry (presumably a scalar).
        sigma: Gaussian kernel bandwidth passed to ``kme``. Defaults to 15.0,
            the same default ``compute_u`` uses.

    Returns:
        ``(1, K)`` row tensor, where ``K = evecs_t.shape[1]``.
    """
    num_examples, num_samples, dim = x_train.shape[0], x_train.shape[1], x_train.shape[-1]
    train_sets = x_train.reshape(num_examples, num_samples, dim)
    # A leading singleton axis lets the inner vmap treat the test set as a batch of one.
    test_sets = samples.reshape(1, -1, dim)
    pairwise_kme = vmap(vmap(kme, (None, 0, None)), (0, None, None))
    samples_inner_product = pairwise_kme(train_sets, test_sets, sigma)  # (num_examples, 1)

    # NOTE(review): standard kernel centring subtracts the per-example row mean.
    # Here it is added; confirm whether callers pass row_mean already negated.
    test_total = (
        -torch.mean(samples_inner_product)
        + row_mean.reshape(num_examples, 1)
        + matrix_mean
        + samples_inner_product
    )
    u = evecs_t.T @ test_total
    return u.reshape(1, -1)

def compute_u_from_samples_2(x_train: torch.Tensor, samples: torch.Tensor, row_mean: torch.Tensor, matrix_mean: torch.Tensor, sigma: float = 15.0):
    """Centred kme vector between each training set and one test sample set.

    Same as ``compute_u_from_samples`` but without the eigenvector projection.

    Args:
        x_train: ``(num_examples, num_samples, dim)`` training sample sets.
        samples: one test sample set, reshapeable to ``(-1, dim)``.
        row_mean: ``num_examples`` values, reshaped to ``(num_examples, 1)``.
        matrix_mean: term broadcast-added to every entry (presumably a scalar).
        sigma: Gaussian kernel bandwidth passed to ``kme``. Defaults to 15.0,
            the same default ``compute_u`` uses.

    Returns:
        ``(1, num_examples)`` row tensor.
    """
    num_examples, num_samples, dim = x_train.shape[0], x_train.shape[1], x_train.shape[-1]
    train_sets = x_train.reshape(num_examples, num_samples, dim)
    # A leading singleton axis lets the inner vmap treat the test set as a batch of one.
    test_sets = samples.reshape(1, -1, dim)
    pairwise_kme = vmap(vmap(kme, (None, 0, None)), (0, None, None))
    samples_inner_product = pairwise_kme(train_sets, test_sets, sigma)  # (num_examples, 1)

    # NOTE(review): row_mean is added, not subtracted; confirm the intended sign.
    u = (
        -torch.mean(samples_inner_product)
        + row_mean.reshape(num_examples, 1)
        + matrix_mean
        + samples_inner_product
    )
    return u.reshape(1, -1)

def compute_u_from_samples_3(x_train: torch.Tensor, samples: torch.Tensor, sigma):
    """kme between each training sample set and one test sample set.

    Args:
        x_train: tensor whose dim 0 indexes training sample sets; last dim is ``dim``.
        samples: one test sample set, reshapeable to ``(-1, dim)``. It may
            contain a different number of points than each training set,
            because ``kme`` averages over all pairs.
        sigma: Gaussian kernel bandwidth.

    Returns:
        ``(num_examples,)`` tensor with ``kme(x_train[i], samples, sigma)``.
    """
    num_examples, dim = x_train.shape[0], x_train.shape[-1]
    # -1 (not x_train.shape[1]) so the test-set size is not tied to the training-set size.
    samples = samples.reshape(-1, dim)
    u = vmap(kme, (0, None, None))(x_train, samples, sigma).reshape(num_examples,)
    return u

def compute_u_train_and_test(z_train, z_test, num_sensors = 10, sigma=15.0):
    """Compute kme features for the training sets and for each test set.

    Args:
        z_train: training sample sets (dim 0 indexes sets).
        z_test: test sample sets (dim 0 indexes sets).
        num_sensors: unused here; kept for interface compatibility.
        sigma: Gaussian kernel bandwidth.

    Returns:
        ``(u_train, u_test)``. ``u_train`` is ``compute_u(z_train)``. ``u_test``
        stacks ``compute_u_from_samples_3`` over the first axis of ``z_test``.
    """
    per_test_set = vmap(compute_u_from_samples_3, in_dims=(None, 0, None))
    return compute_u(z_train, sigma=sigma), per_test_set(z_train, z_test, sigma)
                                    
