# gaussian.py

import math
import os

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset

# Public API for star-imports: only the Lightning data module. GaussianPerp
# is constructed internally by Gaussian's dataloader methods.
__all__ = ["Gaussian"]


class GaussianPerp(Dataset):
    """Map-style dataset over one split of the synthetic Gaussian data.

    The split is read from ``data_train.npy``, ``data_valid.npy`` or
    ``data_test.npy`` inside ``root`` and converted to a float32 tensor.
    The array is indexed as 2-D with at least three columns, giving
    (one row per sample):

    * ``x``     -- ``(N, 1)`` column 0 as stored.
    * ``y``     -- ``(N, 2)`` column 0 and the cube of column 1.
    * ``s``     -- ``(N, 1)`` cube of column 0.
    * ``label`` -- ``(N,)`` column 2 cast to int64.

    Args:
        root: Directory containing the ``.npy`` files; a trailing path
            separator is optional.
        train: If true, load the training split (takes precedence over
            ``val``).
        val: If true and ``train`` is false, load the validation split.
            When both flags are false the test split is loaded.
    """

    def __init__(self, root, train, val):
        super().__init__()
        if train:
            filename = 'data_train.npy'
        elif val:
            filename = 'data_valid.npy'
        else:
            filename = 'data_test.npy'

        # NOTE(review): allow_pickle=True is kept for compatibility with the
        # existing files, but it lets np.load unpickle arbitrary objects --
        # only point ``root`` at trusted data.
        raw = np.load(os.path.join(root, filename), allow_pickle=True)
        data = torch.from_numpy(raw).float()

        # Every attribute gets its own storage (no in-place writes through
        # views of ``data``), so mutating one never changes another.
        self.x = data[:, 0:1].clone()
        self.y = torch.stack((data[:, 0], torch.pow(data[:, 1], 3)), dim=1)
        self.s = torch.pow(data[:, 0].unsqueeze(1), 3)
        self.label = data[:, 2].long()

    def __len__(self):
        """Return the number of samples (rows) in the split."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x, y, s, label)`` tuple for sample ``index``."""
        return self.x[index], self.y[index], self.s[index], self.label[index]

class Gaussian(pl.LightningDataModule):
    """Lightning data module serving the synthetic Gaussian splits.

    ``opts`` is read for:

    * ``ngpu`` -- host memory is pinned when this is non-zero.
    * ``dataset_type`` -- only ``'GaussianPerp'`` is supported.
    * ``batch_size_train`` / ``batch_size_test`` -- loader batch sizes.
    * ``nthreads`` -- number of DataLoader worker processes.
    """

    # Directory expected to hold data_{train,valid,test}.npy.
    DATA_ROOT = './data/gaussian/'

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # Pin host memory only when GPUs are requested.
        self.pin_memory = opts.ngpu != 0

    def _make_dataset(self, train, val):
        """Build the dataset for one split.

        Raises:
            ValueError: if ``opts.dataset_type`` is not a known Gaussian type
                (previously this printed and passed ``None`` to DataLoader,
                failing later with an unrelated error).
        """
        if self.opts.dataset_type == 'GaussianPerp':
            return GaussianPerp(root=self.DATA_ROOT, train=train, val=val)
        raise ValueError(
            f'Gaussian dataset type does not exist: {self.opts.dataset_type!r}'
        )

    def _make_loader(self, train, val, batch_size, shuffle):
        """Wrap one split in a DataLoader using the shared worker settings."""
        return DataLoader(
            dataset=self._make_dataset(train, val),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.opts.nthreads,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return self._make_loader(
            train=True, val=False,
            batch_size=self.opts.batch_size_train, shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return self._make_loader(
            train=False, val=True,
            batch_size=self.opts.batch_size_test, shuffle=False,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        return self._make_loader(
            train=False, val=False,
            batch_size=self.opts.batch_size_test, shuffle=False,
        )