from dataclasses import dataclass
from pathlib import Path
import shutil
from matplotlib import pyplot as plt
import polars as pl
import h5py
import numpy as np
from scipy import special
from scipy.stats import multivariate_normal
import itertools as it
import best3 as b3
import yaml
import best3 as b3
import torch as th


@dataclass
class Model:
    """Ground-truth parameters and sizes describing one synthetic dataset.

    Fields keep their original order, so positional construction still works.
    """

    # Mean utility vector; centred and used as the normal distribution's mean in `run`.
    mu: np.ndarray
    # Covariance matrix of the utilities (passed as `cov` to multivariate_normal).
    si: np.ndarray
    # Embedding matrix of the alternatives; saved as float32 to embed.hdf5.
    embed: np.ndarray
    # Number of customers simulated per split (each contributes M ranked rows).
    N: int
    # Number of alternatives each customer is shown and ranks.
    M: int
    # Dataset label; used as a directory component of the output path.
    name: str


def pd(x):
    """Build a positive-definite matrix from the rows of ``x``.

    Forms the Gram matrix ``x @ x.T`` (positive semi-definite) and adds a
    ``1e-8`` diagonal jitter. The result is strictly positive definite even
    when ``x`` has fewer columns than rows.

    Args:
        x: 2-D array of shape ``(n, k)``; other ranks raise ``ValueError``.

    Returns:
        Array of shape ``(n, n)``.
    """
    n_rows, _ = x.shape
    gram = x @ x.T
    jitter = np.eye(n_rows) * 1e-8
    return gram + jitter


def entropy(x):
    """Return the Shannon entropy (in nats) of the distribution proportional to ``x``.

    ``x`` is normalised to sum to one first, so any non-negative weight vector
    is accepted. Zero weights contribute nothing (the convention
    ``0 * log(0) = 0``) instead of producing ``nan``.

    Args:
        x: Array-like of non-negative weights with a positive sum.

    Returns:
        ``-sum(p * log(p))`` where ``p = x / sum(x)``.
    """
    weights = np.asarray(x)
    p = weights / weights.sum()
    # The log must use the normalised p, not the raw weights; special.entr
    # computes -p*log(p) elementwise and defines entr(0) == 0.
    return special.entr(p).sum()


def run(rng: np.random.Generator, i: int, model: Model):
    """Generate one synthetic ranking dataset and write it to disk.

    The output directory ``datasets/syn/<model.name>/<i>/`` is deleted and
    recreated, then filled with:
      - ``embed.hdf5``: ``model.embed`` as float32;
      - ``alternatives.pl``: the alternative id/name table;
      - per split (``train``, ``val``): sampled utilities and a choices table
        of (customer, rating, alternative) rows;
      - ``sol.npz``: the (centred) ``mu`` and ``si``;
      - a PDF figure produced by ``b3.Probit3(mu, si).viz()``.

    Args:
        rng: Generator used for every random draw made here.
        i: Replicate index; used as the leaf directory name.
        model: Ground-truth parameters. ``model.mu`` is rebound to a centred
            copy; the caller's original array is left unmodified.
    """
    # Shifting every utility by the same constant leaves all rankings
    # unchanged, so centre mu. Rebind rather than use `-=`: the caller may
    # share this array, and in-place subtraction fails for integer dtypes.
    model.mu = model.mu - np.mean(model.mu)
    n_alt = model.mu.shape[-1]
    n_cust, n_shown = model.N, model.M

    print(model.name)
    out_dir = Path("datasets") / "syn" / model.name / str(i)
    shutil.rmtree(out_dir, ignore_errors=True)
    out_dir.mkdir(parents=True, exist_ok=True)
    with h5py.File(out_dir / "embed.hdf5", "w") as f:
        f.create_dataset("embed", model.embed.shape, dtype=np.float32)
        f["embed"][...] = model.embed  # type: ignore
    dist = multivariate_normal(mean=model.mu, cov=model.si)  # type: ignore

    alternatives = pl.DataFrame(
        {"alternative": np.arange(n_alt), "name": [str(a) for a in range(n_alt)]}
    )
    alternatives.serialize(out_dir / "alternatives.pl")

    rows = np.arange(n_cust)[:, None]
    for split in ["train", "val"]:
        # Sample with `rng` (not scipy's global state) so one generator drives
        # all randomness. rvs squeezes singleton axes (N == 1 or D == 1), so
        # restore the (customers, alternatives) layout explicitly.
        utils = dist.rvs(size=n_cust, random_state=rng).reshape(n_cust, n_alt)
        # NOTE(review): np.save appends ".npy" to names not ending in ".npy",
        # so this is actually written as "<split>-util.npz.npy". Name kept
        # unchanged for any existing readers.
        np.save(out_dir / f"{split}-util.npz", utils)
        # argsort of iid uniforms is a random permutation, so each customer is
        # shown a uniformly random subset of n_shown alternatives.
        shown = np.argsort(rng.uniform(0, 1, (n_cust, n_alt)), axis=-1)[:, :n_shown]
        # Order each customer's shown alternatives from highest to lowest utility.
        order = np.argsort(utils[rows, shown])[:, ::-1]
        ranking = shown[rows, order]
        choices = pl.DataFrame(
            {
                "customer": np.arange(n_cust).repeat(n_shown),
                # Top-ranked alternative gets rating n_shown - 1, last gets 0.
                "rating": np.tile(np.arange(n_shown)[::-1], n_cust),
                "alternative": ranking.flatten(),
            }
        )
        choices.serialize(out_dir / f"{split}-choices.pl")

    np.savez(out_dir / "sol.npz", mu=model.mu, si=model.si)
    rum = b3.Probit3(th.tensor(model.mu).float(), th.tensor(model.si).float())
    fig = rum.viz()
    # NOTE(review): "basliene" looks like a typo of "baseline"; the filename is
    # kept so anything already reading this file still finds it.
    fig.savefig(out_dir / "basliene.pdf", bbox_inches="tight", transparent=True)
    plt.close(fig)


def main():
    """Generate every synthetic dataset configuration.

    For each dimension D (currently only 8) and each of 3 replicates, this
    crosses two mean vectors with four covariances and calls ``run`` on each
    pair. The means are uniform in [-1, 1) ("r") and all-zero ("0"). The
    covariances are:
      - identity ("I");
      - a random diagonal with entries in [0.5, 1) ("rI");
      - a random Gram matrix ("r");
      - a rank-one +/-1 block matrix plus jitter ("bin").
    """
    rng = np.random.default_rng()
    N = 100_000

    for D in [8]:
        M = D
        # +1 for the first D // 2 coordinates and -1 for the rest. Unlike two
        # equal halves, this always has exactly D rows (odd D included), and it
        # avoids np.concat, which only exists in NumPy >= 2.0.
        signs = np.where(np.arange(D) < D // 2, 1.0, -1.0)[:, None]
        for i in range(3):
            for mu, mu_n in [
                (rng.uniform(-1, 1, (D,)), "r"),
                (np.zeros((D,)), "0"),
            ]:
                for si, si_n in [
                    (np.eye(D), "I"),
                    (np.diag(rng.uniform(0.5, 1, (D,))), "rI"),
                    (pd(rng.uniform(-1, 1, (D, D))), "r"),
                    (pd(signs), "bin"),
                ]:
                    run(
                        rng,
                        i,
                        Model(
                            # Copy: the same mu array is reused for all four
                            # covariances, so a run must not see another's edits.
                            mu=mu.copy(),
                            si=si,
                            embed=np.eye(D),
                            N=N,
                            M=M,
                            name=f"mu_{mu_n}-si_{si_n}_{D}",
                        ),
                    )


# Guard the entry point so importing this module does not wipe and regenerate
# every dataset under datasets/syn.
if __name__ == "__main__":
    main()
