from .. import dataset,util
from typing import Literal
from dataclasses import dataclass,asdict,field
from .conf import NormalByzantineConf
from .wrapper import create_wrapper,WrapNeglectId

def getIidConf(train_dataset,random_seed,data_num):
    """Return a transform-free worker conf over IID indices.

    Indices come from ``dataset.IIDIndices`` drawn over the whole of
    ``train_dataset`` with replacement, seeded by ``random_seed``.
    """
    index_gen = dataset.IIDIndices(
        random_seed, data_num, indexNum=len(train_dataset), replace=True
    )
    return dataset.WorkerConfNoTrans(train_dataset, index_gen)

def getNoisyConf(train_dataset,random_seed,data_num,get_wrong_ratio):
    """Return a ``dataset.NoisyConf`` over IID indices with wrong-label corruption.

    Args:
        train_dataset: dataset the indices refer to.
        random_seed: seed for the IID index draw; also forwarded to NoisyConf.
        data_num: sample count forwarded to ``dataset.IIDIndices``.
        get_wrong_ratio: callable forwarded to NoisyConf; presumably called with
            a worker id and returning that worker's corruption ratio (confirm
            against dataset.NoisyConf).
    """
    # Arguments are evaluated left to right: index generator first, then the
    # label transform, matching the original construction order.
    return dataset.NoisyConf(
        train_dataset,
        dataset.IIDIndices(random_seed, data_num, indexNum=len(train_dataset), replace=True),
        dataset.WrongLabelTransform(),
        random_seed,
        get_wrong_ratio,
    )

def getWlConf(train_dataset,random_seed,data_num):
    """Return a wrong-label adversary conf over IID indices drawn with replacement."""
    index_gen = dataset.IIDIndices(
        random_seed, data_num, indexNum=len(train_dataset), replace=True
    )
    label_flip = dataset.WrongLabelTransform()
    return dataset.AdversaryConf(train_dataset, index_gen, label_flip)

def get_elGen(train_dataset,random_seed,data_num,class_num=10):
    """Return an index generator drawing the same number of samples per class.

    ``data_num`` is split evenly over ``class_num`` classes (the quotient is
    truncated with ``int``), and the per-class index lists are precomputed via
    ``dataset.SpecLabelNumIndices.getClassIndex``.
    """
    per_class = int(data_num / class_num)
    class_index = dataset.SpecLabelNumIndices.getClassIndex(train_dataset)
    return dataset.EqualLabelNumIndices(
        random_seed, train_dataset, class_num, per_class, classindex=class_index
    )

def getEwlConf(train_dataset,random_seed,data_num,class_num=10):
    """Return a wrong-label adversary conf over equal-per-class indices (see get_elGen)."""
    equal_label_gen = get_elGen(train_dataset, random_seed, data_num, class_num=class_num)
    return dataset.AdversaryConf(
        train_dataset, equal_label_gen, dataset.WrongLabelTransform()
    )


def getSwlConf(train_dataset,gen):
    """Return a wrong-label adversary conf whose indices are frozen to one draw.

    The indices ``gen.getIndices(0)`` are wrapped in ``dataset.FixIndices``, so
    every user of the returned conf presumably sees that same index set.
    """
    frozen_gen = dataset.FixIndices(gen.getIndices(0))
    return dataset.AdversaryConf(
        train_dataset, frozen_gen, dataset.WrongLabelTransform()
    )


def get_swl_redundency_map(iid_worker_num,ad_worker_num):
    """Map each worker id to the data slot it draws from.

    Workers ``0 .. iid_worker_num-1`` map to themselves; the following
    ``ad_worker_num`` adversarial workers all map to slot ``iid_worker_num``
    (they share one data source).
    """
    mapping = {worker_id: worker_id for worker_id in range(iid_worker_num)}
    adversary_ids = range(iid_worker_num, iid_worker_num + ad_worker_num)
    mapping.update(dict.fromkeys(adversary_ids, iid_worker_num))
    return mapping

def get_swl_prior_list(iid_worker_num):
    """Return ``[iid_worker_num, 0, 1, ..., iid_worker_num - 1]``.

    The shared adversarial slot (``iid_worker_num``) comes first, followed by
    the IID worker ids in order.
    """
    return [iid_worker_num, *range(iid_worker_num)]


def get_workersConf(name:Literal["iid","non-iid-class","non-iid-correct","noisy","s",
                                 "wl","ewl","swl","sewl","sbackdoor"],
                    train_dataset,random_seed,data_num,worker_num,class_num=10,
                    **kwargs):
    """Build the list of per-worker dataset confs for one group of workers.

    Args:
        name: Selects how indices are drawn and which label transform applies:
            "iid"             IID indices, no transform.
            "non-iid-class"   dataset.ClassDirichletIndices with ``alpha``, no transform.
            "noisy"           IID indices, wrong labels at a fixed ``wrong_ratio``.
            "non-iid-correct" IID indices, wrong labels at a per-worker ratio drawn
                              from Dirichlet([exp_wrong_ratio*alpha,
                              (1-exp_wrong_ratio)*alpha]).
            "s"               the IID index draw for id 0 reused (FixIndices),
                              no transform.
            "wl"              IID indices, wrong-label adversary.
            "ewl"             equal-per-class indices, wrong-label adversary.
            "swl" / "sewl"    like "wl" / "ewl", but with indices frozen to the
                              draw for id 0 (see getSwlConf).
            "sbackdoor"       currently built exactly like "wl".
        train_dataset: Dataset the indices refer to.
        random_seed: Seed forwarded to the index generators.
        data_num: Sample count forwarded to the index generators (presumably per
            worker -- confirm against dataset.IIDIndices).
        worker_num: Length of the returned list.
        class_num: Number of classes; used by "ewl" and "sewl" to split
            ``data_num`` evenly across classes.
        **kwargs: Extra options read by some names: ``alpha`` ("non-iid-class",
            "non-iid-correct"), ``wrong_ratio`` ("noisy"), ``exp_wrong_ratio``
            ("non-iid-correct"). Unused keys are ignored.

    Returns:
        A list of ``worker_num`` references to ONE shared conf object
        (presumably each conf derives per-worker data from the worker index).

    Raises:
        ValueError: if ``name`` is not a supported conf name.
    """
    if name=="iid":
        conf=getIidConf(train_dataset,random_seed,data_num)
    elif name=="non-iid-class":
        alpha=kwargs.get("alpha")
        noniidGen=dataset.ClassDirichletIndices(random_seed,train_dataset,data_num,alpha)
        conf=dataset.WorkerConfNoTrans(train_dataset,noniidGen)
    elif name=="noisy":
        wrong_ratio=kwargs.get("wrong_ratio")

        def get_wrong_ratio(worker_id):
            # Every worker is corrupted at the same ratio.
            return wrong_ratio

        conf=getNoisyConf(train_dataset,random_seed,data_num,get_wrong_ratio)
    elif name=="non-iid-correct":
        exp_wrong_ratio=kwargs.get("exp_wrong_ratio")
        alpha=kwargs.get("alpha")
        # The first Dirichlet component has mean exp_wrong_ratio; alpha is the
        # total concentration, so larger alpha means less spread across workers.
        a_list=[exp_wrong_ratio*alpha,(1-exp_wrong_ratio)*alpha]

        def get_wrong_ratio(worker_id):
            # Random state derived from (random_seed, worker_id), presumably so
            # each worker's ratio is reproducible.
            rs=util.get_random_state(random_seed,worker_id)
            proportion=rs.dirichlet(a_list, size=None)
            return proportion[0]

        conf=getNoisyConf(train_dataset,random_seed,data_num,get_wrong_ratio)
    elif name=="s":
        iidGen=dataset.IIDIndices(random_seed,data_num,indexNum=len(train_dataset),replace=True)
        sGen=dataset.FixIndices(iidGen.getIndices(0))
        conf=dataset.WorkerConfNoTrans(train_dataset,sGen)
    elif name in ("wl","sbackdoor"):
        # NOTE(review): "sbackdoor" is built identically to "wl" here; no
        # backdoor trigger is applied in this module -- confirm intended.
        conf=getWlConf(train_dataset,random_seed,data_num)
    elif name=="ewl":
        # Forward class_num so the per-class split matches the dataset
        # (the same way the "sewl" branch does).
        conf=getEwlConf(train_dataset,random_seed,data_num,class_num=class_num)
    elif name=="swl":
        iidGen=dataset.IIDIndices(random_seed,data_num,indexNum=len(train_dataset),replace=True)
        conf=getSwlConf(train_dataset,iidGen)
    elif name=="sewl":
        elGen=get_elGen(train_dataset,random_seed,data_num,class_num=class_num)
        conf=getSwlConf(train_dataset,elGen)
    else:
        # Previously fell through and returned None, which only failed later
        # and far from the cause (e.g. when concatenating conf lists).
        raise ValueError(f"unknown workers conf name: {name!r}")
    # All workers share the same conf object.
    return [conf]*worker_num

def get_worker_dataset_from_conf(conf)->dataset.AggWorkersDatasetFromConf:
    """Aggregate a list of per-worker confs into one dataset, with no wrapper."""
    return dataset.AggWorkersDatasetFromConf(dataset.WorkersConf(conf), None)

def get_ref_dataset_from_conf(conf):
    """Return the aggregated dataset for ``conf`` wrapped with ``WrapNeglectId``.

    The wrapper's name suggests worker ids are dropped from the samples
    (not verifiable here -- see .wrapper).
    """
    return dataset.WrapDataset(get_worker_dataset_from_conf(conf), WrapNeglectId())

@util.repr_alias(attr_name=False)
@dataclass
class WorkersArgs:
    """Declarative description of one worker group.

    Every dataclass field is forwarded to ``get_workersConf`` as a keyword
    argument, so subclasses add the extra options particular ``name`` values
    read (e.g. ``alpha``, ``wrong_ratio``).
    NOTE(review): "concat" is listed below but get_workersConf in this module
    has no branch for it.
    """
    name: Literal["iid","non-iid-class","non-iid-correct","noisy","s",
                                 "wl","ewl","swl","sewl","concat"]
    data_num: int
    worker_num:int

    def get_conf(self,train_dataset,random_seed,class_num=10):
        """Return the list of per-worker confs this spec describes."""
        field_kwargs = asdict(self)
        return get_workersConf(
            train_dataset=train_dataset,
            random_seed=random_seed,
            class_num=class_num,
            **field_kwargs,
        )
    

@util.repr_alias(attr_name=False)
@dataclass
class NonIidClassWorkersArgs(WorkersArgs):
    """Worker group with a non-IID class split built by dataset.ClassDirichletIndices(alpha)."""
    # Fixed tag, excluded from __init__: callers pass data_num, worker_num, alpha.
    name:str = field(init=False, default="non-iid-class")
    alpha:float

@util.repr_alias(attr_name=False)
@dataclass
class NonIidCorrectWorkersArgs(WorkersArgs):
    """Worker group whose per-worker wrong-label ratio is Dirichlet-distributed.

    ``exp_wrong_ratio`` is the mean ratio and ``alpha`` the total Dirichlet
    concentration (see the "non-iid-correct" branch of get_workersConf).
    """
    # Fixed tag, excluded from __init__.
    name:str = field(init=False, default="non-iid-correct")
    alpha:float
    exp_wrong_ratio:float

@util.repr_alias(attr_name=False)
@dataclass
class NoisyWorkersArgs(WorkersArgs):
    """Worker group in which every worker mislabels data at the same ``wrong_ratio``."""
    # Fixed tag, excluded from __init__.
    name:str = field(init=False, default="noisy")
    wrong_ratio:float


@util.repr_alias(attr_name=False)
@dataclass
class NBWorkersArgs():
    """Specification of a normal worker group followed by a Byzantine worker group.

    Normal workers' confs come first and Byzantine workers' confs follow them
    when the combined dataset is built.
    """
    normal_args: WorkersArgs
    byzantine_args: WorkersArgs

    def _byzantine_shares_indices(self):
        """Return True when the Byzantine group uses one frozen index set.

        "swl"/"sewl" confs are built by getSwlConf, which freezes the indices
        drawn for id 0 with FixIndices, so all Byzantine workers presumably
        hold identical data.
        """
        return self.byzantine_args.name in ("swl","sewl")

    def get_conf(self):
        """Return a NormalByzantineConf built from the two groups' worker counts."""
        normal_num=self.normal_args.worker_num
        byzantine_num=self.byzantine_args.worker_num
        return NormalByzantineConf(normal_num,byzantine_num)

    def get_dataset(self,train_dataset,random_seed,wrap_method:Literal["default","discrim"],class_num=10):
        """Build the aggregated dataset over all normal and Byzantine workers.

        Args:
            train_dataset: dataset both groups draw from.
            random_seed: seed forwarded to both groups' conf builders.
            wrap_method: passed to ``create_wrapper``; the resulting wrapper is
                handed to the aggregated dataset.
            class_num: number of classes, forwarded to both groups.

        Returns:
            ``dataset.AggWorkersDatasetWithRedundancy`` when the Byzantine group
            is "swl"/"sewl" (all Byzantine ids mapped to one slot by
            get_swl_redundency_map), otherwise ``dataset.AggWorkersDatasetFromConf``.
        """
        normal_num=self.normal_args.worker_num
        byzantine_num=self.byzantine_args.worker_num
        normal_conf=self.normal_args.get_conf(train_dataset,random_seed,class_num=class_num)
        byzantine_conf=self.byzantine_args.get_conf(train_dataset,random_seed,class_num=class_num)
        # Normal workers occupy the first ids, Byzantine workers the rest.
        workersConf=dataset.WorkersConf(normal_conf+byzantine_conf)

        wrapper=create_wrapper(wrap_method)

        if self._byzantine_shares_indices():
            # Byzantine workers share data, so they all map to slot normal_num.
            redundancy_map=get_swl_redundency_map(normal_num,byzantine_num)
            return dataset.AggWorkersDatasetWithRedundancy(workersConf,wrapper,redundancy_map)
        return dataset.AggWorkersDatasetFromConf(workersConf,wrapper)

    def get_byzantine_set(self,train_dataset,random_seed,class_num=10):
        """Return a reference dataset built only from the Byzantine workers' confs.

        For "swl"/"sewl" only the first conf is used, because every Byzantine
        worker reuses the same frozen indices. ``class_num`` now defaults to 10,
        matching ``get_dataset`` and ``WorkersArgs.get_conf``.
        """
        byzantine_conf=self.byzantine_args.get_conf(train_dataset,random_seed,class_num=class_num)
        if self._byzantine_shares_indices():
            byzantine_conf=byzantine_conf[:1]
        return get_ref_dataset_from_conf(byzantine_conf)
