from .comlib import *

# Results directory for the rl_mnist "items3" experiment set.
# NOTE(review): no live code in this section reads it (the functions below take
# their output folder as a parameter) — confirm it is still needed.
FOLDER="./tests_result/rl_mnist/items3"

def conduct_rl_mnist(save_folder, load_mnist_data, weightUpdaterArg: WeightUpdaterArg,
                     nBWorkersArgs: NBWorkersArgs, trainFLArg: TrainFLArg, grad_args,
                     random_seed="test", device="cuda:2",
                     batch_size=4000, value_batch_size=4000):
    """Run the weight-update experiment on MNIST, then every robust-aggregation baseline.

    Args:
        save_folder: output directory handed to both training routines
            (presumably where their logs/CSVs are written — confirm in callees).
        load_mnist_data: ``(train_dataset, test_dataset)`` pair.
        weightUpdaterArg: configuration for ``discriminator_v1_1.weightUpdateThenTrain``.
        nBWorkersArgs: normal/byzantine worker configuration; also builds the
            byzantine evaluation set and the per-worker datasets.
        trainFLArg: federated-learning training configuration shared by all runs.
        grad_args: iterable of gradient-aggregation configs; one
            ``robustfedlearn`` run per item, named ``repr(grad_arg)``.
        random_seed: seed forwarded to every dataset builder and training routine.
        device: torch device string; defaults to the previously hard-coded ``"cuda:2"``.
        batch_size: training batch size, used by both the weight-update run and
            the baselines (previously hard-coded to 4000 in two places).
        value_batch_size: batch size forwarded to ``weightUpdateThenTrain``.
    """
    train_dataset, test_dataset = load_mnist_data
    mnist_task = task.MnistTask()
    eval_datasets = {
        "train": train_dataset,
        "test": test_dataset,
        "byz": nBWorkersArgs.get_byzantine_set(train_dataset, random_seed, mnist_task.class_num),
    }
    nbworkers = nBWorkersArgs.get_conf()

    discriminator_v1_1.weightUpdateThenTrain(
        weightUpdaterArg, trainFLArg,
        mnist_task, device, random_seed, save_folder, train_dataset, eval_datasets,
        nBWorkersArgs, batch_size, value_batch_size)

    # The worker split is built once and shared by all baseline runs so they
    # are compared on identical data.
    worker_dataset = nBWorkersArgs.get_dataset(train_dataset, random_seed, 'default', mnist_task.class_num)
    for grad_arg in grad_args:
        robust_grad_aggregation.robustfedlearn(
            trainFLArg, grad_arg,
            mnist_task, device, random_seed, save_folder, repr(grad_arg),
            eval_datasets, worker_dataset,
            nbworkers, batch_size=batch_size, seg_len=100)

def class_name_only(obj):
    """Return the name of ``obj``'s class (used as a ``json.dumps`` fallback)."""
    cls = obj.__class__
    return cls.__name__

def dump_dataclass(dc):
    """Serialise dataclass ``dc`` to pretty-printed JSON (2-space indent).

    Values that ``json`` cannot encode are written as their class name.
    """
    as_dict = asdict(dc)
    return json.dumps(as_dict, default=class_name_only, indent=2)


def get_string_hash(s: str):
    """Return the first 6 hex digits of the MD5 digest of ``s`` (UTF-8 encoded).

    Intended as a short fingerprint, not a security hash.
    """
    digest = hashlib.md5(s.encode())
    return digest.hexdigest()[:6]


    
class FigRlMnist2(save_log.FigStrategy):
    """Figure strategy plotting last-row metrics against a swept value ``x``.

    For every ``x`` in ``xs`` results are read from ``<save_folder>/br{x:.1e}``;
    each qualifying CSV there (see ``get_line_names``) becomes one line, and one
    figure is produced per entry of ``fig_names``.
    """

    def __init__(self, save_name, nbworkers, xs):
        super().__init__(save_name)
        # One figure per evaluation set ("train"/"test"/"byz") and per worker
        # group ("normal"/"byzantine"); see get_val.
        self.fig_names = ["train",
                          "test", "byz",
                          "normal",
                          "byzantine", ]
        # Swept values plotted on the x axis, e.g. [0.3, 0.6, 0.9].
        self.xs = xs
        # Worker configuration; get_ids(group) must return a group's worker ids.
        self.nbworkers = nbworkers

    @staticmethod
    def get_val(df, name, cri_name, nbworkers):
        """Return metric ``cri_name`` for figure ``name`` from the last row of ``df``.

        For "train"/"test"/"byz" this is column ``f"{cri_name}_{name}"``; for
        "normal"/"byzantine" it is the mean of the ``f"{cri_name}_worker_{i}"``
        columns over the worker ids ``nbworkers.get_ids(name)``.

        Raises:
            ValueError: if ``name`` is not a supported figure name (previously
                ``None`` was returned silently and plotted as a value).
        """
        df_last = df.iloc[-1, :]
        if name in ["train", "test", "byz"]:
            return df_last.loc[f"{cri_name}_{name}"]
        if name in ["normal", "byzantine"]:
            ids = nbworkers.get_ids(name)
            pre = f"{cri_name}_worker"
            col = [f"{pre}_{i}" for i in ids]
            return df_last.loc[col].mean()
        raise ValueError(f"unknown figure name: {name!r}")

    @staticmethod
    def get_subfolders(folder, path=True):
        """List the entries of ``folder`` (files are not filtered out).

        Returns full paths when ``path`` is true, bare entry names otherwise.
        """
        subfolders = os.listdir(folder)
        if path:
            subfolders = [os.path.join(folder, name) for name in subfolders]
        return subfolders

    @staticmethod
    def get_line_names(folder):
        """Return the stems of the CSV files in ``folder`` that should be plotted.

        These are the files whose name starts with ``[`` (``[[]`` is the glob
        escape for a literal ``[``) and does not contain ``ifagg``, followed by
        ``train_phase`` when ``train_phase.csv`` exists.
        """
        folder = Path(folder)
        bracketed = [p.stem for p in folder.glob('[[]*.csv')]
        bracketed = [stem for stem in bracketed if 'ifagg' not in stem]
        train_phase = [p.stem for p in folder.glob('train_phase.csv')]
        return bracketed + train_phase

    def get_fig(self, save_folder, fig_name):
        """Build the lines of figure ``fig_name``: one ``save_log.Line`` per line name.

        Each line gets one point per ``x`` whose sub-folder holds a non-empty
        CSV for it; the y value is the final "acc" metric (see get_val).
        """
        lines = collections.defaultdict(lambda: None)
        for x in self.xs:
            sub_folder = os.path.join(save_folder, f"br{x:.1e}")
            print(sub_folder)
            for line_name in self.get_line_names(sub_folder):
                print(line_name)
                # get_df is inherited from save_log.FigStrategy (presumably
                # loads <sub_folder>/<line_name>.csv — confirm there).
                df = self.get_df(sub_folder, line_name)
                if len(df) == 0:
                    # Skip only this (e.g. unfinished) run; other lines for the
                    # same x may still have data.
                    continue
                val = self.get_val(df, fig_name, "acc", self.nbworkers)
                if line_name not in lines:
                    lines[line_name] = save_log.Line([], [])
                lines[line_name].append_x(x)
                lines[line_name].append_y(val)
        return lines

    def get_figs(self, save_folder):
        """Return ``{fig_name: lines}`` for every name in ``self.fig_names``."""
        return {fig_name: self.get_fig(save_folder, fig_name) for fig_name in self.fig_names}
    

class FigRlMnist2EB(FigRlMnist2):
    """Error-bar variant of FigRlMnist2 that aggregates over repeated runs.

    For each ``x`` the child of ``save_folder`` whose name contains
    ``f"{x:.1e}"`` is located; every entry inside it is treated as one repeat
    (presumably one seed — confirm), and the final accuracies of all repeats
    are combined into a ``save_log.LineErrorbar`` ('se' mode).
    """

    def get_ys(self, folder, fig_name):
        """Collect the final value of ``fig_name`` for every line across all repeats in ``folder``.

        Returns a ``defaultdict(list)`` mapping line name -> list of values, one
        per repeat that produced a non-empty CSV for that line.
        """
        y_lists = collections.defaultdict(list)
        for sub_folder in self.get_subfolders(folder, path=True):
            for line_name in self.get_line_names(sub_folder):
                df = self.get_df(sub_folder, line_name)
                if len(df) == 0:
                    # Skip only this empty CSV; the remaining lines of this
                    # repeat are still collected.
                    continue
                y_lists[line_name].append(self.get_val(df, fig_name, "acc", self.nbworkers))
        return y_lists

    @staticmethod
    def get_sub_folder_by_x(p_folder, x):
        """Return the child of ``p_folder`` whose name contains ``f"{x:.1e}"``, or ``None``.

        If several children match, the first in sorted order is returned so the
        choice is deterministic. ``p_folder`` may be a ``str`` or a ``Path``.
        """
        matches = sorted(Path(p_folder).glob(f'*{x:.1e}*'))
        return matches[0] if matches else None

    def get_fig(self, save_folder, fig_name):
        """Build the error-bar lines of figure ``fig_name`` keyed by display name.

        x values without a matching sub-folder are skipped. Raw line names
        without an entry in the display-name map are shown unchanged.
        """
        name_map = {
            "['geo_median_w'-True]": "geo_median",
            "['iterative_filtering'-True]": "iterative_filtering",
            "['krum'-True]": "krum",
            "train_phase": "this work",
        }
        lines = collections.defaultdict(lambda: None)
        for x in self.xs:
            sub_folder = self.get_sub_folder_by_x(save_folder, x)
            print(sub_folder)
            if sub_folder is None:
                # Without this guard get_ys(None) would call os.listdir(None),
                # which silently lists the current working directory.
                print(f"no sub-folder matching {x:.1e} in {save_folder}, skipped")
                continue
            y_lists = self.get_ys(sub_folder, fig_name)
            for raw_name, y_list in y_lists.items():
                line_name = name_map.get(raw_name, raw_name)
                if line_name not in lines:
                    lines[line_name] = save_log.LineErrorbar([], [], [], 'se')
                lines[line_name].append_x(x)
                lines[line_name].append_y(y_list)
        return lines


    
def get_sub_folders(p_folder):
    """Return the immediate sub-directories of ``p_folder`` as ``Path`` objects.

    Plain files are ignored; the order is whatever ``os.listdir`` yields.
    """
    result = []
    for entry_name in os.listdir(p_folder):
        candidate = os.path.join(p_folder, entry_name)
        if os.path.isdir(candidate):
            result.append(Path(candidate))
    return result

def test_rl_mnist2_plot_compare_(p_folder, xs):
    """Plot and dump the error-bar comparison figures for every experiment under ``p_folder``.

    Each immediate sub-directory of ``p_folder`` is one experiment whose own
    sub-directories are its runs. The ``config.json`` of one run (a list whose
    first element holds the ``NBWorkersArgs`` fields) supplies the worker
    configuration. Experiments without any run sub-directory are skipped.
    """
    print(p_folder)
    folders = get_sub_folders(p_folder)
    print(folders, len(folders))
    for folder in folders:
        s_folders = get_sub_folders(folder)
        if not s_folders:
            # Previously s_folders[0] raised IndexError here.
            print(f"skip {folder}: no run sub-folders")
            continue
        with open(os.path.join(s_folders[0], "config.json"), encoding='utf-8') as f:
            config = json.load(f)  # -> list[dict]
        nBWorkersArgs = from_dict(NBWorkersArgs, config[0])
        nbworkers = nBWorkersArgs.get_conf()
        save_log.plotFigStrategies(folder, [FigRlMnist2EB("", nbworkers, xs)])
        save_log.dumpFigStrategies(folder, [FigRlMnist2EB("", nbworkers, xs)])

def test_rl_mnist2_plot_compare(name="poison_rate"):
    """Plot the comparison figures for experiment family ``name`` from ``./tests_result/<name>``.

    Supported names and their swept values ``xs``:

    * ``"poison_rate"``: ``[x * 0.3 for x in [1, 2, 3]]`` (≈0.3, 0.6, 0.9)
    * ``"data_num"``: 30, 60, 90, 120

    Raises:
        ValueError: for any other ``name`` (previously ``xs`` was left unbound
            and an UnboundLocalError surfaced instead).
    """
    if name == "poison_rate":
        xs = [x * 0.3 for x in [1, 2, 3]]
    elif name == "data_num":
        xs = [30 * i for i in range(1, 5)]
    else:
        raise ValueError(f"unknown experiment name {name!r}; expected 'poison_rate' or 'data_num'")
    p_folder = os.path.join("./tests_result", name)
    test_rl_mnist2_plot_compare_(p_folder, xs)


    
def run_with_oom_wait(func, max_retry=12, base_sec=2, **kwargs):
    """Call ``func(**kwargs)``, retrying with exponential backoff on GPU out-of-memory errors.

    A ``RuntimeError`` whose message contains ``'out of memory'`` triggers
    ``torch.cuda.empty_cache()`` and, if attempts remain, a sleep of
    ``base_sec * 2**attempt`` seconds plus up to 1 s of random jitter. Any
    other ``RuntimeError`` is re-raised immediately.

    Args:
        func: callable invoked as ``func(**kwargs)``.
        max_retry: total number of attempts.
        base_sec: backoff base, in seconds.
        **kwargs: forwarded unchanged to ``func``.

    Returns:
        The result of the first attempt that does not raise.

    Raises:
        RuntimeError: ``"Still OOM after retries"`` once every attempt hit an
            OOM error (chained from the last one).
    """
    last_oom = None
    for attempt in range(max_retry):
        try:
            return func(**kwargs)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            last_oom = e
            torch.cuda.empty_cache()
            if attempt == max_retry - 1:
                # Don't sleep after the final attempt: with the defaults that
                # would be base_sec * 2**11 ≈ 68 minutes spent before raising.
                break
            # Random jitter de-synchronises processes competing for the same GPU.
            delay = base_sec * 2 ** attempt + random.uniform(0, 1)
            print(f"[{os.getpid()}] OOM, sleep {delay:.1f}s")
            time.sleep(delay)
    raise RuntimeError("Still OOM after retries") from last_oom

