import numpy as np
import torch
import random



class OmnipushDataset(torch.utils.data.Dataset):
    """Serve random support/query splits of the OmniPush arrays.

    ``data/OmniPush/x_data.npy`` and ``data/OmniPush/y_data.npy`` are loaded,
    paired entry by entry, shuffled with a fixed seed and split 80/20 into a
    training and a test portion. Indexing the dataset cuts one entry into a
    contiguous support window of ``ns`` rows and a query set made of all the
    remaining rows.
    """

    _PHASES = ('train', 'test')
    _TRAIN_FRACTION = 0.8
    _SHUFFLE_SEED = 18

    def __init__(self, ns, phase, device='cpu'):
        """Load the arrays and keep the portion that belongs to ``phase``.

        Args:
            ns: Number of consecutive rows placed in the support set.
            phase: ``'train'`` keeps the first 80% of the shuffled entries,
                ``'test'`` keeps the remaining ones.
            device: Device the tensors returned by ``__getitem__`` are moved
                to.

        Raises:
            ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
        """
        super().__init__()

        # Fail fast, before touching the disk: an unknown phase used to leave
        # the instance without data and only blew up on first access.
        if phase not in self._PHASES:
            raise ValueError(
                f"phase must be 'train' or 'test', got {phase!r}"
            )

        x_data = np.load('data/OmniPush/x_data.npy')
        y_data = np.load('data/OmniPush/y_data.npy')

        self.device = device
        self.ns = ns

        # Shuffle inputs and targets together so the pairs stay aligned.
        # The seed is fixed so that the 'train' and 'test' instances see the
        # same permutation and therefore receive disjoint portions.
        # NOTE: this reseeds the process-wide ``random`` generator.
        random.seed(self._SHUFFLE_SEED)
        pairs = list(zip(x_data, y_data))
        random.shuffle(pairs)
        shuffled_x, shuffled_y = zip(*pairs)

        x_data = list(shuffled_x)
        y_data = list(shuffled_y)
        print(f"x_data {phase}")
        print(x_data[0][0:2])
        print(f"y_data {phase}")
        print(y_data[0][0:2])

        split = int(self._TRAIN_FRACTION * len(x_data))
        if phase == 'train':
            self.x_data = x_data[:split]
            self.y_data = y_data[:split]
        else:
            self.x_data = x_data[split:]
            self.y_data = y_data[split:]

    def __len__(self):
        """Return the number of entries in the selected portion."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Split entry ``idx`` into a random support window and a query set.

        Args:
            idx: Position of the entry inside the selected portion.

        Returns:
            A tuple ``(x_support, y_support, x_query, y_query)`` of float32
            tensors on ``self.device``. The support tensors hold ``ns``
            consecutive rows; the query tensors hold the rows after the
            window followed by the rows before it.

        Raises:
            ValueError: If the entry does not have more than ``ns`` rows.
        """
        x = torch.from_numpy(self.x_data[idx]).float()
        y = torch.from_numpy(self.y_data[idx]).float()

        # Bound the window offset by the actual number of rows instead of a
        # fixed 250, so entries of any length yield a full support window.
        n_rows = x.shape[0]
        n_offsets = n_rows - self.ns
        if n_offsets <= 0:
            raise ValueError(
                f"ns={self.ns} must be smaller than the number of rows "
                f"({n_rows}) of entry {idx}"
            )

        # NOTE: the upper bound is exclusive, so the window never starts at
        # n_rows - ns and the last row always ends up in the query set. Kept
        # as is so seeded runs draw the same offsets as before.
        shift = np.random.randint(n_offsets)
        stop = shift + self.ns

        x_support = x[shift:stop]
        y_support = y[shift:stop]

        x_query = torch.cat((x[stop:], x[:shift]), 0)
        y_query = torch.cat((y[stop:], y[:shift]), 0)

        return (
            x_support.to(self.device),
            y_support.to(self.device),
            x_query.to(self.device),
            y_query.to(self.device),
        )


if __name__ == '__main__':
    # Smoke test: build the training portion with ns=20, fetch the first
    # entry and show the first five rows of its support inputs.
    x_support, *_ = OmnipushDataset(20, 'train')[0]
    print(x_support[:5])
        