from __future__ import annotations

from abc import ABC, abstractmethod

import lightning as pl
import torch as th

from datasets.base import Env
from models.estimators import ScoreEstimator


class OptStrat(ABC):
    """Abstract strategy that proposes the next context-action pairs to query.

    Subclasses implement :meth:`suggest_next_queries`; this base only stores the
    environment and the number of queries to request per call.
    """

    # presumably: whether the strategy can run without a freshly fitted score
    # estimator (RandomOptStrat sets it and never uses the estimator) - TODO confirm
    support_lazy_fit: bool = False

    env: Env  # source of contexts and of the actions available in each context
    n_queries: int  # number of context-action pairs requested per call

    def __init__(self, env: Env, n_queries: int) -> None:
        super().__init__()
        self.env, self.n_queries = env, n_queries

    @abstractmethod
    def suggest_next_queries(
        self, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Propose the next batch of context-action pairs to query.

        Args:
            score_est (ScoreEstimator): score estimator used to rank actions
            plf (pl.Fabric): lightning fabric object (provides the compute device)

        Returns:
            th.Tensor: (n_q, n_covs) suggested contexts, n_q normally ``n_queries``
            th.Tensor: (n_q,) suggested action for each context
            th.Tensor: (n_q,) improvement score of each pair; implementations that
                do not estimate one return ``inf``
        """


class RandomOptStrat(OptStrat):
    """Baseline strategy: draw contexts and pick an available action uniformly.

    The score estimator and the fabric object are ignored.
    """

    support_lazy_fit = True

    def suggest_next_queries(
        self, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Draw ``n_queries`` contexts and a uniformly random action for each.

        Args:
            score_est (ScoreEstimator): unused
            plf (pl.Fabric): unused

        Returns:
            th.Tensor: (n, n_covs) contexts returned by ``env.get_ctxs``
            th.Tensor: (n,) randomly chosen available action per context
            th.Tensor: (n,) float32 tensor filled with ``inf`` (no estimate)
        """
        ctxs: th.Tensor = self.env.get_ctxs(self.n_queries)
        n_ctxs = len(ctxs)
        # (n, n_avail_acts, n_act_feats), (n, n_avail_acts)
        act_feats, acts_avail = self.env.get_avail_actions(ctxs)
        n_choices = act_feats.shape[1]
        # one uniform column index per context
        choice = th.randint(0, n_choices, (n_ctxs,), dtype=th.long)
        chosen_acts = th.gather(acts_avail, dim=1, index=choice.unsqueeze(1)).squeeze(1)
        no_estimate = th.full((n_ctxs,), th.inf, dtype=th.float32)
        return ctxs, chosen_acts, no_estimate


class ProfileOptStrat(OptStrat, ABC):
    """Base for strategies that score a pool of candidate contexts and keep the best.

    ``n_ctxs`` candidate contexts are drawn from the environment and scored in
    chunks of ``ctxs_bsz`` by :meth:`_get_ctxs_improvements`; the pairs with the
    largest improvement scores become the suggested queries.
    """

    n_ctxs: int  # size of the candidate context pool drawn per call
    ctxs_bsz: int  # number of contexts passed to each _get_ctxs_improvements call

    def __init__(self, env: Env, n_queries: int, n_ctxs: int, ctxs_bsz: int) -> None:
        super().__init__(env, n_queries)
        self.n_ctxs = n_ctxs
        self.ctxs_bsz = ctxs_bsz

    @th.no_grad()
    def suggest_next_queries(
        self, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Rank a pool of candidate contexts and return the top ``n_queries``.

        Runs without gradient tracking.

        Args:
            score_est (ScoreEstimator): score estimator forwarded to the scorer
            plf (pl.Fabric): lightning fabric object forwarded to the scorer

        Returns:
            th.Tensor: (k, n_covs) selected contexts, k = min(pool size, n_queries)
            th.Tensor: (k,) action chosen for each selected context
            th.Tensor: (k,) improvement scores in descending order
        """
        pool: th.Tensor = self.env.get_ctxs(self.n_ctxs)
        scored = [
            self._get_ctxs_improvements(chunk, score_est, plf)
            for chunk in th.split(pool, self.ctxs_bsz)
        ]
        # (len(pool), ) each
        pool_acts: th.Tensor = th.cat([acts for acts, _ in scored], dim=0)
        pool_imps: th.Tensor = th.cat([imps for _, imps in scored], dim=0)
        # highest improvement first
        ranked_imps, order = th.sort(pool_imps, descending=True)
        k: int = min(len(pool), self.n_queries)
        top = order[:k]
        return pool[top], pool_acts[top], ranked_imps[:k]

    @abstractmethod
    def _get_ctxs_improvements(
        self, bctxs: th.Tensor, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor]:
        """Return the chosen action and improvement score for each context in a batch.

        Args:
            bctxs (th.Tensor): (bsz, n_covs) batch of candidate contexts
            score_est (ScoreEstimator): score estimator
            plf (pl.Fabric): lightning fabric object

        Returns:
            th.Tensor: (bsz,) chosen action per context
            th.Tensor: (bsz,) improvement score per context
        """


class MTSPM(ProfileOptStrat):
    """Profile strategy scoring contexts by a posterior-sample vs posterior-mean gap.

    For each context, every available action is evaluated by the score estimator.
    The improvement of a context is the largest value of one posterior sample minus
    the largest posterior mean, where the posterior means are first capped at the
    greatest training target. The suggested action maximizes the posterior sample.
    """

    def _get_ctxs_improvements(
        self, bctxs: th.Tensor, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor]:
        """Compute the suggested action and improvement for a batch of contexts.

        Args:
            bctxs (th.Tensor): (bsz, n_covs) batch of contexts
            score_est (ScoreEstimator): score estimator; put in eval mode and moved
                to ``plf.device`` as a side effect
            plf (pl.Fabric): lightning fabric object providing the compute device

        Returns:
            th.Tensor: (bsz,) action maximizing the posterior sample per context
            th.Tensor: (bsz,) improvement per context (CPU tensor)
        """
        score_est.eval().to(plf.device)
        bsz: int = len(bctxs)
        # (bsz, n_avail_acts, n_act_feats), (bsz, n_avail_acts)
        bxacts, bacts_avail = self.env.get_avail_actions(bctxs)
        n_acts_avail: int = bxacts.shape[1]
        # pair every context with each of its available actions
        bctxs_: th.Tensor = bctxs[:, None, :].expand(-1, n_acts_avail, -1)
        # (bsz * n_avail_acts, n_covs + n_act_feats)
        binputs: th.Tensor = (
            th.cat((bctxs_, bxacts), dim=2).flatten(0, 1).to(device=plf.device)
        )
        bouts: th.Tensor | tuple[th.Tensor, ...] = score_est(binputs)
        # posterior mean and one posterior sample, brought back to CPU
        # (bsz, n_avail_acts)
        bpms: th.Tensor = (
            score_est.get_posterior_mean(bouts)
            .unflatten(0, (bsz, n_acts_avail))
            .to(device="cpu")
        )
        bpsamps: th.Tensor = (
            score_est.sample_posterior(bouts)
            .unflatten(0, (bsz, n_acts_avail))
            .to(device="cpu")
        )
        # cap the posterior means by the greatest observation so far. The cap is
        # moved next to bpms because train_targets may live on the accelerator
        # while bpms is on CPU, and th.minimum rejects mixed-device operands once
        # the cap is broadcast to a non-scalar shape.
        best_obs: th.Tensor = th.max(score_est.train_targets).to(device=bpms.device)
        # (bsz, )
        bbest_pms: th.Tensor = th.max(
            th.minimum(bpms, th.broadcast_to(best_obs, bpms.shape)),
            dim=1,
        ).values
        bbest_psamps, bbest_aidxs = th.max(bpsamps, dim=1)
        bimps: th.Tensor = bbest_psamps - bbest_pms
        # map argmax column indices back to action identifiers
        bbest_acts: th.Tensor = th.gather(
            bacts_avail, dim=1, index=bbest_aidxs[:, None]
        )[:, 0]
        return bbest_acts, bimps


class TS(OptStrat):
    """Thompson sampling: act greedily on one posterior sample per drawn context."""

    @th.no_grad()
    def suggest_next_queries(
        self, score_est: ScoreEstimator, plf: pl.Fabric
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Draw contexts and pick, for each, the action maximizing a posterior sample.

        Runs without gradient tracking.

        Args:
            score_est (ScoreEstimator): score estimator; put in eval mode and moved
                to ``plf.device`` as a side effect
            plf (pl.Fabric): lightning fabric object providing the compute device

        Returns:
            th.Tensor: (n, n_covs) contexts returned by ``env.get_ctxs``
            th.Tensor: (n,) action maximizing the posterior sample per context
            th.Tensor: (n,) float32 tensor filled with ``inf`` (no estimate); its
                length always matches the number of returned contexts
        """
        # inference only, matching MTSPM: no dropout / batch-norm updates, and the
        # estimator sits on the same device as the inputs built below
        score_est.eval().to(plf.device)
        ctxs: th.Tensor = self.env.get_ctxs(self.n_queries)
        n: int = len(ctxs)
        # (n, n_avail_acts, n_act_feats), (n, n_avail_acts)
        xacts, acts_avail = self.env.get_avail_actions(ctxs)
        n_avail_acts: int = xacts.shape[1]
        # pair every context with each of its available actions
        ctxs_: th.Tensor = ctxs[:, None, :].expand(-1, n_avail_acts, -1)
        # (n * n_avail_acts, n_covs + n_act_feats)
        inputs: th.Tensor = (
            th.cat((ctxs_, xacts), dim=2).flatten(0, 1).to(device=plf.device)
        )
        outs: th.Tensor | tuple[th.Tensor, ...] = score_est(inputs)
        # (n, n_avail_acts)
        psamps: th.Tensor = (
            score_est.sample_posterior(outs)
            .unflatten(0, (n, n_avail_acts))
            .to(device="cpu")
        )
        # (n, ) greedy action column under the posterior sample
        _, aidxs = th.max(psamps, dim=1)
        best_acts: th.Tensor = th.gather(acts_avail, dim=1, index=aidxs[:, None])[:, 0]
        # sized by the contexts actually returned, not self.n_queries, so all three
        # outputs stay aligned (as in RandomOptStrat)
        imps: th.Tensor = th.full((n,), th.inf, dtype=th.float32)
        return ctxs, best_acts, imps
