import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Window length (in rows of the per-tick statistics) for rolling-mean smoothing
# of the curves; keep in sync with the smoothing done in compute_mean_std.
rolling_mean_window_size = 10
# Only every `compress_ratio`-th row of the smoothed statistics is plotted.
compress_ratio = 10

# Layout of the model grid: figure_row x figure_col models, each model
# occupying several vertically stacked panels.
figure_row = 2
figure_col = 3

# Input paths: model key -> list of directories. Every sub-directory of a
# listed directory is scanned for the files named in `csv_files`
# (presumably one sub-directory per training run -- confirm against the
# experiment layout).
data_paths = {
    "ResNet": ["./resnet18_cifar100"],
    "DLA": ["./dla_cifar100"],
    "EfficientNetB0": ["./efficientnet_cifar100"],
    "MobileNetv2": ["./mobilenet_cifar100"],
    "Regnetx": ["./regnet_cifar100"],
    "ShuffleNet": ["./shufflenet_cifar100"],
}

# Title drawn above the first panel of each model; keys must match `data_paths`.
data_path_title = {
    "ResNet": "ResNet18 @CIFAR100",
    "DLA": "DLA @CIFAR100",
    "EfficientNetB0": "EfficientNet b0 @CIFAR100",
    "MobileNetv2": "MobileNet v2 @CIFAR100",
    "Regnetx": "RegNet x 200 mf @CIFAR100",
    "ShuffleNet": "ShuffleNet v2 @CIFAR100",
}

# The target filenames to collect from every run directory. The plotting code
# pairs training_loss/training_accuracy in the first panel of a model, draws
# weight_difference_l2 in the second, and pairs loss/accuracy in the third.
csv_files = [
    "training_loss.csv",
    "training_accuracy.csv",
    "weight_difference_l2.csv",
    "loss.csv",
    "accuracy.csv",
]

# Mathtext snippet embedded in the y-axis label of the L2-distance panels.
vector_pd = r'$|P_{i}D|$'
# y-axis labels of the three panels of a model, top to bottom; only shown on
# the left-most column of the figure.
csv_files_title = [
    "Training Loss",
    f"L2 Distance {vector_pd}",
    "Test Loss(left) \nand Accuracy(right)",
]

def downsample(df, step=10):
    """Return every ``step``-th row of ``df``, beginning with the first row.

    The selection is positional, so the index labels of the kept rows are
    preserved unchanged.
    """
    stride = slice(None, None, step)
    return df.iloc[stride]

def collect_dataframes(paths, file_names=None):
    """Load the per-run CSV files found one level below each directory in ``paths``.

    Every immediate sub-directory of each path is scanned; a file is loaded
    when its name is one of ``file_names``. Sub-directories are visited in
    sorted order so the result does not depend on the order in which the
    file system happens to list them.

    Args:
        paths: Iterable of directory paths to scan.
        file_names: File names to look for in every sub-directory. Defaults
            to the module-level ``csv_files`` list.

    Returns:
        A dict mapping each requested file name to the list of DataFrames
        read for it (an empty list when the file was never found).

    Raises:
        FileNotFoundError: If one of ``paths`` does not exist.
    """
    if file_names is None:
        file_names = csv_files
    collected = {name: [] for name in file_names}
    for path in paths:
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not os.path.exists(path):
            raise FileNotFoundError(f"data directory does not exist: {path}")
        for subfolder in sorted(os.listdir(path)):
            run_dir = os.path.join(path, subfolder)
            if not os.path.isdir(run_dir):
                continue
            for file_name in collected:
                file_path = os.path.join(run_dir, file_name)
                if os.path.exists(file_path):
                    print(f"loading {file_path}")
                    collected[file_name].append(pd.read_csv(file_path))
    return collected

def compute_mean_std(dfs, log_transform=False, window=None):
    """Aggregate several runs into smoothed per-tick mean and std DataFrames.

    The frames are stacked, grouped by their ``tick`` column, and the
    per-tick mean and standard deviation of every other column are each
    smoothed with a rolling mean.

    Args:
        dfs: List of DataFrames, each containing a ``tick`` column.
        log_transform: When true, every column except ``tick`` is replaced by
            its base-10 logarithm before aggregating; non-positive values
            become NaN and are therefore ignored by the statistics.
        window: Rolling-mean window length. Defaults to the module-level
            ``rolling_mean_window_size``.

    Returns:
        ``(mean_df, std_df)``, both indexed by ``tick``; ``(None, None)``
        when ``dfs`` is empty.
    """
    if not dfs:
        return None, None
    if window is None:
        window = rolling_mean_window_size
    # concat builds a new frame, so the column assignments below never touch
    # the caller's DataFrames.
    merged = pd.concat(dfs, axis=0)
    if log_transform:
        for col in merged.columns:
            if col != 'tick':
                values = merged[col]
                # Mask non-positive entries first so log10 only sees valid input.
                merged[col] = np.log10(values.where(values > 0))
    grouped = merged.groupby('tick')
    mean_df = grouped.mean().rolling(window=window, min_periods=1).mean()
    std_df = grouped.std().rolling(window=window, min_periods=1).mean()
    return mean_df, std_df

# Per-model tuning of the training panel, indexed by the model's position in
# ``data_paths``: (lower limit of the accuracy axis, legend location).
_TRAIN_PANEL_STYLE = (
    (0.995, 'lower right'),
    (0.994, 'lower right'),
    (0.995, 'lower right'),
    (0.9996, 'upper left'),
    (0.995, 'lower right'),
    (0.995, 'lower right'),
)

# Same for the test panel; a ``None`` limit leaves the accuracy axis autoscaled.
_TEST_PANEL_STYLE = (
    (0.65, 'center'),
    (0.65, 'lower left'),
    (0.58, 'lower left'),
    (None, 'lower left'),
    (None, 'upper center'),
    (None, 'upper center'),
)

# L2-distance columns whose name contains one of these substrings are not drawn.
_SKIPPED_L2_COLUMN_MARKERS = ("bn", "running_mean", "running_var", "num_batches_tracked")


def _downsampled(stats):
    """Thin a ``(mean_df, std_df)`` pair by ``compress_ratio``; pass ``None`` through."""
    mean_df, std_df = stats
    if mean_df is not None and std_df is not None:
        mean_df = downsample(mean_df, step=compress_ratio)
        std_df = downsample(std_df, step=compress_ratio)
    return mean_df, std_df


def _set_xlim_to_index(ax, mean_df):
    """Clamp the x axis to the range of ``mean_df``'s index, if there is any data."""
    if mean_df is None or len(mean_df.index) == 0:
        return
    ax.set_xlim(mean_df.index.min(), mean_df.index.max())


def _plot_mean_band(ax, stats, label, color, **line_kwargs):
    """Draw mean curves with a +/- one std band for every data column of ``stats``."""
    mean_df, std_df = stats
    if mean_df is None or std_df is None:
        return
    pending_label = label
    for column in mean_df.columns:
        if column in ('tick', 'phase'):
            continue
        mean = mean_df[column]
        spread = std_df[column]
        print(f"print column {column}")
        ax.plot(mean_df.index, mean, label=pending_label, linewidth=1, color=color, **line_kwargs)
        ax.fill_between(mean_df.index, mean - spread, mean + spread, alpha=0.2, color=color)
        # Only the first curve carries the label, so the legend has one entry.
        pending_label = None


def _draw_loss_accuracy_panel(ax, loss_stats, acc_stats, loss_label, acc_label):
    """Draw loss on ``ax`` and accuracy on a twin y axis.

    Returns the twin axes and the combined legend handles and labels.
    """
    loss_stats = _downsampled(loss_stats)
    acc_stats = _downsampled(acc_stats)
    _plot_mean_band(ax, loss_stats, loss_label, "C0")
    ax2 = ax.twinx()
    _plot_mean_band(ax2, acc_stats, acc_label, "C1", linestyle='dotted')
    _set_xlim_to_index(ax, loss_stats[0])
    ax.set_xlabel(r'iteration $i$')
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    return ax2, h1 + h2, l1 + l2


def _draw_l2_panel(ax, stats):
    """Draw the per-column L2 distance curves on a log-scaled y axis.

    ``stats`` must hold log10-space statistics (``compute_mean_std`` called
    with ``log_transform=True``); they are mapped back with ``10 ** x`` so the
    band is symmetric on the log axis.
    """
    ax.set_yscale("log")
    mean_df, std_df = _downsampled(stats)
    if mean_df is not None:
        for column in mean_df.columns:
            if any(marker in column for marker in _SKIPPED_L2_COLUMN_MARKERS):
                print(f"{column} skipped")
                continue
            if column == 'tick':
                continue
            centre = mean_df[column]
            spread = std_df[column]
            ax.plot(mean_df.index, 10 ** centre, label=column, linewidth=1)
            ax.fill_between(mean_df.index, 10 ** (centre - spread), 10 ** (centre + spread), alpha=0.2)
    ax.set_xlabel(r'iteration $i$')
    _set_xlim_to_index(ax, mean_df)
    ax.grid(True)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)


def main():
    """Aggregate the runs of every model and write the appendix figure as a PDF.

    Each model in ``data_paths`` gets a cell of the ``figure_row`` x
    ``figure_col`` grid, made of three stacked panels: training loss and
    accuracy, L2 distance, and test loss and accuracy.

    Raises:
        ValueError: If the grid has fewer cells than there are models.
        NotImplementedError: If a model has no entry in ``_TRAIN_PANEL_STYLE``.
    """
    # Fail before any data is loaded; `assert` would be stripped under `-O`.
    if figure_col * figure_row < len(data_paths):
        raise ValueError(
            f"a {figure_row}x{figure_col} grid cannot hold {len(data_paths)} models"
        )

    # Collect and compute means and stds for each group
    group_stats = {}
    for group, paths in data_paths.items():
        dfs = collect_dataframes(paths)
        group_stats[group] = {
            name: compute_mean_std(dfs[name], log_transform=(name == "weight_difference_l2.csv"))
            for name in csv_files
        }

    # One panel row per y-axis label: train, L2 distance, test.
    rows_per_model = len(csv_files_title)
    fig, axs = plt.subplots(rows_per_model * figure_row, figure_col, figsize=(12, 12), squeeze=False, layout="constrained")
    fig.set_constrained_layout_pads(w_pad=0.0, h_pad=0.1, wspace=0.01, hspace=0.1)
    # NOTE(review): these rcParams are set after the axes already exist;
    # verify that both paddings actually take effect on the existing axes.
    plt.rcParams['axes.titlepad'] = 2
    plt.rcParams['axes.labelpad'] = 2

    def panel_axes(idx, panel_row):
        # Models fill the grid row-major; each model spans `rows_per_model` rows.
        return axs[(idx // figure_col) * rows_per_model + panel_row, idx % figure_col]

    # Training loss (left axis) and accuracy (right axis).
    for idx, group in enumerate(data_paths):
        ax = panel_axes(idx, 0)
        if idx % figure_col == 0:
            ax.set_ylabel(csv_files_title[0])
        ax.set_title(data_path_title[group])
        ax2, handles, labels = _draw_loss_accuracy_panel(
            ax,
            group_stats[group]["training_loss.csv"],
            group_stats[group]["training_accuracy.csv"],
            "training loss",
            "training accuracy",
        )
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0), useMathText=True)
        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0), useMathText=True)
        if idx >= len(_TRAIN_PANEL_STYLE):
            raise NotImplementedError(f"no training panel style for model #{idx}")
        acc_floor, legend_loc = _TRAIN_PANEL_STYLE[idx]
        ax2.set_ylim(top=1, bottom=acc_floor)
        ax.legend(handles, labels, loc=legend_loc, fontsize="8")

    # L2 distance, log scale.
    for idx, group in enumerate(data_paths):
        ax = panel_axes(idx, 1)
        if idx % figure_col == 0:
            ax.set_ylabel(csv_files_title[1])
        _draw_l2_panel(ax, group_stats[group]["weight_difference_l2.csv"])

    # Test loss (left axis) and accuracy (right axis).
    for idx, group in enumerate(data_paths):
        ax = panel_axes(idx, 2)
        if idx % figure_col == 0:
            ax.set_ylabel(csv_files_title[2])
        ax2, handles, labels = _draw_loss_accuracy_panel(
            ax,
            group_stats[group]["loss.csv"],
            group_stats[group]["accuracy.csv"],
            "test loss",
            "test accuracy",
        )
        # Models without a style entry are drawn without limits or legend.
        if idx < len(_TEST_PANEL_STYLE):
            acc_floor, legend_loc = _TEST_PANEL_STYLE[idx]
            if acc_floor is not None:
                ax2.set_ylim(bottom=acc_floor)
            ax.legend(handles, labels, loc=legend_loc, fontsize="8")
        ax.grid(True)

    # NOTE(review): pad_inches is documented as applying only together with
    # bbox_inches='tight' -- confirm whether it is still wanted here.
    fig.savefig("m2o_appendix_cifar100.pdf", pad_inches=0)


if __name__ == "__main__":
    main()
