import torch
from torch.utils.data import Dataset, DataLoader
from Pn_util import *

class PndataTangentGaussian(Dataset):
    """P(n) samples generated from a Gaussian in the tangent space at ``Mean``.

    A tangent vector is drawn as ``z @ Cov_sqrt`` with ``z`` standard normal,
    scaled by ``getCoeff2(dim)``, turned into a matrix by ``vec2mat``, passed
    through ``Exp_mat`` and finally multiplied on both sides by
    ``get_sqrt_sym(Mean)``.  The sampling runs on the GPU; everything that is
    stored on the instance is moved back to the CPU.

    Args:
        N: number of samples to draw.
        Mean: a P(n) element; it is viewed as ``(1, dim, dim)``.
        Cov_sqrt: ``(vec_dim, vec_dim)`` factor that right-multiplies the
            standard-normal draws (going by its name, a square root of the
            tangent-space covariance rather than the covariance itself).

    Each item is the tuple ``(train_data, train_data_sqrt,
    train_data_invsqrt)`` for one sample, i.e. ``mat2vec`` of the sample and
    the two outputs of ``get_sqrt_sym(..., returnInvAlso=True)``.
    """

    def __init__(self, N, Mean, Cov_sqrt):
        n = Mean.shape[0]
        self.dim = n
        # number of independent entries of a symmetric n x n matrix
        self.vec_dim = n * (n + 1) // 2
        self.Mean = Mean
        self.Cov_sqrt = Cov_sqrt
        self.N = N

        # e ~ N(0, Cov) in the tangent space of Mean, drawn on the CPU
        noise = torch.empty(N, self.vec_dim, dtype=torch.float32, device="cpu").normal_(0.0, 1.0)
        tangent = torch.mm(noise, Cov_sqrt)

        # map the tangent vectors onto the manifold around Mean (on the GPU)
        mean_root = get_sqrt_sym(Mean.view(1, n, n).cuda())
        exponential = Exp_mat(vec2mat(tangent.cuda() * getCoeff2(n)))
        samples = torch.matmul(torch.matmul(mean_root, exponential), mean_root)
        samples_root, samples_invroot = get_sqrt_sym(samples, returnInvAlso=True)

        self.train_data = mat2vec(samples).cpu()
        self.train_data_sqrt = samples_root.cpu()
        self.train_data_invsqrt = samples_invroot.cpu()

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return (
            self.train_data[idx],
            self.train_data_sqrt[idx],
            self.train_data_invsqrt[idx],
        )

class PndataTangentGaussianMixture(Dataset):
    """P(n) samples from an equally weighted mixture of tangent-space Gaussians.

    For every component ``i`` a tangent vector is drawn as
    ``z @ Cov_sqrts[i]`` (``z`` standard normal) and mapped around
    ``Means[i]`` exactly as in the single-Gaussian case: scale by
    ``getCoeff2(dim)``, ``vec2mat``, ``Exp_mat``, then multiply on both sides
    by ``get_sqrt_sym`` of the mean.

    Args:
        N: total number of samples.
        Means: batch of P(n) elements; ``Means.shape`` is read as
            ``(Nmix, dim, ...)`` so this must be a tensor, not a Python list.
        Cov_sqrts: per-component ``(vec_dim, vec_dim)`` factors, indexable by
            component number.

    Samples are stored grouped by component (not shuffled): each of the first
    ``Nmix - 1`` components gets ``N // Nmix`` samples and the last component
    takes whatever remains.  Mixture weights are equal (fix later...).

    Each item is ``(train_data, train_data_sqrt, train_data_invsqrt)`` for
    one sample.
    """

    def __init__(self, N, Means, Cov_sqrts):
        n = Means.shape[1]
        n_mix = Means.shape[0]
        self.dim = n
        # number of independent entries of a symmetric n x n matrix
        self.vec_dim = n * (n + 1) // 2
        self.Nmix = n_mix
        self.Means = Means
        self.Cov_sqrts = Cov_sqrts
        self.N = N

        per_component = N // n_mix
        last_count = N - (n_mix - 1) * per_component
        samples = torch.zeros(N, n, n, dtype=torch.float32, device="cuda")
        means_root = get_sqrt_sym(Means.cuda())
        scale = getCoeff2(n)
        vec_dim = self.vec_dim

        def draw(count, cov_sqrt, mean_root):
            # `count` samples around one mean; mean_root keeps a leading
            # axis of length 1 so it broadcasts over the batch
            noise = torch.empty(count, vec_dim, dtype=torch.float32, device="cpu").normal_(0.0, 1.0)
            tangent = torch.mm(noise, cov_sqrt)
            exponential = Exp_mat(vec2mat(tangent.cuda() * scale))
            return torch.matmul(torch.matmul(mean_root, exponential), mean_root)

        for k in range(n_mix - 1):
            start = k * per_component
            samples[start:start + per_component] = draw(
                per_component, Cov_sqrts[k], means_root[k:k + 1]
            )
        # the last component absorbs the remainder of N
        samples[(n_mix - 1) * per_component:] = draw(
            last_count, Cov_sqrts[-1], means_root[-1:]
        )

        samples_root, samples_invroot = get_sqrt_sym(samples, returnInvAlso=True)
        self.train_data = mat2vec(samples).cpu()
        self.train_data_sqrt = samples_root.cpu()
        self.train_data_invsqrt = samples_invroot.cpu()

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return (
            self.train_data[idx],
            self.train_data_sqrt[idx],
            self.train_data_invsqrt[idx],
        )

class PndataTangentGaussianMixtureExpanded(Dataset):
    """Mixture dataset augmented with per-sample quantities for the contractive loss.

    Two construction modes:

    * ``idx is None`` -- ``Pn_dataset`` supplies ``Means``, ``Cov_sqrts``,
      ``train_data``, ``train_data_sqrt``, ``train_data_invsqrt`` and ``N``.
      The additional quantities are computed here on the GPU with the
      Pn_util helpers and stored on the CPU.
    * ``idx`` given -- ``Pn_dataset`` must already carry every per-sample
      attribute listed in ``_PER_SAMPLE_FIELDS`` (i.e. be an expanded
      dataset).  Each of them is indexed with ``idx`` along its first axis.

    Args:
        Pn_dataset: source dataset (see the two modes above).
        idx: ``None``, or anything that selects along the first axis of a
            tensor while keeping that axis: a list or tensor of integer
            indices, a boolean mask, or a slice.

    Raises:
        TypeError: if ``idx`` is a scalar index, which would drop the sample
            axis instead of selecting a subset.

    Each item is a 12-tuple in the order of ``_PER_SAMPLE_FIELDS``.
    """

    # Per-sample attributes, in the order __getitem__ returns them.
    _PER_SAMPLE_FIELDS = (
        "train_data",
        "logx",
        "train_data_sqrt",
        "train_data_invsqrt",
        "metric",
        "metricInv",
        "metricDeriv",
        "X_sqrt_dirderiv_set",
        "dLog_xdx",
        "christoffel_sum",
        "S",
        "U",
    )

    def __init__(self, Pn_dataset, idx = None):
        self.dim = Pn_dataset.Means.shape[1]
        # number of independent entries of a symmetric dim x dim matrix
        self.vec_dim = self.dim * (self.dim + 1) // 2
        self.Nmix = Pn_dataset.Means.shape[0]
        self.Means = Pn_dataset.Means
        self.Cov_sqrts = Pn_dataset.Cov_sqrts
        if idx is None:
            self._expand(Pn_dataset)
        else:
            self._subset(Pn_dataset, idx)

    def _expand(self, Pn_dataset):
        # Compute the additional quantities required for the contractive loss.
        self.train_data = Pn_dataset.train_data
        self.train_data_sqrt = Pn_dataset.train_data_sqrt
        self.train_data_invsqrt = Pn_dataset.train_data_invsqrt
        self.N = Pn_dataset.N

        # one-hot direction buffer in vectorised coordinates, reused per axis
        tempdir = torch.zeros(self.train_data.shape, dtype=torch.float32, device="cuda")
        X = vec2mat(self.train_data).cuda()
        metric = metric_P_n(X)
        metricInv = metricInv_P_n(X)
        metricDeriv = metricDeriv_P_n(X)
        christoffel_sum = christoffelSum_P_n(X)
        X_sqrt_dirderiv_set = torch.zeros(
            self.N, self.dim, self.dim, self.vec_dim, dtype=torch.float32, device="cuda"
        )
        dLog_xdx = torch.zeros(
            self.N, self.vec_dim, self.vec_dim, dtype=torch.float32, device="cuda"
        )

        # floor the eigenvalues before they are handed to the log helpers
        eps = 1e-14
        S, U = batch_eigsym(X)
        S[S<eps] = eps
        logx = mat2vec(Log_mat(X, S = S, U = U))

        for i in range(self.vec_dim):
            # directional derivatives along the i-th vectorised coordinate
            tempdir[:,i] = 1
            Xdot = vec2mat(tempdir)
            Xdot_trans = torch.matmul(torch.matmul(U.permute(0,2,1), Xdot), U)
            X_sqrt_dirderiv_set[:,:,:,i] = get_sqrt_sym_DirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans)
            dLog_xdx[:,:,i] = mat2vec(LogDirDeriv(X, Xdot, S = S, U = U, Xdot_trans = Xdot_trans))
            tempdir[:,i] = 0  # restore the buffer for the next direction

        self.logx = logx.cpu()
        self.metric = metric.cpu()
        self.metricInv = metricInv.cpu()
        self.metricDeriv = metricDeriv.cpu()
        self.X_sqrt_dirderiv_set = X_sqrt_dirderiv_set.cpu()
        self.dLog_xdx = dLog_xdx.cpu()
        self.christoffel_sum = christoffel_sum.cpu()
        self.S = S.cpu()
        self.U = U.cpu()

    def _subset(self, Pn_dataset, idx):
        # Take the rows selected by idx from an already expanded dataset.
        for field in self._PER_SAMPLE_FIELDS:
            setattr(self, field, getattr(Pn_dataset, field)[idx])
        if self.train_data.dim() != Pn_dataset.train_data.dim():
            raise TypeError(
                "idx must select a subset along the sample axis "
                "(index list/tensor, boolean mask or slice), not a scalar"
            )
        # Count the rows actually selected: len(idx) would be the mask length
        # for a boolean mask and is undefined for a slice.
        self.N = self.train_data.shape[0]

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return tuple(getattr(self, field)[idx] for field in self._PER_SAMPLE_FIELDS)
    
class PndataTangentGaussianMixtureCoordTransform(Dataset):
    """New dataset obtained from an existing one by the map ``X -> A X A^T``.

    Args:
        Pn_dataset: dataset to transform; its ``Means``, ``Cov_sqrts``,
            ``train_data`` and ``N`` are read.  ``Cov_sqrts`` must be a
            tensor here because ``.cuda()`` is called on it.
        A: affine transform matrix.  It is permuted with ``(0, 2, 1)``, so it
            has to be a 3-D tensor (presumably a leading axis of length 1 or
            ``Nmix`` -- confirm against callers).

    The means are transformed on the CPU and symmetrised; covariance factors
    and samples are transformed on the GPU and stored on the CPU.

    Each item is ``(train_data, train_data_sqrt, train_data_invsqrt)`` for
    one transformed sample.
    """

    def __init__(self, Pn_dataset, A):
        source_means = Pn_dataset.Means
        self.dim = source_means.shape[1]
        # number of independent entries of a symmetric dim x dim matrix
        self.vec_dim = self.dim * (self.dim + 1) // 2
        self.Nmix = source_means.shape[0]
        self.N = Pn_dataset.N

        # transform means (CPU); averaging with the transpose removes any
        # asymmetry left by the two matrix products
        A_cpu = A.cpu()
        moved = torch.matmul(torch.matmul(A_cpu, source_means), A_cpu.permute(0, 2, 1))
        self.Means = 0.5 * (moved + moved.permute(0, 2, 1))

        # transform covariances (GPU)
        A_gpu = A_cpu.cuda()
        A_gpu_T = A_gpu.permute(0, 2, 1)
        M = source_means.cuda()
        M_root = get_sqrt_sym(M)
        _, AMAT_invroot = get_sqrt_sym(
            torch.matmul(torch.matmul(A_gpu, M), A_gpu_T), returnInvAlso=True
        )
        # R = (A M A^T)^(-1/2) A M^(1/2); it is handed to the "orthogonal
        # affine transform" Jacobian helper below
        R = torch.matmul(torch.matmul(AMAT_invroot, A_gpu), M_root)
        J = getOrthogonalAffineTransformJacobian(R)
        self.Cov_sqrts = torch.matmul(Pn_dataset.Cov_sqrts.cuda(), J.permute(0, 2, 1)).cpu()

        # transform the samples themselves (GPU)
        originals = vec2mat(Pn_dataset.train_data.cuda())
        transformed = torch.matmul(torch.matmul(A_gpu, originals), A_gpu_T)
        transformed_root, transformed_invroot = get_sqrt_sym(transformed, returnInvAlso=True)

        self.train_data = mat2vec(transformed).cpu()
        self.train_data_sqrt = transformed_root.cpu()
        self.train_data_invsqrt = transformed_invroot.cpu()

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return (
            self.train_data[idx],
            self.train_data_sqrt[idx],
            self.train_data_invsqrt[idx],
        )
    
class PndataFromMat(Dataset):
    """Dataset built directly from a batch of matrices.

    Args:
        X: batch of matrices; ``X.shape[0]`` is taken as the number of
            samples and ``X.shape[1]`` as the matrix dimension.
        labels: optional labels.  They are stored as ``true_labels`` only
            when given; without labels the attribute is not created.

    ``get_sqrt_sym`` and ``mat2vec`` are evaluated on the GPU and their
    results are stored on the CPU.

    Each item is ``(train_data, train_data_sqrt, train_data_invsqrt)`` for
    one sample; labels are not part of the item.
    """

    def __init__(self, X, labels = None):
        self.N, self.dim = X.shape[0], X.shape[1]
        # number of independent entries of a symmetric dim x dim matrix
        self.vec_dim = self.dim * (self.dim + 1) // 2

        mats = X.cuda()
        mats_root, mats_invroot = get_sqrt_sym(mats, returnInvAlso=True)

        self.train_data = mat2vec(mats).cpu()
        self.train_data_sqrt = mats_root.cpu()
        self.train_data_invsqrt = mats_invroot.cpu()
        if labels is not None:
            self.true_labels = labels

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return (
            self.train_data[idx],
            self.train_data_sqrt[idx],
            self.train_data_invsqrt[idx],
        )