from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Final, Optional, TypeVar

import numpy as np
import torch as th
import torch.distributions.utils
import torch.utils.data as th_data
from lightning.fabric.utilities import AttributeDict

import models.classifiers

# Type variable bound to torch datasets whose items are 3-tuples of tensors.
# In this file, TorchDatasetEnv reads item[0] as the context and item[2] as the
# label; the meaning of item[1] is not visible here -- TODO confirm.
DatasetType = TypeVar(
    "DatasetType", bound=th_data.Dataset[tuple[th.Tensor, th.Tensor, th.Tensor]]
)


@dataclass
class EnvRewardInfo(AttributeDict):
    """Extra information returned alongside the rewards of ``Env.compute_rewards``.

    Subclasses lightning's ``AttributeDict``; presumably the fields are therefore also
    reachable as dict keys (``info["pyhats"]``) -- confirm against the lightning version.
    """

    # predicted label probabilities; TorchDatasetEnv fills this from
    # ``classifier.predict_proba`` and indexes it as a 2-D (n_ctxs, n_labels) tensor
    pyhats: th.Tensor
    # ground-truth labels of the evaluated contexts
    ys: th.Tensor


class Env(ABC):
    """Base class for environments that score (context, action) pairs with a reward.

    The action space is finite: action ``a`` is an integer index into ``_xacts_f``,
    whose row ``a`` holds that action's feature vector. Subclasses provide contexts,
    the actions available for them, and rewards.

    Attributes:
        n_covs: number of covariates per context (contexts are ``(bsz, n_covs)``).
        n_experts_per_fcomb: controls how ``decompose_xacts`` splits action features.
        alpha: scalar stored for subclasses (not used by this base class).
    """

    n_covs: int
    n_experts_per_fcomb: int
    alpha: float

    # NOTE(review): stored but not read by this base class; presumably the
    # per-action context indicators -- confirm against subclasses.
    _cinds_f: th.Tensor
    # (n_acts, n_act_feats) feature vector of every action in the action space
    _xacts_f: th.Tensor

    def __init__(
        self,
        n_covs: int,
        n_experts_per_comb: int,
        cinds_f: th.Tensor,
        xacts_f: th.Tensor,
        alpha: float,
    ) -> None:
        """
        Args:
            n_covs (int): number of covariates per context.
            n_experts_per_comb (int): stored as ``n_experts_per_fcomb``.
            cinds_f (th.Tensor): stored as ``_cinds_f``.
            xacts_f (th.Tensor): (n_acts, n_act_feats) action-feature table.
            alpha (float): stored as ``alpha`` for subclasses.
        """
        super().__init__()
        self.n_covs = n_covs
        self.n_experts_per_fcomb = n_experts_per_comb
        self.alpha = alpha
        self._cinds_f = cinds_f
        self._xacts_f = xacts_f

    @property
    @abstractmethod
    def dataset(self) -> th_data.Dataset: ...

    def get_init_avail_actions(
        self, ctxs: Optional[th.Tensor], generator: Optional[np.random.Generator]
    ) -> tuple[th.Tensor, th.Tensor]:
        """Get the initially available actions.

        This method returns "`n_acts_avail`" actions, not necessarily the entire action
        space; indexing the first or second dimension might not give the actual action!
        The base implementation ignores ``generator`` and delegates to
        ``get_avail_actions``.

        Args:
            ctxs (Optional[th.Tensor]): (bsz, n_covs) contexts of interest; if None, just
                return the action space.
            generator (Optional[np.random.Generator]): numpy random generator used to
                generate the initial available actions.

        Returns:
            th.Tensor: (bsz, n_acts_avail, n_act_feats) action features; if ctxs is None,
                the tensor has shape (n_acts_avail, n_act_feats).
            th.Tensor: (bsz, n_acts_avail) the corresponding actions; if ctxs is None, the
                tensor has shape (n_acts_avail, ).
        """
        return self.get_avail_actions(ctxs)

    @abstractmethod
    def get_avail_actions(
        self, ctxs: Optional[th.Tensor] = None
    ) -> tuple[th.Tensor, th.Tensor]:
        """Get the available actions.

        This method returns "`n_acts_avail`" actions, not necessarily the entire action
        space; indexing the first or second dimension might not give the actual action!

        Args:
            ctxs (Optional[th.Tensor]): (bsz, n_covs) contexts of interest; if None, just
                return the action space.

        Returns:
            th.Tensor: (bsz, n_acts_avail, n_act_feats) action features; if ctxs is None,
                the tensor has shape (n_acts_avail, n_act_feats).
            th.Tensor: (bsz, n_acts_avail) the corresponding actions; if ctxs is None, the
                tensor has shape (n_acts_avail, ).
        """

    def acts_to_xacts(self, acts: th.Tensor) -> th.Tensor:
        """Map actions to their action features.

        Args:
            acts (th.Tensor): (bsz, ) action indices

        Returns:
            th.Tensor: (bsz, n_act_feats) action features
        """
        return self._xacts_f[acts]

    def xacts_to_acts(self, xacts: th.Tensor) -> th.Tensor:
        """Map action features back to actions (inverse of ``acts_to_xacts``).

        Rows are compared exactly after casting to ``th.long``. This assumes the action
        features are integer valued (e.g. 0/1 indicators).

        Args:
            xacts (th.Tensor): (bsz, n_act_feats) action features

        Returns:
            th.Tensor: (bsz, ) long tensor of actions, always on the CPU

        Raises:
            ValueError: if a row of ``xacts`` matches no action, or more than one.
        """
        table: th.Tensor = self._xacts_f.to(dtype=th.long)
        query: th.Tensor = xacts.to(device=table.device, dtype=th.long)
        # (bsz, n_acts): broadcasting avoids materialising expanded copies of the table
        matches: th.Tensor = th.all(query[:, None, :] == table[None, :, :], dim=2)
        n_matches: th.Tensor = matches.sum(dim=1)
        if bool(th.any(n_matches != 1)):
            bad_row: int = int(th.nonzero(n_matches != 1)[0, 0].item())
            raise ValueError(
                f"xacts row {bad_row} matches {int(n_matches[bad_row].item())} "
                "actions; expected exactly 1"
            )
        # every row holds exactly one True and nonzero() enumerates in row-major order,
        # so column 1 is the matching action of each row, in batch order
        acts: th.Tensor = th.nonzero(matches)[:, 1]
        # results are returned on the CPU regardless of where _xacts_f lives
        return acts.cpu()

    def decompose_xacts(self, xacts: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """Decompose action features into context indicators and expert indicators.

        Args:
            xacts (th.Tensor): (bsz, n_act_feats)

        Returns:
            th.Tensor: (bsz, n_covs) context indicators. When
                ``n_experts_per_fcomb == 1``, this is ``xacts`` itself.
            th.Tensor: expert indicators, the (bsz, n_act_feats - n_covs) trailing columns
                of ``xacts``. When ``n_experts_per_fcomb == 1`` there is no expert block,
                and a (bsz, ) long tensor of zeros on ``xacts.device`` is returned instead.
        """
        if self.n_experts_per_fcomb == 1:
            # features carry no expert part: every column is a context indicator
            exinds: th.Tensor = th.zeros(
                (len(xacts),), dtype=th.long, device=xacts.device
            )
            return xacts, exinds
        return xacts[:, : self.n_covs], xacts[:, self.n_covs :]

    def get_init_ctxs(
        self, n_ctxs: int, generator: Optional[np.random.Generator]
    ) -> th.Tensor:
        """Get the initial contexts.

        The base implementation ignores ``generator`` and delegates to ``get_ctxs``.

        Args:
            n_ctxs (int): number of contexts to sample from the environment
            generator (Optional[np.random.Generator]): numpy random generator used to
                generate the initial contexts.

        Returns:
            th.Tensor: (n_ctxs, n_covs) contexts
        """
        return self.get_ctxs(n_ctxs)

    @abstractmethod
    def get_ctxs(self, n_ctxs: int) -> th.Tensor:
        """Get contexts.

        Args:
            n_ctxs (int): number of contexts to sample from the environment

        Returns:
            th.Tensor: (n_ctxs, n_covs) contexts
        """

    @abstractmethod
    def compute_rewards(
        self, ctxs: th.Tensor, acts: th.Tensor
    ) -> tuple[th.Tensor, EnvRewardInfo]:
        """Compute the reward of the given context-action pairs.

        Args:
            ctxs (th.Tensor): (n_ctxs, n_covs) contexts
            acts (th.Tensor): (n_ctxs, ) actions to be taken

        Returns:
            th.Tensor: (n_ctxs, ) reward of each context taking the selected action.
            EnvRewardInfo: additional information
        """

    @abstractmethod
    def compute_optimal_rewards(self, ctxs: th.Tensor) -> th.Tensor:
        """Compute the optimal rewards for the given contexts.

        Args:
            ctxs (th.Tensor): (n_ctxs, n_covs) contexts

        Returns:
            th.Tensor: (n_ctxs, ) rewards when taking the optimal actions
        """

    @abstractmethod
    def has_next(self) -> bool:
        """Whether the environment has further contexts to evaluate.

        Returns:
            bool: True if some contexts are still left to be evaluated
        """

    def reset(self) -> None:
        """Reset the environment; the base implementation does nothing."""
        return


class EnvManager(ABC):
    """Interface grouping the train / validation / test environments of one problem.

    Implementations expose the problem dimensions, one ``Env`` per split and an
    ``expert_func`` hook. Every member is abstract.
    """

    @property
    @abstractmethod
    def n_covs(self) -> int:
        """Number of covariates per context."""

    @property
    @abstractmethod
    def n_experts_per_fcomb(self) -> int:
        """Value forwarded to the environments as ``n_experts_per_fcomb``."""

    @property
    @abstractmethod
    def n_labels(self) -> int:
        """Number of labels of the underlying prediction problem."""

    @property
    @abstractmethod
    def train_env(self) -> Env:
        """Environment used for training."""

    @property
    @abstractmethod
    def val_env(self) -> Env:
        """Environment used for validation."""

    @property
    @abstractmethod
    def test_env(self) -> Env:
        """Environment used for testing."""

    @abstractmethod
    def expert_func(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Hook mapping (contexts, actions) to a tensor; semantics defined by subclasses."""


class TorchDatasetEnv(Env):
    """Environment whose contexts and labels come from a torch ``Subset``.

    Sample ``i`` is read as ``dataset[i][0]`` (the context) and ``dataset[i][2]`` (the
    label ``ys``). The reward of a (context, action) pair is the log-likelihood of the
    label under ``classifier.predict_proba``, minus ``alpha`` times the sum of the
    action's context indicators.

    In training mode (``is_train=True``), contexts are sampled uniformly with
    replacement, and ``has_next`` turns False once ``len(dataset)`` contexts have been
    drawn. In evaluation mode, the dataset is walked sequentially, once per ``reset``.
    """

    classifier: models.classifiers.SubsetFeatureClassifier
    is_train: Final[bool]
    # number of actions offered per context; at most len(_xacts_f)
    n_acts_avail: int

    _dataset: th_data.Subset
    # context (as a tuple of Python floats) -> index into ``_dataset``
    _ctx_to_idx: dict[tuple[float, ...], int]

    # training mode: number of contexts drawn since the last reset
    _n_loaded: int
    # evaluation mode: one past the index of the last context handed out
    _end_idx: int

    # dataset indices of the contexts returned by the most recent context request
    _idxs: list[int] | None

    @property
    def dataset(self) -> th_data.Subset:
        return self._dataset

    def __init__(
        self,
        n_covs: int,
        n_experts_per_comb: int,
        cinds_f: th.Tensor,
        xacts_f: th.Tensor,
        dataset: th_data.Subset,
        classifier: models.classifiers.SubsetFeatureClassifier,
        alpha: float,
        is_train: bool,
        n_acts_avail: Optional[int],
    ) -> None:
        """
        Args:
            n_covs, n_experts_per_comb, cinds_f, xacts_f, alpha: forwarded to ``Env``.
            dataset (th_data.Subset): source of contexts (item 0) and labels (item 2).
            classifier: provides ``predict_proba(ctxs, acts)`` used by ``compute_rewards``.
            is_train (bool): selects random (train) or sequential (eval) context loading.
            n_acts_avail (Optional[int]): actions offered per context; ``None`` or a value
                larger than the action space means "offer every action".
        """
        super().__init__(
            n_covs=n_covs,
            n_experts_per_comb=n_experts_per_comb,
            cinds_f=cinds_f,
            xacts_f=xacts_f,
            alpha=alpha,
        )
        self._dataset = dataset
        self.classifier = classifier
        self.is_train = is_train
        # exact-value lookup, used when the caller's contexts do not match ``_idxs``;
        # if two samples share a context, the later index wins
        self._ctx_to_idx = {
            tuple(dataset[i][0].tolist()): i for i in range(len(dataset))
        }
        self._n_loaded = 0
        # initialised here (not only in reset) so that eval-mode get_ctxs()/has_next()
        # work on a freshly constructed environment
        self._end_idx = 0
        self._idxs = None
        # if the maximum number of available actions is not configured, use all of them;
        # it never exceeds the size of the action space, len(self._xacts_f)
        self.n_acts_avail = (
            len(self._xacts_f)
            if n_acts_avail is None or n_acts_avail > len(self._xacts_f)
            else n_acts_avail
        )

    def get_init_avail_actions(
        self, ctxs: th.Tensor | None, generator: np.random.Generator | None
    ) -> tuple[th.Tensor, th.Tensor]:
        # the entire action space is always available; no need to sample a subset
        if self.n_acts_avail == len(self._xacts_f):
            return self.get_avail_actions(ctxs)
        # sample a subset of actions, using ``generator`` for reproducibility
        if ctxs is None:
            xacts, acts = self._sample_subset_of_actions(1, generator)
            return xacts[0], acts[0]
        xacts, acts = self._sample_subset_of_actions(len(ctxs), generator)
        return xacts, acts

    def get_avail_actions(
        self, ctxs: th.Tensor | None = None
    ) -> tuple[th.Tensor, th.Tensor]:
        # the entire action space is always available; no need to sample a subset
        if self.n_acts_avail == len(self._xacts_f):
            n_acts: int = len(self._xacts_f)
            if ctxs is None:
                return self._xacts_f, th.arange(n_acts, dtype=th.long)
            else:
                n: int = ctxs.shape[0]
                xacts: th.Tensor = self._xacts_f[None, :, :].expand(n, -1, -1)
                acts: th.Tensor = th.arange(n_acts, dtype=th.long)
                acts = acts[None, :].expand(n, -1)
                return xacts, acts
        # sample a subset of actions with the global numpy RNG
        if ctxs is None:
            xacts, acts = self._sample_subset_of_actions(1)
            return xacts[0], acts[0]
        xacts, acts = self._sample_subset_of_actions(len(ctxs))
        return xacts, acts

    def _sample_subset_of_actions(
        self, bsz: int, generator: Optional[np.random.Generator] = None
    ) -> tuple[th.Tensor, th.Tensor]:
        """Sample ``n_acts_avail`` distinct actions independently for each batch row.

        Args:
            bsz (int): the batch size
            generator (Optional[np.random.Generator]): random generator used for sampling;
                if None, the global ``np.random`` state is used. Defaults to None.

        Returns:
            th.Tensor: (bsz, n_acts_avail, n_act_feats) the action features
            th.Tensor: (bsz, n_acts_avail) the actions
        """
        n_acts: int = len(self._xacts_f)
        choose = np.random.choice if generator is None else generator.choice
        # (bsz, n_acts_avail,): each row is drawn without replacement
        acts: th.Tensor = th.stack(
            [
                th.as_tensor(
                    choose(n_acts, (self.n_acts_avail,), replace=False), dtype=th.long
                )
                for _ in range(bsz)
            ]
        )
        xacts: th.Tensor = self._xacts_f[acts]
        return xacts, acts

    def _get_verify_ctxs_ys(self, ctxs: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """Look up the labels of ``ctxs``.

        If ``ctxs`` are the contexts produced by the last context request (same shape
        and ``allclose``), their recorded indices are used. Otherwise every context is
        looked up by exact value in ``_ctx_to_idx``.

        Returns:
            th.Tensor: ``ctxs`` unchanged
            th.Tensor: labels stacked from ``dataset[i][2]``

        Raises:
            KeyError: if, on the lookup path, a context is not present in the dataset.
        """
        assert self._idxs is not None
        idxs: list[int] = self._idxs
        expected: th.Tensor = th.stack([self._dataset[i][0] for i in idxs])
        # compare shapes first: allclose broadcasts, so a different batch size could
        # either raise or falsely "match" and yield the wrong number of labels
        if expected.shape != ctxs.shape or not th.allclose(expected, ctxs):
            idxs = [self._ctx_to_idx[tuple(ctx.tolist())] for ctx in ctxs]
        ys: th.Tensor = th.stack([self._dataset[i][2] for i in idxs])
        return ctxs, ys

    def compute_rewards(
        self, ctxs: th.Tensor, acts: th.Tensor
    ) -> tuple[th.Tensor, EnvRewardInfo]:
        """Reward = label log-likelihood - alpha * sum of the action's context indicators."""
        assert self._idxs is not None
        ctxs, ys = self._get_verify_ctxs_ys(ctxs)
        pyhats: th.Tensor = self.classifier.predict_proba(ctxs, acts)
        n_labels: int = pyhats.shape[1]
        rewards: th.Tensor = (
            -th.nn.functional.cross_entropy(
                torch.distributions.utils.probs_to_logits(pyhats), ys, reduction="none"
            )
            if n_labels > 2
            else self._custom_bce(pyhats, ys)
        )
        xacts: th.Tensor = self.acts_to_xacts(acts)
        cinds: th.Tensor = self.decompose_xacts(xacts)[0]
        rewards = rewards - self.alpha * th.sum(cinds, dim=1)
        info = EnvRewardInfo(pyhats, ys)
        return rewards, info

    def _custom_bce(self, pyhats: th.Tensor, ys: th.Tensor) -> th.Tensor:
        """Binary log-likelihood: log p1 where ys is truthy, log(1 - p1) elsewhere."""
        # clip away from 0/1 so the logs stay finite
        pyhats = th.clip(pyhats, 1e-6, 1.0 - 1e-6)
        # clipping cannot violate these bounds, but NaN fails both comparisons, so the
        # asserts act as a NaN guard
        assert th.all(th.greater(pyhats, 0.0))
        assert th.all(th.less(pyhats, 1.0))
        llike: th.Tensor = th.where(
            ys.to(dtype=th.bool), th.log(pyhats[:, 1]), th.log(1 - pyhats[:, 1])
        )
        return llike

    def compute_optimal_rewards(self, ctxs: th.Tensor) -> th.Tensor:
        # returns +inf for every context
        return th.inf * th.ones((len(ctxs),))

    def get_init_ctxs(
        self, n_ctxs: int, generator: Optional[np.random.Generator]
    ) -> th.Tensor:
        # random sampling (training style) is used in both modes
        return self._get_ctxs_train(n_ctxs, generator)

    def get_ctxs(self, n_ctxs: int) -> th.Tensor:
        if self.is_train:
            return self._get_ctxs_train(n_ctxs)
        return self._get_ctxs_eval(n_ctxs)

    def has_next(self) -> bool:
        if self.is_train:
            return self._has_next_train()
        return self._has_next_eval()

    def reset(self) -> None:
        if self.is_train:
            self._reset_train()
        else:
            self._reset_eval()
        super().reset()

    # train
    def _get_ctxs_train(
        self, n_ctxs: int, generator: Optional[np.random.Generator] = None
    ) -> th.Tensor:
        """Draw ``n_ctxs`` dataset indices uniformly with replacement; return their contexts."""
        self._n_loaded = self._n_loaded + n_ctxs
        self._idxs = (
            generator.integers(
                0, len(self._dataset), (n_ctxs,), dtype=np.int64
            ).tolist()
            if generator is not None
            else th.randint(0, len(self._dataset), (n_ctxs,), dtype=th.long).tolist()
        )
        assert self._idxs is not None
        ctxs: th.Tensor = th.stack([self._dataset[i][0] for i in self._idxs])
        return ctxs

    def _has_next_train(self) -> bool:
        return self._n_loaded < len(self._dataset)

    def _reset_train(self) -> None:
        self._n_loaded = 0
        self._idxs = None

    # eval
    def _get_ctxs_eval(self, n_ctxs: int) -> th.Tensor:
        """Return the next (up to) ``n_ctxs`` contexts in dataset order."""
        start_idx: int = self._end_idx
        end_idx: int = min(start_idx + n_ctxs, len(self._dataset))
        self._idxs = list(range(start_idx, end_idx))
        ctxs = th.stack([self._dataset[i][0] for i in self._idxs])
        self._end_idx = end_idx
        return ctxs

    def _has_next_eval(self) -> bool:
        return self._end_idx < len(self._dataset)

    def _reset_eval(self) -> None:
        self._end_idx = 0
        self._idxs = None
