# %%
import json
import time

import mlflow
import torch
from llmapi import gpt

from clcp.data import CLF_TEST_SUBSETS, Data, get_tokenizer
from clcp.metrics import MultiDatasetMetrics

# %%
# Shared objects for the evaluation cells below: the tokenizer, the
# classification ("clf") test split, and a metrics helper that knows the
# per-subset layout from CLF_TEST_SUBSETS.
# NOTE(review): is_test / is_dummy presumably toggle reduced/dummy data in
# clcp.data — confirm there.
tok = get_tokenizer()
ds = Data(
    tok=tok,
    kind="clf",
    split="test",
    is_test=False,
    is_dummy=False,
)
ev_metrics = MultiDatasetMetrics(clf_metadata=CLF_TEST_SUBSETS)

# %% Flattened (disabled): one GPT call per (text, hypothesis) pair, scored as an entailment probability
# sys_prompt = """
# You are a text classifier.
# Given a text and a hypothesis, output a single float between 0 and 1 representing the probability that the hypothesis is entailed by the text.
# A value near 1 means strong agreement; near 0 means strong disagreement.
# Do not provide explanations — only return the float.
# """
# usr_prompt_template = """
# <TEXT>\n{text}\n</TEXT>\n
# <HYPOTHESIS>\n{hypothesis}\n</HYPOTHESIS>
# """

# ys, yhs, task_names, exceptions = [], [], [], 0
# gpt.set_system_prompt(sys_prompt)
# for i, sample in enumerate(ds.data):
#     usr_prompt = usr_prompt_template.format(text=sample["text"], hypothesis=sample["hypothesis"])
#     ys.append(sample["labels"])
#     task_names.append(sample["task_name"])

#     for attempt in range(5):
#         try:
#             yhs.append(float(gpt(usr_prompt)))
#             break  # exit retry loop
#         except Exception:
#             if attempt < 5 - 1:
#                 time.sleep(60 + 6**attempt)
#             else:
#                 yhs.append(0.0)
#                 exceptions += 1
#                 print(f"number of exceptions: {exceptions}")
#     if i % 1000 == 0:
#         print(f"{i}/{len(ds)}")

# yhs, ys = torch.tensor(yhs), torch.tensor(ys)
# metrics = ev_metrics(yhs, ys, task_names=task_names, step=0)
# mlflow.log_metrics(metrics=metrics)
# ev_metrics.plot_cm(yhs, ys, task_names=task_names)


# # Log mistakes made
# mistake_indices = torch.nonzero(ys - yhs.round(), as_tuple=True)[0].tolist()
# mistake_log, lower = {}, 0
# for subset, meta_data in CLF_TEST_SUBSETS.items():
#     upper = meta_data["len"]
#     subset_indices = [i for i in mistake_indices if lower <= i < upper]
#     mistake_log[subset] = subset_indices
#     lower = upper
# mlflow.log_dict(mistake_log, "eval_mistakes_gpt.json")


# %% Multiclass
# One GPT call per unique text: the model sees the text plus the sorted list of
# candidate labels for its subset and must return one probability per label.
# Each probability is written back onto the matching (text, label_text) row of
# the dataset, so yhs[i], ys[i] and task_names[i] all describe row i of ds.data.
sys_prompt = """
You are a classifier that returns only a JSON list of floats representing probabilities,
one per label, in the same order as the labels are given. Do not return any explanation, labels, or text."""
usr_prompt_template = """
<TEXT>\n{text}\n</TEXT>\n
<LABELS>\n{labels}\n</LABELS>
"""
N_ATTEMPTS = 5  # GPT calls per text before falling back to all-zero predictions


def _query_label_probs(text, labels):
    """Ask GPT for one probability per entry of ``labels`` (same order).

    Raises:
        ValueError: if the reply is not a JSON list of exactly len(labels) items.
        Errors from ``gpt``, ``json.loads`` or ``float`` propagate unchanged.
    """
    reply = json.loads(gpt(usr_prompt_template.format(text=text, labels=labels)))
    if not isinstance(reply, list) or len(reply) != len(labels):
        got = len(reply) if isinstance(reply, list) else type(reply).__name__
        raise ValueError(f"expected a JSON list of {len(labels)} floats, got {got}")
    return [float(p) for p in reply]


def _query_with_retries(text, labels):
    """Return ``(probs, ok)``; on repeated failure return all zeros and ``ok=False``."""
    for attempt in range(N_ATTEMPTS):
        try:
            return _query_label_probs(text, labels), True
        except Exception as e:  # API errors and malformed replies are both retried
            print(e)
            if attempt < N_ATTEMPTS - 1:
                time.sleep(60 + 6**attempt)  # back off (rate limits)
    return [0.0] * len(labels), False


# reset_index makes df.index positional, so group indices address rows of ds.data.
df = ds.data.to_pandas().reset_index(drop=True)
gpt.set_system_prompt(sys_prompt)

# Every row gets a prediction, so metrics and mistake indices line up with ds.data.
ys = df["labels"].tolist()
task_names = df["task_name"].tolist()
yhs = [0.0] * len(df)
exceptions = 0
for subset, dfs in df.groupby("subset", sort=False):
    labels = sorted(dfs["label_text"].unique())
    print(f"Processing {subset}...")
    for text, dft in dfs.groupby("text", sort=False):
        probs, ok = _query_with_retries(text, labels)
        if not ok:
            exceptions += 1
            print(f"number of exceptions: {exceptions}")
        # Map by label name rather than position so row order never matters.
        prob_by_label = dict(zip(labels, probs))
        for row, label in zip(dft.index, dft["label_text"]):
            yhs[row] = prob_by_label[label]

yhs, ys = torch.tensor(yhs), torch.tensor(ys)
metrics = ev_metrics(yhs, ys, task_names=task_names, step=0)
mlflow.log_metrics(metrics=metrics)
ev_metrics.plot_cm(yhs, ys, task_names=task_names)


# Log mistakes made: row indices where the thresholded prediction differs from
# the label, bucketed by subset using the boundaries in CLF_TEST_SUBSETS.
# yhs already holds probabilities in [0, 1] parsed from GPT, so threshold with
# round() (0.5 -> 0, half-to-even). Applying sigmoid first would push every
# p > 0 above 0.5 and count every non-zero prediction as positive.
mistake_indices = torch.nonzero(ys != yhs.round(), as_tuple=True)[0].tolist()
mistake_log, lower = {}, 0
for subset, meta_data in CLF_TEST_SUBSETS.items():
    # NOTE(review): this treats meta_data["len"] as the cumulative end index of
    # the subset; if it is the per-subset length, use `lower + meta_data["len"]`.
    upper = meta_data["len"]
    mistake_log[subset] = [i for i in mistake_indices if lower <= i < upper]
    lower = upper
mlflow.log_dict(mistake_log, "./outputs/eval_mistakes_gpt.json")
