# %%
from __future__ import annotations

import gc
import os
import traceback
from typing import Optional

import _common
import lightning as pl
import lightning.fabric.loggers as plf_loggers
import numpy as np
import torch as th
import torchmetrics as thm
import tqdm
from _strategies import ModisteOptStrat, OptStrat

from datasets.base import Env, EnvManager
from datasets.classifications.synthetic import SyntheticEnvManager
from datasets.classifications.matlab import (
    MatlabClassificationNWCEnvManager,
)
from models.estimators import ModisteKNNScoreEst, ModisteScoreEstimatorBase


# %%
@th.no_grad()
def eval_env(
    env: Env,
    score_est: ModisteScoreEstimatorBase,
    plf: pl.Fabric,
    eval_bsz: int,
    metrics_func: thm.MetricCollection,
    to_log_additioanl_metrics: bool,
) -> dict[str, float]:
    """Evaluate the greedy (arg-max score) policy of ``score_est`` on ``env``.

    The environment is reset and consumed batch by batch until
    ``env.has_next()`` is false. For every context the estimator scores all
    available actions, the highest-scoring one is played, and rewards,
    regrets and squared prediction errors are accumulated.

    Args:
        env: Environment to evaluate on; it is reset at the start.
        score_est: Score estimator; switched to eval mode and moved to
            ``plf.device``.
        plf: Fabric instance providing the compute device.
        eval_bsz: Batch size requested from ``env.get_ctxs``.
        metrics_func: Metric collection updated with ``binfo.pyhats`` and
            ``binfo.ys`` of the selected actions; reset before and after use.
        to_log_additioanl_metrics: If true, additionally play one uniformly
            sampled available action per context and report the estimator's
            squared error on it under the key ``"mse"``. (The misspelling is
            part of the keyword interface used by callers and is kept.)

    Returns:
        The computed ``metrics_func`` values plus the batch-averaged
        ``"reward"``, ``"regret"``, ``"mse_selected"`` and ``"n_selected"``
        (and ``"mse"`` when additional metrics are requested).
    """
    score_est.eval().to(device=plf.device)
    n_selected_l: list[th.Tensor] = []
    rewards_l: list[th.Tensor] = []
    regrets_l: list[th.Tensor] = []
    mses_rand_l: list[th.Tensor] = []
    mses_sel_l: list[th.Tensor] = []
    metrics_func.reset()
    env.reset()
    while env.has_next():
        # (bsz, n_covs)
        bctxs: th.Tensor = env.get_ctxs(eval_bsz)
        # (bsz, n_acts_avail, n_act_feats), (bsz, n_acts_avail)
        bxacts, bacts_avail = env.get_avail_actions(bctxs)
        # shape info
        bsz: int = bxacts.shape[0]
        n_acts_avail: int = bxacts.shape[1]
        n_act_feats: int = bxacts.shape[2]
        # compute score for each ctx-act pair
        # (bsz, n_acts_avail, n_covs)
        bctxs_ = bctxs[:, None, :].expand(-1, n_acts_avail, -1)
        # (bsz * n_acts_avail, n_covs + n_act_feats)
        binputs: th.Tensor = th.cat((bctxs_, bxacts), dim=2).flatten(0, 1)
        # (bsz, n_acts_avail); column j holds the score of the j-th AVAILABLE
        # action, i.e. it is indexed by position, not by action id
        bpms: th.Tensor = th.unflatten(
            score_est(binputs.to(plf.device)),
            dim=0,
            sizes=(bsz, n_acts_avail),
        ).to(device="cpu")
        # choose which action to take
        # (bsz,)
        btargets_est, baidxs = th.max(bpms, dim=1)
        bacts: th.Tensor = th.gather(bacts_avail, dim=1, index=baidxs[:, None])[:, 0]
        # collect rewards
        # (bsz,)
        brewards, binfo = env.compute_rewards(bctxs, bacts)
        # only the first component of the decomposition is needed here
        bcinds, _ = env.decompose_xacts(
            th.gather(
                bxacts, dim=1, index=baidxs[:, None, None].expand(-1, -1, n_act_feats)
            )[:, 0, :]
        )
        bn_selected: th.Tensor = th.sum(bcinds, dim=1)
        bbest_rewards: th.Tensor = env.compute_optimal_rewards(bctxs)
        bregrets: th.Tensor = bbest_rewards - brewards
        # mse of selected actions
        bmses_sel: th.Tensor = th.nn.functional.mse_loss(
            btargets_est, brewards, reduction="none"
        )
        # metrics of current selection
        metrics_func.update(binfo.pyhats[:, :, None], binfo.ys[:, None])
        # record metrics
        n_selected_l.append(bn_selected)
        rewards_l.append(brewards)
        regrets_l.append(bregrets)
        mses_sel_l.append(bmses_sel)
        if to_log_additioanl_metrics:
            # mse of same ctx but random action:
            # sample one POSITION among the available actions per context
            baidxs_rand: th.Tensor = th.randint(
                0, n_acts_avail, (bsz,), dtype=th.long
            )
            # gather needs an index with as many dims as its input, hence the
            # extra trailing axis
            bacts_rand: th.Tensor = th.gather(
                bacts_avail, dim=1, index=baidxs_rand[:, None]
            )[:, 0]
            brewards_rand: th.Tensor = env.compute_rewards(bctxs, bacts_rand)[0]
            # look the prediction up by position (baidxs_rand), mirroring how
            # th.max above pairs bpms columns with baidxs
            bmses_rand: th.Tensor = th.nn.functional.mse_loss(
                th.gather(bpms, dim=1, index=baidxs_rand[:, None])[:, 0],
                brewards_rand,
                reduction="none",
            )
            mses_rand_l.append(bmses_rand)
    metrics_d: dict[str, float] = {
        k: v.item() for k, v in metrics_func.compute().items()
    }
    metrics_func.reset()
    # compute average metrics
    reward: th.Tensor = th.mean(th.cat(rewards_l, dim=0))
    regret: th.Tensor = th.mean(th.cat(regrets_l, dim=0))
    mse_sel: th.Tensor = th.mean(th.cat(mses_sel_l, dim=0))
    n_selected_all: th.Tensor = th.cat(n_selected_l, dim=0)
    if not n_selected_all.is_floating_point():
        # th.mean rejects integer/bool tensors; the per-row sums may be
        # integer-typed depending on what env.decompose_xacts returns
        n_selected_all = n_selected_all.to(dtype=th.get_default_dtype())
    n_selected: th.Tensor = th.mean(n_selected_all)
    metrics_d.update(
        {
            "reward": reward.item(),
            "regret": regret.item(),
            "mse_selected": mse_sel.item(),
            "n_selected": n_selected.item(),
        }
    )
    if to_log_additioanl_metrics:
        mse_rand: th.Tensor = th.mean(th.cat(mses_rand_l, dim=0))
        metrics_d.update({"mse": mse_rand.item()})
    return metrics_d


def make_init_queries(
    train_env: Env,
    score_est: ModisteScoreEstimatorBase,
    init_capital: int,
    plf: pl.Fabric,
    rseed: Optional[int] = None,
) -> dict[str, float]:
    """Seed ``score_est`` with randomly chosen context-action pairs and fit it.

    Initial contexts and their available actions are requested from
    ``train_env``; one available action per context is picked at random, its
    reward is observed, and the resulting pairs replace the estimator's
    training data before a first ``fit_``.

    Args:
        train_env: Training environment; it is reset first.
        score_est: Estimator to seed; moved to ``plf.device``.
        init_capital: Number of initial contexts requested from the env.
        plf: Fabric instance, forwarded to ``score_est.fit_``.
        rseed: Optional seed. When given, a NumPy generator built from it
            drives the env sampling and the action pick; otherwise the env
            receives ``None`` and the pick uses torch's global RNG.

    Returns:
        The metrics returned by ``score_est.fit_`` with ``"n_obs"`` (size of
        the estimator's training set) added.
    """
    score_est.to(plf.device)
    train_env.reset()
    # one round of query
    rng: np.random.Generator | None = None
    if rseed is not None:
        rng = np.random.default_rng(rseed)
    # (init_capital, n_covs)
    ctxs: th.Tensor = train_env.get_init_ctxs(init_capital, generator=rng)
    # (init_capital, n_acts_avail, n_act_feats) (init_capital, n_acts_avail)
    cand_xacts, acts_avail = train_env.get_init_avail_actions(ctxs, rng)
    n_rows: int = cand_xacts.shape[0]
    n_acts_avail: int = cand_xacts.shape[1]
    n_act_feats: int = cand_xacts.shape[2]
    # position of the randomly picked action for each context: (init_capital,)
    if rng is None:
        picks: th.Tensor = th.randint(0, n_acts_avail, (n_rows,), dtype=th.long)
    else:
        picks = th.as_tensor(rng.integers(0, n_acts_avail, (n_rows,)), dtype=th.long)
    # picked action ids: (init_capital,)
    acts: th.Tensor = th.gather(acts_avail, dim=1, index=picks[:, None])[:, 0]
    # features of the picked actions: (init_capital, n_act_feats)
    feat_index: th.Tensor = picks[:, None, None].expand(-1, -1, n_act_feats)
    xacts: th.Tensor = th.gather(cand_xacts, dim=1, index=feat_index)[:, 0, :]
    # (init_capital, n_covs + n_act_feats)
    init_inputs: th.Tensor = th.cat((ctxs, xacts), dim=1)
    init_targets, init_infos = train_env.compute_rewards(ctxs, acts)
    score_est.set_train_data_(init_inputs, init_targets, init_infos)
    # fit score estimator on the freshly installed data
    init_metrics = score_est.fit_(plf)
    init_metrics["n_obs"] = len(score_est.train_inputs)
    return init_metrics


def run_trial(
    score_est: ModisteScoreEstimatorBase, strat: OptStrat, plf: pl.Fabric
) -> dict[str, float]:
    """Run one query round: ask ``strat`` for queries, observe them, store them.

    Args:
        score_est: Estimator used to propose queries and to predict their
            scores; put in eval mode on ``plf.device``.
        strat: Strategy that proposes the next context-action batch; its
            ``env`` is reset and used to observe rewards.
        plf: Fabric instance providing the device and forwarded to ``fit_``.

    Returns:
        avg/max/min of the improvement reported by the strategy (``imp_*``)
        and of the realised one, observed reward minus the estimator's
        prediction (``rimp_*``), the training-set size ``"n_obs"``, plus the
        ``fit_`` metrics when the strategy does not support lazy fitting.
    """
    target_device = plf.device
    score_est.eval().to(device=target_device)
    strat.env.reset()
    # determine next batch of queries
    bctxs, bacts, bimps = strat.suggest_next_queries(score_est, plf)
    # build estimator inputs for the suggested ctx-act pairs
    bxacts: th.Tensor = strat.env.acts_to_xacts(bacts)
    binputs: th.Tensor = th.cat((bctxs, bxacts), dim=1).to(device=target_device)
    # observe the true rewards
    btargets, binfos = strat.env.compute_rewards(bctxs, bacts)
    btargets = btargets.to(device=target_device)
    # realised improvement; predictions are taken BEFORE the new
    # observations are added to the estimator
    brimps: th.Tensor = (btargets - score_est(binputs)).to(device="cpu")
    # add queries to estimator
    score_est.add_to_train_data_(binputs, btargets, binfos)
    metrics_d: dict[str, float] = {}
    metrics_d["imp_avg"] = bimps.mean().item()
    metrics_d["imp_max"] = bimps.max().item()
    metrics_d["imp_min"] = bimps.min().item()
    metrics_d["rimp_avg"] = brimps.mean().item()
    metrics_d["rimp_max"] = brimps.max().item()
    metrics_d["rimp_min"] = brimps.min().item()
    metrics_d["n_obs"] = len(score_est.train_inputs)
    if strat.support_lazy_fit:
        # the caller decides when to refit
        return metrics_d
    metrics_d.update(score_est.fit_(plf))
    return metrics_d


def maximize(
    env_manager: EnvManager,
    score_est: ModisteScoreEstimatorBase,
    strat: OptStrat,
    init_capital: int,
    n_iter: int,
    metrics_func: thm.MetricCollection,
    plf: pl.Fabric,
    eval_bsz: int = 1,
    eval_every_n_iter: int = 1,
    ckpt_p: Optional[str] = None,
    save_ckpt_every_n_iter: int = 1,
    init_capital_rseed: Optional[int] = None,
    to_eval_on_train: bool = False,
    to_log_additional_metrics: bool = False,
    gc_every_n_iter: Optional[int] = None,
):
    """Run the query / fit / evaluate loop and log everything through ``plf``.

    Iteration 0 seeds the estimator via ``make_init_queries``; every later
    iteration performs one ``run_trial`` step. Metrics are logged with the
    prefixes ``train``, ``val`` and ``test``.

    Args:
        env_manager: Provides ``train_env``, ``val_env`` and ``test_env``.
        score_est: Score estimator being trained; moved to ``plf.device``.
        strat: Query strategy used by ``run_trial``.
        init_capital: Number of initial queries made at iteration 0.
        n_iter: Total number of loop iterations (including iteration 0).
        metrics_func: Metric collection passed to ``eval_env``.
        plf: Fabric instance used for device placement, logging and saving.
        eval_bsz: Batch size used by ``eval_env``.
        eval_every_n_iter: Evaluate every this many iterations, and always on
            the last one.
        ckpt_p: Directory for checkpoints; ``None`` disables saving.
        save_ckpt_every_n_iter: Checkpoint period (the last iteration is
            always saved when ``ckpt_p`` is set).
        init_capital_rseed: Seed forwarded to ``make_init_queries``.
        to_eval_on_train: Also evaluate on ``train_env`` at evaluation steps.
        to_log_additional_metrics: Forwarded to ``eval_env``.
        gc_every_n_iter: If set, call ``gc.collect()`` every this many
            iterations.

    ``KeyboardInterrupt`` and ``RuntimeError`` raised inside an iteration are
    printed and end the loop early; the final checkpoint and the test-set
    evaluation still run afterwards.
    """
    # the estimator must live on the Fabric device before any query is made
    score_est.to(plf.device)
    # run trials
    pbar = tqdm.trange(n_iter)
    for itr in pbar:
        is_last_itr: bool = (itr + 1) == n_iter
        try:
            metrics_d: dict[str, float]
            if itr == 0:
                # make initial queries
                metrics_d = make_init_queries(
                    env_manager.train_env,
                    score_est,
                    init_capital,
                    plf,
                    init_capital_rseed,
                )
            else:
                metrics_d = run_trial(score_est=score_est, strat=strat, plf=plf)
            # with lazy fitting run_trial skips the refit, so refit here right
            # before an evaluation step (iteration 0 was already fitted by
            # make_init_queries)
            if (
                strat.support_lazy_fit
                and score_est._enable_lazy_fit
                and (is_last_itr or itr % eval_every_n_iter == 0)
                and (itr != 0)
            ):
                metrics_d.update(score_est.fit_(plf))
            plf.log_dict(_common.add_prefix_to_dict(metrics_d, "train"), itr)
            # evaluate
            if itr % eval_every_n_iter == 0 or is_last_itr:
                if to_eval_on_train:
                    # evaluate on train set
                    metrics_d = eval_env(
                        env=env_manager.train_env,
                        score_est=score_est,
                        metrics_func=metrics_func,
                        plf=plf,
                        eval_bsz=eval_bsz,
                        to_log_additioanl_metrics=to_log_additional_metrics,
                    )
                    plf.log_dict(_common.add_prefix_to_dict(metrics_d, "train"), itr)
                # evaluate on validation set
                metrics_d = eval_env(
                    env=env_manager.val_env,
                    score_est=score_est,
                    metrics_func=metrics_func,
                    plf=plf,
                    eval_bsz=eval_bsz,
                    to_log_additioanl_metrics=to_log_additional_metrics,
                )
                plf.log_dict(_common.add_prefix_to_dict(metrics_d, "val"), itr)
            # save ckpt if needed
            if ckpt_p is not None and (
                (itr % save_ckpt_every_n_iter == 0) or is_last_itr
            ):
                plf.save(
                    os.path.join(ckpt_p, f"itr_{itr}.ckpt"), score_est.state_dict()
                )
            if gc_every_n_iter is not None and (itr + 1) % gc_every_n_iter == 0:
                gc.collect()
        except (KeyboardInterrupt, RuntimeError):
            # print_exc() reports the exception being handled and, unlike the
            # single-argument print_exception(e), also works before Python 3.10
            traceback.print_exc()
            break
    pbar.close()
    if ckpt_p is not None:
        plf.save(os.path.join(ckpt_p, "itr_end.ckpt"), score_est.state_dict())
    # evaluate on test set
    metrics_d: dict[str, float] = eval_env(
        env=env_manager.test_env,
        score_est=score_est,
        metrics_func=metrics_func,
        plf=plf,
        eval_bsz=eval_bsz,
        to_log_additioanl_metrics=to_log_additional_metrics,
    )
    plf.log_dict(_common.add_prefix_to_dict(metrics_d, "test"), n_iter)


# %%
# run configuration
# NOTE(review): this cell is labelled "skin dataset config", but the
# environment constructed further down loads "space_shuttle" — confirm these
# values are intended for that dataset.
# score estimator conf
# passed as n_neighbors to ModisteKNNScoreEst
n_neighbors: int = 50
# strategy conf
# both are passed to ModisteOptStrat
n_queries: int = 1
epsilon: float = 0.3
# train_conf
# number of initial queries made at iteration 0 of maximize()
init_capital: int = 500
# total number of iterations of the maximize() loop
n_iter: int = 10000
# batch size used when evaluating an environment
eval_bsz: int = 4096
# seed of the NumPy generator used for the initial queries
init_capital_rseed: int = 42
# experiment conf
# validation runs every this many iterations (and on the last one)
eval_every_n_iter: int = 500
# when True, evaluation steps also evaluate on the training environment
to_eval_on_train: bool = False
# when True, eval_env additionally reports the "mse" of a random action
to_log_additional_metrics: bool = False

# %%
# make environment
# the manager exposes train_env / val_env / test_env as used by maximize()
env_manager: EnvManager = MatlabClassificationNWCEnvManager("space_shuttle")
# construct score estimator
# sized from the environment's n_covs and n_experts_per_fcomb attributes
score_est: ModisteScoreEstimatorBase = ModisteKNNScoreEst(
    n_ctx_covs=env_manager.n_covs,
    n_experts_per_fcomb=env_manager.n_experts_per_fcomb,
    n_neighbors=n_neighbors,
)
# construct optimization strategy
# queries are proposed against the training environment only
strat: OptStrat = ModisteOptStrat(
    env_manager.train_env, n_queries=n_queries, epsilon=epsilon
)

# %%
# configure logger and ckpt path
# one output directory per (env manager, estimator, strategy) class triple;
# reruns of the same triple reuse the same directory
output_dir: str = (
    f"outputs/run/{env_manager.__class__.__name__}/{score_est.__class__.__name__}/{strat.__class__.__name__}"
)
os.makedirs(output_dir, exist_ok=True)
# empty name/version are passed so no extra sub-folder names are supplied
tfb_logger = plf_loggers.TensorBoardLogger(root_dir=output_dir, name="", version="")
csv_logger = plf_loggers.CSVLogger(root_dir=output_dir, name="", version="")
# checkpointing is currently disabled: no ckpt_p is defined or passed on
# ckpt_p: str = os.path.join(tfb_logger.log_dir, "checkpoints")

# %%
# train selector
# Fabric supplies the device and fans log_dict() out to both loggers
plf = pl.Fabric(loggers=[tfb_logger, csv_logger])
# only move the env manager when it happens to be a torch module
if isinstance(env_manager, th.nn.Module):
    env_manager.to(device=plf.device)
n_labels: int = env_manager.n_labels
# multiclass classification metrics, updated inside eval_env
metrics_func = thm.MetricCollection(
    {
        "acc": thm.Accuracy(task="multiclass", num_classes=n_labels),
        "precision": thm.Precision(task="multiclass", num_classes=n_labels),
        "recall": thm.Recall(task="multiclass", num_classes=n_labels),
        "f1-score": thm.F1Score(task="multiclass", num_classes=n_labels),
        "auroc": thm.AUROC(task="multiclass", num_classes=n_labels),
    }
)
# ckpt_p is not passed, so maximize() writes no checkpoints
maximize(
    env_manager=env_manager,
    score_est=score_est,
    strat=strat,
    init_capital=init_capital,
    n_iter=n_iter,
    metrics_func=metrics_func,
    plf=plf,
    eval_bsz=eval_bsz,
    eval_every_n_iter=eval_every_n_iter,
    init_capital_rseed=init_capital_rseed,
    to_eval_on_train=to_eval_on_train,
    to_log_additional_metrics=to_log_additional_metrics,
)
# logger flush record and close
# NOTE(review): maximize() catches KeyboardInterrupt/RuntimeError internally
# and returns normally, so "success" is reported even for a run that was cut
# short — confirm this is intended.
tfb_logger.finalize("success")
csv_logger.finalize("success")

# %%
