from .comlib import *

from .discriminator_strategy import LinearCEGrad
from .worker_model import WorkerModelSigmoid
from .discriminator_ import Discriminator,PlotEval

from .max_variance import MaxVariance
from .known_byzantine import KnownByzantine

class GradStrategy():
    """Pair a worker-weighting strategy with a linear gradient strategy.

    ``get`` asks the weight strategy for a weight over the chosen workers,
    divides it by the number of chosen workers, and hands the result to the
    linear strategy to obtain the gradient.
    """

    def __init__(self,weight_strategy,linear_strategy:fed_learning.LinearGrad):
        # Both collaborators are stored untouched and only used by ``get``.
        self.linear_strategy=linear_strategy
        self.weight_strategy=weight_strategy

    def get(self,chosen_workers,worker_dataset,modelSetup,*args):
        """Return the gradient produced by the linear strategy.

        Args:
            chosen_workers: Workers taking part in this step; must support
                ``len()``.
            worker_dataset: Passed to ``WorkerModelSigmoid`` and to the
                linear strategy.
            modelSetup: Passed to ``WorkerModelSigmoid`` and to the linear
                strategy.
            *args: Accepted but not used.

        Returns:
            Whatever ``linear_strategy.get`` returns.
        """
        worker_model=WorkerModelSigmoid(modelSetup,worker_dataset)
        raw_weight=self.weight_strategy.get(chosen_workers,worker_model)
        # Scale the weight by the number of chosen workers.
        scaled_weight=raw_weight/len(chosen_workers)
        return self.linear_strategy.get(
            scaled_weight,chosen_workers,worker_dataset,modelSetup)


@dataclass
class WeightArgs():
    """Generic weight-strategy configuration carrying only a strategy name.

    ``get`` builds nothing here: it accepts any positional arguments and
    returns ``None``.
    """
    # Tag naming the weight strategy ("mv" or "kb").
    name: Literal["mv","kb"]

    def get(self,*args):
        """Accept any positional arguments and return ``None``."""
        return None

@util.repr_alias(attr_name=False)
@dataclass
class MvArgs():
    """Configuration that builds a ``MaxVariance`` weight strategy."""
    # Fixed tag "mv"; excluded from the generated __init__.
    name:str=field(default="mv",init=False)
    # Forwarded as the first argument of MaxVariance.
    permutation:float=0
    # Forwarded as the second argument of MaxVariance.
    value_batch_size:int=4000

    def get(self,*args):
        """Return ``MaxVariance(permutation, value_batch_size)``.

        Any positional arguments are accepted and ignored.
        """
        return MaxVariance(self.permutation,self.value_batch_size)

@util.repr_alias(attr_name=False)
@dataclass
class KbArgs():
    """Configuration that builds a ``KnownByzantine`` weight strategy."""
    # Fixed tag "kb"; excluded from the generated __init__.
    name:str=field(default="kb",init=False)

    def get(self,nbworkers):
        """Return a ``KnownByzantine`` strategy built from ``nbworkers``.

        Args:
            nbworkers: Object exposing ``worker_num`` and
                ``get_ids("byzantine")``; both values are forwarded to
                ``KnownByzantine``.
        """
        worker_total=nbworkers.worker_num
        byzantine_ids=nbworkers.get_ids("byzantine")
        return KnownByzantine(worker_total,byzantine_ids)

@util.repr_alias(attr_name=False)
@dataclass
class StrategyArgs():
    """Bundle a weight-strategy configuration with a linear strategy name."""
    # Configuration object whose ``get(nbworkers)`` yields the weight strategy.
    weight_args: WeightArgs
    # Name understood by ``create_linear_strategy`` ("s" or "ce").
    linear_name: Literal["s","ce"]

    def get(self,batch_size:int,nbworkers=None):
        """Build the ``GradStrategy`` described by this configuration.

        Args:
            batch_size: Forwarded to ``create_linear_strategy``.
            nbworkers: Passed as the single positional argument to
                ``weight_args.get``; defaults to ``None``.

        Returns:
            A ``GradStrategy`` combining the two strategies.
        """
        return GradStrategy(
            self.weight_args.get(nbworkers),
            create_linear_strategy(self.linear_name,batch_size),
        )

def create_linear_strategy(linear_name:Literal["s","ce"],batch_size):
    """Create the linear gradient strategy selected by ``linear_name``.

    Args:
        linear_name: ``"s"`` selects ``fed_learning.LinearGradR`` built with
            ``nn.Sigmoid()``; ``"ce"`` selects ``LinearCEGrad``.
        batch_size: Forwarded to the selected strategy's constructor.

    Returns:
        The constructed strategy object.

    Raises:
        ValueError: If ``linear_name`` is neither ``"s"`` nor ``"ce"``; the
            same message is printed before raising.
    """
    if linear_name == "s":
        return fed_learning.LinearGradR(nn.Sigmoid(),batch_size)
    if linear_name == "ce":
        return LinearCEGrad(batch_size)
    # Neither known name matched: report and fail.
    print("Unknown strategy name")
    raise ValueError("Unknown strategy name")


# Factory function
def create_strategy(weight_name:Literal["mv","kb"],
                    linear_name:Literal["s","ce"],
                    batch_size, **kwargs) -> GradStrategy:
    """Build a ``GradStrategy`` from a weight-strategy name and a linear name.

    Args:
        weight_name: ``"mv"`` builds ``MaxVariance``; ``"kb"`` builds
            ``KnownByzantine``.
        linear_name: Forwarded to ``create_linear_strategy``.
        batch_size: Batch size for the linear strategy, and the default
            ``value_batch_size`` for ``"mv"``. ``None`` falls back to 400.
        **kwargs: Strategy-specific options:
            ``permutation`` (``"mv"``, default 0),
            ``value_batch_size`` (``"mv"``, default ``batch_size``),
            ``nbworkers`` (``"kb"``, required; must expose ``worker_num`` and
            ``get_ids("byzantine")``). Other keys are ignored.

    Returns:
        The assembled ``GradStrategy``.

    Raises:
        ValueError: If ``weight_name`` is unknown, or if ``weight_name`` is
            ``"kb"`` and no ``nbworkers`` was supplied.
    """
    # ``batch_size`` is a named parameter, so it can never show up in
    # ``**kwargs``. The former ``kwargs.get("batch_size", 400)`` therefore
    # always produced 400 and silently discarded the caller's value. 400 is
    # kept only as the fallback when no batch size is given.
    if batch_size is None:
        batch_size=400

    if weight_name == "mv":
        permutation=kwargs.get("permutation",0)
        value_batch_size=kwargs.get("value_batch_size",batch_size)
        weight_strategy=MaxVariance(permutation,value_batch_size)
    elif weight_name == "kb":
        nbworkers=kwargs.get("nbworkers")
        nbworkers:worker_with_byzantine.NormalByzantineConf
        if nbworkers is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                "weight_name 'kb' requires the 'nbworkers' keyword argument")
        weight_strategy=KnownByzantine(nbworkers.worker_num,nbworkers.get_ids("byzantine"))
    else:
        print("Unknown strategy name")
        raise ValueError("Unknown strategy name")

    linear_strategy=create_linear_strategy(linear_name,batch_size)
    return GradStrategy(weight_strategy,linear_strategy)
    

    
def get_modelSetup(t,device,random_seed,**kwargs):
    """Create the ``training.ModelSetup`` for a task.

    Args:
        t: Task object; only ``task.MnistTask`` instances are supported.
        device: Forwarded to ``training.ModelSetup``.
        random_seed: Used to build ``task.Initialize(random_seed)``.
        **kwargs: Forwarded to ``task.create_model``.

    Returns:
        A ``training.ModelSetup`` wrapping an ``"embed"`` model for ``t``.

    Raises:
        TypeError: If ``t`` is not a ``task.MnistTask``. Previously this case
            fell through to ``return modelSetup`` with the name unbound and
            failed with ``UnboundLocalError``.
    """
    if not isinstance(t,task.MnistTask):
        raise TypeError(
            f"Unsupported task type for get_modelSetup: {type(t).__name__}")
    model=task.create_model("embed",t,**kwargs)
    return training.ModelSetup(device,model,task.Initialize(random_seed))
    
def get_optimizer(modelSetup,linear_name,learning_rate):
    """Create the Adam ``training.Optimizer`` matching a linear strategy name.

    Args:
        modelSetup: Forwarded to ``training.Optimizer``.
        linear_name: ``"s"`` builds the optimizer with ``maximize=True``;
            ``"ce"`` builds it without that flag.
        learning_rate: Forwarded to ``training.Optimizer``.

    Returns:
        The constructed ``training.Optimizer``.

    Raises:
        ValueError: If ``linear_name`` is neither ``"s"`` nor ``"ce"``.
            Previously this case silently returned ``None``;
            ``create_linear_strategy`` already rejects the same names with
            ``ValueError``.
    """
    if linear_name=="s":
        return training.Optimizer("Adam",modelSetup,learning_rate,maximize=True)
    if linear_name=="ce":
        return training.Optimizer("Adam",modelSetup,learning_rate)
    raise ValueError(f"Unknown strategy name: {linear_name!r}")

def get_worker_dataset(workersConf,redundency_map=None):
    """Build the aggregated workers dataset for the discriminator.

    Args:
        workersConf: Workers configuration forwarded to the dataset class.
        redundency_map: Optional map; when it is not ``None`` the dataset is
            built with ``dataset.AggWorkersDatasetWithRedundancy``, otherwise
            with ``dataset.AggWorkersDatasetFromConf``.

    Returns:
        The constructed dataset object.
    """
    # Both variants take a fresh WrapWorkersDiscriminator as second argument.
    wrapper=fed_learning.WrapWorkersDiscriminator()
    if redundency_map is None:
        return dataset.AggWorkersDatasetFromConf(workersConf,wrapper)
    return dataset.AggWorkersDatasetWithRedundancy(
        workersConf,wrapper,redundency_map)


def create_discriminator(nbworkers,weight_name,linear_name,batch_size,grad_strategy_args,
        save_folder,max_epoch,chosenWorkerNum,modelSetup,optimizer,worker_dataset,value_batch_size,
                 save_round_inteval=None, save_epoch_inteval=None,**kwargs):
    """Assemble a ``Discriminator`` with its evaluation objects and strategy.

    Args:
        nbworkers: Used to build ``PlotEval`` and forwarded to
            ``create_strategy`` as ``nbworkers``.
        weight_name, linear_name, batch_size: Forwarded to
            ``create_strategy``.
        grad_strategy_args: Mapping unpacked as extra keyword arguments for
            ``create_strategy``. It is combined with explicit ``nbworkers``
            and ``value_batch_size`` keywords, so it must not repeat those
            keys (Python rejects duplicate keyword arguments).
        save_folder, max_epoch, chosenWorkerNum, modelSetup, optimizer,
        worker_dataset, value_batch_size: Forwarded positionally to
            ``Discriminator``.
        save_round_inteval, save_epoch_inteval: Forwarded to
            ``Discriminator`` as keywords.
        **kwargs: Accepted but not used.

    Returns:
        The constructed ``Discriminator``.
    """
    eval_objs=PlotEval(nbworkers).get_evalObjs(worker_dataset,value_batch_size)
    strategy=create_strategy(
        weight_name,linear_name,batch_size,
        nbworkers=nbworkers,value_batch_size=value_batch_size,
        **grad_strategy_args)

    return Discriminator(
        eval_objs,save_folder,max_epoch,chosenWorkerNum,strategy,
        modelSetup,
        optimizer,
        worker_dataset,
        value_batch_size,
        save_round_inteval=save_round_inteval,
        save_epoch_inteval=save_epoch_inteval)

def create_weightUpdater():
    """Placeholder with no implementation yet; returns ``None``."""
    return None

def plot_discriminator(save_folder,nbworkers):
    """Plot evaluation results for ``save_folder`` via ``PlotEval``.

    Args:
        save_folder: Passed to ``PlotEval.plot``.
        nbworkers: Used to construct the ``PlotEval`` instance.
    """
    PlotEval(nbworkers).plot(save_folder)