import csv
import pandas as pd
import random
import numpy as np
import matplotlib.pyplot as plt

# Directory holding the input CSVs read (and rewritten) by the driver loop.
DATASET_PATH = '../datasets'
# Directory the rendered PDF figures are saved into.
PDF_PATH = '../pdfs'
# Matplotlib line styles; plot_fig cycles through them, one per curve.
linestyles = [
    '-',   # solid
    '--',  # dashed
    '-.',  # dash-dot
    ':',   # dotted
]

# Literal substrings deleted from every column label, applied in this order
# after quote/whitespace/dash normalisation.
_COLUMN_FRAGMENTS_TO_DROP = (
    'Exp1_',
    'Exp0_',
    'N20_C500_UVTrue_',
    'S0_P0_D1000_T0.8_O1.0_0__Test/',
    'S0_P1_D1000_T0.8_O1.0_0__Test/',
    'S0_P0_D100_T0.8_O1.0_0__Test/',
    'S0_P0_D100_T0.8_O1.3_0__Test/',
)


def _clean_column_name(name):
    """Return one CSV column label in normalised form.

    Double quotes are removed and surrounding whitespace is stripped *before*
    inner spaces become underscores. The reverse order would turn leading or
    trailing spaces into underscores that ``strip()`` can no longer remove
    (e.g. ``' Step '`` -> ``'_Step_'``). Dashes are then deleted and every
    fragment in ``_COLUMN_FRAGMENTS_TO_DROP`` is removed.

    All replacements are plain ``str.replace`` calls, i.e. literal. The ``.``
    characters in the fragments never act as regex wildcards, regardless of
    the installed pandas version.
    """
    name = str(name).replace('"', '').strip()
    name = name.replace(' ', '_').replace('-', '')
    for fragment in _COLUMN_FRAGMENTS_TO_DROP:
        name = name.replace(fragment, '')
    return name


def preprocess_csv(csv_path):
    """Normalise column names and dtypes of a CSV file and overwrite it in place.

    Every column label is cleaned with ``_clean_column_name``. The ``Step``
    column is then cast to ``int`` and every other column to ``float``. The
    result is written back to ``csv_path`` without the DataFrame index.

    Args:
        csv_path: Path of the CSV file. Its first row is the header, and after
            cleaning it must contain a ``Step`` column.

    Raises:
        KeyError: If no ``Step`` column exists after cleaning.
        Conversion errors raised by pandas ``astype`` (e.g. ``ValueError`` for
        non-numeric cells) propagate unchanged.
    """
    csv_file = pd.read_csv(csv_path, header=0)
    csv_file.columns = [_clean_column_name(column) for column in csv_file.columns]

    csv_file['Step'] = csv_file['Step'].astype(int)
    # Select the value columns by name, not position, so the result does not
    # depend on 'Step' being the first column.
    value_columns = [column for column in csv_file.columns if column != 'Step']
    for column in value_columns:
        csv_file[column] = csv_file[column].astype(float)
    # Destructive: the cleaned data replaces the original file.
    csv_file.to_csv(csv_path, index=False)

def plot_fig(csv_path, system_names, label_names, dataset_name, model_name, aggregator_type, window=5):
    """Plot smoothed top-1 accuracy curves for several systems and save them as a PDF.

    For every entry of ``system_names``, the column
    ``f'{system}_UTincrease_efficiency_{aggregator_type}_acc_top_1'`` is read
    from the CSV. It is smoothed with a trailing rolling mean over ``window``
    rows and drawn against the ``Step`` column. The figure is saved to
    ``{PDF_PATH}/{dataset_name}_{model_name}_{aggregator_type}_top1_accuracy.pdf``
    and then closed.

    Args:
        csv_path: Path of the CSV to plot. Presumably already normalised by
            ``preprocess_csv``, so that the column names above exist.
        system_names: Column-name prefixes, one curve per entry.
        label_names: Legend labels, parallel to ``system_names``.
        dataset_name: Used only in the output file name.
        model_name: Used only in the output file name.
        aggregator_type: Part of both the column name and the output file name.
        window: Rolling-mean window size (default 5, the previous hard-coded
            value). The first ``window - 1`` points of each curve are NaN and
            are therefore not drawn.

    Raises:
        ValueError: If there are fewer labels than systems.
        KeyError: If an expected column is missing from the CSV.
    """
    if len(label_names) < len(system_names):
        raise ValueError(
            f'need one label per system: got {len(label_names)} labels '
            f'for {len(system_names)} systems'
        )

    df = pd.read_csv(csv_path, delimiter=',', header=0)

    fig = plt.figure(figsize=(3, 2.5), dpi=120)
    try:
        ax1 = fig.add_subplot(1, 1, 1)
        # Major tick spacing: every 10 on the y axis, every 200 on the x axis.
        ax1.yaxis.set_major_locator(plt.MultipleLocator(10))
        ax1.xaxis.set_major_locator(plt.MultipleLocator(200))

        n_systems = len(system_names)
        for i, (system_name, label_name) in enumerate(zip(system_names, label_names)):
            # NOTE(review): the original carried a trailing comment naming
            # 'UTincrease_train_acc_mulitply_loss', presumably an alternative metric.
            system_value = df[f'{system_name}_UTincrease_efficiency_{aggregator_type}_acc_top_1']
            # Smooth the curve with a trailing moving average (the mean over the
            # window, not its maximum).
            system_value = system_value.rolling(window=window).mean()
            ax1.plot(
                df['Step'],
                system_value,
                label=label_name,
                # Evenly spaced samples of tab20, plus cycled line styles.
                color=plt.cm.tab20(i / n_systems),
                linestyle=linestyles[i % len(linestyles)],
            )

        ax1.set_xlabel('Training Rounds')
        ax1.set_ylabel('Top1 Accuracy(%)')
        ax1.spines['right'].set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.grid(True, linestyle='-.', alpha=0.3, axis='y')
        ax1.legend(loc='best', ncol=2, frameon=False, columnspacing=0.5, labelspacing=0.5, handlelength=1.2, handletextpad=0.5, borderaxespad=0.5, borderpad=0.5)
        # Operate on this figure explicitly instead of pyplot's "current" figure.
        fig.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)
        fig.savefig(f'{PDF_PATH}/{dataset_name}_{model_name}_{aggregator_type}_top1_accuracy.pdf', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed. The module-level
        # loop calls this once per (dataset, aggregator) pair, so figures would
        # otherwise accumulate for the lifetime of the process.
        plt.close(fig)


# Parallel lists: dataset_names[k] is paired with model_names[k], and each
# pair is plotted once per aggregator type.
dataset_names = ['cifar10', 'cifar10', 'cifar100', 'femnist']
#dataset_names = ['googlespeech']
model_names = ['mobilenet', 'resnet18', 'shufflenet', 'resnet18']
#model_names = ['resnet34']
aggregator_types = ['FedAvg', 'YoGi']
# Parallel lists: CSV column-name prefixes and their legend labels.
system_names = ['random', 'safa', 'refl', 'oort', 'suv']
label_names = ['FedAvg', 'SAFA', 'REFL', 'Oort', 'FedSUV']

# zip() would silently drop unmatched entries, so reject a mismatched
# configuration up front instead of failing (or skipping work) mid-run.
if len(dataset_names) != len(model_names):
    raise ValueError(
        f'dataset_names ({len(dataset_names)}) and model_names '
        f'({len(model_names)}) must have the same length'
    )

for dataset_name, model_name in zip(dataset_names, model_names):
    for aggregator_type in aggregator_types:
        csv_path = f'{DATASET_PATH}/{dataset_name}_{model_name}_{aggregator_type}.csv'
        # preprocess_csv overwrites the CSV in place before it is plotted.
        preprocess_csv(csv_path)
        plot_fig(csv_path, system_names, label_names, dataset_name, model_name, aggregator_type)