from .comlib import *
from .worker_model import WorkerModelSigmoid
# import matplotlib.pyplot as plt
# import seaborn as sns

class Discriminator(fed_learning.FedLearn11):
    """Trainer built on ``fed_learning.FedLearn11`` that also tracks per-worker values.

    The instance keeps its own model setup, optimizer and aggregated worker
    dataset, and trains on a chosen subset of workers through ``maximize``.
    """

    def __init__(self, evalObjs, save_folder, max_epoch, chosenWorkerNum, grad_strategy,
                 modelSetup:training.ModelSetup,
                 optimizer:training.Optimizer,
                 worker_dataset:dataset.AggWorkersDatasetFromConf,
                 value_batch_size,
                 save_round_inteval=None, save_epoch_inteval=None):
        """Forward the shared settings to the parent and store the local ones.

        The parent constructor additionally receives the literal ``"dis0"``
        (presumably a default save name -- confirm against ``FedLearn11``).

        Args:
            modelSetup: model wrapper used by ``maximize`` and ``reset_model_optimizer``.
            optimizer: optimizer wrapper used by ``maximize`` and ``reset_model_optimizer``.
            worker_dataset: full dataset from which ``maximize`` takes a sub-dataset.
            value_batch_size: batch size handed to ``WorkerModelSigmoid.getValue``.
        """
        super().__init__(
            evalObjs, save_folder, max_epoch, chosenWorkerNum, grad_strategy,
            "dis0", save_round_inteval, save_epoch_inteval,
        )
        self.modelSetup = modelSetup
        self.optimizer = optimizer
        self.worker_dataset = worker_dataset
        self.value_batch_size = value_batch_size

    def set_name(self, save_name):
        """Replace ``self.saveObj`` with a newly initialized one using ``save_name``.

        The new object reuses the current save folder and the header derived
        from ``self.evalObjs``.
        """
        folder = self.saveObj.save_folder
        header = self.get_header(self.evalObjs)
        self.saveObj = training.SaveTrain11(
            folder, header, initialize=True, save_name=save_name
        )

    def reset_model_optimizer(self, flag_model=True, flag_optimizer=True):
        """Re-initialize the stored model and/or optimizer, depending on the flags."""
        if flag_model:
            self.modelSetup.initialize_model()
        if flag_optimizer:
            self.optimizer.initialize_optimizer()

    def getWorkerValues(self, modelSetup, worker_dataset):
        """Return ``WorkerModelSigmoid.getValue`` for every worker id in ``worker_dataset``.

        Worker ids are ``0 .. worker_dataset.worker_num - 1`` and the evaluation
        uses ``self.value_batch_size`` as batch size.
        """
        all_ids = list(range(worker_dataset.worker_num))
        sigmoid_model = WorkerModelSigmoid(modelSetup, worker_dataset)
        return sigmoid_model.getValue(all_ids, batch_size=self.value_batch_size)

    def get_std_worker_values(self, modelSetup, worker_dataset):
        """Return the worker values and their standard deviation.

        Returns:
            dict with key ``"std"`` (``torch.std`` with ``correction=0``) and
            key ``"values"`` (the output of ``getWorkerValues``).
        """
        worker_values = self.getWorkerValues(modelSetup, worker_dataset)
        return {
            "std": torch.std(worker_values, correction=0),
            "values": worker_values,
        }

    def save_model_stat(self, modelSetup, flag):
        """When ``flag`` is truthy, add weight norm and ``"worker"`` evaluation to the current row."""
        if not flag:
            return
        self.saveObj.add_data_to_row(
            {"weight_norm": self.get_weight_norm(modelSetup)}
        )
        for eval_key in ("worker",):
            # r=fed_learning.Eval.get_dict_from_values(
                # values,["sigmoid"],key,list(range(len(values))))
            eval_result = self.evalObjs[eval_key].get(modelSetup)
            self.saveObj.add_data_to_row(eval_result)

    def train(self,worker_dataset,modelSetup:training.ModelSetup,
          optimizer):
        """Run ``self.max_epoch`` epochs of rounds over ``worker_dataset``.

        Every round: optionally record statistics (when ``get_save_flag`` is
        truthy), pick workers, obtain a gradient from ``self.grad_strategy``,
        step the optimizer with it and save the current row.

        Returns:
            list of the dicts produced by ``get_std_worker_values`` on the
            rounds for which the save flag was truthy.
        """
        std_history = []
        n_workers = worker_dataset.worker_num
        rounds_in_epoch = int(n_workers / self.get_chosenWorkerNum(n_workers))

        for epoch in range(self.max_epoch):
            for round_idx in range(rounds_in_epoch):
                should_save = self.get_save_flag(epoch, rounds_in_epoch, round_idx)

                # Statistics are taken before this round's optimizer step.
                if should_save:
                    std_history.append(
                        self.get_std_worker_values(modelSetup, worker_dataset)
                    )
                self.save_model_stat(modelSetup, should_save)

                selected = self.get_chosen_workers(n_workers)
                grad_tuple = self.grad_strategy.get(
                    selected,
                    worker_dataset,
                    modelSetup,
                    self.get_save_func("grad_agg", should_save),
                )

                self.save_grad_norm(grad_tuple, should_save)

                optimizer.stepByGradTuple(grad_tuple)
                self.saveObj.save_cur_row()

        return std_history

    def maximize(self, chosen_workers):
        """Train the stored model/optimizer on the sub-dataset of ``chosen_workers``.

        Returns:
            the history list returned by ``train``.
        """
        sub_dataset = self.worker_dataset.getSubWokerDataset(chosen_workers)
        return self.train(sub_dataset, self.modelSetup, self.optimizer)


    # def getWorkerValues_vb_1(self,chosen_workers:list):
    #     workerModel=WorkerModelSigmoid(self.modelSetup,self.worker_dataset)
    #     workerValues=workerModel.getValue(chosen_workers,batch_size=self.value_batch_size)
    #     return workerValues
    




    
class PlotEval(fed_learning.PlotEval):
    """Builds evaluation objects and figures for per-worker ``"sigmoid"`` values.

    Workers are split into the ``"normal"`` and ``"byzantine"`` id groups
    provided by ``nbworkers.get_ids``.
    """

    def __init__(self, nbworkers:worker_with_byzantine.NormalByzantineConf):
        """Store ``nbworkers`` and initialize the parent with ``nbworkers.worker_num``."""
        self.nbworkers = nbworkers
        super().__init__(nbworkers.worker_num)
        self.criterions_unreduced = {
            # "loss":nn.CrossEntropyLoss(reduction="none"),
            "sigmoid": nn.Sigmoid(),
        }
        self.dataset_name = ["worker"]

    def get_evalObjs(self, workers_dataset, batch_size):
        """Return ``{"worker": fed_learning.Eval(...)}`` for ``workers_dataset``."""
        return {
            "worker": fed_learning.Eval(
                workers_dataset, "worker", self.criterions_unreduced, batch_size
            ),
        }

    # def get_loss_fig(self):
    #     self.discriminator_strategy.getInstantLoss()
    @staticmethod
    def get_df(save_folder, save_name):
        """Read the saved training log (``training.SaveTrain11``) as a DataFrame."""
        reader = training.SaveTrain11(save_folder, initialize=False, save_name=save_name)
        return reader.read_to_df()

    def get_values_fig(self, df, x_name, cri_name="sigmoid"):
        """Return one ``fed_learning.Eval.get_line`` result per worker group.

        Returns:
            dict keyed by ``"normal"`` and ``"byzantine"``.
        """
        return {
            group: fed_learning.Eval.get_line(
                df, x_name, cri_name, "worker", self.nbworkers.get_ids(group)
            )
            for group in ("normal", "byzantine")
        }

    @staticmethod
    def get_scores_from_values(df_values):
        """Return a DataFrame of squared centered values.

        The values are centered with ``util.centerize(..., dim=1, keepdims=True)``
        (presumably subtracting the per-row mean -- confirm against ``util``)
        and then squared element-wise. Columns and index of ``df_values`` are
        kept.
        """
        # DataFrame -> 2-D np.ndarray
        raw = df_values.to_numpy()
        centered = util.centerize(raw, dim=1, keepdims=True)

        # 2-D np.ndarray -> DataFrame
        return pd.DataFrame(
            centered ** 2, columns=df_values.columns, index=df_values.index
        )

    def get_std_scores_fig(self, df, x_name, cri_name="sigmoid"):
        """Build the ``"std"`` and ``"scores"`` figure dictionaries.

        ``"std"`` holds a single ``"line"``: the row-wise standard deviation
        (``ddof=0``) of the columns selected with ``get_ids("worker")``.
        ``"scores"`` holds one ``save_log.LineErrorbar`` per worker group,
        filled from the squared centered values of that group's columns.
        """
        value_cols = fed_learning.Eval.get_header_cri_name(
            cri_name, "worker", self.nbworkers.get_ids("worker")
        )
        values = df.loc[:, value_cols]
        scores = self.get_scores_from_values(values)

        std_line = save_log.Line(df.loc[:, x_name], values.std(axis=1, ddof=0))

        score_lines = {}
        for group in ("normal", "byzantine"):
            group_cols = fed_learning.Eval.get_header_cri_name(
                cri_name, "worker", self.nbworkers.get_ids(group)
            )
            errorbar = save_log.LineErrorbar(df.loc[:, x_name])
            errorbar.setYFromdf(scores.loc[:, group_cols])
            score_lines[group] = errorbar

        return {"std": {"line": std_line}, "scores": score_lines}

    def get_figs(self, save_folder, save_name):
        """Assemble all figures for one saved run, with ``"round"`` as x-axis.

        Returns:
            dict with keys ``weight_norm``, ``grad_norm``, ``std``, ``values``
            and ``scores`` (in that order).
        """
        x_axis = "round"
        log_df = self.get_df(save_folder, save_name)

        figs = {}
        for metric in ("weight_norm", "grad_norm"):
            figs[metric] = {0: save_log.get_line_from_df(log_df, metric, x_axis)}

        # Empty placeholders fix the key order before update() fills them in.
        # "loss" is intentionally not part of the figure set.
        figs["std"] = {}
        figs["values"] = self.get_values_fig(log_df, x_axis, "sigmoid")
        figs["scores"] = {}
        figs.update(self.get_std_scores_fig(log_df, x_axis, "sigmoid"))

        return figs

    def plot(self, save_folder, save_name):
        """Pickle every figure dict and render it to a PNG next to the log.

        Files are written as ``<save_folder>/<save_name>_<figure name>.pkl``
        and ``.png``.
        """
        for name, fig in self.get_figs(save_folder, save_name).items():
            with open(f"{save_folder}/{save_name}_{name}.pkl", 'wb') as pkl_file:
                pickle.dump(dict(fig), pkl_file)
            save_log.plot1(fig, f"{save_folder}/{save_name}_{name}.png")


