import pandas as pd
import numpy as np
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
from matplotlib.ticker import MaxNLocator

DATASET_PATH = '../../simulator/simDataset'
PDF_PATH = '../pdfs'

def preprocess_csv(csv_path, output_path=None):
    """Collapse raw rows into one ``select_round`` value per (utility_mu, valid_mu) pair.

    Reads ``csv_path`` (first row is the header) and keeps only the
    ``utility_mu``, ``valid_mu`` and ``select_round`` columns. It rounds both
    ``*_mu`` columns to 2 decimals and keeps the maximum ``select_round``
    for every rounded (utility_mu, valid_mu) pair. Rows whose
    ``select_round`` is 0 are kept.

    Args:
        csv_path: Path of the raw CSV to read.
        output_path: Where to write the processed CSV (without the index).
            Defaults to
            ``{DATASET_PATH}/clients_vu_for_sim_select_process.csv``.

    Returns:
        The processed DataFrame. It has columns ``utility_mu``, ``valid_mu``
        and ``select_round`` (int), and is sorted by (utility_mu, valid_mu)
        because of the groupby.
    """
    if output_path is None:
        output_path = f"{DATASET_PATH}/clients_vu_for_sim_select_process.csv"

    # .copy() detaches the column subset from the frame read from disk, so the
    # column assignments below do not trigger SettingWithCopyWarning.
    df = pd.read_csv(csv_path, header=0)[['utility_mu', 'valid_mu', 'select_round']].copy()

    # Round before grouping so values that agree to 2 decimals share a group.
    for col in ('utility_mu', 'valid_mu'):
        df[col] = df[col].astype(float).round(2)

    # When several rows share the same (utility_mu, valid_mu), keep the
    # largest select_round.
    df = (
        df.groupby(['utility_mu', 'valid_mu'])
        .agg({'select_round': 'max'})
        .reset_index()
    )
    df['select_round'] = df['select_round'].astype(int)

    df.to_csv(output_path, index=False)
    return df

def _center_ticks(axis_values, targets):
    """Return ``(positions, labels)`` for the targets found on one heatmap axis.

    ``axis_values`` are the pivot-table labels along that axis, in display
    order. The heatmap draws cell ``i`` over ``[i, i + 1]``, so each matching
    target is placed at the cell centre ``i + 0.5``. Targets that do not
    match any axis value (within ``np.isclose`` tolerance) are skipped.
    """
    values = np.asarray(axis_values, dtype=float)
    positions, labels = [], []
    for target in targets:
        hits = np.flatnonzero(np.isclose(values, target))
        if hits.size:
            positions.append(float(hits[0]) + 0.5)
            labels.append(target)
    return positions, labels


def plot_heatmap(csv_path, pdf_path=None):
    """Draw the select_round heatmap over (utility_mu, valid_mu) and save it as a PDF.

    Expects columns ``utility_mu``, ``valid_mu`` and ``select_round``, with at
    most one row per (utility_mu, valid_mu) pair. ``DataFrame.pivot`` raises
    on duplicate pairs. The figure is closed after saving.

    Args:
        csv_path: Path of the processed CSV to plot.
        pdf_path: Output file. Defaults to ``{PDF_PATH}/heatmap_select.pdf``.
    """
    if pdf_path is None:
        pdf_path = f'{PDF_PATH}/heatmap_select.pdf'

    df = pd.read_csv(csv_path, delimiter=',', header=0)
    # Print how many distinct utility_mu and valid_mu values there are.
    print(df['utility_mu'].nunique(), df['valid_mu'].nunique())
    print(df.head())

    # Rows: valid_mu, columns: utility_mu, cells: select_round
    # (NaN where a pair is missing from the CSV).
    pivot_table = df.pivot(index='valid_mu', columns='utility_mu', values='select_round')

    fig, ax = plt.subplots(figsize=(4, 3), dpi=120)
    sns.heatmap(pivot_table, cmap='Blues', ax=ax, square=True)

    colorbar = ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=13)
    colorbar.set_label('Number of rounds', fontsize=13)
    colorbar.ax.yaxis.set_major_locator(MultipleLocator(10))  # one colorbar tick every 10

    # The heatmap puts the first pivot row at the top; flip so valid_mu grows upward.
    ax.invert_yaxis()
    ax.set_ylabel('Validity', fontsize=13)
    ax.set_xlabel('Utility', fontsize=13)

    # Label 0.0, 0.1, ..., 1.0 at the centre of the matching cells. Positions
    # come from the actual grid rather than assuming 5 cells per 0.1 step
    # (a 0.02 grid). If no target lies on an axis, the default ticks stay.
    targets = [i / 10 for i in range(11)]
    x_pos, x_labels = _center_ticks(pivot_table.columns, targets)
    if x_pos:
        ax.set_xticks(x_pos)
        ax.set_xticklabels(x_labels)
    y_pos, y_labels = _center_ticks(pivot_table.index, targets)
    if y_pos:
        ax.set_yticks(y_pos)
        ax.set_yticklabels(y_labels)
    ax.tick_params(axis='both', which='major', labelsize=13)

    fig.tight_layout(pad=0.1, w_pad=0.1, h_pad=.0)
    fig.savefig(pdf_path, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == '__main__':
    # Guarded so that importing this module does not read, write or plot anything.
    # Step 1: reduce the raw simulation CSV to one row per
    # (utility_mu, valid_mu) pair.
    raw_csv_path = f'{DATASET_PATH}/clients_vu_for_sim_select.csv'
    preprocess_csv(raw_csv_path)
    # Step 2: plot the processed file that preprocess_csv writes by default.
    processed_csv_path = f'{DATASET_PATH}/clients_vu_for_sim_select_process.csv'
    plot_heatmap(processed_csv_path)

    