from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import get_project_data
from experiments.fns import do_coeff_analysis
from experiments.fns import get_baselines
from experiments.fns import rename_row
from experiments.fns import label_row

# Run sources. Previously used alternatives:
#   data_project: "shuf_vecs_cifar", "vecs_cifar"
#   base_project: "lr_baselines"
data_project, base_project = "bttvecs_cifar", "cifar_baselines"

# Plot axes. x_var alternative: "cola_params". target_var is also the metric
# the runs are ranked by (lowest first).
x_var, target_var = "cola_flops", "train_loss_avg"

# Load all runs of the data project and keep only those whose state is
# "finished". .copy() detaches the filtered rows from the original frame so
# the column assignments below do not trigger SettingWithCopyWarning.
df = get_project_data(project=data_project)
df = df[df["state"] == "finished"].copy()
if df.empty:
    # Row-wise apply on an empty frame fails with an unhelpful error when the
    # result is assigned to a column; fail early with a clear message instead.
    raise SystemExit(f"No finished runs found in project {data_project!r}")
# Derive the "vec" and "label" columns row by row via the project helpers.
df["vec"] = df.apply(rename_row, axis=1)
df["label"] = df.apply(label_row, axis=1)
do_coeff_analysis(df)
# Print the 20 runs with the lowest target metric, then the number of
# distinct vecs.
df = df.sort_values(by=target_var, ascending=True)
print(df.head(20))
print(f"There are: {len(df['vec'].unique()):,d} vecs")

# Hand-picked coefficient vectors of interest. In the original list, three
# entries were the same vector; it is listed once here. The " (BMM0) Adam"
# suffix is appended so the entries can be compared against df["vec"].
selected_vecs = [
    "0.5-0.5-0.0-0.0-0.5-0.5-0.0",
    "0.0-0.5-0.5-0.0-0.5-0.0-0.5",
    "0.21-0.0-0.0-0.79-0.0-0.26-0.74",
    "0.1-0.0-0.0-0.9-0.0-0.19-0.81",
    "0.6-0.0-0.0-0.4-0.0-0.0-1.0",
    "0.01-0.62-0.37-0.0-0.33-0.29-0.38",
    "0.03-0.62-0.35-0.0-0.05-0.67-0.27",
    "0.01-0.45-0.55-0.0-0.42-0.20-0.38",
    "0.35-0.39-0.25-0.0-0.45-0.11-0.43",
    "0.03-0.01-0.96-0.0-0.25-0.43-0.32",
    "0.02-0.05-0.93-0.0-0.2-0.44-0.36",
]
selected_vecs = [vec + " (BMM0) Adam" for vec in selected_vecs]

# Set to True to restrict the analysis to selected_vecs. It is off by default,
# matching the filter that was previously commented out.
FILTER_TO_SELECTED = False
if FILTER_TO_SELECTED:
    df = df[df["vec"].isin(selected_vecs)]

# Keep only the best run (lowest target_var) for each (vec, width) pair.
# Rows with a missing metric are dropped first. Otherwise a group whose
# metric is entirely NaN makes idxmin return NaN or raise, depending on the
# pandas version, and the .loc lookup then fails.
df = df.dropna(subset=[target_var])
df = df.loc[df.groupby(["vec", "width"])[target_var].idxmin()]

# Baseline runs from the baseline project, restricted to three reference
# vecs: "none" plus two fixed coefficient vectors.
dfb = get_baselines(project=base_project, target_var=target_var)
exprs = [
    exp + " (BMM0) Adam"
    for exp in ("none", "0.0-0.5-0.5-0.0-0.5-0.0-0.5", "0.5-0.5-0.0-0.0-0.5-0.5-0.0")
]
# .copy() so the "label" assignment below writes to an independent frame
# instead of a slice, which would trigger SettingWithCopyWarning.
dfb = dfb[dfb["vec"].isin(exprs)].copy()
do_coeff_analysis(dfb)
dfb["label"] = dfb.apply(label_row, axis=1)

# Log-log plot of target_var against x_var:
# - baseline runs (dfb) as markers joined by lines, one style per label;
# - the selected runs (df) as markers, coloured by width and styled by label.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
plt.figure(dpi=75, figsize=(40, 25))
sns.scatterplot(x=x_var, y=target_var, data=dfb, style="label", s=200)
# legend=False: the scatterplot above already adds legend entries for the
# same dfb "label" groups; without this they would appear in the legend twice.
sns.lineplot(x=x_var, y=target_var, data=dfb, style="label", legend=False)
sns.scatterplot(x=x_var, y=target_var, data=df, style="label", hue="width", s=200)
plt.ylabel("Train Loss" if target_var.startswith("train_loss") else "Error")
plt.xlabel("FLOPs" if x_var == "cola_flops" else "Params")
plt.xscale("log")
plt.yscale("log")
# Anchor the legend just outside the right edge of the axes.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
