import itertools
import typing as tp
from pathlib import Path
import h5py
import numpy as np
import polars as pl
from scipy import special
from torch.utils.data import IterableDataset
import torch as th
import best3 as b3


def identity(x):
    """Return the argument unchanged (a no-op transform)."""
    unchanged = x
    return unchanged


def make_pointer(x):
    """Return the boundaries of runs of equal consecutive rows in ``x``.

    ``x`` is converted to an array; a 1-D input is treated as a single
    column.  Two neighbouring rows belong to the same run when all of their
    entries compare equal.

    Returns:
        A 1-D integer array ``r`` such that run ``k`` occupies rows
        ``r[k]:r[k + 1]``.  ``r[0]`` is ``0`` and ``r[-1]`` is ``len(x)``.
        An empty input has no runs and yields ``[0]``.
    """
    rows = np.asarray(x)
    if rows.ndim == 1:
        rows = rows[..., None]

    n_rows = len(rows)
    if n_rows == 0:
        # No rows means no runs: the pointer is just the terminating offset.
        return np.zeros(1, dtype=np.intp)

    # Compare neighbours directly instead of testing ``np.diff(...) != 0``:
    # this also works for non-numeric keys (e.g. string ids) and does not
    # report a boundary between two equal infinities (inf - inf is NaN).
    (last_of_run,) = np.where((rows[1:] != rows[:-1]).any(-1))

    # Every start offset is in [0, n_rows) by construction.
    return np.concatenate(
        [
            np.zeros_like(last_of_run, shape=(1,)),
            last_of_run + 1,
            np.full_like(last_of_run, n_rows, shape=(1,)),
        ]
    )


class Dataset(IterableDataset[tuple[th.Tensor, th.Tensor]]):
    """Endless stream of random batches of per-customer choice rows.

    Data is loaded lazily from ``base_dir`` on first iteration:
    ``alternatives.pl``, ``{cat}-choices.pl`` (unless ``choices`` is passed
    in) and the ``embed`` dataset of ``embed.hdf5``.

    Each batch draws ``batch_size`` customers (with replacement) among those
    that have at least ``sum(ranks)`` choice rows, and for every drawn
    customer ``sum(ranks)`` distinct rows.  The sampled rows are split into
    consecutive groups of sizes ``ranks``; inside each group they are ordered
    by decreasing row index.  Because the choice table is sorted by
    ``("customer", "rating")``, that is the order of non-increasing rating.

    Args:
        base_dir: Directory holding the serialized tables and embeddings.
        ranks: Sizes of the consecutive groups sampled per customer.
        cat: Prefix of the choices file (``"{cat}-choices.pl"``).
        batch_size: Number of customers drawn per batch.
        device: Device on which the yielded tensors are created.
        ret_embed: If true, yield embeddings of the sampled alternatives
            instead of the raw row indices.
        choices: Optional choice table used instead of the file on disk.
    """

    def __init__(
        self,
        base_dir: Path,
        ranks: tp.Sequence[int],
        cat: str = "train",
        batch_size: int = 1024,
        device: str | th.device = "cpu",
        ret_embed: bool = True,
        choices: pl.DataFrame | None = None,
    ):
        self.inited = False
        self.ret_embed = ret_embed
        self.__ranks = tuple(ranks)
        # Cumulative group offsets: group k covers columns sp[k]:sp[k + 1].
        self.sp = list(itertools.accumulate(self.__ranks, initial=0))
        # Total number of rows sampled per customer.
        self.__rank = sum(self.__ranks)

        self.base_dir = base_dir
        self.batch_size = batch_size
        self.device = th.device(device)
        self.cat = cat
        self._choices = choices

        # Column positions belonging to each group.  Built from the stored
        # tuple rather than the raw argument so that a one-shot iterable
        # (already consumed above) does not trip ``zip(strict=True)``.
        self.constraints = [
            tuple(range(start, start + size))
            for start, size in zip(self.sp[:-1], self.__ranks, strict=True)
        ]

    def lazy_init(self) -> None:
        """Load tables and embeddings on first use; later calls are no-ops.

        Raises:
            ValueError: If no customer has at least ``sum(ranks)`` choice
                rows, so that nothing could ever be sampled.
        """
        if self.inited:
            return

        self.alternatives = pl.DataFrame.deserialize(
            self.base_dir / "alternatives.pl"
        ).with_row_index("alternative_index")
        self.choices = (
            pl.DataFrame.deserialize(self.base_dir / f"{self.cat}-choices.pl")
            if self._choices is None
            else self._choices
        )
        # Attach the row index of each alternative and group rows by
        # customer, ordered by rating inside each customer.
        self.choices = self.choices.join(self.alternatives, on="alternative").sort(
            ("customer", "rating")
        )

        self.customer_of = self.choices["customer"].to_numpy()
        self.rating_of = self.choices["rating"].to_numpy()
        self.alternative_index_of = th.tensor(
            self.choices["alternative_index"].to_numpy(),
            device=self.device,
            dtype=th.long,
        )

        # customer_pointer[k]:customer_pointer[k + 1] are the rows of the
        # k-th customer block.
        self.customer_pointer = make_pointer(self.customer_of)

        # Uniform distribution over the customers that have enough rows.
        eligible = np.diff(self.customer_pointer) >= self.__rank
        n_eligible = int(eligible.sum())
        if n_eligible == 0:
            # Dividing by zero here would silently yield NaN probabilities.
            raise ValueError(
                f"no customer has at least {self.__rank} choices to sample from"
            )
        self.customer_prob = eligible / n_eligible

        with h5py.File(self.base_dir / "embed.hdf5", "r") as emb_f:
            self.embed = th.tensor(np.array(emb_f["embed"]), device=self.device)  # type: ignore

        self.rng = np.random.default_rng()
        self.inited = True

    def sample_one(self) -> tuple[np.ndarray, np.ndarray]:
        """Draw one batch of customers and choice rows.

        Returns:
            ``(c, idx)`` where ``c`` has shape ``(batch_size,)`` and holds
            positions of customer blocks (indices into ``customer_pointer``),
            and ``idx`` has shape ``(batch_size, sum(ranks))`` and holds row
            indices into the sorted choice table.
        """
        c = self.rng.choice(
            len(self.customer_prob), (self.batch_size,), p=self.customer_prob
        )

        # For every drawn customer pick distinct offsets inside its block and
        # shift them by the start of the block.
        offsets = np.stack(
            [
                self.rng.choice(
                    self.customer_pointer[i + 1] - self.customer_pointer[i],
                    self.__rank,
                    replace=False,
                    shuffle=True,
                )
                for i in c
            ]
        )
        idx = self.customer_pointer[c, None] + offsets

        # Order each group by decreasing row index.
        idx = np.concat(
            [
                np.sort(idx[..., i:j], axis=-1, stable=True)[..., ::-1]
                for i, j in itertools.pairwise(self.sp)
            ],
            axis=-1,
        )
        ret = (c, idx)
        for i in ret:
            assert len(ret[0]) == len(i)

        return ret

    def __iter__(self):
        """Yield batches forever.

        Yields ``(c, embed[alternative_index_of[idx]])`` when ``ret_embed``
        is true and ``(c, idx)`` otherwise, as tensors on ``self.device``.
        """
        self.lazy_init()
        while True:
            ret = tuple(th.tensor(i, device=self.device) for i in self.sample_one())
            if self.ret_embed:
                *ret1, idx = ret
                yield (*ret1, self.embed[self.alternative_index_of[idx]])
            else:
                yield ret


class ValDataset(Dataset):
    """``Dataset`` preset with two groups per customer: ``rank`` then ``pred``.

    All arguments other than ``rank`` and ``pred`` are forwarded unchanged to
    ``Dataset``; ``ranks`` is set to ``(rank, pred)`` and ``cat`` defaults to
    ``"val"``.
    """

    def __init__(
        self,
        base_dir: Path,
        rank: int = 4,
        pred: int = 2,
        cat: str = "val",
        batch_size: int = 1024,
        device: str | th.device = "cpu",
        ret_embed: bool = True,
        choices: pl.DataFrame | None = None,
    ):
        super().__init__(
            base_dir=base_dir,
            ranks=(rank, pred),
            cat=cat,
            batch_size=batch_size,
            device=device,
            ret_embed=ret_embed,
            choices=choices,
        )
        self.__rank = rank

        # Column positions of the first (``rank``) and second (``pred``) group.
        rank_slots = tuple(range(rank))
        pred_slots = tuple(range(rank, rank + pred))
        self.constraints: list[tuple[int, ...]] = [rank_slots, pred_slots]

    # def sample_one(self):
    #     c, idx = super().sample_one()
    #     a = self.rng.choice(idx[..., 0 : self.__rank], axis=-1)
    #     b = self.rng.choice(idx[..., self.__rank : 2 * self.__rank], axis=-1)
    #     idx = np.concat(
    #         [
    #             idx,
    #             np.sort(
    #                 np.stack([a, b], axis=-1),
    #                 axis=-1,
    #                 stable=True,
    #             )[..., ::-1],
    #         ],
    #         axis=-1,
    #     )

    #     return c, idx


def eval_prob(prob: th.Tensor):
    """Summarise probabilities scored against a target of 1.

    Three per-element scores are computed from ``prob``:

    * ``acc``   -- 1.0 where ``prob >= 0.5``; entries exactly equal to 0.5
      are replaced by a uniform random number first, so ties are broken by
      a coin flip.
    * ``brier`` -- ``(prob - 1) ** 2``.
    * ``nll``   -- ``-log(prob)`` with ``prob`` clamped to [0.001, 0.999].

    Each score is pushed into a ``b3.Acc`` accumulator and the result holds,
    for every score name ``k``, the keys ``k``, ``k_std`` and ``k_err``
    (presumably mean, standard deviation and standard error -- confirm
    against ``best3.Acc``).

    NOTE(review): the sample count passed to the accumulator is ``len(s)``,
    i.e. the first dimension only, while the means run over all elements;
    confirm this is intended if ``prob`` is ever multi-dimensional.
    """
    coin = th.rand(*prob.shape, device=prob.device)
    tie_broken = th.where(prob == 0.5, coin, prob)
    per_element = {
        "acc": (tie_broken >= 0.5).float(),
        "brier": (prob - 1) ** 2,
        "nll": -th.log(th.clamp(prob, 0.001, 0.999)),
    }

    summary = {}
    for name, scores in per_element.items():
        stats = b3.Acc()
        stats.push(scores.mean(), (scores**2).mean(), len(scores))
        spread, stderr = map(float, stats.val)
        summary[name] = float(stats.v)
        summary[f"{name}_std"] = spread
        summary[f"{name}_err"] = stderr
    return summary


def save_dataset(
    data_dir: Path,
    mdf: pl.DataFrame,
    df: pl.DataFrame,
    emebed: np.ndarray | None = None,
):
    """Write alternatives, choices, a train/val split and embeddings to disk.

    Files created in ``data_dir`` (created if missing):

    * ``alternatives.pl``  -- ``mdf`` as given.
    * ``choices.pl``       -- ``df`` restricted to alternatives present in
      ``mdf``.
    * ``train-choices.pl`` / ``val-choices.pl`` -- the same rows split by
      customer: a random tenth of the unique customers (rounded down) goes
      to validation, the rest to training.
    * ``embed.hdf5``       -- dataset ``embed`` (float32), replacing any
      existing dataset of that name.

    Args:
        data_dir: Output directory.
        mdf: Table of alternatives; must have an ``alternative`` column.
        df: Table of choices; must have ``customer`` and ``alternative``
            columns.
        emebed: One embedding row per row of ``mdf``.  Defaults to the
            identity matrix (one-hot embeddings).

    Raises:
        ValueError: If ``emebed`` does not have exactly ``len(mdf)`` rows.
    """
    if emebed is None:
        emebed = np.eye(len(mdf))
    # Validate up front: checking only after the tables are written and the
    # old "embed" dataset is deleted would leave the directory half-updated.
    if len(emebed) != len(mdf):
        raise ValueError(
            f"embedding has {len(emebed)} rows but there are {len(mdf)} alternatives"
        )

    df = df.filter(pl.col.alternative.is_in(mdf["alternative"].implode()))
    print(mdf)
    print(df)
    data_dir.mkdir(exist_ok=True, parents=True)

    # Split by customer so no customer appears in both subsets.
    uc = df["customer"].unique().shuffle()
    n_val = len(uc) // 10
    val_uc = uc[:n_val]
    train_uc = uc[n_val:]

    mdf.serialize(data_dir / "alternatives.pl")
    df.serialize(data_dir / "choices.pl")

    pl.DataFrame({"customer": train_uc}).join(df, "customer").serialize(
        data_dir / "train-choices.pl"
    )
    pl.DataFrame({"customer": val_uc}).join(df, "customer").serialize(
        data_dir / "val-choices.pl"
    )

    with h5py.File(data_dir / "embed.hdf5", "a") as f:
        if "embed" in f:
            del f["embed"]
        f.create_dataset("embed", emebed.shape, dtype=np.float32)
        f["embed"][...] = emebed


class Acc:
    """Running count-weighted mean of a value and of its square.

    Batches are merged with ``push(mean, mean_of_squares, count)``.  ``v`` and
    ``v2`` hold the weighted means over everything pushed so far and ``n`` the
    total count; ``val`` derives the standard deviation and standard error
    from them.
    """

    def __init__(self):
        self.v = th.tensor(0, requires_grad=False)
        self.v2 = th.tensor(0, requires_grad=False)
        self.n = th.tensor(0, dtype=th.int64, requires_grad=False)

    def mts(self, v) -> th.Tensor:
        """Return ``v`` as a tensor that does not take part in autograd."""
        if isinstance(v, th.Tensor):
            v = v.detach()
        else:
            v = th.tensor(v)
        return v.requires_grad_(False)

    def push(self, v, v2, n) -> None:
        """Merge a batch with mean ``v``, mean of squares ``v2`` and size ``n``.

        A batch with ``n == 0`` is ignored: it carries no information, and
        merging it would either divide 0 by 0 (fresh accumulator) or multiply
        a possibly-NaN mean of an empty batch by zero, turning the running
        statistics into NaN for good.
        """
        n = self.mts(n)
        if not n.any():
            return
        self.n += n
        # Weight of the new batch relative to everything seen so far.
        alpha = n / self.n
        self.v = self.v * (1 - alpha) + self.mts(v) * alpha
        self.v2 = self.v2 * (1 - alpha) + self.mts(v2) * alpha

    @property
    def val(self) -> tuple[th.Tensor, th.Tensor]:
        """Return ``(std, stderr)`` of the accumulated values.

        ``std`` is ``sqrt(|E[x^2] - E[x]^2|)``; the absolute value guards
        against a slightly negative variance from rounding.  ``stderr`` is
        ``std / sqrt(n)``.
        """
        s = th.sqrt(th.abs(self.v2 - self.v * self.v))
        e = s / th.sqrt(self.n)
        return s, e

    def __repr__(self) -> str:
        s, e = self.val
        return f"{self.v} += {s} {e} from {self.n} points"
