from .. import dataset,util
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from typing import Literal
from dataclasses import dataclass,asdict,field
from .conf import NormalByzantineConf
from .wrapper import create_wrapper
from .workers import get_swl_redundency_map,WorkersArgs,get_ref_dataset_from_conf

class WrapNeglectId():
    """Wrapper strategy that passes samples through and drops the worker ids.

    It is handed to ``dataset.WrapDataset`` in this module wherever the
    per-worker identity of a sample should not be exposed.
    """

    @staticmethod
    def wrap(worker_ids, data, target):
        """Return ``(data, target)`` unchanged; ``worker_ids`` is ignored."""
        sample = (data, target)
        return sample



@util.repr_alias(attr_name=False)
@dataclass
class ConcatArgs(WorkersArgs):
    """Worker arguments whose data concatenates two differently-generated parts.

    ``data_num`` is split into ``data_num2 = int(data_num * args2_ratio)``
    samples produced by ``args2_name`` (over the training dataset) and the
    remaining ``data_num1 = data_num - data_num2`` samples produced by
    ``args1_name`` (over a reference dataset, defaulting to the training
    dataset). The two confs are combined with ``dataset.MixConf.CONCAT``.
    """
    # Fixed identifier for this argument type; not accepted by ``__init__``.
    name:str=field(default="concat",init=False)
    # ``WorkersArgs`` name used to build the first (reference-based) part.
    args1_name:str
    # ``WorkersArgs`` name used to build the second (training-set-based) part.
    args2_name:str
    # Fraction of ``data_num`` assigned to the second part; must lie in [0, 1].
    args2_ratio:float

    def _split_data_num(self):
        """Return ``(data_num1, data_num2)``, the sample counts of both parts.

        Raises:
            ValueError: if ``args2_ratio`` is outside [0, 1] (or NaN), which
                could otherwise produce a negative sample count.
        """
        if not 0 <= self.args2_ratio <= 1:
            raise ValueError(
                f"args2_ratio must be in [0, 1], got {self.args2_ratio!r}")
        data_num2=int(self.data_num*self.args2_ratio)
        return self.data_num-data_num2, data_num2

    def get_args2(self):
        """Return single-worker ``WorkersArgs`` describing the second part."""
        _, data_num2=self._split_data_num()
        return WorkersArgs(self.args2_name,data_num2,1)

    def get_conf(self,train_dataset,random_seed,class_num=10,ref_dataset=None):
        """Build the concatenated conf and replicate it for every worker.

        Args:
            train_dataset: Dataset the second part is drawn from.
            random_seed: Seed forwarded to both part confs and to ``MixConf``.
            class_num: Number of classes forwarded to the part confs.
            ref_dataset: Dataset the first part is drawn from; defaults to
                ``train_dataset`` when ``None``.

        Returns:
            A list of length ``worker_num``; every entry is the SAME
            ``MixConf`` object (list multiplication does not copy).
        """
        if ref_dataset is None:
            ref_dataset=train_dataset
        data_num1,_=self._split_data_num()

        conf1=WorkersArgs(self.args1_name,data_num1,1).get_conf(ref_dataset,random_seed,class_num)[0]
        # Reuse get_args2 so the second part is defined in exactly one place.
        conf2=self.get_args2().get_conf(train_dataset,random_seed,class_num)[0]
        conf=dataset.MixConf(random_seed,
                    conf1,
                    conf2,
                    dataset.MixConf.CONCAT)
        return [conf]*self.worker_num
    

@util.repr_alias(attr_name=False)
@dataclass
class NBWorkersArgsRef():
    """Normal + Byzantine worker arguments with a reference dataset.

    The reference dataset is built from the data of the first
    ``ref_workernum`` normal workers and is handed to the Byzantine
    arguments' ``get_conf`` when the full worker dataset is assembled.
    """
    # Arguments describing the normal workers.
    normal_args: WorkersArgs
    # Arguments describing the Byzantine workers. ``get_dataset`` passes a
    # ``ref_dataset`` keyword to its ``get_conf``; ``get_byzantine_set``
    # additionally requires ``get_args2`` (provided by ``ConcatArgs``).
    byzantine_args: WorkersArgs|ConcatArgs
    # Number of leading normal workers whose data forms the reference
    # dataset; ``None`` means all normal workers.
    ref_workernum:int|None=field(default=None)

    def get_conf(self):
        """Return a ``NormalByzantineConf`` built from the two worker counts."""
        normal_num=self.normal_args.worker_num
        byzantine_num=self.byzantine_args.worker_num
        return NormalByzantineConf(normal_num,byzantine_num)

    @staticmethod
    def get_subset_from_workerdataset(random_seed,workerdataset,data_num):
        """Sample ``data_num`` indices (with replacement) from ``workerdataset``.

        The indices come from ``dataset.IIDIndices`` seeded with
        ``random_seed + "concat"``, so ``random_seed`` must support ``+``
        with a ``str`` (presumably it is a string seed -- confirm with
        callers). The subset is wrapped so worker ids are discarded.
        """
        iidGen=dataset.IIDIndices(random_seed+"concat",data_num,indexNum=len(workerdataset),replace=True)
        refdataset=Subset(workerdataset,iidGen.getIndices(0))
        ds=dataset.WrapDataset(refdataset,WrapNeglectId())
        return ds

    @staticmethod
    def get_dataset_from_conf(normal_conf)->dataset.AggWorkersDatasetFromConf:
        """Build an ``AggWorkersDatasetFromConf`` from ``normal_conf`` with no wrapper."""
        workersConf=dataset.WorkersConf(normal_conf)
        workersDataset=dataset.AggWorkersDatasetFromConf(workersConf,None)
        return workersDataset

    def get_ref_dataset(self,normal_conf):
        """Return the reference dataset made of the first reference workers.

        Uses the first ``ref_workernum`` workers built from ``normal_conf``;
        when ``ref_workernum`` is ``None`` all normal workers are used
        (previously ``range(None)`` raised ``TypeError``). The result is
        wrapped so worker ids are discarded.
        """
        workersDataset=NBWorkersArgsRef.get_dataset_from_conf(normal_conf)
        ref_num=(self.normal_args.worker_num
                 if self.ref_workernum is None else self.ref_workernum)
        chosendataset=workersDataset.getSubWokerDataset(list(range(ref_num)))
        return dataset.WrapDataset(chosendataset,WrapNeglectId())

    def get_byzantine_set(self,train_dataset,random_seed,class_num):
        """Return a reference dataset for the second ("args2") Byzantine part.

        Requires ``byzantine_args`` to provide ``get_args2`` (e.g.
        ``ConcatArgs``). For the ``"swl"``/``"sewl"`` methods only the first
        conf is used.
        """
        args2=self.byzantine_args.get_args2()
        conf=args2.get_conf(train_dataset,random_seed,class_num)
        if args2.name in ["swl","sewl"]:
            return get_ref_dataset_from_conf(conf[:1])
        else:
            return get_ref_dataset_from_conf(conf)

    def get_dataset(self,train_dataset,random_seed,wrap_method:Literal["default","discrim"],class_num=10):
        """Assemble the aggregated dataset of all normal and Byzantine workers.

        Args:
            train_dataset: Source training dataset for both worker groups.
            random_seed: Seed forwarded to the conf builders.
            wrap_method: Name passed to ``create_wrapper``.
            class_num: Number of classes forwarded to the conf builders.

        Returns:
            An aggregated worker dataset whose confs list the normal workers
            first, then the Byzantine workers. For Byzantine method
            ``"swl"``/``"sewl"`` it is an ``AggWorkersDatasetWithRedundancy``
            using ``get_swl_redundency_map``; otherwise an
            ``AggWorkersDatasetFromConf``.
        """
        wrapper=create_wrapper(wrap_method)

        normal_num=self.normal_args.worker_num
        byzantine_num=self.byzantine_args.worker_num
        normal_conf=self.normal_args.get_conf(train_dataset,random_seed,class_num=class_num)

        ref_dataset=self.get_ref_dataset(normal_conf)
        # NOTE(review): assumes byzantine_args.get_conf accepts ``ref_dataset``
        # (ConcatArgs does) -- confirm for plain WorkersArgs.
        byzantine_conf=self.byzantine_args.get_conf(
            train_dataset,
            random_seed,class_num=class_num,ref_dataset=ref_dataset)

        workersConf=dataset.WorkersConf(normal_conf+byzantine_conf)

        if self.byzantine_args.name in ["swl","sewl"]:
            workersDataset=dataset.AggWorkersDatasetWithRedundancy(
                workersConf,wrapper,
                get_swl_redundency_map(normal_num,byzantine_num))
        else:
            workersDataset=dataset.AggWorkersDatasetFromConf(workersConf,wrapper)

        return workersDataset