from collections import defaultdict
import json
import os

# Restrict OpenMP to a single thread for this process. The assignment sits
# ahead of the numpy/torch imports below, presumably so those libraries see
# the setting when they are loaded -- TODO confirm.
os.environ["OMP_NUM_THREADS"] = "1"
import typing
import yaml
import dataclasses as dc
from collections.abc import Iterable
import contextlib
from dataclasses import dataclass
from functools import partial
import os
from pathlib import Path
import uuid
import h5py
from matplotlib import pyplot as plt
import numpy as np
import polars as pl
from scipy import special
from torch.utils.data import IterableDataset
import torch as th
from torch import nn, optim
import torch.multiprocessing as mp
import itertools as it
from tqdm import tqdm
import best3 as b3
from loguru import logger
import mininterface
import datetime as dt
import sys
from torch.utils.tensorboard import SummaryWriter


# Process-wide setup, executed at import time (including in spawned workers
# that re-import this module).
# Request the "highest" float32 matmul precision setting from torch.
th.set_float32_matmul_precision("highest")
# Device used for every module built in this file; CUDA is hard-coded.
device = th.device("cuda")
logger.info("device {}", device)
logger.info("OMP_NUM_THREADS {}", os.environ["OMP_NUM_THREADS"])


@dataclass(kw_only=True)
class Conf:
    """Run configuration, filled in by ``mininterface.run(Conf, ...)`` in ``main``."""

    # Dataset directory; ``main`` reads ``embed.hdf5`` and ``alternatives.pl`` from it.
    base_dir: str
    # Batch size handed to both ``b3.Dataset`` and ``b3.ValDataset``.
    batch_size: int = 1024
    # DataLoader worker count. The default is evaluated once, at class definition.
    n_workers: int = min(b3.cpu_count(), 8)
    # Number of training batches consumed per epoch.
    epoch_steps: int = 3000
    # Number of evaluation rounds; training runs after every round but the last.
    epochs: int = 4
    # NOTE(review): not referenced anywhere in this file.
    eval_batch_size: int = 128
    # Sample budget per item for the validation estimate (drawn in chunks of 100_000).
    eval_samples: int = 300_000
    # Number of validation batches consumed per evaluation round.
    eval_steps: int = 1000
    # Width of the hidden blocks built by ``make_head``.
    net_hidden: int = 1024 * 4
    # Input feature size; 0 means "use the last dimension of the ``embed`` dataset".
    net_input: int = 0
    # Number of hidden blocks per head; 0 yields a single linear layer.
    net_layers: int = 0
    # Output width of the second head (probit3, and probit2 in "dot" mode);
    # 0 means ``min(4 * embed.shape[0], 4096)``.
    net_phi: int = 0
    # Stack the hidden blocks in ``Skipential`` (True) or ``nn.Sequential`` (False).
    net_skip: bool = True
    # Append a ``LayerNorm`` after the GELU of every block.
    net_layer_norm: bool = False
    # Forwarded as ``phi_sig`` to the probit ``from_flat`` factories.
    net_sig_ceof: str = "none"
    # Load the whole embedding and dump ``rum.csv(...)`` tables at every evaluation.
    net_log_full: bool = True
    # Also save ``rum.viz(...)`` figures (only consulted when ``net_log_full`` is set).
    net_log_viz: bool = False
    # Optimizer learning rate.
    lr: float = 3e-4
    # Parent of the run directory; ``None`` means ``exps/<base_dir>``.
    output_dir: None | str = None
    # Weight decay; used by the "Adam" optimizer branch in ``main``.
    weight_decay: float = 0
    # Optimizer selector: "Adam" (builds ``optim.AdamW``) or "SGD".
    optim: str = "Adam"
    # Forwarded as ``temp`` to ``b3.Logit.from_flat``.
    temp: float = 1.0
    # Model family: "probit3", "logit" or "probit2".
    rum: str = "probit3"
    # Forwarded as ``mode`` to the probit factories; "" is replaced by "dot" in
    # ``main``. ``Net`` accepts "dot", "spherical" and "unit" for probit2.
    rum_mode: str = ""

    # Passed as ``N`` to ``diff_logprob`` in the training loss.
    train_samples: int = 200
    # Maximum gradient norm for clipping; 0 disables clipping.
    grad_norm: float = 10
    # Passed to ``b3.Dataset`` as ``ranks=[train_rank]``.
    train_rank: int = 3
    # NOTE(review): not referenced anywhere in this file.
    dl_device: bool = True
    # Prefix of the run directory name.
    run_name: str = ""


class Skipential(nn.Module):
    """Sequential container that adds a skip connection around each stage.

    Stages are applied in order. Whenever a stage's output has the same shape
    as its input, the input is added to the output (residual connection);
    otherwise the output is passed on unchanged.
    """

    def __init__(self, *fs: nn.Module) -> None:
        """Store the stages ``fs`` in application order."""
        super().__init__()
        self.fs = nn.ModuleList(fs)

    # Implemented as ``forward`` rather than by overriding ``__call__`` so that
    # ``nn.Module.__call__`` keeps running registered hooks and so that
    # ``module.forward(x)`` works like it does for every other module.
    def forward(self, x: th.Tensor) -> th.Tensor:
        """Apply every stage to ``x``, adding the stage input when shapes match."""
        for f in self.fs:
            px = x
            x = f(x)
            if px.shape == x.shape:
                x = px + x

        return x

    def append(self, x: nn.Module) -> None:
        """Add ``x`` as the last stage."""
        self.fs.append(x)

    def __iter__(self) -> typing.Iterator[nn.Module]:
        """Iterate over the stages in application order."""
        return iter(self.fs)


def make_head(conf: Conf, output_size) -> nn.Module:
    """Build one MLP head mapping ``conf.net_input`` features to ``output_size``.

    Layout of the returned ``nn.Sequential``:
      * an input block ``net_input -> net_hidden`` (only if ``net_layers > 0``),
      * a container (``Skipential`` or ``nn.Sequential``, chosen by
        ``conf.net_skip``) holding ``net_layers - 1`` hidden blocks; it is
        present, possibly empty, even when there are no hidden blocks,
      * a final linear read-out to ``output_size``.

    Each block is ``Linear -> GELU`` followed by an optional ``LayerNorm``.
    All parameters are created on the module-level ``device``.
    """

    def build_stage(n_in, n_out):
        # Linear + GELU, with LayerNorm appended on request.
        parts = [nn.Linear(n_in, n_out, device=device), nn.GELU()]
        if conf.net_layer_norm:
            parts += [nn.LayerNorm(n_out, device=device)]
        return nn.Sequential(*parts)

    has_hidden = conf.net_layers > 0
    stages = [build_stage(conf.net_input, conf.net_hidden)] if has_hidden else []

    # Hidden stack; created after the input block and before the read-out so
    # parameters are initialised in the same order as the layers are applied.
    container = Skipential if conf.net_skip else nn.Sequential
    hidden_blocks = [
        build_stage(conf.net_hidden, conf.net_hidden)
        for _ in range(conf.net_layers - 1)
    ]
    stages.append(container(*hidden_blocks))

    # Without an input block the read-out consumes the raw input features.
    readout_in = conf.net_hidden if has_hidden else conf.net_input
    readout = nn.Linear(readout_in, output_size, device=device)
    return nn.Sequential(*stages, readout)


class Net(nn.Module):
    def __init__(self, conf: Conf) -> None:
        super().__init__()
        output_size: list[int]
        match conf.rum:
            case "probit3":
                self.gen = partial(
                    b3.Probit3.from_flat,
                    phi_sig=conf.net_sig_ceof,
                    mode=conf.rum_mode,  # type: ignore
                )
                output_size = [1, conf.net_phi]
            case "logit":
                self.gen = partial(
                    b3.Logit.from_flat,
                    temp=conf.temp,
                )
                output_size = [1]
            case "probit2":
                self.gen = partial(
                    b3.Probit2.from_flat,
                    phi_sig=conf.net_sig_ceof,
                    mode=conf.rum_mode,  # type: ignore
                )
                output_size = [1]
                match conf.rum_mode:
                    case "dot":
                        output_size.append(conf.net_phi)
                    case "spherical":
                        output_size.append(1)
                    case "unit":
                        pass
                    case _:
                        raise NotImplementedError

            case _:
                raise NotImplementedError(conf.rum)

        self.inner = nn.ModuleList([make_head(conf, o) for o in output_size])

    def forward(self, x) -> b3.RUM:
        return self.gen(*(i(x) for i in self.inner))


def take_n[T](n, x: Iterable[T]) -> Iterable[T]:
    for i in x:
        yield i


# fake vmap for debugging
def vmap(f):
    """Return an eager, loop-based stand-in for a vectorising map over ``f``.

    The returned function calls ``f`` once per index along the first axis of
    its first argument, passing the matching slice of every further argument,
    and stacks the per-index results with ``th.stack``.
    """

    def inner(x, *xs):
        rows = []
        for idx in range(len(x)):
            head = x[idx]
            extras = [other[idx] for other in xs]
            rows.append(f(head, *extras))
        return th.stack(rows)

    return inner


def main():
    mp.set_start_method("spawn", force=True)

    conf: Conf = mininterface.run(Conf, title="B3").env  # type: ignore
    if conf.rum_mode == "":
        match conf.rum:
            case "probit3":
                conf.rum_mode = "dot"
            case "probit2":
                conf.rum_mode = "dot"

    output_dir = (
        (
            Path("exps") / Path(conf.base_dir)
            if conf.output_dir is None
            else Path(conf.output_dir)
        )
        / f"{conf.run_name}{dt.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}-{uuid.uuid1()}"
    )
    (output_dir / "models").mkdir(parents=True)
    fig_dir = output_dir / "fig"
    fig_dir.mkdir()
    csv_dir = output_dir / "csv"
    csv_dir.mkdir()
    pl_dir = output_dir / "pl"
    pl_dir.mkdir()
    logger.add(output_dir / "logs.txt")
    base_dir = Path(conf.base_dir)
    with h5py.File(base_dir / "embed.hdf5", "r") as emb_f:
        if conf.net_input == 0:
            conf.net_input = emb_f["embed"].shape[-1]
        if conf.net_phi == 0:
            conf.net_phi = min(emb_f["embed"].shape[0] * 4, 4 * 1024)

        if conf.net_log_full:
            embed = th.tensor(np.array(emb_f["embed"])).pin_memory()
        else:
            embed = None

    # assume conf is inited at this point

    with (output_dir / "conf.yaml").open("w") as fd:
        yaml.safe_dump(dc.asdict(conf), fd)
    writer = SummaryWriter(output_dir)
    train_ds = b3.Dataset(
        base_dir,
        ranks=[conf.train_rank],
        batch_size=conf.batch_size,
        device=th.device("cpu"),
    )
    test_ds = b3.ValDataset(
        base_dir,
        batch_size=conf.batch_size,
        device=th.device("cpu"),
    )
    prefetch = 2 if conf.n_workers > 0 else None
    logger.info(f"command {sys.argv}")
    with open(output_dir / "com", "w") as fd:
        fd.write(" ".join(sys.argv))

    def mk_dl(x: b3.Dataset):
        return map(
            lambda x: x[-1],
            iter(
                th.utils.data.DataLoader(
                    x,
                    batch_size=None,
                    num_workers=conf.n_workers,
                    persistent_workers=conf.n_workers > 0,
                    prefetch_factor=prefetch,
                    collate_fn=b3.identity,
                    pin_memory=x.device.type == "cpu",
                )
            ),
        )

    test_dl = mk_dl(test_ds)

    train_dl = mk_dl(train_ds)
    mdf = pl.DataFrame.deserialize(Path(conf.base_dir) / "alternatives.pl")

    model = Net(conf).to(device)
    logger.info("model {}", model)
    opt: optim.Optimizer
    match conf.optim:
        case "Adam":
            opt = optim.AdamW(
                model.parameters(),
                lr=conf.lr,
                fused=True,
                weight_decay=conf.weight_decay,
            )
        case "SGD":
            opt = optim.SGD(model.parameters(), lr=conf.lr, fused=True, nesterov=True)
        case _:
            raise NotImplementedError(conf.optim)
    if b3.debug_enabled():

        def comp_fn(x):
            return x
    else:
        comp_fn = th.compile(fullgraph=True, dynamic=False)

    @comp_fn
    def loss_fn(
        batch: th.Tensor,
    ) -> tuple[th.Tensor, dict[str, tuple[th.Tensor, th.Tensor]]]:
        nll = -th.vmap(
            lambda batch: model(batch).diff_logprob(
                [range(batch.shape[-2])], N=conf.train_samples
            )
        )(batch)
        prob = (-nll).exp()
        brier = (prob - 1).pow(2)
        acc = (th.where(prob == 0.5, th.rand_like(prob), prob) >= 0.5).float()
        return nll.mean(), {
            "nll": (nll.mean().detach(), (nll * nll).mean().detach()),
            "brier": (brier.mean().detach(), brier.pow(2).mean().detach()),
            "acc": (acc.mean(), acc.pow(2).mean()),
        }

    logger.info(f"conf: {yaml.safe_dump(conf.__dict__)}")
    test_const = test_ds.constraints
    @comp_fn
    def test_fn(emb: th.Tensor) -> dict[str, tuple[th.Tensor, th.Tensor]]:
        B, N, E = emb.shape

        def f(batch):
            rum: b3.RUM = model(batch)
            cnd_count = th.tensor(0, dtype=th.long)
            obs_cnd_count = th.tensor(0, dtype=th.long)
            N = 100_000
            for _ in range(0, conf.eval_samples, N):
                x = rum.sample(N)
                cond = b3.apply_constraint(test_const[:-1], x)
                obs_and_cond = th.logical_and(cond, b3.apply_constraint(test_const[-1:], x))
                cnd_count = cnd_count + cond.sum()
                obs_cnd_count = obs_cnd_count + obs_and_cond.sum()
            return obs_cnd_count, cnd_count

        co, o = th.vmap(f, randomness="different")(emb)
        prob = th.where(o == 0, 0.5, co / o)
        acc = (th.where(o == 0.5, th.rand_like(prob), prob) >= 0.5).float()
        brier = (prob - 1).pow(2)
        nll = -prob.clamp(0.001, 0.999).log()
        return {
            "acc": (acc.mean(), acc.pow(2).mean()),
            "brier": (brier.mean(), brier.pow(2).mean()),
            "nll": (nll.mean(), nll.pow(2).mean()),
        }

    with contextlib.closing(writer):
        val_res: dict[str, float] = {}

        for epoch in tqdm(range(conf.epochs), total=conf.epochs):
            with th.no_grad():
                model.eval()
                if conf.net_log_full:
                    assert embed is not None
                    rum: b3.RUM = model(embed.to(device))
                    if conf.net_log_viz:
                        fig = rum.viz(keys=mdf["name"] if "name" in mdf.columns else None)
                        fig.savefig(
                            fig_dir / f"full-{epoch:05d}.jpg",
                            bbox_inches="tight",
                            transparent=True,
                        )
                        fig.savefig(
                            fig_dir / "last.pdf", bbox_inches="tight", transparent=True
                        )
                        plt.close(fig)
                    for k, v in rum.csv(keys=mdf["name"] if "name" in mdf.columns else None).items():
                        v.write_csv(csv_dir / f"{epoch:05}-{k}.csv")
                        v.serialize(pl_dir / f"{epoch:05}-{k}.pl")

                val_acc: defaultdict[str, b3.Acc] = defaultdict(b3.Acc)
                for emb in tqdm(
                    it.islice(test_dl, conf.eval_steps),
                    total=conf.eval_steps,
                ):
                    emb = emb.to(device, non_blocking=device.type != "cpu")
                    for k, (v0, v1) in test_fn(emb).items():
                        val_acc[k].push(v0, v1, emb.shape[0])
                logger.info("validation loss: {}", val_acc)
                val_res = {}
                for k, v in val_acc.items():
                    std, err = map(float, v.val)
                    val_res[k] = float(v.v)
                    val_res[f"{k}_std"] = std
                    val_res[f"{k}_err"] = err
                for k, v in val_res.items():
                    writer.add_scalar(f"/val/{k}", v, global_step=epoch)
            if epoch + 1 == conf.epochs:
                break
            train_acc: defaultdict[str, b3.Acc] = defaultdict(b3.Acc)
            model.train()
            with tqdm(total=conf.epoch_steps) as pbar:
                for emb in it.islice(train_dl, conf.epoch_steps):
                    emb = emb.to(device, non_blocking=device.type != "cpu")
                    opt.zero_grad()
                    nll, acc = loss_fn(emb)
                    nll.backward()
                    if conf.grad_norm != 0:
                        norm = nn.utils.clip_grad_norm_(
                            model.parameters(), conf.grad_norm, error_if_nonfinite=True
                        )
                        train_acc["norm"].push(norm, norm * norm, 1)
                    opt.step()
                    for k, (v, v2) in acc.items():
                        train_acc[k].push(v, v2, len(emb))
                    pbar.update()
                logger.info("train loss {}", train_acc)
                for k, v in train_acc.items():
                    std, err = v.val
                    writer.add_scalar(f"/train/{k}", v.v, global_step=epoch)
                    writer.add_scalar(f"/train/{k}_std", std, global_step=epoch)
                    writer.add_scalar(f"/train/{k}_err", err, global_step=epoch)

            th.save(model.state_dict(), output_dir / "models" / f"model-{epoch:05d}.pkl")
    with open(output_dir / "done", "w") as fd:
        json.dump(val_res, fd, indent=2)


# Script entry point: start the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
