from abc import abstractmethod
from collections.abc import Callable
from typing import cast

import numpy as np
import torch

from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import _nstep_return, _gae_return


class MultiObjectivePolicy:
    """Mixin that scalarizes vector-valued rewards before computing returns.

    Rewards read from the batch / buffer are collapsed with a (weighted) average
    over axis 1, i.e. one column per objective, before the usual GAE or n-step
    return computation. ``__init__`` stores ``reward_weights`` and forwards all
    other arguments to the next class in the MRO. The class is therefore presumably
    listed before a tianshou policy class when subclassed — NOTE(review): confirm
    against concrete subclasses.
    """

    def __init__(
        self,
        reward_weights=None,
        *args,
        **kwargs,
    ) -> None:
        """Store the objective weights and defer to the next ``__init__`` in the MRO.

        :param reward_weights: per-objective weights, kept unchanged on
            ``self.reward_weights``. It is the first positional parameter, so a
            positional argument intended for the parent class lands here instead.
        :param args: positional arguments forwarded to ``super().__init__``.
        :param kwargs: keyword arguments forwarded to ``super().__init__``.
        """
        self.reward_weights = reward_weights
        super().__init__(*args, **kwargs)

    @staticmethod
    def compute_episodic_return(
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
        v_s_: np.ndarray | torch.Tensor | None = None,
        v_s: np.ndarray | torch.Tensor | None = None,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        reward_weights: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        r"""Compute returns over given batch.

        Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
        to calculate q/advantage value of given batch. Returns are calculated as
        advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
        for estimating returns.

        The multi-objective rewards in ``batch.rew`` are first scalarized by a
        weighted average over axis 1 (see ``reward_weights``).

        Setting `v_s_` and `v_s` to None (or all zeros) and `gae_lambda` to 1.0 calculates the
        discounted return-to-go/ Monte-Carlo return.

        :param batch: a data batch which contains several episodes of data in
            sequential order. Mind that the end of each finished episode of batch
            should be marked by done flag, unfinished (or collecting) episodes will be
            recognized by buffer.unfinished_index().
        :param buffer: the corresponding replay buffer.
        :param indices: tells the batch's location in buffer, batch is equal
            to buffer[indices].
        :param v_s_: the value function of all next states :math:`V(s')`.
            If None, it will be set to an array of 0 (requires ``gae_lambda == 1``).
        :param v_s: the value function of all current states :math:`V(s)`. If None,
            it is set based upon `v_s_` rolled by 1.
        :param gamma: the discount factor, should be in [0, 1].
        :param gae_lambda: the parameter for Generalized Advantage Estimation,
            should be in [0, 1].
        :param reward_weights: per-objective weights used to scalarize the rewards.
            None (default) weights all objectives equally, which is the plain mean
            used previously and matches :meth:`compute_nstep_return`.

        :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
        """
        # Collapse the objective axis into one scalar reward per transition.
        rew = np.average(batch.rew, weights=reward_weights, axis=1)

        if v_s_ is None:
            assert np.isclose(gae_lambda, 1.0)
            v_s_ = np.zeros_like(rew)
        else:
            v_s_ = to_numpy(v_s_.flatten())
            # Per tianshou, value_mask is False where obs_next is not a valid
            # bootstrap target, so those next-state values are zeroed.
            v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices)
        v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())

        end_flag = np.logical_or(batch.terminated, batch.truncated)
        # Steps of episodes still being collected are treated as episode ends.
        end_flag[np.isin(indices, buffer.unfinished_index())] = True
        advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
        returns = advantage + v_s
        # normalization varies from each policy, so we don't do it here
        return returns, advantage

    @staticmethod
    def compute_nstep_return(
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
        reward_weights: np.ndarray | None = None
    ) -> BatchWithReturnsProtocol:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`. The reward :math:`r_i` is
        the weighted average over axis 1 of the buffer's multi-objective rewards.

        :param batch: a data batch, which is equal to buffer[indices].
        :param buffer: the data buffer.
        :param indices: tell batch's location in buffer
        :param function target_q_fn: a function which compute target Q value
            of "obs_next" given data buffer and wanted indices.
        :param gamma: the discount factor, should be in [0, 1].
        :param n_step: the number of estimation step, should be an int greater
            than 0.
        :param rew_norm: normalize the reward to Normal(0, 1).
            TODO: passing True is not supported and will cause an error!
        :param reward_weights: per-objective weights used to scalarize the rewards.
            None (default) weights all objectives equally.
        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with the same shape as target_q_fn's return tensor.
        :raises ValueError: if ``indices`` and ``batch`` have different lengths.
        """
        assert not rew_norm, "Reward normalization in computing n-step returns is unsupported now."
        if len(indices) != len(batch):
            raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.")

        # _nstep_return looks rewards up by buffer index, so the rewards of the
        # whole buffer are scalarized. np.average with weights=None is the
        # equal-weight mean.
        rew = np.average(buffer.rew, weights=reward_weights, axis=1)

        bsz = len(indices)
        indices = [indices]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)
        # terminal indicates buffer indexes nstep after 'indices',
        # and are truncated at the end of each episode
        terminal = indices[-1]
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch.reshape(bsz, -1))
        target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
        end_flag = buffer.done.copy()
        # Steps of episodes still being collected are treated as episode ends.
        end_flag[buffer.unfinished_index()] = True
        target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return cast(BatchWithReturnsProtocol, batch)

    @abstractmethod
    def ask_policy(self, obs):
        """Query the policy for ``obs``; must be implemented by subclasses.

        The abstract marker is only enforced when the final class uses an
        ``ABCMeta``-based metaclass, because this mixin does not inherit from ``ABC``.
        NOTE(review): the expected return value is not visible here — confirm
        against concrete subclasses.
        """

    @staticmethod
    def prompt_user_choice(options):
        """Interactively ask the user to pick one of ``options`` or type a custom trade-off.

        A menu is printed with ``0. exit`` followed by each option numbered from 1.

        - Entering a listed number returns ``(number - 1, options[number - 1])``.
        - Entering ``0`` raises ``SystemExit(0)``.
        - Entering ``>`` followed by whitespace-separated numbers returns
          ``(None, trade_off)``. Here ``trade_off`` is a float array that must have as
          many elements as the last listed option.
        - Any other input prints a message and re-prompts.

        :param options: an indexable sequence of array-like trade-offs.
        :return: ``(index, trade_off)``; ``index`` is None for a custom trade-off.
        :raises SystemExit: when the user chooses ``0``.
        :raises EOFError: when standard input is closed. This is deliberately not
            retried: every further ``input()`` call would fail again.
        """
        n_options = len(options)

        print("Please choose a trade-off (or input an array with the desired trade-off):")
        print("0. exit")
        for number, option in enumerate(options, 1):
            print(f"{number}. {option}")

        # Custom trade-offs must have the same number of entries as the options.
        # np.asarray makes this work for list-valued options too.
        expected_size = np.asarray(options[-1]).size if n_options > 0 else None

        while True:
            input_ = input("Enter the number of your choice (add starting '>' for custom trade-off): ").strip()

            if input_.startswith('>'):
                try:
                    custom_trade_off = np.array(input_[1:].split(), dtype=float)
                except ValueError:
                    print("Invalid trade-off. Please enter numbers separated by spaces after '>'.")
                    continue
                if expected_size is None:
                    print("Custom trade-offs need at least one listed option to infer their size.")
                    continue
                # An explicit check rather than assert, so it also holds under `python -O`.
                if custom_trade_off.size != expected_size:
                    print(f"You need to input {expected_size} numbers")
                    continue
                return None, custom_trade_off

            try:
                choice = int(input_)
            except ValueError:
                print(
                    "Invalid input. Please enter the number of your choice (add starting '>' for custom trade-off): ")
                continue

            if 1 <= choice <= n_options:
                return choice - 1, options[choice - 1]
            if choice == 0:
                raise SystemExit(0)
            print("Invalid choice. Please enter a number from the list.")
