from .comlib import *
from .discriminator_strategy import StrategySigmoidValue,LinearGrad,LinearCEGrad,LinearSigmoidGrad

from .max_variance import MaxVariance
from .known_byzantine import KnownByzantine



# Factory function: builds a strategy object from its short name.
def create_strategy(strategy_name: str, **kwargs) -> StrategySigmoidValue:
    """Build the strategy object selected by ``strategy_name``.

    Supported names and what they construct:

    * ``"mvs"``  -- ``MaxVariance`` wrapping a ``LinearSigmoidGrad``
    * ``"mvce"`` -- ``MaxVariance`` wrapping a ``LinearCEGrad``
    * ``"kbs"``  -- ``KnownByzantine`` wrapping a ``LinearSigmoidGrad``
    * ``"kbce"`` -- ``KnownByzantine`` wrapping a ``LinearCEGrad``

    Keyword Args:
        batch_size: Passed to the gradient-strategy constructor.
            Defaults to 400.
        permutation: (``mvs``/``mvce`` only) second positional argument
            of ``MaxVariance``. Defaults to 0.
        value_batch_size: (``mvs``/``mvce`` only) third positional
            argument of ``MaxVariance``. Defaults to ``batch_size``.
        worker_num: (``kbs``/``kbce`` only) forwarded to
            ``KnownByzantine``; ``None`` when omitted.
        byzantine_ids: (``kbs``/``kbce`` only) forwarded to
            ``KnownByzantine``; ``None`` when omitted.

    Returns:
        The constructed ``MaxVariance`` or ``KnownByzantine`` instance.

    Raises:
        ValueError: If ``strategy_name`` is not one of the names above.
    """
    batch_size = kwargs.get("batch_size", 400)

    if strategy_name in ("mvs", "mvce"):
        grad_cls = LinearSigmoidGrad if strategy_name == "mvs" else LinearCEGrad
        dg_strategy = grad_cls(batch_size)
        permutation = kwargs.get("permutation", 0)
        # Both MaxVariance variants must honour an explicit
        # value_batch_size; it only falls back to batch_size when the
        # caller did not supply one.
        value_batch_size = kwargs.get("value_batch_size", batch_size)
        return MaxVariance(dg_strategy, permutation, value_batch_size)

    if strategy_name in ("kbs", "kbce"):
        grad_cls = LinearSigmoidGrad if strategy_name == "kbs" else LinearCEGrad
        dg_strategy = grad_cls(batch_size)
        worker_num = kwargs.get("worker_num")
        byzantine_ids = kwargs.get("byzantine_ids")
        return KnownByzantine(dg_strategy, worker_num, byzantine_ids)

    raise ValueError("Unknown strategy name")