from .comlib import *
from . import adversary_transform as at
from . import worker_indices as wi
from . import dataset_operation, worker_indices

class WorkerConf(ABC):
    """Interface for a per-worker dataset configuration.

    Implementations decide which dataset the worker with a given index
    receives.
    """

    @abstractmethod
    def getDataset(self, worker_id):
        """Return the dataset assigned to worker number ``worker_id``."""


class WorkerConfSimple(WorkerConf):
    """Configuration that hands every worker the same, unmodified dataset.

    Args:
        baseDataset: The dataset returned for every worker.
    """

    # Derives from WorkerConf (not bare ABC) so instances satisfy the
    # WorkerConf interface, and getDataset is deliberately NOT abstract:
    # marking this concrete method @abstractmethod made the class
    # impossible to instantiate.
    def __init__(self, baseDataset):
        self.baseDataset = baseDataset

    def getDataset(self, worker_id):
        """Return ``baseDataset``; ``worker_id`` is ignored."""
        return self.baseDataset

class WorkerConfNoTrans(WorkerConf):
    """Configuration giving each worker an untransformed subset of a dataset.

    Args:
        baseDataset: The dataset that per-worker subsets are drawn from.
        baseIndicesGen: Object whose ``getIndices(worker_id)`` yields the
            indices selected for that worker.
    """

    def __init__(self, baseDataset,
                 baseIndicesGen: wi.BaseIndices):
        self.baseIndicesGen = baseIndicesGen
        self.baseDataset = baseDataset

    def getDataset(self, worker_id):
        """Return ``torch.utils.data.Subset`` of ``baseDataset`` for a worker.

        Side effect: the indices just generated are stored on
        ``self.baseDataIndice`` (overwritten on every call).
        """
        indices = self.baseIndicesGen.getIndices(worker_id)
        self.baseDataIndice = indices
        return torch.utils.data.Subset(self.baseDataset, indices)

    def __str__(self):
        """Describe this configuration as ``no transform,<indices generator>``."""
        return "no transform,{}".format(self.baseIndicesGen)


class WorkersConf:
    """Ordered collection of worker configurations, one per worker.

    The position of a configuration in ``workerConfs`` is the worker id
    passed to its ``getDataset``.

    Args:
        workerConfs: One ``WorkerConf`` per worker, indexed by worker id.
    """

    def __init__(self, workerConfs: list[WorkerConf]):
        self.workerConfs = workerConfs
        # Worker count at construction time, kept as a public attribute.
        self.worker_num = len(workerConfs)

    def getDatasetList(self):
        """Build the dataset of every worker.

        Returns:
            A list whose i-th element is ``workerConfs[i].getDataset(i)``.
        """
        # Iterate the list itself rather than range(worker_num) so the
        # result always matches the current configurations.
        return [
            workerConf.getDataset(worker_id)
            for worker_id, workerConf in enumerate(self.workerConfs)
        ]

    @staticmethod
    def printListWithRepeat(lst):
        """Summarize ``lst`` as ``"[a]*n + [b]*m + ..."``.

        Despite its name this does not print; it returns the summary string.
        Equal elements are merged into a single ``[item]*count`` term no
        matter where they occur, and terms appear in order of each element's
        first occurrence. Elements must be hashable. An empty ``lst`` yields
        ``""``.
        """
        from collections import Counter  # local: keeps module imports unchanged

        # Counter preserves first-occurrence order (it is a dict subclass).
        return " + ".join(
            f"[{item}]*{count}" for item, count in Counter(lst).items()
        )

    def __str__(self):
        """Return ``"WorkersConf:"`` followed by the grouped configuration strings."""
        summary = self.printListWithRepeat(
            [str(workerConf) for workerConf in self.workerConfs]
        )
        return f"WorkersConf:{summary}"