import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MyDataset(Dataset):
    """Few-shot dataset: each item is a query sample plus a same-label support set.

    For every query index, the support set is (up to) ``k_shot`` other samples
    sharing the query's label. Per-label support candidates are drawn by
    :meth:`split`, which is re-run every time ``len()`` is called on the dataset.
    """

    def __init__(self, file_path, k_shot):
        """Load the data tensors.

        Args:
            file_path: path passed to ``torch.load``; the loaded object must
                support the keys ``'state'``, ``'init'`` and ``'label'``.
            k_shot: number of support samples returned per query.
        """
        # NOTE(review): torch.load may unpickle arbitrary objects; only load
        # trusted files.
        data = torch.load(file_path)
        self.xt_gt = data['state']
        self.x0_gt = data['init']
        # Flatten labels into a 1-D int32 numpy array for np.unique / np.where.
        self.labels = data['label'].reshape(-1).int().numpy()
        self.k_shot = k_shot
        # label -> candidate support indices; filled by split().
        self.spt_idx = None
        # Seeds numpy's *global* RNG (visible to every np.random user);
        # kept so the support-set draws stay reproducible.
        np.random.seed(0)

    def __len__(self):
        """Return the number of samples, redrawing the support sets first.

        The side effect is deliberate: torch's samplers call ``len(dataset)``
        when iterated, so (with ``num_workers=0``) the support sets are
        resampled each epoch.
        NOTE(review): with ``num_workers > 0`` the redraw happens in the main
        process and may not reach worker copies of the dataset — confirm.
        """
        self.split()
        return self.xt_gt.shape[0]

    def __getitem__(self, idx):
        """Return ``(xt_gt, xt_D, x0_gt, x0_D)`` for query ``idx``.

        ``xt_D`` / ``x0_D`` hold up to ``k_shot`` samples with the same label
        as ``idx``, never including ``idx`` itself.
        """
        if getattr(self, 'spt_idx', None) is None:
            # Dataset indexed before len() was ever called: draw the support
            # sets now instead of failing with AttributeError.
            self.split()
        xt_gt = self.xt_gt[idx]
        x0_gt = self.x0_gt[idx]
        idx_D = self.spt_idx[self.labels[idx]]
        if idx in idx_D:
            # Query is one of the k_shot + 1 candidates: drop it.
            idx_D = idx_D[idx_D != idx]
        else:
            # Query is not a candidate: drop the last one to keep k_shot.
            idx_D = idx_D[:self.k_shot]
        xt_D = self.xt_gt[idx_D]
        x0_D = self.x0_gt[idx_D]
        return xt_gt, xt_D, x0_gt, x0_D

    def split(self):
        """Draw up to ``k_shot + 1`` candidate support indices for every label.

        Stores ``label -> sorted array of dataset indices`` in ``self.spt_idx``.
        Labels with at most ``k_shot + 1`` samples keep all of them. Uses
        numpy's global RNG, one shuffle per label in ``np.unique`` order.
        """
        self.spt_idx = {}
        for label in np.unique(self.labels):
            samples = np.where(self.labels == label)[0]
            order = np.arange(0, len(samples))
            np.random.shuffle(order)
            # One extra candidate so a query drawn into the set can be removed.
            spt = np.sort(order[0:self.k_shot + 1])
            self.spt_idx[label] = samples[spt]


def MyDataLoader(file_path, k_shot, batch_size, shuffle):
    """Create a ``DataLoader`` over a ``MyDataset`` read from ``file_path``.

    Args:
        file_path: file handed to ``MyDataset`` (loaded via ``torch.load``).
        k_shot: support-set size per query.
        batch_size: number of queries per batch.
        shuffle: whether to shuffle query order each epoch.
    """
    episode_set = MyDataset(file_path, k_shot)
    loader = DataLoader(episode_set, batch_size=batch_size, shuffle=shuffle)
    return loader


if __name__ == "__main__":
    # Smoke test: build the training loader and print the shapes of one batch.
    train_loader = MyDataLoader('data_train.pt', 7, 10, True)
    # MyDataset.__getitem__ yields four tensors per sample, so unpack four
    # (unpacking three raised ValueError).
    for xt_gt, xt_D, x0_gt, x0_D in train_loader:
        print(xt_gt.shape, xt_D.shape, x0_gt.shape, x0_D.shape)
        break