from dataclasses import dataclass
import json
import polars as pl
import functools as ft
import itertools as it
import operator
import torch as th
import best3 as b3
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
from loguru import logger
import sys
import mininterface
import uuid
import datetime as dt


@dataclass
class Conf:
    """Run configuration; the script fills it in via ``mininterface.run``."""

    # Directory holding the serialized input frames ("alternatives.pl" and
    # "train-choices.pl"); it is also handed to b3.ValDataset.
    base_dir: Path
    # Root directory for results. When None, results go under exps/<base_dir>.
    output_dir: str | None = None
    # When True, the conditioned ("cond") solver is evaluated in addition to
    # the unconditioned one.
    optimal: bool = False
    # Prefix prepended to the timestamped, uuid-tagged run directory name.
    run_name: str = ""


# Build the run configuration from the mininterface front-end.
conf: Conf = mininterface.run(Conf, title="Optimal").env  # type: ignore


# Results root: the explicit output_dir when given, otherwise exps/<base_dir>.
if conf.output_dir is None:
    _results_root = Path("exps") / Path(conf.base_dir)
else:
    _results_root = Path(conf.output_dir)

# Every run gets its own directory: <run_name><timestamp>-<uuid1>.
_run_stamp = dt.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
output_dir = _results_root / f"{conf.run_name}{_run_stamp}-{uuid.uuid1()}"
output_dir.mkdir(parents=True)

# Show the alternatives table, then load the training choices used by the
# solvers below.
_alternatives_path = conf.base_dir / "alternatives.pl"
print(pl.DataFrame.deserialize(_alternatives_path))
_choices_path = conf.base_dir / "train-choices.pl"
df = pl.DataFrame.deserialize(_choices_path)
logger.info("df={}", len(df))


# Validation dataset read from the same base directory.
test_ds = b3.ValDataset(conf.base_dir, batch_size=100, ret_embed=False)


def colname(x):
    """Return the column-name suffix for join position ``x``.

    Positive positions map to their decimal string; zero and negative
    positions map to the empty string (the un-suffixed column).
    """
    return str(x) if x > 0 else ""


def extract(_df: pl.DataFrame, constraint: tuple[int, ...], query: np.ndarray):
    """Find customers whose ratings follow the ordering given by ``constraint``.

    ``constraint`` holds positions into ``query``; ``query[constraint[i]]`` is
    the alternative at step ``i`` of the ordering. The frame is read through
    its "customer", "alternative" and "rating" columns.

    Returns a pair of unique "customer" series:
      * customers whose rating strictly decreases at every step of the
        ordering, and
      * customers who rated every alternative of the ordering with differing
        ratings at each consecutive step (the comparable population).
    """
    relevant = _df.filter(pl.col.alternative.is_in(query))
    steps = len(constraint)
    # Rating column per step: "rating" for step 0, "rating<i>" for joined steps.
    rating_cols = ["rating" + colname(step) for step in range(steps)]

    # Self-join once per further step, keeping only customers whose ratings of
    # the two consecutive alternatives differ.
    comparable = relevant.filter(pl.col.alternative == query[constraint[0]])
    for step in range(1, steps):
        step_rows = relevant.filter(pl.col.alternative == query[constraint[step]])
        joined = comparable.join(step_rows, on="customer", suffix=str(step))
        comparable = joined.filter(
            pl.col(rating_cols[step - 1]) != pl.col(rating_cols[step])
        )

    # Of those, keep the customers whose ratings strictly decrease throughout.
    ordered = comparable
    for earlier, later in zip(rating_cols, rating_cols[1:]):
        ordered = ordered.filter(pl.col(earlier) > pl.col(later))

    return ordered["customer"].unique(), comparable["customer"].unique()


# Diagnostic: show the constraints exposed by the validation dataset
# (classify() below reads them).
print(test_ds.constraints)


def classify(optimal, df, obs):
    """Estimate how often customers satisfy the last constraint on ``obs``.

    ``obs`` is converted to a tuple of ints and ``df`` is restricted to those
    alternatives. When ``optimal`` is true, the frame is first narrowed,
    constraint by constraint, to customers satisfying every constraint in
    ``test_ds.constraints`` except the last.

    Returns the fraction of comparable customers that satisfy the last
    constraint, or 0.5 when there are no comparable customers.
    """
    alternatives = tuple(int(alt) for alt in obs)
    candidates = df.filter(pl.col.alternative.is_in(alternatives))

    if optimal:
        # Condition on all but the last constraint, narrowing progressively.
        for cond in test_ds.constraints[:-1]:
            satisfying, _ = extract(candidates, cond, alternatives)
            candidates = candidates.filter(
                pl.col.customer.is_in(satisfying.implode())
            )

    satisfying, comparable = extract(
        candidates, test_ds.constraints[-1], alternatives
    )
    hits = len(satisfying)
    total = len(comparable)
    if total == 0:
        # No evidence either way.
        return 0.5
    return hits / total


# Solvers to evaluate, keyed by the label used in logs. The unconditioned one
# always runs; the conditioned one is added only when conf.optimal is set.
solvers = {"uncond": ft.partial(classify, False)}
if conf.optimal:
    solvers.update(cond=ft.partial(classify, True))

# Must run before the dataset is iterated below.
# NOTE(review): what lazy_init sets up is not visible here — confirm in b3.
test_ds.lazy_init()


# Scores accumulated per solver across all evaluated batches.
ress: dict[str, list[float]] = {name: [] for name in solvers}

# Only the first 25 batches of the validation dataset are evaluated.
for batch in it.islice(test_ds, 25):
    # The last element of the batch is used to index test_ds.choices.
    choice_ids = batch[-1].detach().numpy()
    for choice_id in choice_ids:
        alternatives = test_ds.choices[choice_id]["alternative"]
        for name, solve in solvers.items():
            ress[name].append(solve(df, alternatives))

    # Running evaluation over everything accumulated so far, once per batch.
    for name, scores in ress.items():
        score_tensor = th.tensor(scores)
        print(name, score_tensor)
        val_res = b3.eval_prob(score_tensor)

        logger.info("solver={} {}", name, val_res)

# NOTE(review): val_res holds only the result of the last solver iterated on
# the last batch, so other solvers' metrics are logged but not persisted; it
# is also unbound if the dataset yields no batches. Confirm this is intended.
with open(output_dir / "done", "w") as fd:
    json.dump(val_res, fd, indent=2)
