import os

import numpy as np
import math
import torch
from torch.utils.data import DataLoader, Dataset


def chunks(l, n):
    """Yield consecutive slices of ``l`` holding at most ``n`` items each.

    Args:
        l: A sliceable sequence that supports ``len``.
        n: Number of items per slice.

    Yields:
        ``l[0:n]``, ``l[n:2*n]``, ... in order. The final slice is shorter
        when ``len(l)`` is not a multiple of ``n``; nothing is yielded for
        an empty sequence.
    """
    # Number of slices needed to cover the whole sequence.
    piece_count = math.ceil(len(l) / n)
    for k in range(piece_count):
        start = k * n
        yield l[start:start + n]
class SplitDataset(Dataset):
    """Feature/label dataset restricted to one group of at most ``max_ways`` classes.

    The samples are read from ``<root>/<dataset>/<split>_data.pt``, a file
    loaded with ``torch.load`` that must provide the keys ``'n_classes'``,
    ``'targets'`` and ``'features'``. When the file holds more than
    ``max_ways`` classes, the class ids ``0 .. n_classes - 1`` are cut into
    consecutive groups of ``max_ways`` ids and only the samples of group
    number ``subset`` are kept; their labels are shifted so that the first
    class id of the group becomes ``0``.

    Attributes set on the instance: ``data`` (features), ``labels``
    (targets) and ``n_classes`` (number of classes in the kept group, or
    the value stored in the file when no filtering happens).
    """

    def __init__(self, root='Ftask', dataset='cifar100', split='train', subset=1, max_ways=10):
        """Load the split file and, if needed, keep a single class group.

        Args:
            root: Directory that contains one sub-directory per dataset.
            dataset: Name of the dataset sub-directory.
            split: Split name; the file read is ``f'{split}_data.pt'``.
            subset: Index of the class group to keep (only used when the
                file has more than ``max_ways`` classes).
            max_ways: Maximum number of classes per group.

        Raises:
            IndexError: If ``subset`` does not designate an existing group.
        """
        super().__init__()
        self.dataset = dataset
        self.split = split
        self.subset = subset
        self.root = root
        self.max_ways = max_ways

        datapath = os.path.join(root, dataset, f'{split}_data.pt')
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        payload = torch.load(datapath)
        n_classes = payload['n_classes']
        labels = payload['targets']
        data = payload['features']

        if n_classes > max_ways:
            groups = list(chunks(list(range(n_classes)), max_ways))
            try:
                targets = groups[subset]
            except IndexError:
                raise IndexError(
                    f'subset {subset} is out of range for {len(groups)} class groups'
                ) from None
            data, labels = self._select_classes(data, labels, targets)
            n_classes = len(targets)

        self.labels = labels
        self.data = data
        self.n_classes = n_classes

    @staticmethod
    def _select_classes(data, labels, targets):
        """Keep the samples whose label is in ``targets`` and re-base the labels.

        Samples are gathered class by class in the order of ``targets``.
        Labels are offset by ``targets[0]`` (the first class id of the
        group) rather than by the smallest label actually present, so the
        mapping stays identical across splits even when a class has no
        sample, and an empty selection does not raise.

        Args:
            data: Feature tensor indexed along dim 0 by sample.
            labels: Label tensor aligned with ``data`` (assumed 1-D).
            targets: Non-empty, ordered sequence of class ids to keep.

        Returns:
            Tuple ``(data, labels)`` with ``labels`` as ``torch.int64``.
        """
        kept_labels = []
        kept_data = []
        for class_id in targets:
            where_class = torch.where(labels == class_id)
            kept_labels.append(labels[where_class])
            kept_data.append(data[where_class])
        merged_labels = torch.cat(kept_labels, dim=-1)
        merged_data = torch.cat(kept_data, dim=0)
        rebased = (merged_labels - targets[0]).to(torch.int64)
        return merged_data, rebased

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for ``idx`` as float32 / int64 tensors."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        features = self.data[idx].type(torch.float32)
        target = self.labels[idx].type(torch.int64)
        return features, target