from collections import defaultdict

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import yaml

from ml_utils.proxies import set_proxies

# %% Setup env
# MLflow tracking server (AzureML workspace) that holds the evaluation runs.
URI = (
    "azureml://db17de6f-6cb8-4996-849f-3fdf0d10a4b9.workspace.westeurope.api.azureml.ms"
    "/mlflow/v2.0/subscriptions/dec849c6-3664-4372-9569-749eb6820434"
    "/resourceGroups/rg-VLProd/providers/Microsoft.MachineLearningServices"
    "/workspaces/mlw-deeplearning1h7w"
)
# Identifiers passed to `mlflow.get_run`: the first run supplies the *_f1 and
# *_roc metrics, the second one the *_acc metrics.
JOB_NAME_F1 = "250801004311_aaai26_f1"
JOB_NAME_ACC = "250801005533_aaai26_acc"

# Model registry; its "model_to_type" keys select and order the rows of every
# score table written below.
with open("./scripts/paper/utils/models.yml", encoding="utf-8") as f:
    models_info = yaml.safe_load(f)


def extract_scores(metrics, suffix):
    """Select the metrics tagged with ``_<suffix>`` and strip that tag.

    Args:
        metrics: Mapping from metric name to value.
        suffix: Tag to match, without the leading underscore (e.g. ``"f1"``).

    Returns:
        A new dict, in the input's iteration order, mapping every name that
        ends with ``_<suffix>`` (minus that single trailing tag) to its value.
        Names that do not end with the tag are skipped.
    """
    tag = f"_{suffix}"
    selected = {}
    for name, value in metrics.items():
        if not name.endswith(tag):
            continue
        # Slice off exactly one trailing occurrence of the tag.
        selected[name[: -len(tag)]] = value
    return selected


# %% Retrieve eval data
# This run logged both the *_f1 and *_roc metrics; `metrics` feeds the F1 cell
# and the ROC cell below.
with set_proxies():
    mlflow.set_tracking_uri(URI)
    run = mlflow.get_run(run_id=JOB_NAME_F1)
metrics = run.data.metrics

# %% Write F1 scores (clf datasets)

# Names are assumed to look like "<model>_<dataset>_f1". After stripping the
# suffix, split on the FIRST "_" only, giving model -> {dataset: score}.
scores = extract_scores(metrics, "f1")
scores_nested = defaultdict(dict)
for metric_name, score in scores.items():
    model_name, dataset_name = metric_name.split("_", 1)
    scores_nested[model_name][dataset_name] = score

# Rows are the models in registry order. `.loc` raises KeyError if a
# registered model has no scores. Columns are datasets, sorted by name.
# The "clf" column is dropped (presumably an aggregate over the classification
# datasets; confirm).
# NOTE(review): `str.contains("nli")` is a substring test, so any dataset
# whose name merely contains "nli" (e.g. "online...") is excluded as well.
df = (
    pd.DataFrame.from_dict(scores_nested, orient="index")
    .drop(columns=["clf"])
    .loc[:, lambda frame: ~frame.columns.str.contains("nli")]
    .loc[models_info["model_to_type"].keys()]
    .sort_index(axis=1)
)

df.to_csv("./scripts/paper/utils/f1_scores.csv")


# %% Write ROC scores (nli datasets)

# Relies on `metrics` still holding the JOB_NAME_F1 run fetched in the
# "Retrieve eval data" cell. The Acc cell rebinds `metrics`, so re-run the
# retrieval cell first when executing cells out of order.
scores = extract_scores(metrics, "roc")
scores_nested = defaultdict(dict)
for metric_name, score in scores.items():
    model_name, dataset_name = metric_name.split("_", 1)
    scores_nested[model_name][dataset_name] = score

# Keep only dataset columns whose name contains "nli" (substring match). The
# bare "nli" column is dropped first (presumably an aggregate; confirm).
# Rows follow the registry order.
df = (
    pd.DataFrame.from_dict(scores_nested, orient="index")
    .drop(columns=["nli"])
    .loc[:, lambda frame: frame.columns.str.contains("nli")]
    .loc[models_info["model_to_type"].keys()]
    .sort_index(axis=1)
)

df.to_csv("./scripts/paper/utils/roc_scores.csv")


# %% Write Acc scores (clf datasets)

# Accuracy lives in a separate run, so fetch it and rebind `run`/`metrics`.
with set_proxies():
    mlflow.set_tracking_uri(URI)
    run = mlflow.get_run(run_id=JOB_NAME_ACC)
metrics = run.data.metrics

# Same model/dataset split as the F1 cell: first "_" separates the model name.
scores = extract_scores(metrics, "acc")
scores_nested = defaultdict(dict)
for metric_name, score in scores.items():
    model_name, dataset_name = metric_name.split("_", 1)
    scores_nested[model_name][dataset_name] = score

# Classification datasets only. The "clf" column is dropped and every column
# whose name contains "nli" (substring match) is excluded. Rows follow the
# registry order and columns are sorted by name.
df = (
    pd.DataFrame.from_dict(scores_nested, orient="index")
    .drop(columns=["clf"])
    .loc[:, lambda frame: ~frame.columns.str.contains("nli")]
    .loc[models_info["model_to_type"].keys()]
    .sort_index(axis=1)
)

df.to_csv("./scripts/paper/utils/acc_scores.csv")
