import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Presumably the window length (in rows) for rolling-mean smoothing of the
# plotted curves -- TODO confirm against the places that smooth with `rolling`.
rolling_mean_window_size = 10
# Passed as `step` to downsample(): only every compress_ratio-th row of an
# aggregated curve is plotted.
compress_ratio = 100

# Group name -> list of run directories. Each immediate subfolder of a listed
# directory is scanned for the files named in `csv_files`; all runs found under
# the same group are aggregated together. The number of groups is the number
# of subplot columns.
data_paths = {
    "ResNet18": ["./resnet18/start_to_end", "./resnet18/end_to_start"],
    "DLA": ["./dla"],
    "CCT7": ["./cct7"],
}

# Group name -> title shown above that group's subplot column.
data_path_title = {
    "ResNet18": "ResNet18 @CIFAR10",
    "DLA": "DLA @CIFAR10 (FDF strategy)",
    "CCT7": "CCT7 @CIFAR10",
}

# Group name -> CSV holding the training accuracy overlaid on the first row.
# The plotting code reads the columns 'tick' and "0" from these files.
training_accuracy_csv_paths = {
    "ResNet18": "./training_accuracy_file/resnet18.csv",
    "DLA": "./training_accuracy_file/dla.csv",
    "CCT7": "./training_accuracy_file/cct7.csv",
}

# The target filenames to collect from every run subfolder.
# Order matters: entries 0 and 1 are drawn on subplot rows 0 and 1, while
# "loss.csv" and "accuracy.csv" are drawn together on row 2.
csv_files = [
    "training_loss.csv",
    "weight_difference_l2.csv",
    "loss.csv",
    "accuracy.csv"
]

# LaTeX (mathtext) symbol embedded in the y-axis label of the L2-distance row.
vector_pd = r'$|P_{i}D|$'
# Y-axis labels of the three subplot rows, indexed by row.
csv_files_title = [
    "Training Loss(left) \n and Accuracy(right)",
    f"L2 Distance {vector_pd}",
    "Test Loss(left) \n and Accuracy(right)",
]

# Helper: thin out a DataFrame by keeping one row out of every `step`
def downsample(df, step=10):
    """Return every ``step``-th row of ``df`` by position, starting with the first row.

    Args:
        df: Object supporting positional ``.iloc`` slicing (a DataFrame here).
        step: Stride between the rows that are kept.

    Returns:
        The strided selection ``df.iloc[::step]``; the index labels of the
        kept rows are preserved.
    """
    stride = slice(None, None, step)
    return df.iloc[stride]

# Function to collect all dataframes under a list of paths
def collect_dataframes(paths, file_names=None):
    """Load the target CSV files from every immediate subfolder of ``paths``.

    Only the direct subdirectories of each path are searched; files lying
    directly in a path, or deeper than one level, are not visited. The path of
    every file that is loaded is printed.

    Args:
        paths: Iterable of directory paths to scan.
        file_names: Base names of the CSV files to load from each subfolder.
            Defaults to the module-level ``csv_files`` list.

    Returns:
        dict mapping each requested file name to the list of DataFrames loaded
        for it (an empty list when no subfolder contains that file).

    Raises:
        FileNotFoundError: If one of ``paths`` does not exist.
    """
    if file_names is None:
        file_names = csv_files
    collected = {name: [] for name in file_names}
    for path in paths:
        # Explicit check rather than `assert`, which is stripped under `python -O`
        if not os.path.exists(path):
            raise FileNotFoundError(f"data path does not exist: {path}")
        # Sorted so the loading order does not depend on the OS listing order
        for subfolder in sorted(os.listdir(path)):
            full_subfolder_path = os.path.join(path, subfolder)
            if not os.path.isdir(full_subfolder_path):
                continue
            for file_name in file_names:
                file_path = os.path.join(full_subfolder_path, file_name)
                if os.path.exists(file_path):
                    print(f"loading {file_path}")
                    collected[file_name].append(pd.read_csv(file_path))
    return collected

# Function to compute the smoothed per-tick mean and std DataFrames
def compute_mean_std(dfs, log_transform=False, window=None):
    """Aggregate several runs into smoothed per-tick mean and std curves.

    All frames are stacked, grouped by their ``'tick'`` column, reduced with
    mean / std across the rows sharing a tick, and each result is then
    smoothed with a rolling mean over ``window`` consecutive ticks
    (``min_periods=1``, so the first rows average over fewer ticks).

    Args:
        dfs: List of DataFrames, each containing a ``'tick'`` column.
        log_transform: If true, every column except ``'tick'`` is replaced by
            its log10 before aggregation; non-positive values become NaN.
        window: Rolling-mean window length. Defaults to the module-level
            ``rolling_mean_window_size``.

    Returns:
        ``(mean_df, std_df)`` indexed by tick, or ``(None, None)`` when
        ``dfs`` is empty. The std is pandas' default sample standard
        deviation, so a tick with a single valid row has a NaN std.
    """
    if not dfs:
        return None, None
    if window is None:
        window = rolling_mean_window_size
    # concat builds a new frame, so the column writes below leave the inputs untouched
    merged = pd.concat(dfs, axis=0, ignore_index=True)
    if log_transform:
        for col in merged.columns:
            if col != 'tick':
                values = merged[col]
                # Mask non-positive entries to NaN first: they have no logarithm
                merged[col] = np.log10(values.where(values > 0))
    per_tick = merged.groupby('tick')
    mean_df = per_tick.mean().rolling(window=window, min_periods=1).mean()
    std_df = per_tick.std().rolling(window=window, min_periods=1).mean()
    return mean_df, std_df

if __name__ == "__main__":
    # Script entry point: aggregate every run of every group, then draw a grid
    # with one column per group and three rows (training loss/accuracy,
    # L2 distance, test loss/accuracy), saved as m2m_main.jpg and m2m_main.pdf.

    # Collect and compute means and stds for each group
    group_stats = {}
    for group, paths in data_paths.items():
        dfs = collect_dataframes(paths)
        group_stats[group] = {}
        for name in csv_files:
            # Only the L2 distance is aggregated in log10 space
            log_transform = (name == "weight_difference_l2.csv")
            group_stats[group][name] = compute_mean_std(dfs[name], log_transform=log_transform)

    # Column-name fragments whose curves are left out of the first two columns
    # (presumably BatchNorm buffers rather than trainable weights -- TODO confirm)
    skipped_column_markers = ("running_mean", "running_var", "num_batches_tracked")
    # Hand-tuned upper y-limits of the accuracy axes, keyed by column index
    train_acc_ylim_top = {0: 1.005, 1: 1.0, 2: 1.01}
    test_acc_ylim_top = {0: 0.96, 1: 0.96, 2: 0.94}

    # Plotting
    # One row fewer than csv_files: loss.csv and accuracy.csv share the last row
    fig, axs = plt.subplots(len(csv_files)-1, len(data_paths), figsize=(12, 6))

    for row_idx, file_name in enumerate(csv_files):
        if row_idx < 2:
            for col_idx, group in enumerate(data_paths.keys()):
                ax = axs[row_idx, col_idx]
                # set yscale to log for L2 distance (row 1 is weight_difference_l2.csv)
                if row_idx == 1:
                    ax.set_yscale("log")
                # only the leftmost column carries the row's y-axis label
                if col_idx == 0:
                    ax.set_ylabel(csv_files_title[row_idx])
                mean_df, std_df = group_stats[group][file_name]
                has_data = mean_df is not None and std_df is not None
                if has_data:
                    mean_df = downsample(mean_df, step=compress_ratio)
                    std_df = downsample(std_df, step=compress_ratio)
                    for column in mean_df.columns:
                        if col_idx in (0, 1) and any(marker in column for marker in skipped_column_markers):
                            print(f"{column} skipped")
                            continue
                        # if col_idx == 2 and "norm" in column:
                        #     print(f"{column} skipped")
                        #     continue
                        if column == 'tick':
                            continue
                        if file_name == "weight_difference_l2.csv":
                            # Statistics were computed in log10 space; map the
                            # mean and the +/- one-std band back for display
                            mean = 10 ** mean_df[column]
                            lower = 10 ** (mean_df[column] - std_df[column])
                            upper = 10 ** (mean_df[column] + std_df[column])
                        else:
                            mean = mean_df[column]
                            lower = mean_df[column] - std_df[column]
                            upper = mean_df[column] + std_df[column]
                        ax.plot(mean_df.index, mean, label="training loss" if row_idx==0 else column, linewidth=1)
                        ax.fill_between(mean_df.index, lower, upper, alpha=0.2)

                ax.set_xlabel(r'iteration $i$')
                # Without any loaded data there is no tick range to clamp the x-axis to
                if has_data and len(mean_df.index) > 0:
                    ax.set_xlim(mean_df.index.min(), mean_df.index.max())
                ax.grid(True)
                ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=True)
                if row_idx == 0:
                    ax.set_title(data_path_title[group])
                    ax.set_ylim(bottom=0)

                    # special setup for training_accuracy.csv
                    training_accuracy_csv = training_accuracy_csv_paths[group]
                    training_accuracy_df = pd.read_csv(training_accuracy_csv)
                    # NOTE(review): unlike the other curves, this one is downsampled
                    # BEFORE the rolling mean, so the window spans
                    # rolling_mean_window_size * compress_ratio source rows -- confirm
                    # that this is intended.
                    training_accuracy_df = downsample(training_accuracy_df, step=compress_ratio)
                    training_accuracy_df = training_accuracy_df.groupby('tick').mean().rolling(window=rolling_mean_window_size, min_periods=1).mean()
                    # Accuracy goes on a secondary y-axis sharing the x-axis
                    ax2 = ax.twinx()
                    ax2.plot(training_accuracy_df.index, training_accuracy_df["0"], linestyle='dotted', label="training accuracy", linewidth=1, color="C1")
                    # Merge the handles of both axes into a single legend
                    h1, l1 = ax.get_legend_handles_labels()
                    h2, l2 = ax2.get_legend_handles_labels()
                    ax2.legend(h1 + h2, l1 + l2, loc='upper right', fontsize="8")
                    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
                    ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
                    if col_idx not in train_acc_ylim_top:
                        # No hand-tuned limit exists for additional columns
                        raise NotImplementedError
                    ax2.set_ylim(top=train_acc_ylim_top[col_idx])
        elif row_idx == 2: # test loss and accuracy
            for col_idx, group in enumerate(data_paths.keys()):
                ax = axs[row_idx, col_idx]
                # only the leftmost column carries the row's y-axis label
                if col_idx == 0:
                    ax.set_ylabel(csv_files_title[row_idx])

                mean_df_loss, std_df_loss = group_stats[group]["loss.csv"]
                mean_df_acc, std_df_acc = group_stats[group]["accuracy.csv"]
                has_loss = mean_df_loss is not None and std_df_loss is not None
                has_acc = mean_df_acc is not None and std_df_acc is not None
                if has_loss:
                    mean_df_loss = downsample(mean_df_loss, step=compress_ratio)
                    std_df_loss = downsample(std_df_loss, step=compress_ratio)
                if has_acc:
                    mean_df_acc = downsample(mean_df_acc, step=compress_ratio)
                    std_df_acc = downsample(std_df_acc, step=compress_ratio)

                if has_loss:
                    first = True
                    for column in mean_df_loss.columns:
                        if column == 'tick' or column == 'phase':
                            continue
                        mean = mean_df_loss[column]
                        lower = mean_df_loss[column] - std_df_loss[column]
                        upper = mean_df_loss[column] + std_df_loss[column]
                        print(f"print column {column}")
                        label = "test loss" if first else None
                        # Only the first curve is labelled, so the legend holds one entry
                        first = False
                        ax.plot(mean_df_loss.index, mean, label=label, linewidth=1, color="C0")
                        ax.fill_between(mean_df_loss.index, lower, upper, alpha=0.2, color="C0")
                # Accuracy goes on a secondary y-axis sharing the x-axis
                ax2 = ax.twinx()
                if has_acc:
                    first = True
                    for column in mean_df_acc.columns:
                        if column == 'tick' or column == 'phase':
                            continue
                        mean = mean_df_acc[column]
                        lower = mean_df_acc[column] - std_df_acc[column]
                        upper = mean_df_acc[column] + std_df_acc[column]
                        print(f"print column {column}")
                        label = "test accuracy" if first else None
                        # Only the first curve is labelled, so the legend holds one entry
                        first = False
                        ax2.plot(mean_df_acc.index, mean, linestyle='dotted', label=label, linewidth=1, color="C1")
                        ax2.fill_between(mean_df_acc.index, lower, upper, alpha=0.2, color="C1")
                # Without any loaded loss data there is no tick range to clamp the x-axis to
                if has_loss and len(mean_df_loss.index) > 0:
                    ax.set_xlim(mean_df_loss.index.min(), mean_df_loss.index.max())
                ax.set_xlabel(r'iteration $i$')
                ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=True)
                ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
                ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)
                # Columns without a hand-tuned limit keep matplotlib's autoscaling
                if col_idx in test_acc_ylim_top:
                    ax.set_ylim(top=0.6)
                    ax2.set_ylim(top=test_acc_ylim_top[col_idx])

                # Merge the handles of both axes into a single legend
                h1, l1 = ax.get_legend_handles_labels()
                h2, l2 = ax2.get_legend_handles_labels()
                ax2.legend(h1 + h2, l1 + l2, loc='upper right', fontsize="8")
                ax.grid(True)


    plt.tight_layout(pad=0.5)
    # NOTE(review): per the matplotlib docs, pad_inches only takes effect together
    # with bbox_inches='tight'; confirm whether a tight bounding box was intended.
    plt.savefig("m2m_main.jpg", pad_inches=0, dpi=400)
    plt.savefig("m2m_main.pdf", pad_inches=0)
