from .comlib import *
from .discriminator_ import Discriminator

class Score():
    """Running per-worker score accumulator stored as a single tensor.

    ``scores[j]`` holds the accumulated score of worker ``j``; all scores start
    at zero on ``device``.
    """

    def __init__(self, worker_num, device):
        self.worker_num = worker_num
        self.device = device
        # One float slot per worker, initialised to zero.
        self.scores = torch.zeros(self.worker_num, device=self.device)

    def get_scores(self, workers: torch.Tensor):
        """Return the accumulated scores of the workers indexed by ``workers``."""
        return self.scores[workers]

    def update(self, workers: torch.Tensor, add_value):
        """Add ``add_value`` to the scores of the workers indexed by ``workers``.

        ``add_value`` may be a scalar or a tensor broadcastable to
        ``self.scores[workers]``.
        NOTE(review): this uses advanced-index assignment, so duplicate indices
        in ``workers`` are not accumulated twice; callers presumably pass unique
        worker ids.
        """
        self.scores[workers] = self.scores[workers] + add_value


class WeightUpdater():
    """Iteratively prune workers using discriminator-derived outlier scores.

    Each round:
      1. runs ``discriminator.maximize`` on the currently remaining workers and
         takes the per-worker ``values`` from the history record with the
         largest ``std``;
      2. converts those values to squared deviations from the centre, normalised
         by their maximum, and adds them to a running :class:`Score`;
      3. keeps only workers whose accumulated score is still below ``eta``.

    Rounds continue while the remaining set has at least
    ``end_condition_ratio * worker_num`` workers.
    """

    def __init__(self, modelSetup: training.ModelSetup, workerDataset: dataset.AggWorkersDatasetFromConf, discriminator: Discriminator,
                 optimizer_initialize: bool, model_initialize: bool, eta, end_condition_ratio, saveObj):
        """
        Args:
            modelSetup: provides ``device`` used for all tensors created here.
            workerDataset: provides ``worker_num``, the total number of workers.
            discriminator: object exposing ``set_name``, ``reset_model_optimizer``
                and ``maximize``.
            optimizer_initialize: forwarded to ``reset_model_optimizer`` each round.
            model_initialize: forwarded to ``reset_model_optimizer`` each round.
            eta: accumulated-score threshold; a worker is removed once its
                score reaches or exceeds it.
            end_condition_ratio: stop once fewer than
                ``end_condition_ratio * worker_num`` workers remain.
            saveObj: object exposing ``save_remain_set(round, list_of_ids)``.
        """
        self.modelSetup = modelSetup
        self.workerDataset = workerDataset
        self.eta = eta
        self.end_condition_ratio = end_condition_ratio

        self.discriminator = discriminator
        self.optimizer_initialize = optimizer_initialize
        self.model_initialize = model_initialize

        self.device = modelSetup.device
        self.worker_num = workerDataset.worker_num

        self.saveObj = saveObj

    def maximizeAndGetValue(self, round, cur_remain_workers: torch.Tensor) -> torch.Tensor:
        """Run one discriminator maximization and return its per-worker values.

        The discriminator is renamed ``dis<round>`` and reset before
        maximizing. Among the returned history records, the one with the
        largest ``'std'`` is chosen (records without ``'std'`` rank lowest),
        and its ``'values'`` entry is returned.
        """
        self.discriminator.set_name(f"dis{round}")
        self.discriminator.reset_model_optimizer(self.optimizer_initialize, self.model_initialize)
        discriminator_history = self.discriminator.maximize(cur_remain_workers)
        best_record = max(discriminator_history, key=lambda d: d.get('std', float('-inf')))
        return best_record["values"]

    @staticmethod
    def getScoreFromValue(values: torch.Tensor) -> torch.Tensor:
        """Return the squared centred values (per-worker deviation scores)."""
        centered_values = util.centerize(values)
        scores = centered_values ** 2
        return scores

    def getRemainWorkersFromValues(self, values, cur_remain_workers, scores: Score) -> torch.Tensor:
        """Accumulate this round's scores and return workers still below ``eta``.

        ``values`` is presumably aligned element-wise with
        ``cur_remain_workers`` (TODO confirm against the discriminator).

        NOTE(review): if every deviation score is zero (e.g. all values equal,
        or a single remaining worker), ``0 / 0`` yields NaN. NaN scores never
        satisfy ``< eta``, so every remaining worker is dropped. That
        also guarantees the caller's loop terminates.
        """
        cur_scores = self.getScoreFromValue(values)
        max_cur_scores = torch.max(cur_scores)
        # Normalise so the most deviant worker gains exactly 1 this round.
        normalized_cur_scores = cur_scores / max_cur_scores
        scores.update(cur_remain_workers, normalized_cur_scores)
        remain_workers_score = scores.get_scores(cur_remain_workers)
        return cur_remain_workers[remain_workers_score < self.eta]

    def conduct_weight_update_round(self, round, cur_remain_workers: torch.Tensor, scores):
        """Run one pruning round, log the surviving set, and return it."""
        values = self.maximizeAndGetValue(round, cur_remain_workers)
        remain_workers = self.getRemainWorkersFromValues(values, cur_remain_workers, scores)
        self.saveObj.save_remain_set(round, remain_workers.cpu().tolist())
        return remain_workers

    def update_weight(self):
        """Run pruning rounds until the remaining set falls below the threshold.

        Returns:
            The remaining-worker index tensor from the last round that was still
            at or above ``end_condition_ratio * worker_num``. If no round ran,
            this is the initial set of all workers.
        """
        # arange yields an int64 index tensor even when worker_num == 0.
        cur_remain_workers = torch.arange(self.worker_num, device=self.device)
        round_idx = 0
        # History of remaining sets: key r is the set *before* round r runs.
        remain_workers = {round_idx: cur_remain_workers}
        scores = Score(self.worker_num, self.device)
        end_condition_num = self.end_condition_ratio * self.worker_num

        # The emptiness check only matters when end_condition_num <= 0; without
        # it the loop would never terminate.
        while len(cur_remain_workers) > 0 and len(cur_remain_workers) >= end_condition_num:
            cur_remain_workers = self.conduct_weight_update_round(round_idx, cur_remain_workers, scores)
            round_idx += 1
            remain_workers[round_idx] = cur_remain_workers

        # The loop stops right after the first round whose result is too small,
        # so return the set from just before that round. Fall back to the initial
        # set when no round ran (previously this raised KeyError on key -1).
        return remain_workers[max(round_idx - 1, 0)]
    
class SaveRemainWorkers():
    """Append one CSV row per round recording which workers remain.

    The file is ``<save_folder>/<save_name>.csv``. Its header is ``round``
    followed by one ``chosen_<j>`` column per worker ``j`` in
    ``range(worker_num)``; each cell is ``True`` if worker ``j`` is in that
    round's remaining set.
    """

    def __init__(self, save_folder, save_name="remain_workers", worker_num=None):
        # NOTE(review): worker_num must be an int; with the default None the
        # header cannot be built (range(None) raises TypeError).
        self.save_folder = save_folder
        self.save_file = f"{save_folder}/{save_name}.csv"

        self.worker_num = worker_num
        header = self.get_header(worker_num)
        self.saveCsvObj = save_log.SaveCsvHeader(self.save_file, header)
        self.saveCsvObj.initialize()
        # Row being assembled; flushed by save_cur_row().
        self.cur_row = {}

    @staticmethod
    def get_x_name():
        """Column name for the round index."""
        return "round"

    @staticmethod
    def get_y_name():
        """Prefix of the per-worker membership columns."""
        return "chosen"

    @staticmethod
    def get_y_header(worker_num):
        """Return ``['chosen_0', ..., 'chosen_<worker_num-1>']``."""
        return [f"{SaveRemainWorkers.get_y_name()}_{j}" for j in range(worker_num)]

    @staticmethod
    def get_header(worker_num):
        """Return the full CSV header: the round column, then the membership columns."""
        h = [SaveRemainWorkers.get_x_name()] + SaveRemainWorkers.get_y_header(worker_num)
        return h

    def add_remain_set(self, round, remain_workers: list):
        """Fill ``cur_row`` with the round index and per-worker membership flags."""
        # Accept tensors/arrays too (anything with .tolist()); build the set once
        # so each membership check is O(1) instead of scanning the list.
        if hasattr(remain_workers, "tolist"):
            remain_workers = remain_workers.tolist()
        chosen = set(remain_workers)
        y_name = self.get_y_name()
        # Build keys from the same helpers as the header so they cannot drift.
        self.cur_row[self.get_x_name()] = round
        for j in range(self.worker_num):
            self.cur_row[f"{y_name}_{j}"] = (j in chosen)

    def save_cur_row(self):
        """Append the pending row (if any) to the CSV and clear it."""
        if len(self.cur_row) > 0:
            self.saveCsvObj.append_data([self.cur_row])
            self.cur_row = {}

    def save_remain_set(self, round, remain_workers: list):
        """Record ``remain_workers`` for ``round`` and write it immediately."""
        self.add_remain_set(round, remain_workers)
        self.save_cur_row()

    
    
def get_fig(save_folder, save_name, nbworkers):
    """Load a remaining-workers CSV and build its figure.

    Reads ``<save_folder>/<save_name>.csv`` (as written by
    ``SaveRemainWorkers``) into a DataFrame and passes it to
    ``worker_with_byzantine.get_values_fig``, using the ``chosen`` column
    prefix as the y-series and ``round`` as the x-axis.

    Args:
        save_folder: directory containing the CSV.
        save_name: CSV file name without the ``.csv`` extension.
        nbworkers: forwarded unchanged to ``get_values_fig``.

    Returns:
        Whatever ``get_values_fig`` returns (the figure object).
    """
    csv_path = f"{save_folder}/{save_name}.csv"
    # No header is passed: the file is only read here, not initialised.
    reader = save_log.SaveCsvHeader(csv_path, None)
    frame = reader.read_to_df()

    y_name = SaveRemainWorkers.get_y_name()
    x_name = SaveRemainWorkers.get_x_name()
    return worker_with_byzantine.get_values_fig(frame, y_name, x_name, nbworkers)