import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import ScalarFormatter

from experiments.fns import fit_scale_law
from experiments.fns import scale_laws_fn

# Dataset whose scaling-sweep results are plotted; its CSV is read from ./logs/.
dataset = "cifar100"
file_path = f"./logs/{dataset}_scaling.csv"
runs = pd.read_csv(filepath_or_buffer=file_path)
# Column holding each run's compute budget, used as the x axis.
xname = 'cola_flops'

# Metric on the y axis and its axis label. To plot training error instead,
# use ('train_error_avg', 'Train Error').
yname, ylabel = 'test_error', 'Test Error'

# When True, fit and overlay a scaling law per structure and draw the
# exponent bar plot.
fit = True

# Per-dataset, per-structure additive offsets, built from one row of values
# per dataset in the order Dense, BTT, Monarch, Low Rank, Kron, TT.
# Not read by the active plotting code; presumably an error floor subtracted
# before a log-log fit — TODO confirm.
struct_offsets = {
    ds_name: dict(zip(("Dense", "BTT", "Monarch", "Low Rank", "Kron", "TT"), offsets))
    for ds_name, offsets in (
        ("cifar100", (25, 20, 25, 10, 0, 0)),
        ("cifar10", (15, 10, 15, 0, 0, 0)),
    )
}

# For each dataset: scatter the chosen metric against FLOPs per structure on
# log-log axes, optionally overlay a fitted scaling law per structure, and
# save the plot, a standalone legend, and a bar plot of the fitted exponents
# under ./figures/.
for ds in [dataset]:
    # NOTE(review): these MongoDB-style run filters are never applied to the
    # data below; they presumably document how the CSV was exported — confirm.
    filters = {
        "state": "finished",
        "config.layers": {
            "$in": ['all_but_last']
        },  # , 'intermediate', 'ffn', 'attn'
        # "config.dataset": {"$eq": 'cifar100'},
        "config.model": {
            "$eq": 'MLP'
        },
        "config.use_wrong_mult": False,
        # "config.cola_flops": {"$lt": 2e7},
    }

    # `runs` was loaded for `dataset` only; load any other dataset's CSV so
    # extending this loop does not silently reuse the wrong results.
    ds_runs = runs if ds == dataset else pd.read_csv(f"./logs/{ds}_scaling.csv")
    # "cifar10" -> "CIFAR-10"; this also turns "cifar100" into "CIFAR-100".
    pretty_ds = ds.replace("cifar10", "CIFAR-10")
    # savefig fails if the output directory is missing.
    os.makedirs('./figures', exist_ok=True)

    # log log scale, scatter plot
    sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
    hue_order = ['Dense', 'BTT', 'Monarch', 'Low Rank', 'Kron', 'TT']
    pallette = sns.color_palette("tab10", n_colors=len(hue_order))

    plt.figure(dpi=100, figsize=(6, 6))
    ax = sns.scatterplot(data=ds_runs, x=xname, y=yname, hue='struct', markers=True, hue_order=hue_order, s=100, palette=pallette)
    if fit:
        # Fit a scaling law per structure and overlay it as a dashed curve.
        slopes = []
        fitted_structs = []  # structures that actually had data to fit
        fitted_colors = []   # their palette colors, aligned with `slopes`
        for struct, color in zip(hue_order, pallette):
            struct_runs = ds_runs[ds_runs['struct'] == struct]
            if struct_runs.empty:
                # No runs for this structure: nothing to fit or plot.
                continue
            x = struct_runs[xname]
            y = struct_runs[yname]
            theta = fit_scale_law((x.values, y.values), lr=1e-2, n_steps=100, tol=1e-4)
            # Assumes theta[2] is the scaling exponent (bar-plotted as "Exponent").
            slopes.append(theta[2])
            fitted_structs.append(struct)
            fitted_colors.append(color)
            # Evaluate the fit on sorted x so the dashed curve is drawn
            # left-to-right instead of zig-zagging in CSV row order.
            x_sorted = x.sort_values().values
            y_pred = scale_laws_fn(theta, x_sorted)
            ax.plot(x_sorted, y_pred, color=color, linestyle='--', linewidth=2, alpha=0.75)
            # An earlier variant (log-linear np.polyfit on y minus
            # struct_offsets[ds][struct]) has been superseded by fit_scale_law.
    ax.set_ylabel(ylabel)
    ax.set_xlabel('FLOPs')
    ax.set_xscale('log')
    ax.set_yscale('log')
    # yticks as plain numbers (not scientific notation)
    ax.yaxis.set_major_formatter(ScalarFormatter())
    ax.yaxis.set_minor_formatter(ScalarFormatter())

    ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)
    # Keep the legend entries for the standalone legend figure, then drop the
    # legend from the main plot.
    handles, labels = ax.get_legend_handles_labels()
    ax.get_legend().remove()
    ax.set_title(f'MLP on {pretty_ds}')
    plt.tight_layout()
    plt.savefig(f'./figures/mlp_{ds}.pdf', bbox_inches='tight')
    plt.show()

    # legend as a separate figure
    legend_fig = plt.figure(figsize=(8, 1))  # Adjust size as needed
    ax_legend = legend_fig.add_subplot(111)
    ax_legend.legend(handles, labels, loc='center', ncol=len(labels))
    ax_legend.axis('off')  # Hide axes
    plt.tight_layout()
    plt.savefig(f'./figures/mlp_{ds}_legend.pdf', bbox_inches='tight')
    plt.show()

    # barplot for slopes (only structures that were actually fitted)
    if fit and slopes:
        plt.figure(dpi=120, figsize=(8, 6))
        sns.barplot(x=fitted_structs, y=slopes, palette=fitted_colors)
        plt.ylabel(r'Exponent')
        # rotate x labels
        plt.xticks(rotation=45)
        plt.title(f'Scaling Exponent on {pretty_ds}')
        plt.tight_layout()
        plt.savefig(f'./figures/mlp_{ds}_slopes.pdf', bbox_inches='tight')
        plt.show()
