import pandas as pd 
import matplotlib
matplotlib.use('Agg')  # Use non-GUI backend
from matplotlib import pyplot as plt
import numpy as np
import argparse
# ---- Command-line options ----
parser = argparse.ArgumentParser()
parser.add_argument('--network_typ', type=str, default='circle')
parser.add_argument('--attack_typ', type=str, default='featureatt')
args = parser.parse_args()

# ---- Experiment configuration ----
# NOTE(review): n_workers and random_state are assigned here but not read
# anywhere else in the visible part of this file.
n_workers = 50
data_name = 'mnist'  # switch to 'cifar' to use the CIFAR outputs instead
network_typ = args.network_typ
attack_typ = args.attack_typ
byz_ratios = [0.15, 0.25, 0.35]
random_state = 2025
output_dir = f'output_{data_name}'  # Output path for accuracy/loss, etc.

# The 'circle' network uses its own degree grid and a shorter method list;
# every other network type gets the full list (two extra methods).
if network_typ == 'circle':
    q_degrees = [0.06, 0.6]
    methods = ['init', 'adfl', 'bgmed', 'bgtrimmed', 'clipped']
    methods_name = ['DFL', 'aDFL', 'BRIDGE-M', 'BRIDGE-T', 'ClippedGossip']
else:
    q_degrees = [0.2, 0.6]
    methods = ['init', 'adfl', 'bgmed', 'bgtrimmed', 'clipped', 'gtmed', 'gttrim']
    methods_name = ['DFL', 'aDFL', 'BRIDGE-M', 'BRIDGE-T', 'ClippedGossip',
                    'SLBRN-M', 'SLBRN-T']

# Metric suffix used in the per-method file names: accuracy for cifar,
# loss for everything else.
metric = 'valacc' if data_name == 'cifar' else 'valloss'

K = 100  # Evaluate every K iterations

# ---- Load results ----
# `data` ends up with one entry per (byz_ratio, q_degree) pair, byz_ratio
# varying slowest; each entry is a list with one DataFrame per method.
data = []

for byz_ratio in byz_ratios:
    single_path = f'{output_dir}/metric_single_{byz_ratio}.csv'
    # metric[3:] drops the leading 'val', i.e. the column is 'loss' or 'acc'.
    # NOTE(review): `oracle` is reassigned on every pass, so once this loop
    # finishes it only holds the value read for the last entry of byz_ratios.
    oracle = pd.read_csv(single_path)[metric[3:]][0]

    for q_degree in q_degrees:
        best_param = pd.read_csv(
            f'{output_dir}/bestparam_{byz_ratio}_{attack_typ}_{network_typ}_{q_degree}.csv'
        )
        # Rows to keep: those where 'byz_labels' is False.
        # NOTE(review): assumes the column is parsed as booleans - confirm.
        keep_rows = ~best_param['byz_labels']

        dfs = []
        for method in methods:
            curve_path = (
                f'{output_dir}/metric_{method}_{byz_ratio}_{attack_typ}'
                f'_{network_typ}_{q_degree}_{metric}.csv'
            )
            # Keep every K-th column, then only the selected rows.
            dfs.append(pd.read_csv(curve_path).iloc[:, ::K][keep_rows])
        data.append(dfs)


def fix_df(df):
    """Flatten sudden upward jumps between neighbouring columns, in place.

    Columns are visited left to right starting at position 2. Any entry that
    is more than twice the entry immediately to its left is overwritten with
    that left-hand entry. Because the sweep is sequential, each comparison
    sees the already-corrected value of the preceding column. Columns 0 and 1
    are never modified.

    The frame passed in is mutated and the same object is returned.
    """
    n_columns = df.shape[1]
    for col in range(2, n_columns):
        previous = df.iloc[:, col - 1]
        current = df.iloc[:, col]
        jumped = current > 2 * previous
        # Where a jump is detected carry the previous value forward,
        # otherwise keep the current one.
        df.iloc[:, col] = np.where(jumped, previous, current)
    return df

def process_data(data):
    """Run ``fix_df`` over every frame of a nested list of DataFrames.

    ``data`` is a sequence of rows, each row a mutable sequence of frames.
    Every row is updated in place with the objects returned by ``fix_df``,
    and ``data`` itself is returned.
    """
    for frames in data:
        frames[:] = [fix_df(frame) for frame in frames]
    return data

# Turn off pandas' chained-assignment warning before the loaded frames are
# rewritten through .iloc assignments.
pd.set_option('mode.chained_assignment', None)
data = process_data(data)

# One x position per column of the first loaded frame.
epochs = np.arange(data[0][0].shape[1])

# Plotting function
def plot_lines_with_confidence_intervals(ax, epochs, dfs, methods, data_name='mnist',
                                         linewidth=1, oracle=None, title=None,
                                         show_ylabel=False, show_xlabel=False):
    """Draw one mean curve with a shaded confidence band per method on ``ax``.

    For each entry of ``dfs`` the column-wise mean (over rows) is plotted
    against ``epochs`` together with a band of
    ``mean +/- 1.96 * std / sqrt(n_rows)``, where ``std`` is ``np.std`` with
    its default settings. The axis is then styled (title, optional labels,
    fixed limits, white grid on a light background, no spines, no tick marks).

    ax           : matplotlib axis object
    epochs       : x-axis values, one per column of each entry in ``dfs``
    dfs          : list of DataFrames or 2-D numpy arrays, one per method
    methods      : list of legend labels, paired with ``dfs`` by position
    data_name    : 'cifar' multiplies values and oracle by 100 and uses
                   accuracy limits; any other value uses the loss limits.
                   The y-label is only set for 'mnist' and 'cifar'.
    linewidth    : line width for the curves and the oracle line
    oracle       : value of oracle for horizontal reference line (skipped
                   when None)
    title        : subplot title
    show_ylabel  : whether to show y-axis label
    show_xlabel  : whether to show x-axis label
    """

    # Custom linestyles, reused cyclically if there are more methods than styles.
    linestyles = ["-", "--", "-.", (0, (3, 1, 1, 1)), (0, (5, 5)), (0, (1, 1)),
                  (0, (4, 2, 1, 2))]

    for idx, (df, method) in enumerate(zip(dfs, methods)):
        style = linestyles[idx % len(linestyles)]
        # np.asarray accepts both DataFrames and plain arrays, as the
        # docstring promises; DataFrame.to_numpy() would reject arrays.
        values = np.asarray(df)
        if data_name == 'cifar':
            values = values * 100  # fractions -> percent

        mean_values = np.mean(values, axis=0)
        half_width = 1.96 * np.std(values, axis=0) / np.sqrt(values.shape[0])
        ci_upper = mean_values + half_width
        ci_lower = mean_values - half_width

        ax.plot(epochs, mean_values,
                linewidth=linewidth,
                label=method,
                linestyle=style)

        ax.fill_between(epochs, ci_lower, ci_upper, alpha=0.2)

    if oracle is not None:
        if data_name == 'cifar':
            oracle = oracle * 100
        ax.axhline(y=oracle, color='black', linestyle='--',
                   linewidth=linewidth, label="Oracle")

    ax.set_title(title, fontsize=10)

    if show_ylabel:
        if data_name == 'mnist':
            ax.set_ylabel('Testing Loss')
        elif data_name == 'cifar':
            ax.set_ylabel('Testing Acc (%)')

    if show_xlabel:
        ax.set_xlabel(r'Iterations ($\times$ 100)')

    # Axis limits: same x range for both datasets, y range depends on metric.
    ax.set_xlim(0, 90)
    if data_name == 'mnist':
        ax.set_ylim(0, 2.0)
    else:  # cifar
        ax.set_ylim(20, 88)

    # Grid and background
    ax.grid(True, color='white')
    ax.set_facecolor('#f5f5ff')

    # Remove axis borders
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Remove tick lines but keep labels
    ax.tick_params(axis='both', length=0, colors='black')

# Figure grid: rows follow q_degrees (2 entries), columns follow byz_ratios
# (3 entries); the hard-coded titles below assume exactly this layout.
n_rows, n_cols = 2, 3
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 4), sharex=True, sharey=True)

# Adjust spacing between subplots
# NOTE(review): the negative wspace/hspace values look unintended, and
# plt.tight_layout() further down recomputes the subplot parameters anyway.
# Left unchanged here so the rendered layout is not altered.
fig.subplots_adjust(wspace=-1, hspace=-3.5, top=0.9, bottom=0)

# Subplot titles, in row-major order (first row = first q_degree).
if network_typ == 'circle':
    titles = [
        r"D = 3  | $\varrho$ = 15%",
        r"D = 3  | $\varrho$ = 25%",
        r"D = 3  | $\varrho$ = 35%",
        r"D = 30 | $\varrho$ = 15%",
        r"D = 30 | $\varrho$ = 25%",
        r"D = 30 | $\varrho$ = 35%",
    ]
else:
    titles = [
        r"q = 0.2  | $\varrho$ = 15%",
        r"q = 0.2  | $\varrho$ = 25%",
        r"q = 0.2  | $\varrho$ = 35%",
        r"q = 0.6 | $\varrho$ = 15%",
        r"q = 0.6 | $\varrho$ = 25%",
        r"q = 0.6 | $\varrho$ = 35%",
    ]

linewidth = 1.2

# Rearrange data order to match subplot titles: `data` was filled with
# byz_ratio as the outer loop and q_degree as the inner one, whereas the
# grid is filled row by row with q_degree fixed along a row.
data2 = [data[col * n_rows + row] for row in range(n_rows) for col in range(n_cols)]

# One oracle value per byz_ratio (i.e. per grid column). The module-level
# `oracle` only holds the value for the last byz_ratio, so using it for every
# panel would draw the wrong reference line in the other columns.
oracle_column = metric[3:]  # 'valloss' -> 'loss', 'valacc' -> 'acc'
oracles = [
    pd.read_csv(f'{output_dir}/metric_single_{byz_ratio}.csv')[oracle_column][0]
    for byz_ratio in byz_ratios
]

# Plot each subplot
handles = None
labels = None

for i, ax in enumerate(axes.flatten()):
    plot_lines_with_confidence_intervals(
        ax,
        epochs,
        data2[i],
        methods=methods_name,
        linewidth=linewidth,
        oracle=oracles[i % n_cols],
        title=titles[i],
        data_name=data_name,
        show_ylabel=(i % n_cols == 0),   # left-most column only
        show_xlabel=(i // n_cols == 1),  # second (bottom) row only
    )

    # All panels share the same methods, so the first panel's legend
    # entries are reused for the figure-level legend.
    if i == 0:
        handles, labels = ax.get_legend_handles_labels()

# Add shared legend outside the top of the plot
if handles and labels:
    fig.legend(handles, labels,
               loc='upper center',
               bbox_to_anchor=(0.5, 1.06),
               frameon=False,
               fontsize=7.0,
               ncol=len(labels))

# Tight layout with space for legend
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.savefig(f'{data_name}{attack_typ}{network_typ}.pdf', dpi=250, bbox_inches='tight')
#plt.show()