from .comlib import *
from ..training import criterion,ModelSetup
from .worker_model import WorkerModel
from .worker_model_v1_1 import WorkerModelR
from .. import worker_with_byzantine


class Eval(training.Eval11):
    """Per-worker evaluator producing flat ``{column_name: value}`` dicts.

    Column names have the form ``"{criterion}_{dataset_name}_{worker_id}"``
    (see :meth:`get_header_cri_name`).
    """

    def __init__(self, dataset, dataset_name, criterions, batch_size,worker_ids=None):
        """Store the evaluation setup and the ids used to label workers.

        Args:
            dataset: Forwarded to the parent constructor; must expose
                ``worker_num`` when ``worker_ids`` is omitted.
            dataset_name: Name embedded in every column header.
            criterions: Mapping of criterion name -> callable.
            batch_size: Batch size passed on to the worker model.
            worker_ids: Ids used in column headers. Defaults to
                ``0 .. dataset.worker_num - 1``.
        """
        super().__init__(dataset, dataset_name, criterions, batch_size)
        if worker_ids is None:
            # presumably self.dataset is set by the parent __init__ -- confirm
            worker_ids = list(range(self.dataset.worker_num))
        self.worker_ids = worker_ids

    @staticmethod
    def get_dict_from_values(values,criterion_names,dataset_name,worker_ids):
        """Pair already-computed values with their column headers.

        Args:
            values: Indexable by criterion position; each entry must support
                ``.cpu().tolist()`` and yield one value per worker.
            criterion_names: Criterion names, in the same order as ``values``.
            dataset_name: Name embedded in every column header.
            worker_ids: Ids used to build the per-worker headers.

        Returns:
            dict mapping each header string to the matching value.
        """
        columns = {}
        for idx, criterion_name in enumerate(criterion_names):
            header = Eval.get_header_cri_name(criterion_name, dataset_name, worker_ids)
            per_worker = values[idx].cpu().tolist()
            # zip stops at the shorter of the two sequences.
            for column, value in zip(header, per_worker):
                columns[column] = value
        return columns


    def get(self,modelSetup:ModelSetup):
        """Evaluate every criterion on the workers and return the flat result dict.

        Args:
            modelSetup: Model setup handed to ``WorkerModelR``.

        Returns:
            dict mapping ``"{criterion}_{dataset_name}_{worker_id}"`` to a value.
        """
        names, functions = zip(*self.criterions.items())
        # NOTE(review): values are requested for workers 0..worker_num-1 while
        # the headers below use self.worker_ids; if the two differ in length,
        # zip silently truncates -- confirm this is intended.
        every_worker = list(range(self.dataset.worker_num))
        worker_model = WorkerModelR(modelSetup, self.dataset)
        values = worker_model.getChosenWorkersValues(functions, every_worker, self.batch_size)

        result = {}
        for idx, criterion_name in enumerate(names):
            header = self.get_header_cri_name(criterion_name, self.dataset_name, self.worker_ids)
            per_worker = values[idx].cpu().tolist()
            for column, value in zip(header, per_worker):
                result[column] = value
        return result

    @staticmethod
    def get_header_cri_name(cri_name,dataset_name,worker_ids):
        """Return one header per worker id for a single criterion."""
        prefix_parts = (cri_name, dataset_name)
        return ["{}_{}_{}".format(*prefix_parts, worker_id) for worker_id in worker_ids]

    def get_header(self):
        """Return all column headers, grouped by criterion in ``self.criterions`` order."""
        return [
            column
            for criterion_name in self.criterions
            for column in self.get_header_cri_name(
                criterion_name, self.dataset_name, self.worker_ids
            )
        ]
    
def get_default_criterions():
    """Build the default criterion mappings.

    Returns:
        A ``(eval_criterions, worker_criterions)`` tuple of dicts keyed by
        criterion name. The eval set uses the default / ``"mean"`` reductions
        and includes ``"out_norm"``; the worker set uses ``reduction="none"``
        and has no ``"out_norm"`` entry.
    """
    eval_criterions = dict(
        loss=nn.CrossEntropyLoss(),
        acc=training.accuracy(reduction="mean"),
        out_norm=training.out_norm,
    )
    # out_norm is not part of the per-worker criterions.
    worker_criterions = dict(
        loss=nn.CrossEntropyLoss(reduction="none"),
        acc=training.accuracy(reduction="none"),
    )
    return eval_criterions, worker_criterions

def get_evalObjs(eval_datasets:dict,eval_criterions,workers_dataset:dataset.AggWorkersDatasetFromConf,worker_criterions,batch_size):
    """Create one evaluator per eval dataset plus a per-worker evaluator.

    Args:
        eval_datasets: Mapping of dataset name -> dataset, e.g.
            ``{"train": train_dataset, "test": test_dataset}``.
        eval_criterions: Criterions given to every ``training.Eval11``.
        workers_dataset: Dataset given to the per-worker ``Eval``.
        worker_criterions: Criterions given to the per-worker ``Eval``.
        batch_size: Batch size shared by all evaluators.

    Returns:
        dict of evaluators keyed by dataset name, with the per-worker
        evaluator stored under ``"worker"`` (it replaces any eval dataset
        that happens to use that name).
    """
    evaluators = {
        name: training.Eval11(data, name, eval_criterions, batch_size)
        for name, data in eval_datasets.items()
    }
    evaluators["worker"] = Eval(workers_dataset, "worker", worker_criterions, batch_size)
    return evaluators


class PlotFL(save_log.FigStrategy):
    """Figure strategy that turns a saved log into loss/acc/norm line plots."""

    def __init__(self, 
                 save_name,nbworkers,
                 eval_dataset_names:list=["train","test"],
                 x_name="round"):
        """Configure which log and which columns are plotted.

        Args:
            save_name: Forwarded to the parent constructor and later used to
                load the log data frame.
            nbworkers: Passed to ``worker_with_byzantine.get_values_fig`` when
                building the per-worker lines.
            eval_dataset_names: Names of the evaluation datasets; each one
                contributes a line to the "loss", "acc" and "out_norm" figures.
            x_name: Name of the column used as the x axis.
        """
        super().__init__(save_name)
        self.nbworkers=nbworkers
        self.x_name=x_name
        # Copy the names so instances never share (or later mutate) the
        # default list object or the caller's list.
        self.eval_dataset_names=list(eval_dataset_names)

    def get_figs(self, save_folder):
        """Build the lines of every figure from the saved log.

        Args:
            save_folder: Folder from which the log data frame is loaded.

        Returns:
            dict mapping figure name ("loss", "acc", "out_norm",
            "weight_norm", "grad_norm") to a dict of its lines. The
            "weight_norm" and "grad_norm" figures hold a single line under
            the integer key ``0``.
        """
        x_name=self.x_name
        # presumably get_df is provided by save_log.FigStrategy -- confirm
        df=self.get_df(save_folder, self.save_name)
        figs={"loss":{},"acc":{},"out_norm":{},
            "weight_norm":{},
            "grad_norm":{}}

        for cri_name in ["loss","acc"]:
            # One line per evaluation dataset, read from "<criterion>_<dataset>".
            lines={
                line_name:save_log.get_line_from_df(df,f"{cri_name}_{line_name}",x_name)
                for line_name in self.eval_dataset_names
            }
            # Lines derived from the "<criterion>_worker" columns are merged in.
            lines.update(
                worker_with_byzantine.get_values_fig(
                    df,f"{cri_name}_worker",x_name,self.nbworkers)
            )
            figs[cri_name]=lines

        for fig_name in ["out_norm"]:
            # Use the configured dataset names (not a hard-coded train/test
            # pair) so this figure stays consistent with "loss" and "acc".
            figs[fig_name]={
                line_name:save_log.get_line_from_df(
                    df,f"{fig_name}_{line_name}",x_name)
                for line_name in self.eval_dataset_names
            }

        for fig_name in ["weight_norm","grad_norm"]:
            figs[fig_name][0]=save_log.get_line_from_df(df,fig_name,x_name)

        return figs


