import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Smoothing / plotting knobs.
# NOTE(review): rolling_mean_window_size should match the rolling-mean windows
# used when smoothing the curves — confirm they stay in sync.
rolling_mean_window_size = 10
compress_ratio = 25  # step passed to downsample() before curves are plotted

# Layout: architecture groups are placed on a figure_row x figure_col grid of slots.
figure_row, figure_col = 3, 3

# Input directories per architecture group. Every sub-folder of a directory is
# scanned for the CSVs listed in ``csv_files`` (presumably one sub-folder per run).
# Dict order decides each group's subplot slot; uncomment an entry to enable it.
data_paths = {
    "Lenet": ["./lenet5_mnist"],
    "VGG": ["./vgg11_bn_cifar10"],
    "DenseNet": ["./densenet_cifar10"],
    "EfficientNetB0": ["./efficientnet_b0_cifar10"],
    "MobileNetv2": ["./mobilenetv2_cifar10"],
    "Regnetx": ["./regnet_x_200_mf_cifar10"],
    "ShuffleNet": ["./shufflenet_v2_cifar10"],
    # "ResNet18": ["./resnet18_cifar100"],
    # "DLA": ["./dla_cifar100"],
}

# Subplot title for every known group, including the currently disabled ones.
data_path_title = dict(
    Lenet="LeNet5 @MNIST",
    VGG="VGG11 @CIFAR10",
    DenseNet="DenseNet @CIFAR10",
    EfficientNetB0="EfficientNet b0 @CIFAR10",
    MobileNetv2="MobileNet v2 @CIFAR10",
    Regnetx="RegNet x 200 mf @CIFAR10",
    ShuffleNet="ShuffleNet v2 @CIFAR10",
    ResNet18="ResNet18 @CIFAR100",
    DLA="DLA @CIFAR100",
)

# CSV file names looked up inside every run folder.
csv_files = [
    "training_loss.csv",
    "weight_difference_l2.csv",
    "loss.csv",
    "accuracy.csv",
]

# Per-group CSV holding the training-accuracy curve drawn next to the training loss.
training_accuracy_csv_paths = dict(
    Lenet="./training_accuracy_file/lenet.csv",
    VGG="./training_accuracy_file/vgg11.csv",
    DenseNet="./training_accuracy_file/densenet.csv",
    EfficientNetB0="./training_accuracy_file/efficientnet_b0.csv",
    MobileNetv2="./training_accuracy_file/mobilenetv2.csv",
    Regnetx="./training_accuracy_file/regnet_x_200mf.csv",
    ShuffleNet="./training_accuracy_file/shufflenet_v2.csv",
    ResNet18="./training_accuracy_file/resnet18_cifar100.csv",
    DLA="./training_accuracy_file/dla_cifar100.csv",
)

# Y-axis labels of a group's stacked rows (test loss and accuracy share the third row).
vector_pd = r'$|P_{i}D|$'
csv_files_title = [
    "Training Loss(left) \n and Accuracy(right)",
    f"L2 Distance {vector_pd}",
    "Test Loss(left) \n and Accuracy(right)",
]

# Helper: thin a DataFrame out by position
def downsample(df, step=10):
    """Return every ``step``-th row of ``df`` by position, starting with the first row."""
    every_nth_row = slice(None, None, step)
    return df.iloc[every_nth_row]

# Function to collect all dataframes under a list of paths
def collect_dataframes(paths, file_names=None):
    """Load the per-run CSVs found one directory level below each path.

    Every sub-folder of each entry in ``paths`` is scanned; any file in it whose
    name is listed in ``file_names`` is read with ``pd.read_csv``.

    Args:
        paths: iterable of root directories to scan.
        file_names: CSV file names to look for in each sub-folder. Defaults to
            the module-level ``csv_files``.

    Returns:
        dict mapping each file name to a list of DataFrames, one per sub-folder
        that contained that file (an empty list when none did).

    Raises:
        FileNotFoundError: if one of ``paths`` does not exist.
    """
    if file_names is None:
        file_names = csv_files
    collected = {name: [] for name in file_names}
    for path in paths:
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not os.path.exists(path):
            raise FileNotFoundError(f"data path does not exist: {path}")
        # os.listdir order is arbitrary; sort so runs load in a reproducible order.
        for subfolder in sorted(os.listdir(path)):
            full_subfolder_path = os.path.join(path, subfolder)
            if not os.path.isdir(full_subfolder_path):
                continue
            for file_name in file_names:
                file_path = os.path.join(full_subfolder_path, file_name)
                # isfile (not exists): a directory with a CSV-like name is not data.
                if os.path.isfile(file_path):
                    print(f"loading {file_path}")
                    collected[file_name].append(pd.read_csv(file_path))
    return collected

# Function to compute per-tick mean and std DataFrames across runs
def compute_mean_std(dfs, log_transform=False, window=10):
    """Aggregate several runs into a smoothed per-tick mean and standard deviation.

    The runs are stacked and grouped by their ``tick`` column. Each group is
    reduced to its mean and sample standard deviation. Both results are then
    smoothed with a trailing rolling mean of ``window`` ticks (``min_periods=1``).

    Args:
        dfs: list of DataFrames, each containing a ``tick`` column.
        log_transform: if True, every numeric column except ``tick`` is replaced
            by its ``log10`` before aggregating. Non-positive values become NaN
            and are therefore ignored by the mean/std.
        window: rolling-mean window size (defaults to 10).

    Returns:
        ``(mean_df, std_df)`` indexed by ``tick`` and holding only numeric
        columns, or ``(None, None)`` when ``dfs`` is empty.
    """
    if not dfs:
        return None, None
    # ignore_index: the runs' row labels overlap, so give the stack a fresh index
    # to keep the column assignments below free of duplicate-label alignment.
    merged = pd.concat(dfs, axis=0, ignore_index=True)
    if log_transform:
        for col in merged.columns:
            if col == 'tick' or not pd.api.types.is_numeric_dtype(merged[col]):
                continue
            values = merged[col].astype(float)
            # Vectorized log10; where() first turns non-positive entries into NaN.
            merged[col] = np.log10(values.where(values > 0))
    grouped = merged.groupby('tick')
    # numeric_only: text columns cannot be averaged (pandas >= 2.0 raises on them).
    mean_df = grouped.mean(numeric_only=True).rolling(window=window, min_periods=1).mean()
    std_df = grouped.std(numeric_only=True).rolling(window=window, min_periods=1).mean()
    return mean_df, std_df

# Column-name fragments marking BatchNorm buffers/parameters that are never plotted.
_SKIPPED_COLUMN_KEYS = ("bn", "running_mean", "running_var", "num_batches_tracked")

# Hand-tuned per-group axis limits (kwargs for ``set_ylim``) and legend positions,
# keyed by ``data_paths`` group name so a tweak follows its model if the groups
# are reordered. Groups that are not listed use matplotlib's automatic limits.
_TRAIN_ACC_YLIM = {
    "Lenet": {"bottom": 0.97},
    "MobileNetv2": {"top": 1, "bottom": 0.9997},
    "Regnetx": {"bottom": 0.9998},
    "DLA": {"bottom": 0.9996},
}
_TRAIN_LEGEND_LOC = {"EfficientNetB0": "center right"}  # default: lower right

_TEST_LOSS_YLIM = {
    "Lenet": {"top": 0.1},
    "VGG": {"top": 0.5},
    "EfficientNetB0": {"top": 7},
}
_TEST_ACC_YLIM = {
    "Lenet": {"top": 1.005},
    "VGG": {"top": 0.907},
    "DenseNet": {"top": 0.936},
    "EfficientNetB0": {"top": 0.94, "bottom": 0.89},
    "MobileNetv2": {"top": 0.952, "bottom": 0.942},
    "Regnetx": {"top": 0.935},
    "ShuffleNet": {"top": 0.905},
    "ResNet18": {"top": 0.78},
}
_TEST_LEGEND_LOC = {"Lenet": "upper left"}  # default: upper right


def _combined_legend(ax, ax2, loc):
    """Draw one legend on ``ax`` holding the labelled artists of ``ax`` and its twin ``ax2``."""
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax.legend(h1 + h2, l1 + l2, loc=loc, fontsize="8")


def _plot_curve_panel(ax, stats, file_name, sub_row, show_ylabel):
    """Plot the smoothed mean ± std of every kept column of one per-run CSV.

    Used for the training-loss row (``sub_row == 0``) and the L2-distance row
    (``sub_row == 1``, log y-axis). ``stats`` maps CSV file names to the
    ``(mean_df, std_df)`` pair returned by ``compute_mean_std``. An empty panel
    (labels and grid only) is drawn when no runs were found.
    """
    if sub_row == 1:
        ax.set_yscale("log")
    if show_ylabel:
        ax.set_ylabel(csv_files_title[sub_row])
    mean_df, std_df = stats[file_name]
    if mean_df is not None:
        mean_df = downsample(mean_df, step=compress_ratio)
        std_df = downsample(std_df, step=compress_ratio)
        # The L2 file is averaged in log10 space; map the band back to linear values.
        in_log10 = file_name == "weight_difference_l2.csv"
        loss_label = "training loss"
        for column in mean_df.columns:
            if any(key in column for key in _SKIPPED_COLUMN_KEYS):
                print(f"{column} skipped")
                continue
            if column == 'tick':
                continue
            mu, sigma = mean_df[column], std_df[column]
            if in_log10:
                mean, lower, upper = 10 ** mu, 10 ** (mu - sigma), 10 ** (mu + sigma)
            else:
                mean, lower, upper = mu, mu - sigma, mu + sigma
            if sub_row == 0:
                # Only the first training-loss curve gets a legend entry.
                label, loss_label = loss_label, None
            else:
                label = column
            ax.plot(mean_df.index, mean, label=label, linewidth=1)
            ax.fill_between(mean_df.index, lower, upper, alpha=0.2)
        ax.set_xlim(mean_df.index.min(), mean_df.index.max())
    ax.set_xlabel(r'iteration $i$')
    ax.grid(True)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=True)


def _overlay_training_accuracy(ax, group):
    """Title the training-loss panel and overlay the group's training accuracy on a twin y-axis.

    Reads ``training_accuracy_csv_paths[group]`` (a CSV with a ``tick`` column and
    a ``"0"`` value column), draws it dotted in colour C1, then adds a combined legend.
    """
    ax.set_title(data_path_title[group])
    acc_df = pd.read_csv(training_accuracy_csv_paths[group])
    # NOTE(review): unlike the per-run CSVs (smoothed first, then downsampled),
    # this file is downsampled before smoothing, so its rolling window covers
    # roughly compress_ratio times as many iterations. Kept as-is; confirm intended.
    acc_df = downsample(acc_df, step=compress_ratio)
    acc_df = acc_df.groupby('tick').mean().rolling(window=rolling_mean_window_size, min_periods=1).mean()
    ax2 = ax.twinx()
    ax2.plot(acc_df.index, acc_df["0"], linestyle='dotted', label="training accuracy", linewidth=1, color="C1")
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
    ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
    if group in _TRAIN_ACC_YLIM:
        ax2.set_ylim(**_TRAIN_ACC_YLIM[group])
    _combined_legend(ax, ax2, _TRAIN_LEGEND_LOC.get(group, "lower right"))


def _plot_band_columns(ax, mean_df, std_df, label, color, **line_kwargs):
    """Plot mean ± std for each data column (all but ``tick``/``phase``) in a single colour.

    Only the first curve carries ``label`` so the legend shows one entry.
    """
    for column in mean_df.columns:
        if column in ('tick', 'phase'):
            continue
        mean = mean_df[column]
        spread = std_df[column]
        print(f"print column {column}")
        ax.plot(mean_df.index, mean, label=label, linewidth=1, color=color, **line_kwargs)
        ax.fill_between(mean_df.index, mean - spread, mean + spread, alpha=0.2, color=color)
        label = None


def _plot_test_panel(ax, group, stats, show_ylabel):
    """Plot test loss (left axis, colour C0) and test accuracy (twin right axis, dotted C1)."""
    if show_ylabel:
        ax.set_ylabel(csv_files_title[2])
    mean_loss, std_loss = stats["loss.csv"]
    mean_acc, std_acc = stats["accuracy.csv"]
    if mean_loss is not None:
        mean_loss = downsample(mean_loss, step=compress_ratio)
        std_loss = downsample(std_loss, step=compress_ratio)
        _plot_band_columns(ax, mean_loss, std_loss, "test loss", "C0")
    ax2 = ax.twinx()
    if mean_acc is not None:
        mean_acc = downsample(mean_acc, step=compress_ratio)
        std_acc = downsample(std_acc, step=compress_ratio)
        _plot_band_columns(ax2, mean_acc, std_acc, "test accuracy", "C1", linestyle='dotted')
    if mean_loss is not None:
        ax.set_xlim(mean_loss.index.min(), mean_loss.index.max())
    ax.set_xlabel(r'iteration $i$')
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=True)
    if group in _TEST_LOSS_YLIM:
        ax.set_ylim(**_TEST_LOSS_YLIM[group])
    if group in _TEST_ACC_YLIM:
        ax2.set_ylim(**_TEST_ACC_YLIM[group])
    _combined_legend(ax, ax2, _TEST_LEGEND_LOC.get(group, "upper right"))
    ax.grid(True)


if __name__ == "__main__":

    # Collect every run of every group and reduce each CSV to smoothed mean/std.
    group_stats = {}
    for group, paths in data_paths.items():
        dfs = collect_dataframes(paths)
        # L2 distances are drawn on a log axis, so they are averaged in log10 space.
        group_stats[group] = {
            name: compute_mean_std(dfs[name], log_transform=(name == "weight_difference_l2.csv"))
            for name in csv_files
        }

    if figure_col * figure_row < len(data_paths):
        raise ValueError(
            f"{len(data_paths)} groups do not fit in a {figure_row}x{figure_col} grid"
        )

    # Stacked rows per group: training loss/accuracy, L2 distance, and
    # test loss/accuracy (loss.csv and accuracy.csv share the last row).
    row_for_one_plot = len(csv_files) - 1

    # matplotlib reads axes.labelpad when an Axes is created, so these must be
    # set before plt.subplots() to affect the main grid.
    plt.rcParams['axes.titlepad'] = 2
    plt.rcParams['axes.labelpad'] = 2
    fig, axs = plt.subplots(row_for_one_plot * figure_row, figure_col, figsize=(12, 16), squeeze=False, layout="constrained")
    fig.set_constrained_layout_pads(w_pad=0.0, h_pad=0.1, wspace=0.01, hspace=0.1)

    for idx, group in enumerate(data_paths):
        base_row = (idx // figure_col) * row_for_one_plot
        col_idx = idx % figure_col
        stats = group_stats[group]
        show_ylabel = col_idx == 0  # y-labels only on the leftmost column

        train_ax = axs[base_row, col_idx]
        _plot_curve_panel(train_ax, stats, "training_loss.csv", 0, show_ylabel)
        _overlay_training_accuracy(train_ax, group)
        _plot_curve_panel(axs[base_row + 1, col_idx], stats, "weight_difference_l2.csv", 1, show_ylabel)
        _plot_test_panel(axs[base_row + 2, col_idx], group, stats, show_ylabel)

    # Hide every subplot of the grid slots that have no group assigned.
    for slot_idx in range(len(data_paths), figure_row * figure_col):
        base_row = (slot_idx // figure_col) * row_for_one_plot
        col_idx = slot_idx % figure_col
        for sub in range(row_for_one_plot):
            axs[base_row + sub, col_idx].set_visible(False)

    fig.savefig("m2o_appendix.jpg", pad_inches=0, dpi=400)
    fig.savefig("m2o_appendix.pdf", pad_inches=0)
