from .comlib import *
from .discriminator_ import Discriminator,DiscriminatorSaveCsv
from .max_variance import MaxVariance
from .discriminator_strategy_factory import create_strategy
from .worker_model import WorkerModelSigmoid

class Score():
    """Running score table with one entry per worker.

    The table is a zero-initialised tensor created with
    ``torch.zeros(worker_num, device=device)``.
    """

    def __init__(self,worker_num,device):
        """Remember the table size and device, then allocate the zeroed table."""
        self.worker_num=worker_num
        self.device=device
        # Every worker starts from a score of zero.
        self.scores=torch.zeros(worker_num,device=device)

    def get_scores(self,workers:torch.tensor):
        """Return the stored scores at the positions selected by ``workers``."""
        selected=self.scores[workers]
        return selected

    def update(self,workers:torch.tensor,add_value):
        """Add ``add_value`` to the stored scores at the positions ``workers``.

        The new values are computed out of place and then written back
        through indexed assignment.
        """
        bumped=self.get_scores(workers)+add_value
        self.scores[workers]=bumped


class WeightUpdater():
    """Run rounds of discriminator maximisation and narrow the worker set.

    Each round maximises the discriminator on the workers still remaining,
    reads one value per remaining worker from the worker model, accumulates
    the max-normalised values into a per-worker score table, and records
    which workers remain for the following round.
    """

    def __init__(self,workerModel:WorkerModelSigmoid, discriminator:Discriminator,optimizer:training.Optimizer,
                 save_folder,label='',eta=1):
        """Store collaborators and seed round 0 with every worker.

        Args:
            workerModel: supplies ``device``, ``worker_num``,
                ``getSubWorkerDataset`` and ``getChosenWorkersSigmoidValue``.
            discriminator: supplies ``strategy.value_batch_size`` and
                ``maximize``.
            optimizer: handed to ``discriminator.maximize``; its
                ``initialize_optimizer`` is called on request by
                ``update_weight``.
            save_folder: forwarded to ``DiscriminatorSaveCsv`` each round.
            label: stored on the instance; not read by this class.
            eta: threshold the accumulated score must exceed for a worker
                to be carried into the next round.
        """
        self.workerModel=workerModel
        self.save_folder=save_folder
        self.label=label
        self.eta=eta

        self.discriminator=discriminator
        self.value_batch_size=self.discriminator.strategy.value_batch_size

        self.optimizer=optimizer

        self.device=workerModel.device
        self.worker_num=workerModel.worker_num
        # Maps round index -> workers remaining at the start of that round.
        self.remain_workers={
            0:list(range(self.worker_num))
        }
        self.cur_round=0
        self.scores=Score(self.worker_num,self.device)

    # def maximizeAndGetValue(kernel):
    #     # maximize_var
    #     max_eigenvalue, max_eigenvector=util.Kernel.get_top_eigen(kernel,k=1)
    #     values=util.Kernel.get_projection(kernel,max_eigenvector)[:,0]
    #     return values

    def getRemainWorkersFromValues(self,cur_remain_workers,values):
        """Accumulate normalised values and return the workers above ``eta``.

        Args:
            cur_remain_workers: worker indices for this round (list, tuple
                or tensor), aligned element-wise with ``values``.
            values: tensor with one value per entry of ``cur_remain_workers``.

        Returns:
            The entries of ``cur_remain_workers`` whose accumulated score is
            strictly greater than ``self.eta``. A list or tuple input yields
            a list; any other input is mask-indexed directly.
        """
        # Nothing to score: avoid torch.max failing on an empty tensor.
        if len(cur_remain_workers)==0:
            return cur_remain_workers

        max_value=torch.max(values)
        normalized_values=values/max_value
        self.scores.update(cur_remain_workers,normalized_values)
        # NOTE(review): when values are non-negative the normalised values are
        # at most 1, so with eta=1 no score can exceed eta after the first
        # round -- confirm that `>` (rather than `<=`) is the intended filter.
        keep_mask=self.scores.get_scores(cur_remain_workers)>self.eta

        if isinstance(cur_remain_workers,(list,tuple)):
            # A Python sequence cannot be indexed with a boolean tensor mask
            # (TypeError), and round 0 is seeded with a list. Mask through an
            # index tensor and hand back a list so the container type stays
            # the same from round to round.
            worker_ids=torch.as_tensor(cur_remain_workers,device=keep_mask.device)
            return worker_ids[keep_mask].tolist()
        return cur_remain_workers[keep_mask]

    def conduct_weight_update_round(self,round):
        """Run one round and store the surviving workers under ``round+1``.

        Args:
            round: key into ``self.remain_workers`` for the current round;
                also used (as a string) as the label of the CSV saver.
        """
        label=str(round)
        saveObj=DiscriminatorSaveCsv(self.save_folder,label,self.worker_num)
        cur_remain_workers=self.remain_workers[round]
        remainWorkerNum=len(cur_remain_workers)
        remainWorkerDataset=self.workerModel.getSubWorkerDataset(cur_remain_workers)

        self.discriminator.maximize(remainWorkerDataset,self.optimizer,saveObj)

        # Values are requested for local positions 0..remainWorkerNum-1;
        # presumably these index the sub-dataset selected above -- TODO confirm.
        values=self.workerModel.getChosenWorkersSigmoidValue(list(range(remainWorkerNum)),self.value_batch_size)

        self.remain_workers[round+1]=self.getRemainWorkersFromValues(cur_remain_workers,values)

    def update_weight(self,update_weight_round,optimizer_initialize:bool):
        """Run the given rounds in order.

        Args:
            update_weight_round: iterable of round indices. Each index must
                already have an entry in ``self.remain_workers`` (round 0 is
                created in ``__init__``; round ``r+1`` is created by running
                round ``r``).
            optimizer_initialize: when true, ``initialize_optimizer`` is
                called on the optimizer before every round.
        """
        for round_index in update_weight_round:
            self.cur_round=round_index
            if optimizer_initialize:
                self.optimizer.initialize_optimizer()
            self.conduct_weight_update_round(round_index)
