import os
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import nibabel as nib
from torch_batch_svd import svd
from data_util import vector2tensor_1dim, Log_vec2Log
from tensor_data_util import ExpJacobian, LogJacobian

class N_n_DataSet(Dataset):
    """Per-voxel (position, covariance) dataset assembled from per-subject ``.pt`` files.

    For every name in ``filelist`` the last 7 characters are stripped
    (presumably a ``.nii.gz`` suffix -- TODO confirm) to form ``stem``, and
    these companion files are loaded and concatenated across subjects::

        <stem>_posData.pt                                -> positions
        <stem><data_prefix>_covvec.pt                    -> appended column-wise to positions
        <stem><data_prefix>_posMetric.pt                 -> self.covInv
        <stem><data_prefix>_posMetric_sqrt_sym.pt        -> self.covInv_sqrt
        <stem><data_prefix>_posMetricInv_sqrt_sym.pt     -> self.cov_sqrt
        <stem><data_prefix>_coveigvec.pt                 -> self.cov_eigvec
        <stem><data_prefix>_coveigval.pt                 -> self.cov_eigval

    The first 3 columns of ``self.posAndCov`` are treated as (x, y, z).

    data_prefix conventions (from the original author):
        '_reg'    : values calculated for fixed and regularized data
        '_fix'    : values calculated for fixed data (non positive-definite voxels removed)
        '_cor0.1' : values calculated for fixed and corrupted data

    Args:
        filelist: per-subject file names.
        returnforDAE: if True, ``__getitem__`` returns ``(data, [], [])``.
        match_num_per_subject: unused; kept for interface compatibility.
        shuffle: unused; kept for interface compatibility.
        randPos: if True, each fetched sample gets uniform jitter in
            [-dx/2, dx/2) added to a copy of its position part.
        dx: voxel spacing; scales the jitter and ``roi_range``.
        roi_range: optional ``(x0, x1, y0, y1, z0, z1)`` in voxel units. Only
            samples with ``x0*dx <= x < x1*dx`` (same for y, z) are kept, and
            positions are shifted so the ROI corner ``(x0, y0, z0)*dx`` is the origin.
        data_prefix: file-name infix selecting which precomputed variant to load.

    After loading (and optional ROI cropping) positions are centred; the
    removed mean is kept in ``self.posMean``. ``self.split_nums`` holds the
    number of samples contributed by each file.
    """

    def __init__(self, filelist,
                 returnforDAE = False, match_num_per_subject = None, shuffle = False, randPos = False,
                 dx = 2, roi_range = None, data_prefix = '_fix'):
        super(N_n_DataSet, self).__init__()
        self.randPos = randPos
        self.dx = dx
        self.returnforDAE = returnforDAE
        self.file_list = filelist

        posAndCov = []
        posMetric = []
        posMetric_sqrt = []
        posMetricInv_sqrt = []
        covEigVec = []
        covEigVal = []
        for filename in self.file_list:
            stem = filename[:-7]  # drop the 7-char suffix (presumably '.nii.gz')
            posAndCov.append(torch.cat((torch.load(stem + '_posData.pt'),
                                        torch.load(stem + data_prefix + '_covvec.pt')), 1))
            posMetric.append(torch.load(stem + data_prefix + '_posMetric.pt'))
            posMetric_sqrt.append(torch.load(stem + data_prefix + '_posMetric_sqrt_sym.pt'))
            posMetricInv_sqrt.append(torch.load(stem + data_prefix + '_posMetricInv_sqrt_sym.pt'))
            covEigVec.append(torch.load(stem + data_prefix + '_coveigvec.pt'))
            covEigVal.append(torch.load(stem + data_prefix + '_coveigval.pt'))

        self.posAndCov = torch.cat(posAndCov, 0)
        self.covInv = torch.cat(posMetric, 0)
        self.covInv_sqrt = torch.cat(posMetric_sqrt, 0)
        self.cov_sqrt = torch.cat(posMetricInv_sqrt, 0)
        self.cov_eigvec = torch.cat(covEigVec, 0)
        self.cov_eigval = torch.cat(covEigVal, 0)
        self.split_nums = [t.shape[0] for t in posAndCov]

        if roi_range is not None:
            self._crop_to_roi(roi_range, dx)

        ################### translation: centre the positions
        self.posMean = torch.mean(self.posAndCov[:, :3], dim=0).view(1, -1)
        self.posAndCov[:, :3] = self.posAndCov[:, :3] - self.posMean
        ###################

        # Placeholder kept for callers; the log-Jacobian is not computed here.
        self.logJacobian = [None] * self.posAndCov.shape[0]

        self.train_data = self.posAndCov

    @staticmethod
    def _roi_mask(pos, roi_range, dx):
        """Boolean row mask: (x, y, z) inside the half-open box ``roi_range * dx``."""
        x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
        return ((x >= roi_range[0] * dx) & (x < roi_range[1] * dx)
                & (y >= roi_range[2] * dx) & (y < roi_range[3] * dx)
                & (z >= roi_range[4] * dx) & (z < roi_range[5] * dx))

    def _crop_to_roi(self, roi_range, dx):
        """Keep only in-ROI samples of every subject and re-origin positions at the ROI corner."""
        # One mask per subject, computed on that subject's slice of posAndCov.
        masks = [self._roi_mask(chunk, roi_range, dx)
                 for chunk in torch.split(self.posAndCov, self.split_nums)]

        def keep(tensor):
            # Apply each subject's mask to the matching slice of `tensor`.
            return torch.cat([chunk[mask] for chunk, mask
                              in zip(torch.split(tensor, self.split_nums), masks)], 0)

        self.posAndCov = keep(self.posAndCov)
        self.covInv = keep(self.covInv)
        self.covInv_sqrt = keep(self.covInv_sqrt)
        self.cov_sqrt = keep(self.cov_sqrt)
        self.cov_eigvec = keep(self.cov_eigvec)
        self.cov_eigval = keep(self.cov_eigval)
        for axis in range(3):
            self.posAndCov[:, axis] -= roi_range[2 * axis] * dx
        self.split_nums = [int(mask.sum()) for mask in masks]

    def __len__(self):
        return self.posAndCov.shape[0]

    def __getitem__(self, idx):
        data = self.posAndCov[idx]
        if self.randPos:
            # Indexing returns a view: jitter a copy, otherwise the stored
            # positions would drift further on every access.
            data = data.clone()
            # add random deviation in [-dx/2, dx/2) to the position part
            data[:3] += torch.rand(3, device=data.device) * self.dx - 0.5 * self.dx
        if self.returnforDAE:
            return data, [], []
        return (data, self.covInv_sqrt[idx], self.cov_sqrt[idx],
                self.cov_eigvec[idx], self.cov_eigval[idx])

from torch.utils.data import Dataset
class N_n_ToyDataSet(Dataset):
    """In-memory toy dataset of (position, covariance-vector) rows.

    Args:
        posAndCov: 2-D tensor whose rows are ``[position, covariance entries]``.
            With 5 columns the positions are 2-D (2 + 3 entries); otherwise
            they are treated as 3-D. The tensor is copied and not modified.
        returnforDAE: if True, ``__getitem__`` returns ``(data, [], [])``.
        randPos: if True, each fetched sample gets uniform jitter in
            [-dx/2, dx/2) added to a copy of its position part.
        dx: jitter scale.

    The covariance part is turned into ``pos_dim x pos_dim`` matrices by
    ``vector2tensor_1dim`` and decomposed with a batched SVD. From that the
    constructor precomputes ``covInv_sqrt``, ``cov_sqrt``, ``cov_eigvec`` and
    ``cov_eigval``. Positions are then centred, and the removed mean is kept
    in ``posMean``.
    """

    def __init__(self, posAndCov, returnforDAE = False, randPos = False, dx = 2):
        super(N_n_ToyDataSet, self).__init__()
        self.randPos = randPos
        self.dx = dx
        self.returnforDAE = returnforDAE
        # 5 columns -> 2-D position + 3 covariance entries; otherwise 3-D.
        self.pos_dim = 2 if posAndCov.shape[1] == 5 else 3

        # Private copy: the centring below is in place and must not alter the
        # caller's tensor.
        self.posAndCov = posAndCov.clone()

        U, S, V = svd(vector2tensor_1dim(self.posAndCov[:, self.pos_dim:])
                      .view(-1, self.pos_dim, self.pos_dim))

        # Correct the sign of the eigenvalues: flip S where diag(U^T V) < 0
        # (for symmetric input -- presumably always the case here -- such a
        # column corresponds to a negative eigenvalue).
        UtV = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2)
        S[UtV < 0] = - S[UtV < 0]
        eps = 1e-5
        Ut = U.permute(0, 2, 1)
        self.covInv_sqrt = torch.matmul(torch.matmul(U, torch.diag_embed(1.0 / torch.sqrt(S + eps))), Ut)
        # NOTE(review): no eps / clamp here, so a negative eigenvalue yields NaN.
        self.cov_sqrt = torch.matmul(U, torch.matmul(torch.diag_embed(torch.sqrt(S)), Ut))
        self.cov_eigvec = U.cpu()
        self.cov_eigval = S.cpu()

        ################### translation: centre the positions
        self.posMean = torch.mean(self.posAndCov[:, :self.pos_dim], dim=0).view(1, -1)
        self.posAndCov[:, :self.pos_dim] = self.posAndCov[:, :self.pos_dim] - self.posMean
        ###################

        # Always exposed, as in the file-backed datasets (previously only set
        # when returnforDAE was True).
        self.train_data = self.posAndCov

    def __len__(self):
        return self.posAndCov.shape[0]

    def __getitem__(self, idx):
        data = self.posAndCov[idx]
        if self.randPos:
            # Indexing returns a view: jitter a copy so stored positions do not
            # drift on every access; draw the noise on the data's device.
            data = data.clone()
            # add random deviation in [-dx/2, dx/2) to the position part
            data[:self.pos_dim] += (torch.rand(self.pos_dim, device=data.device) * self.dx
                                    - 0.5 * self.dx)
        if self.returnforDAE:
            return data, [], []
        return (data, self.covInv_sqrt[idx], self.cov_sqrt[idx],
                self.cov_eigvec[idx], self.cov_eigval[idx])
        
        
class DTI_DataSet(Dataset):
    """Per-voxel (position, covariance-vector) dataset loaded from per-subject ``.pt`` files.

    For every name in ``filelist`` the last 7 characters are stripped
    (presumably a ``.nii.gz`` suffix -- TODO confirm) to form ``stem``, and
    ``<stem>_posData.pt`` is concatenated column-wise with
    ``<stem><data_prefix>_covvec.pt``. The rows of all subjects are then
    stacked. The first 3 columns are treated as (x, y, z).

    data_prefix '_fix': values calculated for fixed data (non positive-definite
    voxels removed); the default '_reg' selects the regularized variant.

    Args:
        filelist: per-subject file names.
        returnPosOnly: if True, ``train_data`` and ``__getitem__`` expose only
            the 3 position columns.
        match_num_per_subject: unused; kept for interface compatibility.
        shuffle: unused; kept for interface compatibility.
        randPos: if True, each fetched sample gets uniform jitter in
            [-dx/2, dx/2) added to a copy of its position part.
        dx: voxel spacing; scales the jitter and ``roi_range``.
        roi_range: optional ``(x0, x1, y0, y1, z0, z1)`` in voxel units. Only
            samples with ``x0*dx <= x < x1*dx`` (same for y, z) are kept, and
            positions are shifted so the ROI corner ``(x0, y0, z0)*dx`` is the origin.
        data_prefix: file-name infix selecting which precomputed variant to load.

    Positions are centred after loading, and the removed mean is kept in
    ``posMean``. ``split_nums`` holds the number of samples per file.
    """

    def __init__(self, filelist, returnPosOnly = False,
                 match_num_per_subject = None, shuffle = False, randPos = False,
                 dx = 2, roi_range = None, data_prefix = '_reg'):
        super(DTI_DataSet, self).__init__()
        self.randPos = randPos
        self.dx = dx
        self.returnPosOnly = returnPosOnly
        self.file_list = filelist

        posAndCov = []
        for filename in self.file_list:
            stem = filename[:-7]  # drop the 7-char suffix (presumably '.nii.gz')
            posAndCov.append(torch.cat((torch.load(stem + '_posData.pt'),
                                        torch.load(stem + data_prefix + '_covvec.pt')), 1))
        self.posAndCov = torch.cat(posAndCov, 0)
        self.split_nums = [t.shape[0] for t in posAndCov]

        if roi_range is not None:
            # One mask per subject, applied to that subject's slice.
            chunks = torch.split(self.posAndCov, self.split_nums)
            masks = [self._roi_mask(chunk, roi_range, dx) for chunk in chunks]
            self.posAndCov = torch.cat([chunk[mask] for chunk, mask in zip(chunks, masks)], 0)
            for axis in range(3):
                self.posAndCov[:, axis] -= roi_range[2 * axis] * dx
            self.split_nums = [int(mask.sum()) for mask in masks]

        ################### translation: centre the positions
        self.posMean = torch.mean(self.posAndCov[:, :3], dim=0).view(1, -1)
        self.posAndCov[:, :3] = self.posAndCov[:, :3] - self.posMean
        ###################

        self.train_data = self.posAndCov[:, :3] if self.returnPosOnly else self.posAndCov

    @staticmethod
    def _roi_mask(pos, roi_range, dx):
        """Boolean row mask: (x, y, z) inside the half-open box ``roi_range * dx``."""
        x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
        return ((x >= roi_range[0] * dx) & (x < roi_range[1] * dx)
                & (y >= roi_range[2] * dx) & (y < roi_range[3] * dx)
                & (z >= roi_range[4] * dx) & (z < roi_range[5] * dx))

    def __len__(self):
        return self.posAndCov.shape[0]

    def __getitem__(self, idx):
        data = self.posAndCov[idx]
        if self.randPos:
            # Indexing returns a view: jitter a copy so stored positions do not
            # drift on every access.
            data = data.clone()
            # add random deviation in [-dx/2, dx/2) to the position part
            data[:3] += torch.rand(3, device=data.device) * self.dx - 0.5 * self.dx
        if self.returnPosOnly:
            return data[:3]
        return data
            
