from .comlib import *
from .discriminator_strategy import LinearCEGrad
from .worker_model import WorkerModelSigmoid
from .discriminator_ import Discriminator,PlotEval

from .max_variance import MaxVariance
from .known_byzantine import KnownByzantine

from .weight_update import WeightUpdater,SaveRemainWorkers
from . import weight_update
from .train_phase import MeanGradWithChosenWorkers

class GradStrategy():
    """Pair a worker-weighting strategy with a linear gradient strategy.

    ``weight_strategy`` is called as ``get(chosen_workers, worker_model)`` and
    ``linear_strategy`` as
    ``get(weight, chosen_workers, worker_dataset, modelSetup)``.
    """

    def __init__(self, weight_strategy, linear_strategy: fed_learning.LinearGrad):
        self.linear_strategy = linear_strategy
        self.weight_strategy = weight_strategy

    def get(self, chosen_workers, worker_dataset, modelSetup, *args):
        """Return the gradient computed by the linear strategy.

        The weight returned by the weight strategy is divided by
        ``len(chosen_workers)`` before being handed to the linear strategy.
        Extra positional arguments are accepted and ignored.
        """
        worker_model = WorkerModelSigmoid(modelSetup, worker_dataset)
        raw_weight = self.weight_strategy.get(chosen_workers, worker_model)
        scaled_weight = raw_weight / len(chosen_workers)
        return self.linear_strategy.get(
            scaled_weight, chosen_workers, worker_dataset, modelSetup
        )


@dataclass
class WeightArgs():
    """Weight-strategy arguments tagged by ``name`` ("mv" or "kb").

    ``get`` is a no-op here; it accepts the same keyword arguments that
    ``GradArgs.get`` passes to its ``weight_args``.
    """
    name: Literal["mv", "kb"]

    def get(self, value_batch_size=None, nbworkers=None):
        """Ignore both arguments and return ``None``."""
        return None

@util.repr_alias(attr_name=False)
@dataclass
class MvArgs():
    """Arguments for the "mv" weight strategy, realised as ``MaxVariance``."""
    # Fixed tag; not an __init__ parameter.
    name: str = field(init=False, default="mv")
    permutation: float = 0

    def get(self, value_batch_size, **kwargs):
        """Build ``MaxVariance(self.permutation, value_batch_size)``.

        Any other keyword arguments (e.g. ``nbworkers``) are accepted and
        ignored.
        """
        return MaxVariance(self.permutation, value_batch_size)

@util.repr_alias(attr_name=False)
@dataclass
class KbArgs():
    """Arguments for the "kb" weight strategy, realised as ``KnownByzantine``."""
    # Fixed tag; not an __init__ parameter.
    name: str = field(init=False, default="kb")

    def get(self, nbworkers, **kwargs):
        """Build a ``KnownByzantine`` strategy from ``nbworkers``.

        Reads ``nbworkers.worker_num`` and ``nbworkers.get_ids("byzantine")``.
        Any other keyword arguments (e.g. ``value_batch_size``) are accepted
        and ignored.
        """
        worker_count = nbworkers.worker_num
        byzantine_ids = nbworkers.get_ids("byzantine")
        return KnownByzantine(worker_count, byzantine_ids)

@util.repr_alias(attr_name=False)
@dataclass
class LinearArgs():
    """Select the linear gradient strategy and the optimizer that goes with it.

    ``name`` is "s" (``fed_learning.LinearGradR`` with a sigmoid) or "ce"
    (``LinearCEGrad``). Both factory methods raise ``ValueError`` for any
    other value.
    """
    name: Literal["s", "ce"]

    def get_strategy(self, batch_size):
        """Return the linear gradient strategy selected by ``self.name``.

        Raises:
            ValueError: if ``self.name`` is neither "s" nor "ce".
        """
        if self.name == "s":
            return fed_learning.LinearGradR(nn.Sigmoid(), batch_size)
        if self.name == "ce":
            return LinearCEGrad(batch_size)
        raise ValueError("Unknown strategy name")

    def get_optimizer(self, modelSetup: training.ModelSetup, name="Adam", lr=0.01, **kwargs):
        """Return a ``training.Optimizer`` configured for ``self.name``.

        Args:
            modelSetup: model setup handed to ``training.Optimizer``.
            name: optimizer name forwarded to ``training.Optimizer``.
            lr: learning rate forwarded to ``training.Optimizer``.
            **kwargs: extra keyword arguments forwarded unchanged.

        Raises:
            ValueError: if ``self.name`` is neither "s" nor "ce". The original
                fell through and returned ``None`` here, which only failed
                later when the optimizer was first used.
        """
        if self.name == "s":
            # The "s" strategy is paired with a maximizing optimizer.
            return training.Optimizer(name, modelSetup, lr, maximize=True, **kwargs)
        if self.name == "ce":
            return training.Optimizer(name, modelSetup, lr, **kwargs)
        raise ValueError("Unknown strategy name")

@util.repr_alias(attr_name=False)
@dataclass
class GradArgs():
    """Weight-strategy and linear-strategy arguments for one ``GradStrategy``."""
    weight_args: WeightArgs
    linear_args: LinearArgs

    def get(self, batch_size: int, value_batch_size=None, nbworkers=None):
        """Build a ``GradStrategy`` from the two argument objects.

        ``value_batch_size`` and ``nbworkers`` are passed by keyword to
        ``weight_args.get``; ``batch_size`` goes to
        ``linear_args.get_strategy``.
        """
        return GradStrategy(
            self.weight_args.get(value_batch_size=value_batch_size, nbworkers=nbworkers),
            self.linear_args.get_strategy(batch_size),
        )



@util.repr_alias(attr_name=False)
@dataclass
class DnnArgDiscrim(training.DnnArg):
    """``training.DnnArg`` variant whose ``name`` is pinned to "embed"."""
    # Not an __init__ parameter; always takes the default below.
    name: str = field(init=False, default="embed")
    
def get_modelSetup(t, device, random_seed, **kwargs):
    """Build the model setup for the "embed" model of a task.

    Args:
        t: task object; only ``task.MnistTask`` instances are supported.
        device: device handed to ``training.ModelSetup``.
        random_seed: seed handed to ``task.Initialize``.
        **kwargs: forwarded to ``task.create_model``.

    Returns:
        A ``training.ModelSetup`` wrapping the created model.

    Raises:
        TypeError: if ``t`` is not a ``task.MnistTask``. The original fell
            through to ``return modelSetup`` with the name unbound and failed
            with an opaque ``UnboundLocalError``.
    """
    if not isinstance(t, task.MnistTask):
        raise TypeError(
            f"get_modelSetup only supports task.MnistTask, got {type(t).__name__}"
        )
    model = task.create_model("embed", t, **kwargs)
    return training.ModelSetup(device, model, task.Initialize(random_seed))
    


def get_worker_dataset(workersConf, redundency_map=None):
    """Build the aggregated workers dataset used by the discriminator.

    With ``redundency_map`` set, returns a
    ``dataset.AggWorkersDatasetWithRedundancy``; otherwise a
    ``dataset.AggWorkersDatasetFromConf``. Both receive a fresh
    ``fed_learning.WrapWorkersDiscriminator``.
    """
    wrapper = fed_learning.WrapWorkersDiscriminator()
    if redundency_map is None:
        return dataset.AggWorkersDatasetFromConf(workersConf, wrapper)
    return dataset.AggWorkersDatasetWithRedundancy(workersConf, wrapper, redundency_map)



def create_discriminator(
        save_folder, modelSetup, optimizer_args, worker_dataset, nbworkers,
        fedLearnArgs: fed_learning.FedLearnArgs, gradArgs: GradArgs,
        batch_size=4000, value_batch_size=None):
    """Assemble a ``Discriminator`` from its configuration objects.

    Args:
        save_folder: output folder handed to the ``Discriminator``.
        modelSetup: model setup shared by the optimizer and the discriminator.
        optimizer_args: keyword arguments for
            ``gradArgs.linear_args.get_optimizer``.
        worker_dataset: workers dataset; its ``worker_num`` is read here.
        nbworkers: worker configuration passed to ``PlotEval`` and
            ``gradArgs.get``.
        fedLearnArgs: supplies ``max_epoch``, ``chosenWorkerRatio`` and the
            two save intervals.
        gradArgs: builds the gradient strategy and the optimizer.
        batch_size: batch size for the gradient strategy.
        value_batch_size: evaluation batch size; falls back to ``batch_size``
            when ``None``.

    Returns:
        The configured ``Discriminator``.
    """
    value_batch_size = batch_size if value_batch_size is None else value_batch_size

    eval_objs = PlotEval(nbworkers).get_evalObjs(worker_dataset, value_batch_size)
    strategy = gradArgs.get(batch_size, value_batch_size, nbworkers)
    opt = gradArgs.linear_args.get_optimizer(modelSetup, **optimizer_args)

    # int() truncates, so the number of chosen workers is rounded down.
    chosen_worker_num = int(fedLearnArgs.chosenWorkerRatio * worker_dataset.worker_num)

    return Discriminator(
        eval_objs, save_folder, fedLearnArgs.max_epoch, chosen_worker_num, strategy,
        modelSetup,
        opt,
        worker_dataset,
        value_batch_size,
        save_round_inteval=fedLearnArgs.save_round_inteval,
        save_epoch_inteval=fedLearnArgs.save_epoch_inteval,
    )
    

def plot_discriminator(save_folder, save_name, nbworkers):
    """Plot one saved discriminator run via ``PlotEval(nbworkers).plot``."""
    PlotEval(nbworkers).plot(save_folder, save_name)

def plot_discriminators(save_folder, nbworkers):
    """Plot every ``dis*.csv`` file found directly inside ``save_folder``.

    Each file's stem (name without extension) is used as the save name.
    """
    # Collect the stems before plotting so the directory listing is taken once.
    save_names = [csv_path.stem for csv_path in Path(save_folder).glob('dis*.csv')]
    for save_name in save_names:
        plot_discriminator(save_folder, save_name, nbworkers)

def plot_remain_workers(save_folder, nbworkers):
    """Render the "remain_workers" figure and save it as a PNG.

    The figure comes from ``weight_update.get_fig`` and is written to
    ``<save_folder>/remain_workers_chosen.png`` through ``save_log.plot1``.
    """
    save_name = "remain_workers"
    figure = weight_update.get_fig(save_folder, save_name, nbworkers)
    target = f"{save_folder}/{save_name}_chosen.png"
    save_log.plot1(figure, target)

def plot_weight_update(save_folder, nbworkers):
    """Produce all weight-update plots: discriminator runs, then remaining workers."""
    for plot in (plot_discriminators, plot_remain_workers):
        plot(save_folder, nbworkers)


    
@util.repr_alias(attr_name=False)
@dataclass
class DiscriminatorIFArgs():
    """Keyword settings that ``WeightUpdaterArg.get`` forwards to ``WeightUpdater``.

    ``end_condition_ratio_bias`` is not forwarded as-is: it is added to the
    byzantine ratio and passed on as ``end_condition_ratio``.
    """
    eta: float
    end_condition_ratio_bias: float = 0
    optimizer_initialize: bool = True
    model_initialize: bool = False

    
@util.repr_alias(attr_name=False)
@dataclass
class WeightUpdaterArg():
    """Everything needed to build a ``WeightUpdater`` and its discriminator."""
    ifArgs: DiscriminatorIFArgs
    fedArgs: fed_learning.FedLearnArgs
    gradArgs: GradArgs
    modelArg: DnnArgDiscrim
    optimizer_args: dict

    def get(self, t, device, random_seed, save_folder, worker_dataset, nbworkers,
            batch_size=4000, value_batch_size=None):
        """Build the model setup, the discriminator and the ``WeightUpdater``.

        ``ifArgs`` is flattened to keyword arguments, with
        ``end_condition_ratio_bias`` replaced by ``end_condition_ratio``
        (the byzantine ratio of ``nbworkers`` plus the bias). Remaining-worker
        logs are saved under ``save_folder`` with the name "remain_workers".
        """
        modelSetup = self.modelArg.get(t, device, random_seed)
        discriminator = create_discriminator(
            save_folder, modelSetup, self.optimizer_args, worker_dataset,
            nbworkers, self.fedArgs, self.gradArgs, batch_size, value_batch_size)

        updater_kwargs = asdict(self.ifArgs)
        byzantine_ratio = nbworkers.get_byzantine_ratio()
        bias = updater_kwargs.pop("end_condition_ratio_bias")
        updater_kwargs["end_condition_ratio"] = byzantine_ratio + bias

        saver = SaveRemainWorkers(save_folder, "remain_workers", nbworkers.worker_num)
        return WeightUpdater(
            modelSetup, worker_dataset, discriminator,
            **updater_kwargs,
            saveObj=saver)


def weightUpdateThenTrain(
        weightUpdaterArg:WeightUpdaterArg,trainArg:fed_learning.TrainFLArg,
        t,device,random_seed,save_folder,train_dataset,eval_datasets,
        nBWorkersArgs:worker_with_byzantine.NBWorkersArgs,batch_size=4000,value_batch_size=None):
    """Run the weight-update phase, then train with the remaining workers.

    Phase 1 builds a 'discrim' workers dataset and a ``WeightUpdater`` and
    calls ``update_weight()`` to get the remaining workers. Phase 2 builds the
    'default' workers dataset and runs ``trainArg.fedlearn`` under the save
    name "train_phase", with a ``MeanGradWithChosenWorkers`` strategy
    restricted to those workers. Nothing is returned.
    """
    nbworkers = nBWorkersArgs.get_conf()

    # Phase 1: decide which workers remain.
    discrim_dataset = nBWorkersArgs.get_dataset(train_dataset, random_seed, 'discrim', t.class_num)
    updater = weightUpdaterArg.get(
        t, device, random_seed, save_folder, discrim_dataset, nbworkers,
        batch_size, value_batch_size)
    remain_workers = updater.update_weight()

    # Phase 2: federated training restricted to the remaining workers.
    worker_dataset = nBWorkersArgs.get_dataset(train_dataset, random_seed, 'default', t.class_num)
    # remain_workers supports .cpu().tolist() -- presumably a tensor of worker ids; confirm.
    remaining_ids = remain_workers.cpu().tolist()
    grad_strategy = MeanGradWithChosenWorkers(trainArg.criterion, batch_size, remaining_ids)
    trainArg.fedlearn(
        grad_strategy, t, device, random_seed, save_folder, "train_phase",
        eval_datasets, worker_dataset, batch_size)

def plot_weight_update_then_train(save_folder, nbworkers):
    """Plot the weight-update results, then the "train_phase" evaluation."""
    plot_weight_update(save_folder, nbworkers)
    fed_learning.PlotEval(nbworkers.worker_num).plot(save_folder, "train_phase")

# @util.repr_alias(attr_name=False)   
# @dataclass
# class FedLearnArgs():
#     max_epoch:int
#     chosenWorkerRatio:float
#     save_round_inteval:int|None=None
#     save_epoch_inteval:int|None=None

# @util.repr_alias(attr_name=False)   
# @dataclass
# class TrainArg():
#     fedArgs:FedLearnArgs
#     modelArg:DnnArgDiscrim
#     optimizer_args:dict
#     criterion:nn.Module|Callable=nn.CrossEntropyLoss(reduction='none')

#     def train(self,t,device,random_seed,save_folder,train_dataset, test_dataset,
#         worker_dataset:dataset.AggWorkersDatasetFromConf,nbworkers,batch_size=4000):
#         plotEval=fed_learning.PlotEval(nbworkers.normal_num)
#         evalObjs=plotEval.get_evalObjs(train_dataset, test_dataset, worker_dataset,batch_size)

#         fedLearn=fed_learning.FedLearn11(evalObjs, save_folder,grad_strategy=fed_learning.MeanGrad(self.criterion,batch_size),save_name="train",**asdict(self.fedArgs))
#         modelSetup=self.modelArg.get(t,device,random_seed)
#         optimizer=training.Optimizer(**self.optimizer_args)
#         fedLearn.train(worker_dataset,modelSetup,optimizer)