import os

import numpy as np

import torch
from torch.utils.data import Dataset


def pos_code(x, lower=-100, upper=300):
    """Return a sinusoidal positional code for position(s) ``x``.

    Positions are mapped onto one half-period of a sine wave spanning
    ``[lower, upper]``: ``lower`` -> -1, the midpoint -> 0, ``upper`` -> +1.
    Positions outside the range continue along the same sine curve.

    With the default bounds (-100, 300) this is exactly
    ``np.sin((x - 100) * np.pi / 400)``.

    Args:
        x: scalar or array-like of positions.
        lower: position mapped to -1.
        upper: position mapped to +1; must be greater than ``lower``.

    Returns:
        ``np.sin`` of the scaled positions (NumPy scalar or array, same shape as ``x``).

    Raises:
        ValueError: if ``upper <= lower`` (the range would be empty and the
            scaling would divide by zero or flip sign).
    """
    if upper <= lower:
        raise ValueError(f"upper ({upper!r}) must be greater than lower ({lower!r})")
    center = (lower + upper) / 2
    span = upper - lower
    # Same operation order as the original formula: ((x - c) * pi) / span.
    return np.sin((x - center) * np.pi / span)


class CustomDataset(Dataset):
    """Map-style dataset over parallel collections of samples, targets and ranges.

    Args:
        data: indexable collection of samples. When ``pos_encoding`` is on,
            ``data[idx]`` must support ``len()`` and ``+`` with a NumPy array.
        target: indexable collection of targets; its length is the dataset length.
        ranges: indexable collection whose entries expose ``[0]`` (start) and
            ``[1]`` (end) of the positional grid. ``ranges[idx]`` is looked up
            on every access, even when ``pos_encoding`` is False.
        transform: optional callable applied to the sample after any
            positional encoding.
        target_transform: optional callable applied to the target.
        pos_encoding: if True, add ``pos_code`` evaluated on ``len(sample)``
            evenly spaced points from start to end to the sample.

    Indexing returns ``(sample, label)``, both built with the legacy
    ``torch.Tensor`` constructor (default float dtype); the label is wrapped
    in a one-element tensor.
    """

    def __init__(self, data, target, ranges, transform=None, target_transform=None, pos_encoding=True):
        self.data, self.targets, self.ranges = data, target, ranges
        self.transform = transform
        self.target_transform = target_transform
        self.pos_encoding = pos_encoding

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.targets[idx]
        span = self.ranges[idx]

        if self.pos_encoding:
            # One grid point per element of the sample, from span[0] to span[1].
            grid = np.linspace(span[0], span[1], len(features))
            features = features + pos_code(grid)

        if self.transform is not None:
            features = self.transform(features)
        if self.target_transform is not None:
            label = self.target_transform(label)

        # NOTE(review): torch.Tensor(...) is the legacy constructor; kept
        # as-is so dtype/conversion behaviour is unchanged.
        return torch.Tensor(features), torch.Tensor([label])


def get_loader(root, b_size, normalize, *, train_frac=0.7, seed=None,
               filename='kde_auc_data_softmax-no_pos.tar', eps=1e-5):
    """Load the dataset archive, split it into train/test and wrap both in DataLoaders.

    The archive at ``os.path.join(root, filename)`` is read with ``torch.load``
    and must be a mapping with keys ``'x'``, ``'y'`` and ``'range'``. Samples
    are shuffled, the first ``int(n * train_frac)`` go to training and the
    remainder to test.

    Args:
        root: directory containing the archive (str or path-like).
        b_size: batch size used by both loaders.
        normalize: if true, standardize both splits with the training split's
            per-feature mean and std.
        train_frac: fraction of samples assigned to training (default 0.7).
        seed: optional seed for the shuffle. ``None`` uses the global NumPy RNG
            (original behaviour); an int gives a reproducible split.
        filename: archive file name inside ``root``.
        eps: added to ``std`` before dividing, to avoid division by zero.

    Returns:
        ``(train_loader, test_loader, mu, std)`` where ``mu`` / ``std`` are
        ``np.mean`` / ``np.std`` of the training split along axis 0. They are
        returned even when ``normalize`` is False.

    Raises:
        ValueError: if ``train_frac`` is not in ``(0, 1]``.
    """
    if not 0.0 < train_frac <= 1.0:
        raise ValueError(f"train_frac must be in (0, 1], got {train_frac!r}")

    # NOTE(review): torch.load unpickles the file -- only load trusted archives.
    data = torch.load(os.path.join(root, filename))
    x, y, ranges = data['x'], data['y'], data['range']

    n = x.shape[0]
    if seed is None:
        permu = np.random.permutation(n)
    else:
        permu = np.random.default_rng(seed).permutation(n)
    x = x[permu]
    y = y[permu]
    ranges = [ranges[i] for i in permu]

    train_num = int(n * train_frac)
    train_x, train_y, train_ranges = x[:train_num], y[:train_num], ranges[:train_num]
    test_x, test_y, test_ranges = x[train_num:], y[train_num:], ranges[train_num:]

    # Statistics come from the training split only, so no test data leaks into them.
    mu = np.mean(train_x, axis=0)
    std = np.std(train_x, axis=0)

    if normalize:
        train_x = (train_x - mu) / (std + eps)
        test_x = (test_x - mu) / (std + eps)

    train_set = CustomDataset(train_x, train_y, train_ranges, pos_encoding=False)
    test_set = CustomDataset(test_x, test_y, test_ranges, pos_encoding=False)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=b_size,
        shuffle=True,
        num_workers=0,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=b_size,
        shuffle=False,
        num_workers=0,
    )
    return train_loader, test_loader, mu, std
