from .comlib import *
from .discriminator_strategy import StrategySigmoidValue
from .worker_model import WorkerModelSigmoid
from .plot import Line,LineErrorbar
# import matplotlib.pyplot as plt
# import seaborn as sns

class Discriminator:
    """Run a stochastic gradient loop over a worker model and log per-round statistics.

    Each round samples ``chosenWorkerNum`` distinct workers, asks ``strategy`` for a
    gradient over them and applies it with the optimizer. The initial state, every
    ``save_interval``-th round and the final round are written to a CSV sink.
    """

    # Annotations are quoted so they are not evaluated when the class is defined.
    def __init__(self,
                 strategy: "StrategySigmoidValue", chosenWorkerNum, max_round,
                 save_interval=10,
                 ):
        """
        Args:
            strategy: object whose ``getGrad(workerModel, chosenWorkers)`` returns
                something exposing the gradient tuple as ``.tensors``.
            chosenWorkerNum: number of distinct workers sampled each round.
            max_round: number of optimisation rounds run by ``maximize``.
            save_interval: log every ``save_interval``-th round; the last round is
                always logged as well. Defaults to 10 (the previous fixed value).

        Raises:
            ValueError: if ``save_interval`` is smaller than 1.
        """
        if save_interval < 1:
            raise ValueError(f"save_interval must be >= 1, got {save_interval}")
        self.strategy = strategy
        self.chosenWorkerNum = chosenWorkerNum
        self.max_round = max_round
        self.save_interval = save_interval

    def maximize(self, workerModel: "WorkerModelSigmoid", optimizer: "training.Optimizer", saveCsvObj):
        """Run ``max_round`` rounds, logging progress to ``saveCsvObj``.

        The state before any step is logged as round ``-1`` (without a gradient
        norm). After that, rounds selected by ``_is_save_round`` are logged together
        with the norm of that round's gradient.

        NOTE(review): the original note read "optimizer: maximize true?". This
        method passes the strategy's gradient to ``optimizer.stepByGradTuple``
        unchanged, so whether it ascends presumably depends on the optimizer's own
        configuration — confirm.

        Raises:
            ValueError: from ``random.sample`` if ``chosenWorkerNum`` exceeds
                ``workerModel.worker_num``.
        """
        worker_num = workerModel.worker_num
        self.save_round(saveCsvObj, -1, workerModel)
        for rnd in range(self.max_round):
            chosenWorkers = random.sample(range(worker_num), k=self.chosenWorkerNum)
            grad_tuple = self.strategy.getGrad(workerModel, chosenWorkers).tensors
            optimizer.stepByGradTuple(grad_tuple)
            if self._is_save_round(rnd):
                self.save_round(saveCsvObj, rnd, workerModel, grad_tuple)

    def _is_save_round(self, rnd):
        """Return True for every ``save_interval``-th round and for the final round."""
        # Must be `or`: the old `a == 0 | b` parsed as `a == (0 | b)` and skipped
        # logging the final round in most cases.
        return rnd % self.save_interval == 0 or rnd == self.max_round - 1

    def getWorkerValues(self, workerModel, batch_size=400):
        """Return the sigmoid values of all ``workerModel.worker_num`` workers.

        ``batch_size`` is forwarded to ``getChosenWorkersSigmoidValue``
        (presumably an evaluation batch size — confirm).
        """
        workers = list(range(workerModel.worker_num))
        return workerModel.getChosenWorkersSigmoidValue(workers, batch_size=batch_size)

    def save_round(self, saveCsvObj, round, workerModel, grad_tuple=None):
        """Append one row for ``round`` to ``saveCsvObj``.

        The row holds the model weight norm, the gradient norm (``None`` when no
        ``grad_tuple`` is given) and every worker's sigmoid value.
        """
        workerValues = self.getWorkerValues(workerModel)
        weight_norm = workerModel.modelSetup.gettWeightNorm().item()
        grad_norm = None
        if grad_tuple is not None:
            grad_norm = training.ModelPara(grad_tuple).norm().item()
        saveCsvObj.appendRoundToSheet(round, weight_norm, grad_norm, workerValues)
        


class DiscriminatorSaveCsv:
    """CSV log of a discriminator run, one row per logged round.

    Columns are ``META_COLUMNS`` (round, weight_norm, grad_norm) followed by one
    column per worker id ``0 .. worker_num-1`` holding that worker's sigmoid
    value. The static helpers turn such a table (read back as a DataFrame) into
    plot lines.
    """

    # Leading non-worker columns. Every worker column comes after these, so the
    # readers below slice at len(META_COLUMNS) instead of a hard-coded offset.
    META_COLUMNS = ("round", "weight_norm", "grad_norm")

    def __init__(self, save_folder, label, worker_num, initialize=True):
        """
        Args:
            save_folder: directory prefix of the CSV file.
            label: run label embedded in the file name.
            worker_num: number of worker columns.
            initialize: if True, call ``initialize()`` on the underlying
                ``save_log.SaveCsvHeader`` (presumably creates the file and writes
                the header — confirm in save_log).
        """
        self.save_folder = save_folder
        self.label = label
        self.worker_num = worker_num

        header = list(self.META_COLUMNS) + list(range(self.worker_num))

        self.file_name = self.get_save_file(label)
        self.saveCsvObj = save_log.SaveCsvHeader(self.file_name, header)

        if initialize:
            self.saveCsvObj.initialize()

    def appendRoundToSheet(self, round, weight_norm, grad_norm, workerValues):
        """Append one row.

        ``workerValues`` is indexed ``0 .. worker_num-1``. Each element must
        provide ``.item()`` (e.g. a tensor or numpy scalar). Extra entries are
        ignored.
        """
        record = {"round": round, "weight_norm": weight_norm, "grad_norm": grad_norm}
        record.update({i: workerValues[i].item() for i in range(self.worker_num)})
        self.saveCsvObj.append_data([record])

    @staticmethod
    def get_sheet_name(label):
        """Return the base file name for ``label``."""
        # NOTE: "Discrimator" (sic) is kept so existing log files remain readable.
        return f"Discrimator_{label}"

    def get_save_file(self, label):
        """Return the CSV path for ``label`` inside ``save_folder``."""
        return f"{self.save_folder}/{self.get_sheet_name(label)}.csv"

    def read_to_df(self):
        """Read the log back through the underlying ``save_log`` object."""
        return self.saveCsvObj.read_to_df()

    @staticmethod
    def get_workers(df: pd.DataFrame):
        """Return the worker ids (worker column names converted to int)."""
        return DiscriminatorSaveCsv.getSigmoidValues(df).columns.astype(int).tolist()

    @staticmethod
    def getSigmoidValues(df: pd.DataFrame):
        """Return only the per-worker sigmoid-value columns."""
        return df.iloc[:, len(DiscriminatorSaveCsv.META_COLUMNS):]

    @staticmethod
    def get_variance(df: pd.DataFrame):
        """Return a Line of round vs. the per-row variance across workers.

        Uses pandas' default sample variance (ddof=1).
        """
        sigmoidValues = DiscriminatorSaveCsv.getSigmoidValues(df)
        variance_rows = sigmoidValues.var(axis=1)
        return Line(df.iloc[:, 0], variance_rows)

    @staticmethod
    def getValuesRange(df: pd.DataFrame, chosen_workers):
        """Return a LineErrorbar over rounds built from the chosen worker columns.

        ``chosen_workers`` are positions among the worker columns. These coincide
        with worker ids because the header lists workers in order 0..worker_num-1.
        How the error bars are derived is decided by ``LineErrorbar.setYFromdf``.
        """
        sigmoidValues = DiscriminatorSaveCsv.getSigmoidValues(df)
        line = LineErrorbar(x=df.iloc[:, 0])
        line.setYFromdf(sigmoidValues.iloc[:, chosen_workers])
        return line

    @staticmethod
    def get_line(df: pd.DataFrame, col_name):
        """Return a Line of round (first column) vs. column ``col_name``."""
        return Line(df.iloc[:, 0], df.loc[:, col_name])
    
    # @staticmethod
    # def get_weight_norm(df:pd.DataFrame):
    #     return Line(df.iloc[:,0],df.iloc[:,1])
    
    # @staticmethod
    # def get_grad_norm(df:pd.DataFrame):
    #     return Line(df.iloc[:,0],df.iloc[:,2])
    # @staticmethod
    # def plot_lines(lines:dict,save_file):
    #     flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]
    #     palette = sns.color_palette(flatui)
    #     plt.figure(figsize=(10, 6))

    #     # 绘制折线
    #     for i,(key,(x,y)) in enumerate(lines.items()):
    #         plt.plot(x,y, marker='o', label=key, color=palette[i])

    #     # 添加图例
    #     plt.legend()

    #     # 添加标题和标签
    #     plt.title('title')
    #     plt.xlabel('X-axis')
    #     plt.ylabel('Y-axis')

    #     # 显示网格
    #     plt.grid(True)

    #     # 显示图形
    #     plt.savefig(save_file)

            
    # def set_sheets(self):
    #     self.sheet_name=self.get_sheet_name(self.label)
    #     sheets={
    #         self.sheet_name:["round","weight_norm"]+list(range(self.worker_num))
    #     }
    #     self.saveCsvObj=save_log.SaveCsv(self.save_folder,sheets)

    # self.byzantine_ids=byzantine_ids
        # self.normal_ids=[x for x in range(worker_num) if x not in byzantine_ids]