from collections import defaultdict
import contextlib
from dataclasses import dataclass
import json
import os
import time
from torch.utils.tensorboard import SummaryWriter

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import polars as pl
from pathlib import Path
import torch as th
import best3 as b3
import itertools
import torch.multiprocessing as mp
import mininterface
import datetime as dt
import uuid
from math import sqrt
from tqdm import tqdm

# Device that holds the model and every per-batch tensor (hard-coded to CUDA).
device = th.device("cuda")
# Request asynchronous host->device copies (`.to(..., non_blocking=...)`) only
# when the target is a CUDA device.
non_blocking = device.type == "cuda"


class MFModel(th.nn.Module):
    """Biased matrix-factorisation scorer.

    For a user ``u`` and an alternative ``a`` the score is::

        <user_emb[u], item_emb[a]> + user_bias[u] + item_bias[a]

    Args:
        emb_size: latent dimension ``d`` of the user/item factor tables.
        n_customers: number of rows in the user tables.
        n_alternatives: number of rows in the item tables.
        sparse: passed to every ``th.nn.Embedding``; when true the tables
            produce sparse gradients, which need an optimizer that accepts
            them (e.g. ``th.optim.SparseAdam``).
    """

    def __init__(self, emb_size, n_customers, n_alternatives, sparse):
        super().__init__()
        self.user_emb = th.nn.Embedding(n_customers, emb_size, sparse=sparse)
        self.item_emb = th.nn.Embedding(n_alternatives, emb_size, sparse=sparse)
        self.user_bias = th.nn.Embedding(n_customers, 1, sparse=sparse)
        self.item_bias = th.nn.Embedding(n_alternatives, 1, sparse=sparse)
        self._init_weights(emb_size)

    def _init_weights(self, emb_size):
        """Draw factors with std 1/sqrt(emb_size) and zero both bias tables."""
        factor_std = 1.0 / sqrt(emb_size)
        with th.no_grad():
            for factors in (self.user_emb, self.item_emb):
                th.nn.init.normal_(factors.weight, mean=0.0, std=factor_std)
            for bias in (self.user_bias, self.item_bias):
                th.nn.init.zeros_(bias.weight)

    def forward(self, users, items):
        """Score (user, item) index pairs.

        ``users`` and ``items`` are integer index tensors that broadcast
        against each other; the result has their broadcast shape.
        """
        interaction = (self.user_emb(users) * self.item_emb(items)).sum(-1)
        user_offset = self.user_bias(users).squeeze(-1)
        item_offset = self.item_bias(items).squeeze(-1)
        return interaction + user_offset + item_offset


@dataclass(kw_only=True)
class Conf:
    """Run configuration, filled in by ``mininterface.run`` in ``__main__``."""

    base_dir: str  # dataset root handed to b3.Dataset / b3.ValDataset
    batch_size: int = 10240  # training batch size given to b3.Dataset
    test_batch_size: int = 10240  # evaluation batch size given to b3.ValDataset
    emb_size: int = 128  # latent dimension of MFModel's factor tables
    n_workers: int = min(b3.cpu_count(), 8)  # intended DataLoader worker count; default computed at import, capped at 8
    sparse: bool = False  # sparse embedding gradients + SparseAdam instead of fused Adam
    run_name: str = ""  # prefix of the per-run output directory name
    epochs: int = 10  # number of eval+train rounds
    batches_per_epoch: int = 100  # training batches consumed per epoch
    eval_patches_per_epoch: int = 100  # evaluation *batches* consumed per epoch (name kept for CLI compatibility)
    output_dir: str | None = None  # parent of the run directory; None -> under ./exps


if __name__ == "__main__":
    # Worker processes are started with "spawn" rather than fork, the start
    # method PyTorch recommends when CUDA is in use.
    mp.set_start_method("spawn", force=True)
    conf: Conf = mininterface.run(Conf, title="B3").env  # type: ignore

    base_dir = Path(conf.base_dir)
    if conf.output_dir is None:
        # `Path("exps") / p` returns `p` unchanged when `p` is absolute, which
        # would put run artefacts inside the dataset directory. Strip the
        # anchor (e.g. "/") so runs always land under ./exps.
        rel_base = base_dir.relative_to(base_dir.anchor) if base_dir.anchor else base_dir
        output_root = Path("exps") / rel_base
    else:
        output_root = Path(conf.output_dir)
    timestamp = dt.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    output_dir = output_root / f"{conf.run_name}{timestamp}-{uuid.uuid1()}"
    output_dir.mkdir(parents=True)
    (output_dir / "models").mkdir()

    train_ds = b3.Dataset(
        base_dir,
        ranks=[1],
        batch_size=conf.batch_size,
        device="cpu",
        ret_embed=False,
    )

    test_ds = b3.ValDataset(
        base_dir,
        batch_size=conf.test_batch_size,
        device="cpu",
        ret_embed=False,
    )

    train_ds.lazy_init()
    test_ds.lazy_init()

    # batch_size=None disables DataLoader auto-batching; the datasets are
    # given their own batch sizes above and presumably yield whole batches.
    train_dl = iter(
        th.utils.data.DataLoader(
            train_ds,
            batch_size=None,
            num_workers=conf.n_workers,
            collate_fn=b3.identity,
            pin_memory=True,
        )
    )

    test_dl = iter(
        th.utils.data.DataLoader(
            test_ds,
            batch_size=None,
            num_workers=conf.n_workers,
            collate_fn=b3.identity,
            pin_memory=True,
        )
    )

    n_customers = len(train_ds.customer_prob)
    n_alternatives = len(train_ds.alternatives)
    model = MFModel(conf.emb_size, n_customers, n_alternatives, sparse=conf.sparse).to(
        device
    )
    model.compile(mode="max-autotune", dynamic=False)
    opt = (
        th.optim.SparseAdam(model.parameters(), lr=0.01)
        if conf.sparse
        else th.optim.Adam(model.parameters(), lr=0.01, fused=True)
    )

    # Lookup tables do not change between epochs: move them to the device once.
    test_alternative_index_of = test_ds.alternative_index_of.to(device)
    train_ratings = th.tensor(train_ds.rating_of, device=device)
    train_alternative_index_of = train_ds.alternative_index_of.to(device)
    # Number of customers scored per forward pass during evaluation.
    customer_chunk = 2048

    c: th.Tensor
    idx: th.Tensor
    val_res: dict[str, float] = {}
    writer = SummaryWriter(output_dir)
    with contextlib.closing(writer):
        for epoch in range(conf.epochs):
            # ---------------------------- evaluation ----------------------------
            val_acc: defaultdict[str, b3.Acc] = defaultdict(b3.Acc)
            n_val_batches = 0
            with th.no_grad():
                for _, idx in tqdm(
                    itertools.islice(test_dl, conf.eval_patches_per_epoch),
                    total=conf.eval_patches_per_epoch,
                ):
                    n_val_batches += 1
                    idx = idx.to(device, non_blocking=non_blocking).squeeze(-1)
                    alt_idx = test_alternative_index_of[idx]

                    # Score every customer against the batch's alternatives in
                    # chunks; c0 sums (over the customer dim 0) the mask from
                    # constraints[:-1], c1 the mask that also passes
                    # constraints[-1].
                    c0, c1 = th.tensor(0), th.tensor(0)
                    for start in range(0, n_customers, customer_chunk):
                        stop = min(n_customers, start + customer_chunk)
                        x = model(
                            th.arange(start, stop, device=device)[:, None, None],
                            alt_idx[None],
                        )
                        filt = b3.apply_constraint(test_ds.constraints[:-1], x)
                        c0 = c0 + filt.sum(0)

                        filt = th.logical_and(
                            filt, b3.apply_constraint(test_ds.constraints[-1:], x)
                        )
                        c1 = c1 + filt.sum(0)
                    # Uninformative 0.5 where nothing passed constraints[:-1].
                    prob = th.where(c0 == 0, 0.5, c1 / c0)
                    # Exact ties at 0.5 are settled by a fair coin: rand_like
                    # is uniform on [0, 1), so `>= 0.5` is true half the time.
                    acc = th.where(
                        prob != 0.5, prob > 0.5, th.rand_like(prob) >= 0.5
                    ).float()
                    for k, vv in {
                        "acc": acc,
                        "brier": (prob - 1).pow(2),
                        "nll": -prob.clamp(0.001, 0.999).log(),
                    }.items():
                        val_acc[k].push(vv.mean(), vv.pow(2).mean(), len(vv))

            val_res = {}
            for k, v in val_acc.items():
                std, err = map(float, v.val)
                val_res[k] = float(v.v)
                val_res[f"{k}_std"] = std
                val_res[f"{k}_err"] = err
            for k, v in val_res.items():
                writer.add_scalar(f"/val/{k}", v, global_step=epoch)
            print(dt.datetime.now(), "test", epoch, n_val_batches, val_res)

            # ----------------------------- training -----------------------------
            window_loss = b3.Acc()  # printed and reset every ~100 batches
            epoch_loss = b3.Acc()  # whole-epoch aggregate for TensorBoard

            for i, (c, idx) in tqdm(
                enumerate(itertools.islice(train_dl, conf.batches_per_epoch), start=1),
                total=conf.batches_per_epoch,
            ):
                c = c.to(device, non_blocking=non_blocking)
                idx = idx.squeeze(-1).to(device, non_blocking=non_blocking)
                alt_idx = train_alternative_index_of[idx]
                rat_idx = train_ratings[idx]
                loss = th.nn.functional.huber_loss(
                    model(c, alt_idx), rat_idx.float(), reduction="mean"
                )
                opt.zero_grad()
                loss.backward()
                opt.step()
                loss = loss.detach()
                window_loss.push(loss, loss * loss, 1)
                epoch_loss.push(loss, loss * loss, 1)

                if window_loss.n >= 100 or i == conf.batches_per_epoch:
                    print(
                        dt.datetime.now(), "train_loss", epoch, i, "avg_loss", window_loss
                    )
                    window_loss = b3.Acc()

            # Log the epoch aggregate, not the progress window (which has just
            # been reset on the final batch).
            if epoch_loss.n > 0:
                train_acc: dict[str, b3.Acc] = {"loss": epoch_loss}
                for k, v in train_acc.items():
                    std, err = v.val
                    writer.add_scalar(f"/train/{k}", v.v, global_step=epoch)
                    writer.add_scalar(f"/train/{k}_std", std, global_step=epoch)
                    writer.add_scalar(f"/train/{k}_err", err, global_step=epoch)

            # One checkpoint per epoch (f-string so the epoch number is filled in).
            th.save(model.state_dict(), output_dir / "models" / f"model-{epoch:05d}.pkl")
    # Marker file holding the last epoch's validation metrics.
    with open(output_dir / "done", "w") as fd:
        json.dump(val_res, fd, indent=2)
