from utils.toolkit import dirichlet_split
import numpy as np
from torch.utils.data import Subset
from utils.data import *


class DataManager():
    def __init__(self, args, task_id):
        self.train_units = None
        self.task_id = task_id
        self.args = args
        self.get_dataset()  # download the dataset for the id-th task
        self.split_data_task_client()  # split the dataset to all clients according to the dirichlet distribution with level alpha

    def get_dataset(self):
        match self.args.dataset_list:
            case 'DIGIT10':
                datalist = iDIGIT10(self.task_id)
                self.data_train, self.data_test, self.num_classes = datalist.download_data()
            case 'PACS':
                datalist = iPACS(self.task_id)
                self.data_train, self.data_test, self.num_classes = datalist.download_data()
            case 'VLCS':
                datalist = iVLCS(self.task_id)
                self.data_train, self.data_test, self.num_classes = datalist.download_data()

    def split_data_task_client(self):
        train_indices = list(range(len(self.data_train)))
        np.random.shuffle(train_indices)
        train_data = Subset(self.data_train, train_indices)
        train_units = []
        train_dirichlet = dirichlet_split(
            train_data,
            num_clients=self.args.num_clients,
            alpha=self.args.dirichlet_coef)
        for sublist in train_dirichlet.values():
            global_indices = [train_indices[i] for i in sublist]
            train_units.append(Subset(self.data_train, global_indices))

        self.train_units = train_units
