
from .comlib import *
from . import adversary_transform, backdoor, dataset_operation
from torch.distributions.dirichlet import Dirichlet
from .. import util

class BaseIndices():
    """Common base for per-worker index samplers.

    Stores the experiment-level ``random_seed`` and provides helpers that
    seed Python's global ``random`` module from an arbitrary tuple of values.
    """

    def __init__(self, random_seed):
        self.random_seed = random_seed

    @staticmethod
    def combine_seed(*args):
        """Return the string made by concatenating ``str(a)`` for every argument.

        NOTE: the parts are joined without a separator, so different argument
        tuples can collapse to the same seed, e.g. ``(1, 23)`` and ``(12, 3)``
        both give ``"123"``.
        """
        return ''.join(str(part) for part in args)

    @staticmethod
    def set_random_seed(*args):
        """Seed the global ``random`` module with ``combine_seed(*args)``."""
        seed_text = BaseIndices.combine_seed(*args)
        random.seed(seed_text)

    def getIndices(self, worker_id):
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

class IIDIndices(BaseIndices):
    """Uniformly sample ``data_num`` indices for each worker.

    The population is ``indexList`` when given, otherwise ``range(indexNum)``.

    Args:
        random_seed: experiment seed, combined with the worker id before sampling.
        data_num: number of indices returned per worker.
        indexNum: size of the index range to draw from (used only when
            ``indexList`` is None).
        indexList: explicit population of indices; takes precedence over
            ``indexNum``.
        replace: draw with replacement (``random.choices``) when true,
            without replacement (``random.sample``) otherwise.

    Raises:
        TypeError: if neither ``indexNum`` nor ``indexList`` is provided.
    """

    def __init__(self,random_seed,data_num,indexNum=None,indexList=None,replace=True):
        super().__init__(random_seed)
        self.indexNum = indexNum
        self.data_num = data_num
        self.replace = replace
        if indexList is not None:
            self.indices = indexList
        elif indexNum is not None:
            self.indices = range(indexNum)
        else:
            # Previously this surfaced as an opaque error from range(None).
            raise TypeError(
                "IIDIndices requires either indexNum or indexList to be provided"
            )

    def getIndices(self, worker_id):
        """Return the indices for ``worker_id``; deterministic per (seed, worker)."""
        self.set_random_seed(self.random_seed, worker_id)
        if self.replace:
            return random.choices(self.indices, k=self.data_num)
        # Without replacement; random.sample raises ValueError if data_num
        # exceeds the population size.
        return random.sample(self.indices, self.data_num)

    def __str__(self):
        return f"{self.__class__.__name__},data_num:{self.data_num}"

class SpecLabelNumIndices(BaseIndices):
    """Sample a specified number of indices from each class, with replacement.

    Args:
        random_seed: experiment seed, combined with the worker id before sampling.
        global_dataset: dataset the class index is built from when
            ``classindex`` is None.
        class_num: number of classes to draw from (labels ``0..class_num-1``).
        dataNumList: ``dataNumList[c]`` is the number of samples drawn for class ``c``.
        classindex: optional precomputed mapping ``label -> indices``; when
            omitted it is built with ``getClassIndex``.
    """

    def __init__(self,random_seed,global_dataset,class_num,dataNumList,classindex=None):
        super().__init__(random_seed)
        self.global_dataset = global_dataset
        self.class_num = class_num
        self.dataNumList = dataNumList
        if classindex is None:
            self.classindex = self.getClassIndex(self.global_dataset)
        else:
            self.classindex = classindex

    @staticmethod
    def getClassIndex(dataset):
        """Map each label in ``range(len(dataset.classes))`` to a tensor of sample indices.

        ``dataset.targets`` may be a tensor, a NumPy array or a plain list;
        it is converted with ``torch.as_tensor`` so the element-wise
        comparison below is always a tensor operation.
        """
        labels = torch.as_tensor(dataset.targets)
        n_class = len(dataset.classes)
        return {c: torch.where(labels == c)[0] for c in range(n_class)}

    @staticmethod
    def getClassIndex2(dataset):
        """Build ``label -> list of indices`` by reading ``dataset[i][1]`` for every item.

        Slower than ``getClassIndex`` but only requires ``__len__`` and
        ``__getitem__`` on the dataset.
        """
        class_index_dict = collections.defaultdict(list)
        for i in range(len(dataset)):
            target = dataset[i][1]
            # Tensors hash by identity, so a tensor label would never match an
            # int lookup; unwrap single-element tensors to a Python scalar.
            if torch.is_tensor(target) and target.numel() == 1:
                target = target.item()
            class_index_dict[target].append(i)
        return class_index_dict

    def getIndices(self,worker_id):
        """Return the concatenated per-class draws for ``worker_id``.

        Elements are whatever ``classindex`` stores (0-d tensors when the
        index was built by ``getClassIndex``).
        """
        self.set_random_seed(self.random_seed, worker_id)
        indices = []
        for label in range(self.class_num):
            drawn = random.choices(self.classindex[label], k=self.dataNumList[label])
            indices.extend(drawn)
        return indices

    def __str__(self):
        return f"{self.__class__.__name__},dataNumList:{self.dataNumList}"

class EqualLabelNumIndices(SpecLabelNumIndices):
    """Per-class sampler that draws the same number of samples from every class."""

    def __init__(self,random_seed, global_dataset, class_num, data_num_per_class, classindex=None):
        # One identical quota per class, forwarded to the parent sampler.
        per_class_quota = [data_num_per_class for _ in range(class_num)]
        super().__init__(
            random_seed,
            global_dataset,
            class_num,
            per_class_quota,
            classindex,
        )




class RepeatedIndices(BaseIndices):
    """Sample a few distinct indices and repeat them to fill a worker's data.

    Args:
        random_seed: experiment seed, combined with the worker id before sampling.
        indexNum: size of the index range to draw from (``range(indexNum)``);
            a ready-made sequence of indices is also accepted.
        data_num: target number of indices per worker.
        actual_data_num: number of distinct indices drawn before repetition.
    """

    def __init__(self,random_seed,indexNum,data_num,actual_data_num=1):
        super().__init__(random_seed)
        self.indexNum = indexNum
        self.data_num = data_num
        self.actual_data_num = actual_data_num

    def getIndices(self, worker_id):
        """Return ``actual_data_num`` distinct indices repeated end to end.

        The sample is repeated ``int(data_num / actual_data_num)`` times, so
        the result is shorter than ``data_num`` when it is not an exact
        multiple of ``actual_data_num``.
        """
        self.set_random_seed(self.random_seed, worker_id)
        # random.sample needs a sequence; an integer count is expanded to a
        # range, while a sequence passed by the caller is used as-is.
        try:
            population = range(self.indexNum)
        except TypeError:
            population = self.indexNum
        distinct = random.sample(population, k=self.actual_data_num)
        return distinct * int(self.data_num / self.actual_data_num)

class FixIndices(BaseIndices):
    """Sampler that hands every worker the same, caller-supplied indices."""

    def __init__(self,indices):
        # No randomness is involved, but initialise the base class so that
        # ``random_seed`` exists on every BaseIndices instance.
        super().__init__(None)
        self.indices = indices

    def getIndices(self, worker_id):
        """Return the stored indices unchanged; ``worker_id`` is ignored."""
        return self.indices
    
class ClassDirichletIndices(BaseIndices):
    """Sample each worker's indices with class weights drawn from a Dirichlet.

    For every worker a class-proportion vector is drawn from a symmetric
    Dirichlet with concentration ``alpha``; each sample in the dataset is then
    weighted by the proportion of its class and ``data_num`` indices are drawn
    with replacement.
    """

    def __init__(self, random_seed, global_dataset,data_num, alpha):
        super().__init__(random_seed)
        self.global_dataset = global_dataset
        self.data_num = data_num
        self.class_num = len(global_dataset.classes)
        self.alpha = alpha

    def countDataNumPerClass(self,indices):
        """Return, for each class, how many of ``indices`` carry that label."""
        # assumes dataset.targets supports fancy indexing (tensor/ndarray) — TODO confirm
        picked = list(self.global_dataset.targets[indices])
        counts = []
        for label in range(self.class_num):
            counts.append(picked.count(label))
        return counts

    def getFromProportion(self,proportion):
        """Draw ``data_num`` dataset indices, weighting each by its class proportion."""
        population = list(range(len(self.global_dataset)))
        # Per-sample weight = proportion of that sample's class.
        sample_weights = proportion[self.global_dataset.targets]
        return random.choices(population, weights=sample_weights, k=self.data_num)

    def getIndices(self,worker_id):
        """Return the indices for ``worker_id``.

        Python's ``random`` is seeded for the index draw; the Dirichlet draw
        uses the generator returned by ``util.get_random_state`` (presumably a
        NumPy ``RandomState``-like object — verify against ``util``).
        """
        self.set_random_seed(self.random_seed, worker_id)
        generator = util.get_random_state(self.random_seed, worker_id)
        concentration = np.full((self.class_num,), self.alpha)
        class_proportion = generator.dirichlet(concentration)
        return self.getFromProportion(class_proportion)


    # def __init__(self, random_seed, global_dataset, class_num, alpha, classindex=None):
    #     self.set_random_seed(self.random_seed, worker_id)
    #     classes=list(range(class_num))
    #     dist = Dirichlet(torch.full((class_num,), alpha))
    #     proportions = dist.sample((class_num,))
    #     dataNumList=random.choices(self.classindex[i], k=self.dataNumList[i])
    #     # np.random.choice(classes,p=proportions)
    #     super().__init__(random_seed, global_dataset, class_num, dataNumList, classindex)

    # def __init__(self,random_seed,alpha):
    #     super().__init__(random_seed)
    #     self.indices=alpha

    # def getIndices(self, worker_id):
    #     return self.indices
        
    # def dirichlet_split(labels, alpha, n_clients):
    #     """
    #     labels: 1-D tensor of class indices
    #     alpha : Dirichlet concentration parameter
    #     returns: list of tensors, each holding the sample indices of one client
    #     """
    #     n_classes = labels.max().item() + 1
    #     # Generate the client distribution for each class (n_classes, n_clients)
    #     dist = Dirichlet(torch.full((n_clients,), alpha))
    #     proportions = dist.sample((n_classes,))   # (n_classes, n_clients)

    #     client_indices = [[] for _ in range(n_clients)]
    #     for c in range(n_classes):
    #         idx_c = torch.where(labels == c)[0]
    #         splits = (proportions[c] * len(idx_c)).int()
    #         splits[-1] = len(idx_c) - splits[:-1].sum()
    #         idx_split = torch.split(idx_c, splits.tolist())
    #         for cid, ids in enumerate(idx_split):
    #             client_indices[cid].append(ids)
    #     return [torch.cat(ids) for ids in client_indices]
