from .comlib import *
from ..training import criterion,ModelSetup
from .worker_model import WorkerModel
from .worker_model_v1_1 import WorkerModelR


class Eval(training.Eval11):
    """Per-worker evaluator: reports every criterion separately for each tracked worker.

    Result/header columns are named ``f"{cri_name}_{dataset_name}_{worker_id}"``.
    """

    def __init__(self, dataset, dataset_name, criterions, batch_size, worker_ids=None):
        """Create a per-worker evaluator.

        Args:
            dataset: Workers dataset; must expose ``worker_num`` (read when
                ``worker_ids`` is None and in ``get``).
            dataset_name: Name used as the middle part of every column name.
            criterions: Mapping of criterion name -> criterion callable.
            batch_size: Batch size forwarded to the worker model.
            worker_ids: Integer worker ids to report. Defaults to every worker,
                ``0 .. dataset.worker_num - 1``.
        """
        super().__init__(dataset, dataset_name, criterions, batch_size)
        if worker_ids is None:
            # Relies on the base class storing ``dataset`` as ``self.dataset``.
            worker_ids = list(range(self.dataset.worker_num))
        self.worker_ids = worker_ids

    @staticmethod
    def get_dict_from_values(values, criterion_names, dataset_name, worker_ids):
        """Flatten per-criterion value vectors into a ``{column_name: value}`` dict.

        ``values[i]`` holds the values of ``criterion_names[i]`` and must expose
        ``.cpu().tolist()`` (presumably a 1-D torch tensor -- confirm). Its k-th
        entry is paired with ``worker_ids[k]``; surplus entries on either side
        are dropped by ``zip``.
        """
        r_dict = {}
        for i, cri_name in enumerate(criterion_names):
            h = Eval.get_header_cri_name(cri_name, dataset_name, worker_ids)
            r_dict.update(zip(h, values[i].cpu().tolist()))
        return r_dict

    def get(self, modelSetup: ModelSetup):
        """Evaluate ``modelSetup`` on the workers dataset, one value per worker.

        Returns:
            dict mapping ``f"{cri_name}_{dataset_name}_{worker_id}"`` to a Python
            scalar for every criterion and every id in ``self.worker_ids``.
            Empty when there are no criterions.
        """
        if not self.criterions:
            # ``zip(*{}.items())`` cannot be unpacked into two names.
            return {}
        criterion_key, criterion_list = zip(*self.criterions.items())
        worker_num = self.dataset.worker_num
        workerModel = WorkerModelR(modelSetup, self.dataset)
        # Values are requested for all workers in id order, so position j of each
        # vector is worker j (assuming getChosenWorkersValues keeps the order of
        # the ids it is given -- confirm).
        values = workerModel.getChosenWorkersValues(
            criterion_list, list(range(worker_num)), self.batch_size)

        r_dict = {}
        for i, cri_name in enumerate(criterion_key):
            h = self.get_header_cri_name(cri_name, self.dataset_name, self.worker_ids)
            all_workers = values[i].cpu().tolist()
            # Select by worker id rather than pairing by position, so a subset or
            # reordered ``worker_ids`` is labelled with its own values.
            r_dict.update(zip(h, (all_workers[j] for j in self.worker_ids)))
        return r_dict

    @staticmethod
    def get_header_cri_name(cri_name, dataset_name, worker_ids):
        """Return the column names of one criterion, one per worker id, in order."""
        h = [f"{cri_name}_{dataset_name}_{j}" for j in worker_ids]
        return h

    def get_header(self):
        """Return every column name this evaluator produces, grouped by criterion."""
        return [
            col
            for cri_name in self.criterions
            for col in self.get_header_cri_name(cri_name, self.dataset_name, self.worker_ids)
        ]

    @staticmethod
    def get_line(df, x_name, cri_name, dataset_name="worker", worker_ids=None):
        """Build an error-bar line of ``cri_name`` across workers from a log DataFrame.

        Args:
            df: DataFrame read from the CSV log (columns as produced by ``get``).
            x_name: Column used as the x axis.
            cri_name: Criterion name.
            dataset_name: Dataset part of the column names.
            worker_ids: Worker ids whose columns are aggregated. When None, every
                column named ``f"{cri_name}_{dataset_name}_<int>"`` in ``df`` is
                used, in column order.

        Raises:
            ValueError: if ``worker_ids`` is None and no matching column exists.
        """
        if worker_ids is None:
            worker_ids = Eval._infer_worker_ids(df, cri_name, dataset_name)
        h = Eval.get_header_cri_name(cri_name, dataset_name, worker_ids)
        line = save_log.LineErrorbar(df.loc[:, x_name])
        line.setYFromdf(df.loc[:, h])
        return line

    @staticmethod
    def _infer_worker_ids(df, cri_name, dataset_name):
        """Return the integer worker ids of ``df``'s per-worker columns for one criterion."""
        prefix = f"{cri_name}_{dataset_name}_"
        ids = [
            int(col[len(prefix):])
            for col in df.columns
            if isinstance(col, str)
            and col.startswith(prefix)
            and col[len(prefix):].isdecimal()
        ]
        if not ids:
            raise ValueError(f"no per-worker columns with prefix {prefix!r} in DataFrame")
        return ids

    

class PlotEval(training.PlotEval):
    """Plotting helper for federated runs: adds a per-worker line next to train/test."""

    def __init__(self, worker_num):
        """Remember how many workers were logged (ids ``0 .. worker_num - 1``)."""
        super().__init__()
        self.dataset_name = ["train", "test", "worker"]
        self.worker_num = worker_num

    def get_evalObjs(self, train_dataset, test_dataset, workers_dataset: dataset.AggWorkersDatasetFromConf, batch_size):
        """Return the base train/test evaluators plus a per-worker ``Eval`` under ``"worker"``.

        The worker evaluator is built from ``self.criterions_unreduced``
        (presumably provided by the base class -- confirm).
        """
        objs = super().get_evalObjs(train_dataset, test_dataset, batch_size)
        objs["worker"] = Eval(workers_dataset, "worker", self.criterions_unreduced, batch_size)
        return objs

    def get_worker_ids(self):
        """Return ``[0, 1, ..., worker_num - 1]``."""
        return [*range(self.worker_num)]

    @staticmethod
    def get_df(save_folder, save_name):
        """Read ``<save_folder>/<save_name>.csv`` into a DataFrame via ``save_log``."""
        csv_path = f"{save_folder}/{save_name}.csv"
        return save_log.SaveCsvHeader(csv_path, None).read_to_df()

    def get_figs(self, save_folder, save_name):
        """Collect every figure to plot from the CSV log.

        Returns:
            dict of fig_name -> {line_name: line}. Each criterion figure holds
            ``"train"``, ``"test"`` and ``"worker"`` lines; ``weight_norm`` and
            ``grad_norm`` hold a single line under key ``0``. ``out_norm`` stays
            empty unless it is one of ``self.criterions``.
        """
        x_col = "round"
        fig_names = ("loss", "acc", "out_norm", "weight_norm", "grad_norm")
        figures = {fig: {} for fig in fig_names}
        frame = self.get_df(save_folder, save_name)

        for cri_name in self.criterions:
            per_split = {
                split: training.Eval11.get_line(frame, x_col, cri_name, split)
                for split in ("train", "test")
            }
            per_split["worker"] = Eval.get_line(frame, x_col, cri_name, "worker", self.get_worker_ids())
            figures[cri_name] = per_split

        for norm_name in ("weight_norm", "grad_norm"):
            figures[norm_name][0] = save_log.get_line_from_df(frame, norm_name, x_col)
        return figures

    def plot(self, save_folder, save_name="train"):
        """Render each figure from ``get_figs`` to ``<save_folder>/<fig_name>.png``."""
        for fig_name, lines in self.get_figs(save_folder, save_name).items():
            save_log.plot1(lines, f"{save_folder}/{fig_name}.png")

# def get_evalObjs_fl(train_dataset,test_dataset,workers_dataset:dataset.AggWorkersDatasetFromConf,batch_size):
#     evalObjs=training.get_evalObjs(train_dataset,test_dataset,batch_size)

#     criterions={
#         "loss":nn.CrossEntropyLoss(reduction="none"),
#         "acc":criterion.accuracy(reduction="none")
#     }
#     evalObjs["worker"]=Eval(workers_dataset,"worker",criterions,batch_size)
#     return evalObjs





# class Save():
#     def __init__(self,dataset_name,criterions:dict):
#         self.dataset_name=dataset_name
#         self.criterions=criterions
#     @staticmethod
#     def get_header(worker_num,criterions,dataset_name):
#         l=[]
#         for cri_name in criterions:
#             l=l+[f"{cri_name}_{dataset_name}_{j}" for j in range(worker_num)]
#         return l

# criterions={
#     "loss":get_workers_criterion(nn.CrossEntropyLoss(reduction="none"),worker_num),
#     "acc":get_workers_criterion(criterion.accuracy(reduction="none"),worker_num)
# }

# def get_workers_criterion(criterion_unreduced,worker_num):
#     def f(out,worker_ids,*out_args):
#         return scatterValue(criterion_unreduced,out,out_args,worker_ids,worker_num)
#     return f

# def scatterValue(criterion_unreduced,out,out_args,worker_ids,worker_num):
#     device=out.device
#     loss_data_vec=criterion_unreduced(out,*out_args).view(-1)
#     loss_worker_vec = torch.zeros((worker_num,), dtype=loss_data_vec.dtype,device=device)
#     loss_worker_vec.scatter_add_(0, worker_ids, loss_data_vec)

#     return loss_worker_vec

    # @staticmethod
    # def get_line(df,x_name,cri_name,dataset_name,normal_num,byzantine_num):
    #     worker_num=normal_num+byzantine_num
    #     ids={}
    #     ids["normal"]=list(range(normal_num))
    #     ids["byzantine"]=list(range(normal_num,normal_num+byzantine_num))
    #     lines={}
    #     for i in ["normal","byzantine"]:
    #         h=[f"{cri_name}_{dataset_name}_{j}" for j in ids[i]]
    #         line=save_log.LineErrorbar(df.loc[:,x_name])
    #         line.setYFromdf(df.loc[:,h])
    #         lines[i]=line
    #     return lines

# def getLine(line_name=Literal["mean","var","std","weighted"]):   