import math
import numbers

import numpy as np
from scipy.io import loadmat
import torch


def _smart_path(path):
    # TODO : add real path parse
    return path

# Default train/test split: the first entry is for training, the second for
# testing. Integer entries mean an exact number of samples; float entries
# mean a fraction of the dataset. This is the format interpreted by
# _get_format_slice_data, which requires float fractions to sum to 1.
default_slice_param = [0.6, 0.4]

def np_list_to_tensor_list(np_list):
    """Convert every numpy array in ``np_list`` into a float32 torch tensor.

    Each element goes through ``torch.from_numpy`` followed by ``.float()``;
    the tensors are returned in a new list in the original order.
    """
    tensors = [torch.from_numpy(array).float() for array in np_list]
    return tensors

def dict_pattern(path, function, interp_available):
    """Build one dataset registry entry.

    Returns a dict with the keys ``'path'`` (the file the loader reads),
    ``'function'`` (the loader callable) and ``'interp_available'`` (whether
    interpolated outputs may be requested for this dataset).
    """
    entry = dict(path=path, function=function)
    entry['interp_available'] = interp_available
    return entry

def _concat_on_new_last_dim(arrays):
    arrays = [_array.reshape(*_array.shape, 1) for _array in arrays]
    return np.concatenate(arrays, axis=-1)

def _force_2d(arrays):
    N = arrays.shape[0]
    n = arrays[0].shape[0] * arrays[0].shape[1]
    return np.reshape(arrays, (N, n))

def _get_format_slice_data(length, slice):
    if isinstance(slice[0], int):
        return slice[0], slice[1]

    elif isinstance(slice[0], float):
        assert slice[0] + slice[1] == 1, 'slice sum should be 1'
        _tr = int(length * slice[0])
        _te = int(length * slice[1])
        while _tr + _te > length:
            _tr -= 1
            _te -= 1
        return _tr, _te


class SP_DataLoader(object):
    """Loader for the ``plasmonic2_MF`` .mat dataset.

    ``get_data()`` returns ``(x, y, None, None)``: ``x`` is a one-element list
    holding the file's ``'X'`` array and ``y`` is a list of the entries of the
    file's ``'Y'`` cell row. The last two slots are always ``None``.
    """

    dataset_available = ['plasmonic2_MF']

    def __init__(self, dataset_name, need_interp=False) -> None:
        """Validate the dataset choice and build the loader registry.

        Raises:
            AssertionError: ``dataset_name`` is unknown, or interpolated data
                was requested for a dataset that does not provide it.
        """
        # Registry: dataset name -> {'path', 'function', 'interp_available'}.
        self.dataset_info = {
            'plasmonic2_MF':
                dict_pattern('data/MF_data/plasmonic2_MF.mat', self._general, False),
        }

        # Raised explicitly (not via ``assert``) so the checks still run under
        # ``python -O``; AssertionError is kept for existing callers.
        if dataset_name not in self.dataset_info:
            raise AssertionError(
                f"unknown dataset {dataset_name!r}; "
                f"available: {self.dataset_available}"
            )
        if need_interp and not self.dataset_info[dataset_name]['interp_available']:
            raise AssertionError(
                f"dataset {dataset_name!r} has no interpolated data"
            )
        self.dataset_name = dataset_name
        self.need_interp = need_interp

    def get_data(self):
        """Run the loader registered for ``self.dataset_name`` and return its result."""
        return self.dataset_info[self.dataset_name]['function']()

    def _general(self):
        """Load the .mat file and return ``([X], [Y_0, ..., Y_{k-1}], None, None)``.

        ``'Y'`` is read as ``data['Y'][0]``, i.e. it is assumed to be a 1xK
        MATLAB cell array (loaded by scipy as an object ndarray) — confirm
        against the data files.
        """
        data = loadmat(_smart_path(self.dataset_info[self.dataset_name]['path']))
        x = [data['X']]
        y = list(data['Y'][0])
        return x, y, None, None

    def _get_distribute(self):
        # Placeholder: not implemented yet.
        pass


class Standard_mat_DataLoader(object):
    """Loader for .mat datasets that store separate train and test arrays.

    ``get_data()`` returns ``(x_tr, y_tr, x_te, y_te)``:

    - ``x_tr`` / ``x_te`` are one-element lists holding the ``'xtr'`` /
      ``'xte'`` arrays as torch tensors.
    - ``y_tr`` / ``y_te`` are lists with one tensor per entry of the
      ``'Ytr'`` / ``'Yte'`` cell row (or ``'Ytr_interp'`` /
      ``'Yte_interp'`` when ``need_interp`` is truthy).
    - Each output tensor is flattened to ``(N, features)`` by ``_force_2d``.
    """

    # Names accepted by ``__init__``; must match the keys of ``dataset_info``.
    dataset_available = ['Burget_mfGent_v5_15',
                         'Heat_mfGent_v5',
                         'Poisson_mfGent_v5',
                         'TopOP_mfGent_v6',]

    def __init__(self, dataset_name, need_interp=False) -> None:
        """Validate the dataset choice and build the loader registry.

        Raises:
            AssertionError: ``dataset_name`` is unknown, or interpolated data
                was requested for a dataset that does not provide it.
        """
        # Registry: dataset name -> {'path', 'function', 'interp_available'}.
        self.dataset_info = {
            'Burget_mfGent_v5_15': dict_pattern('data/Burget_mfGent_v5_15.mat', self._general, True),
            'Heat_mfGent_v5': dict_pattern('data/Heat_mfGent_v5.mat', self._general, True),
            'Poisson_mfGent_v5': dict_pattern('data/Poisson_mfGent_v5.mat', self._general, True),
            'TopOP_mfGent_v6': dict_pattern('data/TopOP_mfGent_v6.mat', self._general, True),
        }
        # Raised explicitly (not via ``assert``) so the checks still run under
        # ``python -O``; AssertionError is kept for existing callers.
        if dataset_name not in self.dataset_info:
            raise AssertionError(
                f"unknown dataset {dataset_name!r}; "
                f"available: {self.dataset_available}"
            )
        if need_interp and not self.dataset_info[dataset_name]['interp_available']:
            raise AssertionError(
                f"dataset {dataset_name!r} has no interpolated data"
            )
        self.dataset_name = dataset_name
        self.need_interp = need_interp

    @staticmethod
    def _cells_to_tensors(cells):
        """Flatten each array of a loaded cell row to 2-D and wrap it as a tensor."""
        return [torch.from_numpy(_force_2d(cell)) for cell in cells]

    def _general(self):
        """Load the .mat file and return ``(x_tr, y_tr, x_te, y_te)``.

        The output cells are read as ``data[key][0]``, i.e. assumed to be 1xK
        MATLAB cell arrays (loaded by scipy as object ndarrays) — confirm
        against the data files.
        """
        data = loadmat(_smart_path(self.dataset_info[self.dataset_name]['path']))
        x_tr = [torch.from_numpy(data['xtr'])]
        x_te = [torch.from_numpy(data['xte'])]
        # Truthiness rather than ``is False`` so need_interp=0/None selects the
        # raw outputs instead of the interpolated ones.
        suffix = '_interp' if self.need_interp else ''
        y_tr = self._cells_to_tensors(data['Ytr' + suffix][0])
        y_te = self._cells_to_tensors(data['Yte' + suffix][0])
        return x_tr, y_tr, x_te, y_te

    def get_data(self):
        """Run the loader registered for ``self.dataset_name`` and return its result."""
        return self.dataset_info[self.dataset_name]['function']()
