from abc import ABC, abstractmethod

import torch
import torchvision.transforms as transforms

from dataloader.target_loader.metaDataset import UserSetDataset, EpisodicBatchSampler
from dataloader.transforms import STFT, GADF, Resize, Raw1D, Normalize1D


class TransformLoader:
    """Build the ``transforms.Compose`` pipeline for one input representation.

    Args:
        image_size: Passed to ``Resize`` (STFT path) and to ``GADF``.
        nperseg: Segment length passed to ``STFT``.
    """

    def __init__(self, image_size, nperseg):
        self.image_size = image_size
        self.nperseg = nperseg

    def get_composed_transform(self, method='stft', stft_mean=None, stft_std=None, gadf_mean=None, gadf_std=None,
                               raw_mean=None, raw_std=None):
        """Return the composed transform for ``method``.

        Args:
            method: ``'stft'`` (STFT -> Resize -> Normalize), ``'gadf'``
                (GADF -> Normalize) or ``'raw'`` (Raw1D, optionally followed
                by Normalize1D).
            stft_mean, stft_std: Statistics for ``transforms.Normalize`` on the
                STFT path. ``None`` means ``[]`` (the previous default).
            gadf_mean, gadf_std: Statistics for ``transforms.Normalize`` on the
                GADF path. ``None`` means ``[]``.
            raw_mean, raw_std: Statistics for ``Normalize1D``. When ``raw_mean``
                is ``None`` the raw path is not normalized.

        Returns:
            A ``transforms.Compose`` instance.

        Raises:
            ValueError: If ``method`` is not one of the supported names
                (previously this surfaced as an ``UnboundLocalError``).
        """
        # None sentinels avoid shared mutable default lists while keeping the
        # old default value ([]) for callers that omit these arguments.
        stft_mean = [] if stft_mean is None else stft_mean
        stft_std = [] if stft_std is None else stft_std
        gadf_mean = [] if gadf_mean is None else gadf_mean
        gadf_std = [] if gadf_std is None else gadf_std

        if method == 'stft':
            return transforms.Compose(
                [STFT(self.nperseg), Resize(self.image_size), transforms.Normalize(stft_mean, stft_std)])
        if method == 'gadf':
            return transforms.Compose(
                [GADF(self.image_size), transforms.Normalize(gadf_mean, gadf_std)])
        if method == 'raw':
            steps = [Raw1D()]
            # Normalization is optional on the raw path: only applied when stats are given.
            if raw_mean is not None:
                steps.append(Normalize1D(raw_mean, raw_std))
            return transforms.Compose(steps)
        raise ValueError(
            "Unknown transform method {!r}; expected 'stft', 'gadf' or 'raw'".format(method))


class DataManager(ABC):
    """Abstract interface for target-dataset managers that build data loaders.

    Inheriting from ``ABC`` is what makes ``@abstractmethod`` enforced: on a
    plain class the decorator is a no-op and the base method would silently
    return ``None``.
    """

    @abstractmethod
    def get_data_loader(self, data_file, aug):
        """Return a data loader for ``data_file``.

        Args:
            data_file: Subclass-specific description of the data to load.
            aug: Augmentation flag, kept for interface compatibility.
        """
        raise NotImplementedError


class HHAR(DataManager):
    """Episodic (few-shot) data manager for the HHAR target dataset.

    Class attributes are the dataset-level configuration:

    - ``n_way``: number of classes per episode.
    - ``seed``: seed list forwarded to ``EpisodicBatchSampler``.
    - ``nperseg``: STFT segment length.
    - Per-channel statistics for ``transforms.Normalize`` on the ``'stft'`` and
      ``'gadf'`` paths.

    No raw-signal statistics are passed on, so the ``'raw'`` path is left
    un-normalized.
    """

    n_way = 6
    seed = [0, 1, 4, 7, 8, 9]
    all_mode = ['p', 'i', 'd']  # personalized, independent, dependent
    nperseg = 30
    data_file = ''
    # Six entries each, presumably one per sensor channel (TODO confirm).
    stft_mean = [0.0562, 0.0418, 0.0380, 0.0454, 0.0370, 0.0426]
    stft_std = [0.1503, 0.1098, 0.1016, 0.1216, 0.0998, 0.1134]
    gadf_mean = [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
    gadf_std = [0.5311, 0.5214, 0.5349, 0.5735, 0.5504, 0.5798]

    def __init__(self, data_root, image_size, trans_method, n_support, n_query, mode, n_eposide=10, users=None,
                 **kwargs):
        """Store the loader configuration.

        Args:
            data_root: Root directory forwarded to ``UserSetDataset``.
            image_size: Forwarded to ``TransformLoader``.
            trans_method: Representation name, ``'stft'``, ``'gadf'`` or ``'raw'``.
            n_support: Support samples per class (``n_shot`` of the dataset).
            n_query: Query samples per class.
            mode: One of ``all_mode``.
            n_eposide: Forwarded to ``EpisodicBatchSampler``. Presumably the
                number of episodes.
            users: User ids as strings. Defaults to ``'0'`` .. ``'8'``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.data_root = data_root
        self.image_size = image_size
        self.trans_method = trans_method
        self.n_eposide = n_eposide
        # Must come after image_size is set, since it reads self.image_size.
        self.trans_loader = TransformLoader(self.image_size, self.nperseg)
        self.n_support = n_support
        self.n_query = n_query
        self.mode = mode
        self.users = users if users is not None else [str(uid) for uid in range(9)]

    def get_data_loader(self, data_file, aug=False):
        """Build the episodic ``DataLoader`` for ``data_file``.

        Stores the dataset and sampler on ``self.dataset`` / ``self.sampler``.
        ``aug`` is accepted for interface compatibility and ignored.
        """
        transform = self.trans_loader.get_composed_transform(
            self.trans_method, self.stft_mean, self.stft_std, self.gadf_mean, self.gadf_std)
        self.dataset = UserSetDataset(
            data_file=data_file,
            users=self.users,
            n_way=self.n_way,
            n_shot=self.n_support,
            n_query=self.n_query,
            transform=transform,
            mode=self.mode,
            data_root=self.data_root,
        )
        self.sampler = EpisodicBatchSampler(
            len(self.dataset), self.n_way, self.n_eposide, self.seed, self.dataset.samplers, self.mode)
        return torch.utils.data.DataLoader(
            self.dataset, batch_sampler=self.sampler, num_workers=0, pin_memory=True)


class WESAD(DataManager):
    """Episodic (few-shot) data manager for the WESAD target dataset.

    Same interface as the other data managers. In addition to the STFT/GADF
    statistics, this class defines ``raw_mean`` / ``raw_std``, which are
    forwarded to the transform builder for the ``'raw'`` representation.
    """

    n_way = 3
    seed = [0, 1, 4, 7, 8, 9]
    all_mode = ['p', 'i', 'd']  # personalized, independent, dependent
    nperseg = 75
    data_file = ''
    # NOTE(review): these single-channel STFT/GADF values equal HHAR's
    # first-channel statistics; they may be placeholders. Verify before use.
    stft_mean = [0.0562]
    stft_std = [0.1503]
    gadf_mean = [0.0000]
    gadf_std = [0.5311]
    raw_mean = [480.0982]
    raw_std = [106.1475]

    def __init__(self, data_root, image_size, trans_method, n_support, n_query, mode, n_eposide=10, users=None,
                 **kwargs):
        """Store the loader configuration.

        Args:
            data_root: Root directory forwarded to ``UserSetDataset``.
            image_size: Forwarded to ``TransformLoader``.
            trans_method: Representation name, ``'stft'``, ``'gadf'`` or ``'raw'``.
            n_support: Support samples per class (``n_shot`` of the dataset).
            n_query: Query samples per class.
            mode: One of ``all_mode``.
            n_eposide: Forwarded to ``EpisodicBatchSampler``. Presumably the
                number of episodes.
            users: User ids as strings. Defaults to ``'0'`` .. ``'8'``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.data_root = data_root
        self.image_size = image_size
        self.trans_method = trans_method
        self.n_eposide = n_eposide
        # Must come after image_size is set, since it reads self.image_size.
        self.trans_loader = TransformLoader(self.image_size, self.nperseg)
        self.n_support = n_support
        self.n_query = n_query
        self.mode = mode
        self.users = users if users is not None else [str(uid) for uid in range(9)]

    def get_data_loader(self, data_file, aug=False):
        """Build the episodic ``DataLoader`` for ``data_file``.

        Stores the dataset and sampler on ``self.dataset`` / ``self.sampler``.
        ``aug`` is accepted for interface compatibility and ignored.
        """
        transform = self.trans_loader.get_composed_transform(
            self.trans_method,
            self.stft_mean, self.stft_std,
            self.gadf_mean, self.gadf_std,
            self.raw_mean, self.raw_std,
        )
        self.dataset = UserSetDataset(
            data_file=data_file,
            users=self.users,
            n_way=self.n_way,
            n_shot=self.n_support,
            n_query=self.n_query,
            transform=transform,
            mode=self.mode,
            data_root=self.data_root,
        )
        self.sampler = EpisodicBatchSampler(
            len(self.dataset), self.n_way, self.n_eposide, self.seed, self.dataset.samplers, self.mode)
        return torch.utils.data.DataLoader(
            self.dataset, batch_sampler=self.sampler, num_workers=0, pin_memory=True)


def target_dataloader(dataset='HHAR'):
    """Return the data-manager class registered under ``dataset``.

    Args:
        dataset: ``'HHAR'`` or ``'WESAD'``.

    Returns:
        The class itself (not an instance): ``HHAR`` or ``WESAD``.

    Raises:
        ValueError: For any other name. Previously ``None`` was returned
            silently, which only failed later when the caller used it.
    """
    if dataset == 'HHAR':
        return HHAR
    if dataset == 'WESAD':
        return WESAD
    raise ValueError("Unknown target dataset {!r}; expected 'HHAR' or 'WESAD'".format(dataset))

class prettyfloat(float):
    """``float`` whose ``repr()`` and ``str()`` show exactly four decimals.

    Because lists print their elements with ``repr``, a list of these prints
    compactly, e.g. ``[0.0562, 0.1503]``.
    """

    def __repr__(self):
        return format(float(self), '.4f')

    __str__ = __repr__


if __name__ == "__main__":
    # Smoke test: build an episodic loader and print the shape of every batch.
    # Suggested (n_shot, n_query) per mode:
    #   'p' n_shot=10, n_query=5
    #   'd' n_shot=10, n_query=15
    #   'i' n_shot=10, n_query=5

    # dataroot = 'D:/Dataset/HybridCrossDataset/HHar'
    # data_file = 'filelists/HHAR/hhar.pkl'
    # Raw string: the previous plain literal relied on invalid escapes such as
    # '\D' being kept verbatim, which emits a DeprecationWarning/SyntaxWarning.
    dataroot = r'D:\Dataset\HybridCrossDataset\Datasets\Target\WESAD'
    data_file = 'filelists/WESAD/WESAD.pkl'
    n_shot = 5
    n_query = 15
    image_size = 224
    n_eposide = 3
    trans_method = 'raw'  # 'stft', 'gadf' or 'raw'
    mode = 'd'  # 'p', 'i' or 'd'
    users_list = [str(i) for i in range(9)]
    # manager = HHAR(dataroot, image_size, trans_method, n_shot, n_query, mode=mode, n_eposide=n_eposide, users=users_list)
    manager = WESAD(dataroot, image_size, trans_method, n_shot, n_query, mode=mode, n_eposide=n_eposide,
                    users=users_list)
    loader = manager.get_data_loader(data_file)

    print(len(loader))
    for x, y in loader:
        print(x.shape, y.shape)

    # Recipe used to recompute the per-channel normalization statistics
    # (stft_mean/stft_std etc.). It assumes each batch x is 4-D and that
    # channels sit at dim 1 after the reshape below. Confirm for the chosen
    # trans_method before reusing.
    # data = []
    # for x, y in loader:
    #     data.append(x)
    # data = torch.stack(data)
    # data = data.contiguous().view((data.size(0) * data.size(1) * data.size(2), data.size(3), data.size(4)))
    # mean = data.mean(axis=(0, 2))
    # std = data.std(axis=(0, 2))
    # means = map(prettyfloat, mean.tolist())
    # stds = map(prettyfloat, std.tolist())
    # print('Mean: {}'.format(list(means)))
    # print('STD: {}'.format(list(stds)))
