from overtraining.plotting.shared import *
from overtraining.plotting.figure1 import figure1
from overtraining.plotting.emperical import emperical
from overtraining.plotting.error import prediction_error
from overtraining.plotting.error_vs import error_vs
from overtraining.plotting.error_vs_count import error_vs_count
from overtraining.plotting.constants import *
from overtraining.plotting.grid import *
from overtraining.plotting.grid_color import *
from overtraining.plotting.downstream_corr import *
from overtraining.plotting.error_downstream import *
from overtraining.plotting.slopes import *
from overtraining.plotting.downstream_corr_all import *
from overtraining.plotting.non_trivial import *
from overtraining.plotting.downstream_emperical import *


def plot_all():
    """Render every figure in sequence and write each one as a PDF under ``plots/``.

    The output directory ``plots`` is a relative path, so it is created in
    (and resolved against) the current working directory. Each step calls one
    plotting function and then saves and closes via ``plt``, so figures are
    produced strictly in the order listed below.

    ``os``, ``plt`` and several plotting functions are not imported by name
    here; they are expected to be provided by the star imports at the top of
    the module (presumably the stdlib ``os`` and ``matplotlib.pyplot`` --
    confirm against ``overtraining.plotting.shared``).

    Raises:
        FileExistsError: if ``plots`` already exists but is not a directory.
    """
    root = "plots"
    # exist_ok=True replaces the exists()+mkdir() pair, which could race with
    # another process creating the directory between the check and the mkdir.
    os.makedirs(root, exist_ok=True)

    def save(stem):
        """Call ``plt.savefig`` on ``<root>/<stem>.pdf`` and then ``plt.close()``."""
        plt.savefig(f"{root}/{stem}.pdf")
        plt.close()

    # figure 1
    dataset = "rpj"
    evalset = "c4_val"
    figure1(val_suffix=evalset, dataset=dataset)
    save(f"figure1_{dataset}_{evalset}")

    # relative error vs. compute
    error_vs()
    save("error_vs_1b")

    error_vs(
        target_N="open_lm_7b",
        target_M=1.0,
    )
    save("error_vs_7b")

    non_trivial()
    save("eval_ablation")

    downstream_corr_all()
    save("downstream_corr_all")

    slopes()
    save("slopes")

    grid_full_plot_color()
    save("grid_full_color")

    # empirical phenomenon (function and file names keep their original spelling)
    emperical()
    save("emperical")

    downstream_emperical()
    save("downstream_emperical")

    prediction_error_downstream()
    save("error_downstream_all")

    prediction_error_downstream(push="avg_subset")
    save("error_downstream_subset")

    # small mults
    emperical(
        cc_mults=[0.25, 0.5, 1.0, 2.0, 4.0],
        dataset_val_pairs=[("c4_original", "c4_val"), ("rpj", "c4_val"), ("rw_original", "c4_val")],
        compute_range=[3e14, 3e17, 3e18, 3e19, 3e20],
    )
    save("emperical_small")

    # prediction error across many
    prediction_error(eval_dir=None)
    save("error")

    # OOD prediction error
    prediction_error(
        train_val_pairs=[
            ("c4_original", "paloma_dolma_100_programing_languages"),
            ("c4_original", "paloma_ptb"),
            ("c4_original", "val_de-en_100"),
        ],
        eval_dir=None,
    )
    save("error_ood")

    # ID prediction error
    prediction_error(
        train_val_pairs=[
            ("c4_original", "paloma_c4_en"),
            ("rpj", "paloma_redpajama"),
            ("rw_original", "paloma_falcon-refinedweb"),
        ],
        eval_dir=None,
    )
    save("error_id")

    downstream_corr()
    save("downstream_corr")

    # The two ablations share the same loss axes and differ only in which
    # error metric (and matching label) is plotted on the y-axis.
    downstream_corr(
        x_axis=["loss_paloma_c4_en", "loss_paloma_redpajama", "loss_paloma_falcon-refinedweb"],
        y_axis=["err_avg_subset", "err_avg_subset", "err_avg_subset"],
        x_labels=["Loss: C4", "Loss: RedPajama", "Loss: RefinedWeb"],
        y_labels=["Average top-1 error: 17-task split"] * 3,
    )
    save("downstream_corr_ablation")

    downstream_corr(
        x_axis=["loss_paloma_c4_en", "loss_paloma_redpajama", "loss_paloma_falcon-refinedweb"],
        y_axis=["err_avg", "err_avg", "err_avg"],
        x_labels=["Loss: C4", "Loss: RedPajama", "Loss: RefinedWeb"],
        y_labels=["Average top-1 error: 46-task split"] * 3,
    )
    save("downstream_corr_ablation_all")

    grid_full_plot()
    save("grid_full")

    # error_vs_count output (note: written to error_vs.pdf, not error_vs_count.pdf)
    error_vs_count()
    save("error_vs")


# Unguarded module-level call: every figure is generated whenever this module
# is executed or imported.
# NOTE(review): an `if __name__ == "__main__":` guard would allow importing
# without plotting -- confirm no caller relies on import-time plotting first.
plot_all()
