from typing import (
    Any,
    Dict,
    Sequence,
    Tuple,
    Type,
    Union,
    Optional,
)

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Independent, Normal, Categorical
from einops import rearrange, repeat

from tianshou.data import Batch
from vmoc.base_net import MLP, BaseNet, PositionalEmbedding, log_probs_from_softmax, CatSOLayer
from utils.utils_debug import check_leaf_dims

# Alias for an ``nn.Module`` subclass (the class object, not an instance).
# Used below to annotate the ``norm_layer`` / ``activation`` parameters.
ModuleType = Type[nn.Module]
# Alias used below to annotate the ``norm_args`` / ``act_args`` parameters:
# a single tuple or dict, or a sequence of tuples / dicts.
ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
                 Sequence[Dict[Any, Any]]]

# Clamp bounds applied to the raw output of ``BiActorProb.sigma`` before it is
# exponentiated, i.e. they bound the log standard deviation of the policy.
SIGMA_MIN = -20
SIGMA_MAX = 2


#* Option
class OptionNet(nn.Module):
  """MLP trunk sized for a flattened state plus a flattened action.

  The input width handed to ``MLP`` is
  ``prod(state_shape) + prod(action_shape)``. ``MLP`` is built with an output
  width of ``0``; whatever feature width it reports afterwards via
  ``model.output_dim`` is re-exported as ``self.output_dim``.
  NOTE(review): presumably ``output_dim=0`` makes ``MLP`` stop at its last
  hidden layer, as in Tianshou -- confirm against ``vmoc.base_net.MLP``.

  ``dueling_param`` is accepted but not used by this class. ``option_embeds``
  is only stored on the instance; ``forward`` never reads it.
  """

  def __init__(
      self,
      state_shape: Union[int, Sequence[int]],
      action_shape: Union[int, Sequence[int]] = 0,
      hidden_sizes: Sequence[int] = (),
      norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
      norm_args: Optional[ArgsType] = None,
      activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
      act_args: Optional[ArgsType] = None,
      device: Union[str, int, torch.device] = "cpu",
      dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
      linear_layer: Type[nn.Linear] = nn.Linear,
      option_embeds=None,
  ) -> None:
    super().__init__()
    self.device = device

    # Total input width: flattened state size plus flattened action size
    # (the default ``action_shape=0`` contributes nothing).
    flat_state_dim = int(np.prod(state_shape))
    flat_action_dim = int(np.prod(action_shape))
    in_features = flat_state_dim + flat_action_dim

    self.model = MLP(in_features, 0, hidden_sizes, norm_layer, norm_args,
                     activation, act_args, device, linear_layer)
    # Expose the trunk's feature width so downstream heads can size themselves.
    self.output_dim = self.model.output_dim

    self.option_embeds = option_embeds

  def forward(
      self,
      obs: Union[np.ndarray, torch.Tensor],
      state: Any = None,
      info: Dict[str, Any] = {},
  ) -> Tuple[torch.Tensor, Any]:
    """Run ``obs`` through the MLP and hand ``state`` back untouched.

    ``info`` is accepted for interface compatibility and is not read.
    """
    return self.model(obs), state


class OptionActor(nn.Module):
  """Categorical option policy P(o_t | s_t, o_{t-1}).

  The observation and the previous option's embedding are flattened,
  concatenated along the feature axis, passed through ``preprocess_net`` and
  then through a final ``MLP`` head with ``prod(num_options)`` outputs, which
  are used as the logits of a ``Categorical`` distribution.
  """

  def __init__(
      self,
      preprocess_net: nn.Module,
      num_options: Sequence[int],
      hidden_sizes: Sequence[int] = (),
      device: Union[str, int, torch.device] = "cpu",
  ) -> None:
    super().__init__()
    self.device = device
    self.preprocess = preprocess_net
    # One logit per option.
    self.output_dim = int(np.prod(num_options))
    trunk_dim = preprocess_net.output_dim
    self.last = MLP(
        trunk_dim,  # type: ignore
        self.output_dim,
        hidden_sizes,
        device=self.device)

  def forward(
      self,
      obs: Union[np.ndarray, torch.Tensor],
      oembed_tm1,
      state: Any = None,
      info: Dict[str, Any] = {},
  ) -> Tuple[torch.Tensor, Any]:
    """Sample an option and report its distribution statistics.

    Args:
      obs: batch of observations; flattened to 2-D from dim 1 onwards.
      oembed_tm1: embedding of the previous option; flattened the same way.
      state: forwarded to ``preprocess_net`` as its second argument.
      info: accepted for interface compatibility; not read.

    Returns:
      ``((olabel_t, pO_t, ent_t), hidden)`` where ``olabel_t`` is the sampled
      option index with a trailing singleton axis, ``pO_t`` is the
      distribution's ``probs`` tensor, ``ent_t`` is its entropy with a
      trailing singleton axis, and ``hidden`` is the second value returned by
      ``preprocess_net``.
    """
    # Bring both inputs onto the module's device as float32 and collapse
    # everything after the batch axis.
    flat_obs, flat_prev_opt = (
        torch.as_tensor(x, device=self.device, dtype=torch.float32).flatten(1)
        for x in (obs, oembed_tm1))
    features = torch.cat((flat_obs, flat_prev_opt), dim=1)

    trunk_out, hidden = self.preprocess(features, state)
    option_dist = Categorical(logits=self.last(trunk_out))

    # NOTE: the option is always sampled; there is no deterministic (argmax)
    # evaluation path in this module.
    # The 'b -> b 1' pattern requires a 1-D sample / entropy tensor.
    olabel_t = rearrange(option_dist.sample(), 'b -> b 1')
    pO_t = option_dist.probs
    ent_t = rearrange(option_dist.entropy(), 'b -> b 1')

    return (olabel_t, pO_t, ent_t), hidden


#* Somo P(a|s,o) Q(s,a,o) implementation
class TriNet(nn.Module):
  """MLP trunk whose input may hold state, action and option features.

  With ``concat=False`` the input width is ``prod(state_shape)``. With
  ``concat=True`` it is
  ``prod(state_shape) + prod(action_shape) * num_atoms + prod(option_shape)``.

  ``MLP`` is built with an output width of ``0``; the feature width it reports
  through ``model.output_dim`` is re-exported as ``self.output_dim``.
  NOTE(review): presumably ``output_dim=0`` makes ``MLP`` stop at its last
  hidden layer, as in Tianshou -- confirm against ``vmoc.base_net.MLP``.

  Args:
    state_shape: int or shape sequence of the observation.
    action_shape: int or shape sequence of the action.
    option_shape: int or shape sequence of the option embedding. Both forms
      are flattened with ``np.prod``, exactly like the other two shapes.
    hidden_sizes, norm_layer, norm_args, activation, act_args, device,
      linear_layer: passed positionally to ``MLP`` in that order.
    softmax: stored on the instance; not read by ``forward``.
    concat: when true, action and option widths are added to the input width.
    num_atoms: stored on the instance and used as a multiplier on the action
      width.
    dueling_param: accepted but not used by this class.
  """

  def __init__(
      self,
      state_shape: Union[int, Sequence[int]],
      action_shape: Union[int, Sequence[int]] = 0,
      option_shape: Union[int, Sequence[int]] = 0,
      hidden_sizes: Sequence[int] = (),
      norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
      norm_args: Optional[ArgsType] = None,
      activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
      act_args: Optional[ArgsType] = None,
      device: Union[str, int, torch.device] = "cpu",
      softmax: bool = False,
      concat: bool = False,
      num_atoms: int = 1,
      dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
      linear_layer: Type[nn.Linear] = nn.Linear,
  ) -> None:
    super().__init__()
    self.device = device
    self.softmax = softmax
    self.num_atoms = num_atoms
    input_dim = self._compute_input_dim(state_shape, action_shape,
                                        option_shape, num_atoms, concat)
    output_dim = 0
    self.model = MLP(input_dim, output_dim, hidden_sizes, norm_layer, norm_args,
                     activation, act_args, device, linear_layer)
    self.output_dim = self.model.output_dim

  @staticmethod
  def _compute_input_dim(
      state_shape: Union[int, Sequence[int]],
      action_shape: Union[int, Sequence[int]],
      option_shape: Union[int, Sequence[int]],
      num_atoms: int,
      concat: bool,
  ) -> int:
    """Return the flattened input width of the trunk MLP."""
    input_dim = int(np.prod(state_shape))
    if concat:
      action_dim = int(np.prod(action_shape)) * num_atoms
      # Flatten option_shape like the other shapes: adding a raw tuple/list to
      # an int would raise TypeError for sequence-valued option shapes.
      option_dim = int(np.prod(option_shape))
      input_dim += action_dim + option_dim
    return input_dim

  def forward(
      self,
      obs: Union[np.ndarray, torch.Tensor],
      state: Any = None,
      info: Dict[str, Any] = {},
  ) -> Tuple[torch.Tensor, Any]:
    """Run ``obs`` through the MLP and return ``(features, state)``.

    ``state`` is returned unchanged; ``info`` is accepted for interface
    compatibility and is not read. Any concatenation of action / option
    features into ``obs`` must be done by the caller.
    """
    logits = self.model(obs)
    return logits, state


class BiActorProb(nn.Module):
  """Tanh-squashed Gaussian action policy P(a | s, o).

  The observation and the current option's embedding are flattened and
  concatenated, passed through ``preprocess_net``, and fed to two separate
  ``MLP`` heads: ``mu`` (mean) and ``sigma`` (log standard deviation, clamped
  to ``[SIGMA_MIN, SIGMA_MAX]`` before exponentiation).

  ``option_embeds`` is only stored on the instance; ``forward`` does not read
  it.
  """

  def __init__(self,
               preprocess_net: nn.Module,
               action_shape: Sequence[int],
               option_embeds,
               hidden_sizes: Sequence[int] = (),
               device: Union[str, int, torch.device] = "cpu") -> None:
    super().__init__()
    self.preprocess = preprocess_net
    self.device = device
    self.output_dim = int(np.prod(action_shape))
    trunk_dim = preprocess_net.output_dim

    def _make_head():
      # Both heads share the same architecture but have their own weights.
      return MLP(
          trunk_dim,  # type: ignore
          self.output_dim,
          hidden_sizes,
          device=self.device)

    self.mu = _make_head()
    self.sigma = _make_head()

    self.option_embeds = option_embeds

    # float32 machine epsilon; keeps the log in the tanh correction finite.
    self.__eps = float(np.finfo(np.float32).eps)

  def forward(
      self,
      obs: Union[np.ndarray, torch.Tensor],
      oembed_t=None,
      state=None,
      info: Dict[str, Any] = {},
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
    """Sample a squashed action and compute its log-probability.

    Args:
      obs: batch of observations; flattened to 2-D from dim 1 onwards.
      oembed_t: embedding of the current option; flattened the same way. It is
        converted unconditionally, so the ``None`` default is not handled
        specially here.
      state: not passed to ``preprocess_net``; returned unchanged.
      info: accepted for interface compatibility; not read.

    Returns:
      ``((act, log_prob, ent_t), state)`` where ``act`` is the tanh-squashed
      reparameterised sample, ``log_prob`` is its log-probability with a
      trailing singleton axis, and ``ent_t`` is the entropy of the
      *un-squashed* Gaussian with a trailing singleton axis.
    """
    flat_obs, flat_opt = (
        torch.as_tensor(x, device=self.device, dtype=torch.float32).flatten(1)
        for x in (obs, oembed_t))
    # The second value returned by the trunk is discarded.
    trunk_out, _ = self.preprocess(torch.cat((flat_obs, flat_opt), dim=1))

    mean = self.mu(trunk_out)
    log_std = torch.clamp(self.sigma(trunk_out), min=SIGMA_MIN, max=SIGMA_MAX)

    # Diagonal Gaussian: the last axis is treated as a single event.
    gaussian = Independent(Normal(mean, log_std.exp()), 1)
    # Reparameterised draw so gradients flow through the sample.
    pre_tanh = gaussian.rsample()
    # NOTE: the action is always sampled; there is no deterministic
    # evaluation path in this module.
    gauss_log_prob = gaussian.log_prob(pre_tanh).unsqueeze(-1)

    # Change-of-variables correction for the tanh squashing, see the Soft
    # Actor-Critic paper (arXiv 1801.01290), Appendix C, Eq. 21:
    #   log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2)
    act = torch.tanh(pre_tanh)
    squash_correction = torch.log((1 - act.pow(2)) + self.__eps).sum(
        -1, keepdim=True)
    log_prob = gauss_log_prob - squash_correction

    # The 'b -> b 1' pattern requires a 1-D entropy tensor.
    ent_t = rearrange(gaussian.entropy(), 'b -> b 1')
    return (act, log_prob, ent_t), state


class TriCritic(nn.Module):
  """Scalar critic Q(s, a, o) over state, action and option inputs.

  The three inputs are flattened, concatenated along the feature axis, passed
  through ``preprocess_net`` and then through a final ``MLP`` head with a
  single output.

  The head's input width is ``preprocess_net.output_dim`` when that attribute
  exists, otherwise ``preprocess_net_output_dim``. ``option_embeds`` is only
  stored on the instance; ``forward`` does not read it.
  """

  def __init__(
      self,
      preprocess_net: nn.Module,
      hidden_sizes: Sequence[int] = (),
      option_embeds=None,
      device: Union[str, int, torch.device] = "cpu",
      preprocess_net_output_dim: Optional[int] = None,
      linear_layer: Type[nn.Linear] = nn.Linear,
      flatten_input: bool = True,
  ) -> None:
    super().__init__()
    self.device = device
    self.preprocess = preprocess_net
    self.output_dim = 1
    # Prefer the width advertised by the trunk; fall back to the explicit
    # argument for trunks that do not expose ``output_dim``.
    trunk_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
    self.last = MLP(
        trunk_dim,  # type: ignore
        1,
        hidden_sizes,
        device=self.device,
        linear_layer=linear_layer,
        flatten_input=flatten_input,
    )
    self.option_embeds = option_embeds

  def forward(
      self,
      obs: Union[np.ndarray, torch.Tensor],
      act: Optional[Union[np.ndarray, torch.Tensor]] = None,
      opt=None,
      info: Dict[str, Any] = {},
  ) -> torch.Tensor:
    """Evaluate the critic on a batch of (state, action, option) inputs.

    Each of ``obs``, ``act`` and ``opt`` is converted to a float32 tensor on
    ``self.device`` and flattened from dim 1 onwards; they are concatenated in
    that order. All three are converted unconditionally, so the ``None``
    defaults of ``act`` and ``opt`` are not handled specially here.
    ``info`` is accepted for interface compatibility and is not read.

    Returns:
      The output of the final head applied to the trunk features.
    """
    parts = [
        torch.as_tensor(x, device=self.device, dtype=torch.float32).flatten(1)
        for x in (obs, act, opt)
    ]
    # The second value returned by the trunk is discarded.
    features, _ = self.preprocess(torch.cat(parts, dim=1))
    return self.last(features)
