from __future__ import annotations

from abc import ABC
from typing import Optional

import torch as th
import torch.distributions.utils
import torch.utils.data as th_data

import models.classifiers

from .. import base


class SyntheticEnv(base.Env, ABC):
    """Shared base for the synthetic train / eval environments.

    Builds the full table of feature-subset actions (one expert per feature
    combination) with ``SubsetFeatureClassifier.make_full_action_features``
    and stores the hyper-parameters both subclasses use.
    """

    # number of samples served per pass over the environment
    n_samps: int
    # bounds forwarded to make_full_action_features (max may be None)
    min_features: int
    max_features: Optional[int]
    # scale used by the expert decision boundary (see subclasses)
    temperature: float

    # action index -> feature combination, and the inverse mapping
    cact_to_fcomb: list[tuple[int, ...]]
    fcomb_to_cact: dict[tuple[int, ...], int]

    def __init__(
        self,
        n_covs: int,
        n_samps: int,
        min_features: int = 1,
        max_features: Optional[int] = None,
        temperature: float = 8.0,
    ) -> None:
        """Build the action table and record shared hyper-parameters.

        Args:
            n_covs: number of covariates per context.
            n_samps: number of samples per pass over the environment.
            min_features: forwarded unchanged to
                ``make_full_action_features``.
            max_features: forwarded unchanged to
                ``make_full_action_features``; may be ``None``.
            temperature: stored for the expert decision boundary used by
                subclasses.
        """
        cinds_f, xacts_f, cact_to_fcomb, fcomb_to_cact = (
            models.classifiers.SubsetFeatureClassifier.make_full_action_features(
                n_covs=n_covs,
                min_features=min_features,
                max_features=max_features,
                n_experts_per_comb=1,
            )
        )
        super().__init__(
            n_covs=n_covs,
            n_experts_per_comb=1,
            cinds_f=cinds_f,
            xacts_f=xacts_f,
            alpha=0.0,
        )
        # shape information
        self.n_samps = n_samps
        self.min_features = min_features
        self.max_features = max_features
        # hparam for expert decision boundary
        self.temperature = temperature
        # The action <-> feature-combination maps are declared on this class
        # and read by get_avail_actions, but were previously discarded. Only
        # set them when base.Env has not already provided them, so an existing
        # base-class attribute (or read-only property) is never clobbered.
        if not hasattr(self, "cact_to_fcomb"):
            self.cact_to_fcomb = cact_to_fcomb
        if not hasattr(self, "fcomb_to_cact"):
            self.fcomb_to_cact = fcomb_to_cact

    def get_avail_actions(
        self, ctxs: Optional[th.Tensor] = None
    ) -> tuple[th.Tensor, th.Tensor]:
        """Return the action features and an availability mask.

        Args:
            ctxs: optional batch of contexts; only its first dimension is used.

        Returns:
            ``(xacts, actms)``. Without ``ctxs``: the full action-feature
            table and an empty 0-dim placeholder tensor. With ``ctxs``: the
            table broadcast to ``(n, n_acts, ...)`` (an ``expand`` view, no
            copy) and an all-ones float32 ``(n, n_acts)`` mask, meaning every
            action is available for every context.
        """
        if ctxs is None:
            return self._xacts_f, th.empty(())
        n: int = ctxs.shape[0]
        n_acts: int = len(self.cact_to_fcomb)
        xacts: th.Tensor = self._xacts_f[None, :, :].expand(n, -1, -1)
        # keep the mask on the same device as the action features
        actms: th.Tensor = th.ones(
            (n, n_acts), dtype=th.float32, device=xacts.device
        )
        return xacts, actms


class SyntheticTrainEnv(SyntheticEnv):
    """Training environment that draws a fresh synthetic batch per request.

    A static ``dataset`` of ``n_samps`` samples is generated once at
    construction; independently, every :meth:`get_ctxs` call samples a new
    batch (global RNG) and remembers it so rewards can be looked up by
    context.
    """

    _xacts: th.Tensor

    # number of contexts handed out since the last reset
    _n_loaded: int
    # (xs, ys, pyhats, fms) of the most recent batch
    _curr_data: tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor] | None
    # maps a context row (as a tuple of floats) to its row in _curr_data
    _ctx_to_idx: dict[tuple[float, ...], int] | None
    _dataset: th_data.TensorDataset

    @property
    def dataset(self) -> th_data.TensorDataset:
        """Static dataset ``(xs, ones_like(xs), ys)`` built in ``__init__``."""
        return self._dataset

    def __init__(
        self,
        n_samps: int,
        n_covs: int,
        min_features: int = 1,
        max_features: int | None = None,
        temperature: float = 8,
    ) -> None:
        super().__init__(
            n_samps=n_samps,
            n_covs=n_covs,
            min_features=min_features,
            max_features=max_features,
            temperature=temperature,
        )
        # per-pass bookkeeping starts empty
        self._n_loaded = 0
        self._curr_data = None
        self._ctx_to_idx = None
        feats, labels, _, _ = generate_synthetic_data(
            n_samps, self.n_covs, self.temperature
        )
        self._dataset = th_data.TensorDataset(feats, th.ones_like(feats), labels)

    def _ctx_indices(self, ctxs: th.Tensor) -> th.Tensor:
        """Row positions in ``_curr_data`` of each context in ``ctxs``."""
        lookup = self._ctx_to_idx
        positions = [lookup[tuple(row.tolist())] for row in ctxs]
        return th.as_tensor(positions, dtype=th.long)

    def get_ctxs(self, n_ctxs: int) -> th.Tensor:
        """Sample a new batch of ``n_ctxs`` contexts and remember it."""
        batch = generate_synthetic_data(n_ctxs, self.n_covs, self.temperature)
        self._curr_data = batch
        self._n_loaded += n_ctxs
        self._ctx_to_idx = {
            tuple(row.tolist()): pos for pos, row in enumerate(batch[0])
        }
        return batch[0]

    def compute_rewards(
        self, ctxs: th.Tensor, acts: th.Tensor
    ) -> tuple[th.Tensor, base.EnvRewardInfo]:
        """Reward and expert predictions for taking ``acts`` on ``ctxs``."""
        assert self._curr_data is not None and self._ctx_to_idx is not None
        batch_labels: th.Tensor = self._curr_data[1][self._ctx_indices(ctxs)]
        act_feats: th.Tensor = self.acts_to_xacts(acts)
        expert_probs: th.Tensor = true_expert_func(ctxs, act_feats, self.temperature)
        rewards: th.Tensor = reward_func(
            ctxs,
            act_feats,
            batch_labels,
            pyhats=expert_probs,
            temperature=self.temperature,
        )
        return rewards, base.EnvRewardInfo(expert_probs, batch_labels)

    def compute_optimal_rewards(self, ctxs: th.Tensor) -> th.Tensor:
        """Reward obtained with the optimal feature mask for each context."""
        assert self._curr_data is not None and self._ctx_to_idx is not None
        _, all_labels, _, all_masks = self._curr_data
        sel = self._ctx_indices(ctxs)
        return reward_func(
            ctxs, all_masks[sel], all_labels[sel], temperature=self.temperature
        )

    def has_next(self) -> bool:
        """True while fewer than ``n_samps`` contexts were served this pass."""
        return self._n_loaded < self.n_samps

    def reset(self) -> None:
        """Forget the current batch and restart the pass."""
        self._n_loaded = 0
        self._curr_data = None
        self._ctx_to_idx = None
        return super().reset()


class SyntheticEvalEnv(SyntheticEnv):
    """Evaluation environment over a fixed, seeded synthetic dataset.

    ``n_samps`` samples are generated once in ``__init__`` from a dedicated
    generator seeded with ``random_seed``. :meth:`get_ctxs` then serves
    consecutive, non-overlapping slices until the dataset is exhausted.
    """

    # dedicated RNG used to draw the fixed dataset
    _generator: th.Generator

    # (xs, ys, pyhats, fms) for the whole evaluation set
    _fixed_dataset: tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]

    # maps a context row (as a tuple of floats) to its row in _curr_data
    _ctx_to_idx: dict[tuple[float, ...], int]
    # exclusive end index of the most recently served slice
    _end_idx: int
    # slice of _fixed_dataset served by the last get_ctxs call
    _curr_data: tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor] | None
    _dataset: th_data.TensorDataset

    @property
    def dataset(self) -> th_data.TensorDataset:
        """The fixed dataset as ``(xs, ones_like(xs), ys)``."""
        return self._dataset

    def __init__(
        self,
        n_samps: int,
        n_covs: int,
        min_features: int = 1,
        max_features: int | None = None,
        temperature: float = 8,
        random_seed: int = 279,
    ) -> None:
        super().__init__(
            n_samps=n_samps,
            n_covs=n_covs,
            min_features=min_features,
            max_features=max_features,
            temperature=temperature,
        )
        self._generator = th.Generator().manual_seed(random_seed)
        # make a fixed dataset
        self._fixed_dataset = generate_synthetic_data(
            n_samps, n_covs, temperature, generator=self._generator
        )
        # bookkeeping for iterating over dataset
        self._ctx_to_idx = {}
        self._end_idx = 0
        self._curr_data = None
        self._dataset = th_data.TensorDataset(
            self._fixed_dataset[0],
            th.ones_like(self._fixed_dataset[0]),
            self._fixed_dataset[1],
        )

    def _ctx_indices(self, ctxs: th.Tensor) -> th.Tensor:
        """Row positions in ``_curr_data`` of each context in ``ctxs``.

        Raises:
            KeyError: if a context was not part of the last served slice.
        """
        return th.as_tensor(
            [self._ctx_to_idx[tuple(ctx.tolist())] for ctx in ctxs], dtype=th.long
        )

    def get_ctxs(self, n_ctxs: int) -> th.Tensor:
        """Serve the next (at most ``n_ctxs``) contexts of the fixed dataset."""
        start_idx: int = self._end_idx
        end_idx: int = min(start_idx + n_ctxs, len(self._fixed_dataset[0]))
        ctxs, ys, pyhats, fms = self._fixed_dataset
        self._curr_data = (
            ctxs[start_idx:end_idx],
            ys[start_idx:end_idx],
            pyhats[start_idx:end_idx],
            fms[start_idx:end_idx],
        )
        self._ctx_to_idx = {
            tuple(ctx.tolist()): i for i, ctx in enumerate(self._curr_data[0])
        }
        self._end_idx = end_idx
        return self._curr_data[0]

    def compute_rewards(
        self, ctxs: th.Tensor, acts: th.Tensor
    ) -> tuple[th.Tensor, base.EnvRewardInfo]:
        """Reward and expert predictions for taking ``acts`` on ``ctxs``."""
        assert self._curr_data is not None
        ys: th.Tensor = self._curr_data[1][self._ctx_indices(ctxs)]
        xacts: th.Tensor = self.acts_to_xacts(acts)
        pyhats: th.Tensor = true_expert_func(ctxs, xacts, self.temperature)
        rewards: th.Tensor = reward_func(
            ctxs, xacts, ys, pyhats=pyhats, temperature=self.temperature
        )
        info = base.EnvRewardInfo(pyhats, ys)
        return rewards, info

    def compute_optimal_rewards(self, ctxs: th.Tensor) -> th.Tensor:
        """Reward obtained with the optimal feature mask for each of ``ctxs``.

        Rows are aligned with ``ctxs`` (looked up by value), matching
        :meth:`compute_rewards`. Previously ``ctxs`` was ignored and the
        rewards for the whole current slice were returned in slice order.
        """
        assert self._curr_data is not None
        _, ys, _, fms = self._curr_data
        idxs: th.Tensor = self._ctx_indices(ctxs)
        return reward_func(ctxs, fms[idxs], ys[idxs], temperature=self.temperature)

    def has_next(self) -> bool:
        """True while part of the fixed dataset has not been served yet."""
        return self._end_idx < len(self._fixed_dataset[0])

    def reset(self) -> None:
        """Restart iteration from the beginning of the fixed dataset."""
        self._curr_data = None
        # also drop the stale context lookup, as the train env does
        self._ctx_to_idx = {}
        self._end_idx = 0
        return super().reset()


class SyntheticEnvManager(base.EnvManager):
    """Bundle a synthetic train env with a seeded evaluation env.

    The same evaluation environment instance is returned for both
    ``val_env`` and ``test_env``. Labels are binary (``n_labels == 2``).
    """

    _train_env: SyntheticTrainEnv
    _val_env: SyntheticEvalEnv

    _n_covs: int
    _n_labels: int

    def __init__(
        self,
        n_covs: int,
        min_features=1,
        max_features: Optional[int] = None,
        temperature: float = 8.0,
        n_train_samps: int = 500,
        n_val_samps: int = 2000,
        val_random_seed: int = 279,
    ) -> None:
        super().__init__()
        self._n_covs = n_covs
        self._n_labels = 2
        # settings shared by the train and evaluation environments
        common = dict(
            n_covs=n_covs,
            min_features=min_features,
            max_features=max_features,
            temperature=temperature,
        )
        self._train_env = SyntheticTrainEnv(n_samps=n_train_samps, **common)
        self._val_env = SyntheticEvalEnv(
            n_samps=n_val_samps, random_seed=val_random_seed, **common
        )

    @property
    def n_covs(self):
        """Number of covariates per context."""
        return self._n_covs

    @property
    def n_expers_per_fcomb(self) -> int:
        """Experts per feature combination, as reported by the train env."""
        return self.train_env.n_experts_per_fcomb

    @property
    def n_labels(self):
        """Number of label classes (always 2)."""
        return self._n_labels

    @property
    def train_env(self):
        """The training environment."""
        return self._train_env

    @property
    def val_env(self):
        """The seeded evaluation environment."""
        return self._val_env

    @property
    def test_env(self):
        """Alias for the evaluation environment (no separate test split)."""
        return self._val_env

    def expert_func(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Ground-truth expert class probabilities for ``acts`` on ``ctxs``."""
        env = self.train_env
        return true_expert_func(
            ctxs, env.acts_to_xacts(acts), temperature=env.temperature
        )


def generate_synthetic_data(
    n_samps: int,
    n_covs: int,
    temperature: float = 8.0,
    generator: Optional[th.Generator] = None,
) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """Sample a synthetic binary-classification dataset.

    The first ``n_covs - 1`` covariates are standard normal. The last one is
    uniform on [0, 1) and selects which of the ``n_covs - 1`` equal-width bins
    (edges from ``linspace(0, 1, n_covs)``, each open on the left) is active.
    The probability of label 1 is ``sigmoid(temperature * x[active_bin])``.

    Every random draw, including the labels, uses ``generator``, so a seeded
    generator yields a fully reproducible dataset.

    Args:
        n_samps: number of samples to draw.
        n_covs: total number of covariates (including the bin selector).
        temperature: multiplier inside the sigmoid.
        generator: optional RNG; the global RNG is used when ``None``.

    Returns:
        ``(xs, ys, pyhats, fms_)``:
        xs: ``(n_samps, n_covs)`` float32 covariates.
        ys: ``(n_samps,)`` int64 labels drawn from the true probabilities.
        pyhats: ``(n_samps, 2)`` true probabilities ``[P(y=0), P(y=1)]``.
        fms_: ``(n_samps, n_covs)`` float32 optimal feature mask: the active
            bin's covariate plus the last (selector) covariate.
    """
    xs: th.Tensor = th.cat(
        (
            th.randn((n_samps, n_covs - 1), dtype=th.float32, generator=generator),
            th.rand((n_samps, 1), dtype=th.float32, generator=generator),
        ),
        dim=1,
    )
    x_ths: th.Tensor = th.linspace(0, 1, n_covs)
    # one-hot (bool) indicator of the bin containing the last covariate
    fms: th.Tensor = (x_ths[:-1] < xs[:, -1:]) & (xs[:, -1:] <= x_ths[1:])
    p1: th.Tensor = th.sigmoid(temperature * th.sum(xs[:, :-1] * fms, dim=1))
    # Labels must use the same generator; otherwise seeded datasets (e.g. the
    # evaluation set) would have labels depending on global RNG state.
    ys: th.Tensor = (
        th.rand(n_samps, dtype=th.float32, generator=generator) <= p1
    ).to(dtype=th.long)
    pyhats: th.Tensor = th.stack((1 - p1, p1), dim=1)
    # convert explicitly instead of relying on bool/float promotion in cat
    fms_ = th.cat(
        (fms.to(dtype=th.float32), th.ones((n_samps, 1), dtype=th.float32)), dim=1
    )
    return xs, ys, pyhats, fms_


def true_expert_func(
    xs: th.Tensor, inds: th.Tensor, temperature: float = 8.0
) -> th.Tensor:
    """Class probabilities of an expert restricted to the features in ``inds``.

    The optimal mask is the covariate of the active bin (selected by the last
    covariate, as in ``generate_synthetic_data``) plus the selector itself.
    If ``inds`` contains both optimal features, the true probability is
    shrunk towards 0.5 in proportion to the fraction of the ``n_covs - 2``
    distractor features also included. Otherwise the expert outputs 0.5.

    Args:
        xs: ``(n, n_covs)`` contexts.
        inds: ``(n, n_covs)`` 0/1 feature masks (broadcastable with the
            float mask built here).
        temperature: multiplier inside the sigmoid.

    Returns:
        ``(n, 2)`` tensor ``[P(y=0), P(y=1)]`` per row.
    """
    n: int = xs.shape[0]
    n_covs: int = xs.shape[1]
    # build helpers on the input's device so GPU inputs work
    x_ths: th.Tensor = th.linspace(0, 1, n_covs, device=xs.device)
    fms: th.Tensor = (x_ths[:-1] < xs[:, -1:]) & (xs[:, -1:] <= x_ths[1:])
    pyhats: th.Tensor = th.sigmoid(temperature * th.sum(xs[:, :-1] * fms, dim=1))
    fms_: th.Tensor = th.cat(
        (fms, th.ones((n, 1), dtype=th.bool, device=xs.device)), 1
    ).to(dtype=th.float32)
    # both the active-bin covariate and the selector must be present
    has_corrects: th.Tensor = th.sum(inds * fms_, 1) == 2
    # With n_covs == 2 there are no distractors; dividing by 0 would give NaN.
    n_wrongs: th.Tensor = th.sum(inds * (1 - fms_), dim=1) / max(n_covs - 2, 1)
    alphas: th.Tensor = has_corrects * (1 - n_wrongs)
    p1: th.Tensor = pyhats * alphas + 0.5 * (1 - alphas)
    return th.stack((1.0 - p1, p1), dim=1)


def reward_func(
    xs: th.Tensor,
    inds: th.Tensor,
    ys: th.Tensor,
    pyhats: Optional[th.Tensor] = None,
    temperature: float = 8.0,
) -> th.Tensor:
    """Per-sample reward: negative cross-entropy of ``ys`` under the expert.

    Args:
        xs: contexts, used only when ``pyhats`` must be computed.
        inds: feature masks, used only when ``pyhats`` must be computed.
        ys: integer class labels.
        pyhats: precomputed class probabilities; computed with
            ``true_expert_func`` when ``None``.
        temperature: forwarded to ``true_expert_func``.

    Returns:
        Unreduced tensor of ``-cross_entropy`` values, one per sample.
    """
    probs = true_expert_func(xs, inds, temperature) if pyhats is None else pyhats
    # probs_to_logits clamps before taking the log, avoiding -inf
    logits = torch.distributions.utils.probs_to_logits(probs)
    return -th.nn.functional.cross_entropy(logits, ys, reduction="none")


def make_uniform_masks(
    n_samps: int,
    n_covs: int,
    min_features: int = 1,
    max_features: Optional[int] = None,
    generator: Optional[th.Generator] = None,
):
    """Draw one random binary feature mask per sample.

    For each row, a random permutation of the columns is drawn. Then a size
    ``k`` is drawn uniformly from ``[min_features, max_features]`` (with
    ``max_features`` defaulting to ``n_covs``), and the first ``k`` permuted
    columns are set to 1.0.

    Returns:
        ``(n_samps, n_covs)`` float32 tensor of 0/1 masks.
    """
    upper = max_features if max_features is not None else n_covs
    masks: th.Tensor = th.zeros((n_samps, n_covs), dtype=th.float32)
    for row_idx in range(n_samps):
        # RNG order (permutation first, then size) matters for reproducibility
        order = th.randperm(n_covs, generator=generator)
        k = int(th.randint(min_features, upper + 1, (1,), generator=generator).item())
        masks[row_idx, order[:k]] = 1.0
    return masks
