from .comlib import *
from . import worker_conf
from .. import util

class AggWorkersDatasetFromList(Dataset):
    '''
    Concatenate several per-worker datasets into one flat dataset.

    Global index ``g`` maps to sample ``localIndex`` of worker ``workerId``.
    Items are returned as ``(workerId, *sample)``. If a ``wrapper`` is given,
    ``wrapper.wrap(workerId, *sample)`` is returned instead.

    Args:
        workerDatasetList: list of sized, indexable datasets, one per worker.
            A worker's id is its position in this list (0-based int).
        wrapper: optional object exposing ``wrap(workerId, *sample)``.
    '''
    def __init__(self, workerDatasetList: list, wrapper=None):

        self.workerDatasetList = workerDatasetList
        self.wrapper = wrapper
        self.worker_num = len(workerDatasetList)
        # Flat lookup tables: global index -> worker id / index inside that worker.
        self.globalIndex_to_workerId = []
        self.globalIndex_to_localIndex = []
        self.dataNumPerWorker = []

        for i, dataset in enumerate(workerDatasetList):
            workerDataNum = len(dataset)
            self.globalIndex_to_workerId.extend([i] * workerDataNum)
            # Store plain ints rather than one 0-d tensor per sample: much
            # cheaper in memory and accepted by any dataset's __getitem__.
            self.globalIndex_to_localIndex.extend(range(workerDataNum))
            self.dataNumPerWorker.append(workerDataNum)

    def __getitem__(self, item):
        '''Return ``(workerId, *sample)`` (or the wrapped form) for global index ``item``.'''
        workerId = self.globalIndex_to_workerId[item]
        localIndex = self.globalIndex_to_localIndex[item]
        # The sample is unpacked, so each worker dataset must yield an iterable
        # (e.g. a ``(data, target)`` tuple).
        data_and_target = self.workerDatasetList[workerId][localIndex]
        if self.wrapper is None:
            return (workerId, *data_and_target)
        else:
            return self.wrapper.wrap(workerId, *data_and_target)

    def __len__(self):
        '''Total number of samples over all workers, as a plain ``int``.'''
        # A plain int (not a tensor) keeps len() working even with no workers.
        return sum(self.dataNumPerWorker)

    def getSubWokerDataset(self, worker_ids: list):
        '''
        Build a new aggregated dataset from the selected workers only.

        The id of each worker will not remain the same: the selected workers
        are renumbered 0..len(worker_ids)-1 in the order given. The wrapper
        is shared with the new dataset.
        '''
        subDataList = [self.workerDatasetList[workerId] for workerId in worker_ids]
        return AggWorkersDatasetFromList(subDataList, self.wrapper)

    def getSubWokerDataNum(self, worker_ids: list):
        '''Return a long tensor with the sample count of each worker in ``worker_ids``.'''
        dataNumPerWorkerTensor = torch.tensor(self.dataNumPerWorker, dtype=torch.long)
        return dataNumPerWorkerTensor[worker_ids]

    def randomSubWokerDataset(self, choose_num):
        '''
        Pick ``choose_num`` distinct workers uniformly at random and return
        their aggregated sub-dataset (see ``getSubWokerDataset``).

        ``random.sample`` raises ValueError if ``choose_num`` exceeds the
        number of workers.
        '''
        worker_ids = random.sample(range(self.worker_num), k=choose_num)
        return self.getSubWokerDataset(worker_ids)

    # def getOneWokerDataset(self,workerId):
    #     return self.workerDatasetList[workerId]

    

# Example usage: dataloader = DataLoader(AggWorkersDatasetFromList(workerDatasetList), batch_size=batch_size)

class AggWorkersDatasetFromConf(AggWorkersDatasetFromList):
    '''
    Aggregated worker dataset built from a worker configuration.

    Args:
        conf: worker configuration; ``conf.getDatasetList()`` supplies the
            per-worker datasets when ``workerDatasetList`` is not given.
        wrapper: optional sample wrapper, forwarded to the base class.
        workerDatasetList: explicit per-worker datasets; when given, it is
            used directly and ``conf`` is ignored.

    Raises:
        ValueError: if neither ``conf`` nor ``workerDatasetList`` is given.
    '''
    def __init__(self, conf: worker_conf.WorkersConf=None, wrapper=None, workerDatasetList=None):
        if workerDatasetList is None:
            if conf is None:
                raise ValueError(
                    "AggWorkersDatasetFromConf needs either conf or workerDatasetList")
            workerDatasetList = conf.getDatasetList()
        # Previously an explicit workerDatasetList left the local name unbound
        # (NameError); both paths now feed the same variable.
        super().__init__(workerDatasetList, wrapper)


class AggWorkersDatasetWithRedundancy(AggWorkersDatasetFromList):
    '''
    Aggregated worker dataset that also keeps a ``util.RedundentTensor``
    built from ``redundentMap``.

    Args:
        conf: worker configuration; ``conf.getDatasetList()`` supplies the
            per-worker datasets when ``workerDatasetList`` is not given.
        wrapper: optional sample wrapper, forwarded to the base class.
        redundentMap: passed unchanged to ``util.RedundentTensor``; its
            expected format is defined there.
        workerDatasetList: explicit per-worker datasets; when given, it is
            used directly and ``conf`` is ignored.

    Raises:
        ValueError: if neither ``conf`` nor ``workerDatasetList`` is given.
    '''
    def __init__(self, conf=None, wrapper=None, redundentMap=None, workerDatasetList=None):
        if workerDatasetList is None:
            if conf is None:
                raise ValueError(
                    "AggWorkersDatasetWithRedundancy needs either conf or workerDatasetList")
            workerDatasetList = conf.getDatasetList()
        # Previously an explicit workerDatasetList left the local name unbound
        # (NameError); both paths now feed the same variable.
        super().__init__(workerDatasetList, wrapper)
        self.redundentTensor = util.RedundentTensor(redundentMap)

    # NOTE: the methods below are disabled drafts. While they stay commented
    # out, getSubWokerDataset / getSubWokerDataNum are inherited unchanged
    # from AggWorkersDatasetFromList and do not consult redundentTensor.

    # def __init__(self, conf: worker_conf.WorkersConf,wrapper=None,redundentMap=None):
    #     workerIndicesMap=conf.getDatasetList()
    #     super().__init__(workerIndicesMap,wrapper)
    #     self.redundentTensor=util.RedundentTensor(redundentMap)


    # def getSubWokerDataset(self,worker_ids:list):
    #     '''
    #     the id of each worker will not remain the same
    #     '''
    #     rt=self.redundentTensor.get_sub(worker_ids)
    #     subDataList=[self.workerDatasetList[workerId] for workerId in rt.unique_keys]
    #     return AggWorkersDatasetFromList(subDataList,self.wrapper)
    
    # def getSubWokerDataNum(self,worker_ids:list):
    #     rt=self.redundentTensor.get_sub(worker_ids)
    #     dataNumPerWorkerTensor=torch.tensor(self.dataNumPerWorker)
    #     subWorkerTensor=torch.tensor(rt.unique_keys)
    #     return dataNumPerWorkerTensor[subWorkerTensor]
    
    
    # def restoreFromUnique(self,workerValues,worker_ids:list):
    #     rt=self.redundentTensor.get_sub(worker_ids)
    #     rt.set(workerValues)
    #     return rt.get()


    # def get_unique_worker_ids(self,worker_ids):
    #     if self.repetition is None:
    #         return worker_ids
    #     else:
    #         temp=[self.repetition[i] for i in worker_ids]
    #         return []
        

