# import sys
# sys.path.append('/home/yjf/FL/robustfl')

from .comlib import *

class Discriminator:
    """Train a model so that per-worker aggregated scores spread apart.

    Every sample's model output is passed through a sigmoid, and the results are
    summed per worker (``scatterValue``). The objective used for training is the
    population variance (``correction=0``) of those per-worker sums
    (``getVarFunc``). Its gradient is obtained from ``modelSetup.calcuJacobian``
    and applied with ``optimizer.stepByGradTuple``. Whether that step ascends or
    descends depends on the optimizer implementation. The method name suggests
    ascent, but confirm this against ``training.Optimizer``.

    Periodic statistics are appended to the ``"workerValue"`` sheet of an Excel
    log managed by ``save_log.SaveExcel``.
    """

    # Batch size used for both the training loop and statistics collection.
    BATCH_SIZE = 400
    # Statistics are logged every STAT_INTERVAL epochs (starting at epoch 0).
    STAT_INTERVAL = 10

    def __init__(self, modelSetup: training.ModelSetup, workerDataset: dataset.AggWorkersDatasetFromConf, save_file, label='', normal_worker_num=50):
        '''
        Model input: (data, target).

        Args:
            modelSetup: wraps the model; provides ``device``, ``calcuJacobian``
                and ``calcuLoss``.
            workerDataset: dataset yielding ``(worker_id, data, target)``
                triples; must expose ``worker_num``, ``dataNumPerWorker`` and
                ``randomSubWokerDataset``.
            save_file: path passed to ``save_log.SaveExcel``.
            label: free-form label stored on the instance.
            normal_worker_num: the first ``normal_worker_num`` workers are
                averaged into the "normal" column and the rest into the "ad"
                column of the log. Presumably these are benign and adversarial
                workers respectively; verify against the dataset config.
        '''
        self.device = modelSetup.device
        self.modelSetup = modelSetup
        self.workerDataset = workerDataset
        self.sigmoid = nn.Sigmoid()
        self.normal_worker_num = normal_worker_num

        self.save_file = save_file
        self.label = label
        self.set_sheets()

    def scatterValue(self, out, worker_ids, worker_num):
        """Sum ``sigmoid(out)`` per worker.

        Args:
            out: model output for a batch. It is flattened, so one value per
                sample is assumed (TODO confirm the model emits a single logit
                per sample).
            worker_ids: integer worker index of each sample in the batch.
            worker_num: number of workers, i.e. the length of the result.

        Returns:
            1-D tensor of length ``worker_num`` whose entry ``i`` is the sum of
            ``sigmoid(out)`` over the samples with ``worker_ids == i``.
        """
        loss_data_vec = self.sigmoid(out).reshape(-1)
        # scatter_add requires an int64 index on the same device as its source;
        # DataLoader-produced ids typically live on the CPU.
        index = worker_ids.to(device=loss_data_vec.device, dtype=torch.long)
        loss_worker_vec = torch.zeros((worker_num,), dtype=loss_data_vec.dtype, device=loss_data_vec.device)
        # Out-of-place variant: same result as scatter_add_, but does not mutate
        # a tensor inside a function that may be differentiated/transformed.
        return loss_worker_vec.scatter_add(0, index, loss_data_vec)

    def getScatterFunc(self, worker_num):
        """Return ``func(out, worker_ids)`` computing the per-worker sums."""
        def func(out, worker_ids):
            return self.scatterValue(out, worker_ids, worker_num)
        return func

    def getVarFunc(self, worker_num):
        """Return ``func(out, worker_ids)`` computing the population variance
        of the per-worker sums produced by ``scatterValue``."""
        def func(out, worker_ids):
            values = self.scatterValue(out, worker_ids, worker_num)
            return torch.var(values, correction=0)
        return func

    def calcu_variance_grad(self, worker_num, worker_ids, data, target):
        """Compute the Jacobian of the per-worker variance objective.

        Returns whatever ``modelSetup.calcuJacobian`` returns for the inputs
        ``(data, target)`` and extra objective arguments ``(worker_ids,)``.
        """
        input_args = (data, target)
        output_args = (worker_ids,)
        return self.modelSetup.calcuJacobian(input_args, self.getVarFunc(worker_num), output_args)

    def maximize_variance(self, max_epoch, optimizer: training.Optimizer, chosenWorkerNum):
        """Run ``max_epoch`` epochs, each on a random subset of workers.

        Every epoch draws ``chosenWorkerNum`` workers via
        ``randomSubWokerDataset``. The worker ids it yields are assumed to lie
        in ``[0, chosenWorkerNum)``; verify this in the dataset class. One
        optimizer step is taken per batch, and statistics over *all* workers
        are logged every ``STAT_INTERVAL`` epochs.
        """
        for epoch in range(max_epoch):
            chosenWorkerDataset = self.workerDataset.randomSubWokerDataset(chosenWorkerNum)
            dataloader = DataLoader(chosenWorkerDataset, batch_size=self.BATCH_SIZE, shuffle=False)

            for worker_ids, data, target in dataloader:
                jacobian_tuple = self.calcu_variance_grad(chosenWorkerNum, worker_ids, data, target)
                optimizer.stepByGradTuple(jacobian_tuple)
            if epoch % self.STAT_INTERVAL == 0:
                self.getCurStat(epoch)

    def getCurStat(self, epoch):
        """Log per-worker mean scores over the full dataset for ``epoch``.

        A row is appended to the ``"workerValue"`` sheet. It holds the epoch,
        the population std over workers, the mean over the first
        ``normal_worker_num`` workers ("normal"), the mean over the remaining
        workers ("ad"), and then one value per worker. The summary part is also
        printed.
        """
        worker_num = self.workerDataset.worker_num
        dataloader = DataLoader(self.workerDataset, batch_size=self.BATCH_SIZE, shuffle=False)
        scatter_func = self.getScatterFunc(worker_num)
        workerValues = torch.zeros(worker_num, device=self.device)
        for worker_ids, data, target in dataloader:
            input_args = (data, target)
            output_args = (worker_ids,)
            workerValues = workerValues + self.modelSetup.calcuLoss(input_args, scatter_func, output_args)
        # Detach so logged numbers never keep an autograd graph alive.
        workerValues = workerValues.detach().cpu()
        # Per-worker sums -> per-worker means.
        workerValues = workerValues / torch.tensor(self.workerDataset.dataNumPerWorker, dtype=workerValues.dtype)

        std = torch.std(workerValues, correction=0).item()
        split = self.normal_worker_num
        # An empty slice (e.g. worker_num <= split) yields NaN, as before.
        normal_mean = workerValues[:split].mean().item()
        ad_mean = workerValues[split:].mean().item()

        # Plain Python floats, not 0-d tensors, so pandas/Excel get numbers.
        row = [epoch, std, normal_mean, ad_mean] + workerValues.tolist()
        self.saveExcelObj.append_data("workerValue", pd.DataFrame([row], columns=self.saveExcelObj.sheets["workerValue"]))
        print([epoch, std, normal_mean, ad_mean])

    def set_sheets(self):
        """Create the Excel logger with a ``"workerValue"`` sheet whose columns
        are epoch, std, normal, ad, then one column per worker index."""
        sheets = {
            "workerValue": ["epoch", "std", "normal", "ad"] + list(range(self.workerDataset.worker_num))
        }
        self.saveExcelObj = save_log.SaveExcel(f"{self.save_file}", sheets)

    # @classmethod
    # def draw_variance_maximization_process(self,table_name,save_file):
    #     sns.set_theme(style="whitegrid")
    #     flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]
    #     sns.set_palette(flatui)

    #     criterions=["discrete_values"]
    #     for criterion in criterions:
    #         self.plot_worker_values(table_name,criterion,save_file)
    #         self.plot_worker_variance(table_name,criterion,save_file)
    #     self.plot_criterion(table_name,"loss",save_file)
            

    # @classmethod
    # def plot_worker_values(self,table_name,criterion,save_file):
    #     plt.figure(figsize=(10, 6))
    #     statement=f"""SELECT * FROM {table_name} 
    #     WHERE criterion=?
    #     """

    #     df=db_util.select_data_from_table_to_df(DATABASEFILE, statement, params=(criterion,))
    #     filtered_data1 = df[df['worker_name'] <10]
    #     filtered_data2 = df[df['worker_name'] >=10]
    #     sns.lineplot(data=filtered_data1, x='epoch', y='value', errorbar="sd")
    #     sns.lineplot(data=filtered_data2, x='epoch', y='value', errorbar="sd")

    #     # mean_per_epoch = filtered_data1.groupby('epoch')['value'].mean().to_frame()

    #     # sns.lineplot(data=df, x='epoch', y='value', hue='worker_name')
    #     LOGGER.debug(df)

    #     plt.legend()
    #     plt.xlabel("epoch")
    #     plt.ylabel("value")
    #     plt.title(f"criterion:{criterion}")
    #     plt.savefig(save_file+"_workers_"+str(criterion))

    # @classmethod
    # def plot_worker_variance(self,table_name,criterion,save_file):
    #     plt.figure(figsize=(10, 6))
    #     statement=f"""SELECT * FROM {table_name} 
    #     WHERE criterion=?
    #     """

    #     df=db_util.select_data_from_table_to_df(DATABASEFILE, statement, params=(criterion,))

    #     variance_by_time = df.groupby('epoch')['value'].var().to_frame()
    #     LOGGER.debug(variance_by_time)

    #     sns.lineplot(x="epoch", y="value", data=variance_by_time)
    #     # LOGGER.debug(df)

    #     # plt.legend()
    #     plt.xlabel("epoch")
    #     plt.ylabel("value")
    #     plt.title(f"criterion:{criterion}")
    #     plt.savefig(save_file+"_variance_"+str(criterion))

    # @classmethod
    # def plot_criterion(self,table_name,criterion,save_file):
    #     plt.figure(figsize=(10, 6))
    #     statement=f"""SELECT * FROM {table_name} 
    #     WHERE criterion=?
    #     """

    #     df=db_util.select_data_from_table_to_df(DATABASEFILE, statement, params=(criterion,))

    #     sns.lineplot(x="epoch", y="value", data=df)
    #     LOGGER.debug(df)

    #     # plt.legend()
    #     plt.xlabel("epoch")
    #     plt.ylabel("value")
    #     plt.title(f"criterion:{criterion}")
    #     plt.savefig(save_file+"_"+str(criterion))

