from typing import Callable, Any, Literal
import torch
from .. import save_log

class NormalByzantineConf():
    """Split of a worker population into normal and Byzantine workers.

    Worker ids are the integers ``0 .. worker_num - 1``. The first
    ``normal_num`` ids are the normal workers and the remaining
    ``byzantine_num`` ids are the Byzantine workers.
    """

    def __init__(self, normal_num, byzantine_num):
        """Store both group sizes and their sum as ``worker_num``."""
        self.normal_num = normal_num
        self.byzantine_num = byzantine_num
        self.worker_num = normal_num + byzantine_num

    def get_byzantine_ratio(self):
        """Return ``byzantine_num / worker_num``.

        Raises:
            ZeroDivisionError: if there are no workers at all.
        """
        return self.byzantine_num / self.worker_num

    def get_ids(self, name: Literal["worker", "normal", "byzantine"]):
        """Return the list of worker ids belonging to the group ``name``.

        Args:
            name: ``"worker"`` for every id, ``"normal"`` for the first
                ``normal_num`` ids, ``"byzantine"`` for the remaining ids.

        Raises:
            ValueError: if ``name`` is not one of the three known groups.
        """
        if name == "worker":
            return list(range(self.worker_num))
        if name == "normal":
            return list(range(self.normal_num))
        if name == "byzantine":
            return list(range(self.normal_num, self.worker_num))
        # Fail loudly: silently returning None only moves the error to the
        # caller, where it shows up as an unrelated indexing/iteration failure.
        raise ValueError(
            f"unknown group {name!r}; expected 'worker', 'normal' or 'byzantine'"
        )

    def print_values(self, values):
        """Return the std and mean of ``values`` for each worker group.

        Despite its name this method prints nothing; it only builds and
        returns the statistics.

        Args:
            values: per-worker values indexed by worker id along the first
                dimension. A ``torch.Tensor`` is moved to the CPU; anything
                else is converted with ``torch.as_tensor``. The data must
                have a floating-point (or complex) dtype, as required by
                ``torch.std_mean``.

        Returns:
            A dict with the keys ``normal_std``, ``normal_mean``,
            ``byzantine_std`` and ``byzantine_mean``, each mapped to the
            tensor produced by ``torch.std_mean`` for that group.
        """
        if isinstance(values, torch.Tensor):
            values = values.cpu()
        else:
            # Lists / arrays do not support the list-of-ids indexing (or the
            # torch reduction) used below, so normalise them to a tensor.
            values = torch.as_tensor(values)
        stats = {}
        for name in ("normal", "byzantine"):
            ids = self.get_ids(name)
            # torch.std_mean returns the pair in (std, mean) order.
            stats[f"{name}_std"], stats[f"{name}_mean"] = torch.std_mean(values[ids])
        return stats


def get_line(df,name,x_name,worker_ids=None):
    """Build a ``save_log.LineErrorbar`` for one group of workers.

    The x data is the ``x_name`` column of ``df``; the y data is the set of
    columns named ``f"{name}_{id}"`` for every id in ``worker_ids``, handed
    to ``LineErrorbar.setYFromdf``.

    NOTE(review): ``worker_ids`` defaults to ``None`` but is iterated
    unconditionally, so callers must always pass an iterable of ids;
    leaving the default raises ``TypeError``.
    """
    y_columns = []
    for worker_id in worker_ids:
        y_columns.append(f"{name}_{worker_id}")
    x_values = df.loc[:, x_name]
    errorbar = save_log.LineErrorbar(x_values)
    y_frame = df.loc[:, y_columns]
    errorbar.setYFromdf(y_frame)
    return errorbar
def get_values_fig(df,name,x_name,nbworkers):
    """Return one line object per worker group, keyed by group name.

    For each of ``"normal"`` and ``"byzantine"`` (in that order) the ids are
    taken from ``nbworkers.get_ids`` and passed to ``get_line`` together
    with ``df``, ``name`` and ``x_name``.
    """
    return {
        group: get_line(df, name, x_name, nbworkers.get_ids(group))
        for group in ("normal", "byzantine")
    }