from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    no_type_check,
)

import numpy as np
import torch
from torch import nn
from einops import rearrange, repeat
from vmoc.feat_attn import SeqFeatAttn

# A class (not an instance) deriving from nn.Module, e.g. nn.ReLU or
# nn.LayerNorm; used as a layer factory.
ModuleType = Type[nn.Module]
# Extra constructor arguments for a layer factory: a tuple (positional) or a
# dict (keyword), either one shared value or a sequence with one per layer.
ArgsType = Union[
    Tuple[Any, ...],
    Dict[Any, Any],
    Sequence[Tuple[Any, ...]],
    Sequence[Dict[Any, Any]],
]


def miniblock(input_size: int,
              output_size: int = 0,
              norm_layer: Optional[ModuleType] = None,
              norm_args: Optional[Union[Tuple[Any, ...], Dict[Any,
                                                              Any]]] = None,
              activation: Optional[ModuleType] = None,
              act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
              linear_layer: Type[nn.Linear] = nn.Linear,
              dropout=None) -> List[nn.Module]:
  """Build the layers of one MLP stage: linear -> [norm] -> [act] -> [dropout].

  :param input_size: input features of the linear layer.
  :param output_size: output features of the linear layer (also the feature
      size handed to ``norm_layer`` as its first argument).
  :param norm_layer: optional normalization class, built as
      ``norm_layer(output_size, ...)``.
  :param norm_args: extra args for ``norm_layer``; a tuple is splatted
      positionally, a dict by keyword, anything else is ignored.
  :param activation: optional activation class, built with ``act_args``.
  :param act_args: extra args for ``activation`` (same tuple/dict rule).
  :param linear_layer: class used for the linear layer.
  :param dropout: if not None, an ``nn.Dropout(dropout)`` is appended last.
  :returns: the list of instantiated modules, in forward order.
  """

  def _instantiate(factory, leading, extra):
    # Tuple -> extra positional args, dict -> keyword args, else none.
    if isinstance(extra, tuple):
      return factory(*leading, *extra)
    if isinstance(extra, dict):
      return factory(*leading, **extra)
    return factory(*leading)

  stage: List[nn.Module] = [linear_layer(input_size, output_size)]
  if norm_layer is not None:
    stage.append(_instantiate(norm_layer, (output_size,), norm_args))
  if activation is not None:
    stage.append(_instantiate(activation, (), act_args))
  if dropout is not None:
    stage.append(nn.Dropout(dropout))
  return stage


class MLP(nn.Module):
  """Simple MLP backbone.

    Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ...
    * hidden_sizes[-1] * output_dim

    :param int input_dim: dimension of the input vector.
    :param int output_dim: dimension of the output vector. If set to 0, there
        is no final linear layer.
    :param hidden_sizes: shape of MLP passed in as a list, not including
        input_dim and output_dim.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. You can also pass a list (or
        tuple) of normalization modules with the same length as hidden_sizes,
        to use a different normalization module in each layer. Default to no
        normalization.
    :param norm_args: extra constructor args for the norm layers: a tuple
        (positional) or dict (keyword). When ``norm_layer`` is a list/tuple,
        this may also be a list with one entry per hidden layer.
    :param activation: which activation to use after each layer, can be both
        the same activation for all layers if passed in nn.Module, or different
        activation for different Modules if passed in a list (or tuple).
        Default to nn.ReLU.
    :param act_args: extra constructor args for the activations, same rules
        as ``norm_args``.
    :param device: if not None, ``forward`` converts its input to a float32
        tensor on this device. The parameters are NOT moved. Default to None.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param bool flatten_input: whether to flatten input data. Default to True.
    :param dropout: if not None, dropout probability applied after every
        hidden layer (after its activation) and after the final linear layer.
        Two Dropout modules are never stacked back to back. With no hidden
        layers and ``output_dim == 0`` the model is a single Dropout.
    """

  def __init__(
      self,
      input_dim: int,
      output_dim: int = 0,
      hidden_sizes: Sequence[int] = (),
      norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
      norm_args: Optional[ArgsType] = None,
      activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
      act_args: Optional[ArgsType] = None,
      device: Optional[Union[str, int, torch.device]] = None,
      linear_layer: Type[nn.Linear] = nn.Linear,
      flatten_input: bool = True,
      dropout=None,
  ) -> None:
    super().__init__()
    self.device = device
    num_hidden = len(hidden_sizes)
    norm_layer_list, norm_args_list = self._expand_per_layer(
        norm_layer, norm_args, num_hidden)
    activation_list, act_args_list = self._expand_per_layer(
        activation, act_args, num_hidden)

    dims = [input_dim] + list(hidden_sizes)
    layers: List[nn.Module] = []
    for in_dim, out_dim, norm, n_args, activ, a_args in zip(
        dims[:-1], dims[1:], norm_layer_list, norm_args_list,
        activation_list, act_args_list):
      layers += miniblock(in_dim, out_dim, norm, n_args, activ, a_args,
                          linear_layer, dropout)
    if output_dim > 0:
      layers.append(linear_layer(dims[-1], output_dim))
    # Trailing dropout, but skip it if the last hidden miniblock already ended
    # in one (output_dim == 0): stacking two would compound the drop rate.
    # Dropout has no parameters, so state_dict keys are unaffected.
    if dropout is not None and not (layers and
                                    isinstance(layers[-1], nn.Dropout)):
      layers.append(nn.Dropout(dropout))
    self.output_dim = output_dim or dims[-1]
    self.model = nn.Sequential(*layers)
    self.flatten_input = flatten_input

  @staticmethod
  def _expand_per_layer(layer, args, num_layers):
    """Broadcast a layer factory and its args to one entry per hidden layer.

    A falsy ``layer`` yields all-None lists; a list/tuple of factories must
    have ``num_layers`` entries (``args`` may then be a matching list or a
    single shared value); a single factory is repeated ``num_layers`` times.
    """
    if not layer:
      return [None] * num_layers, [None] * num_layers
    if isinstance(layer, (list, tuple)):
      assert len(layer) == num_layers
      if isinstance(args, list):
        assert len(args) == num_layers
        return list(layer), list(args)
      return list(layer), [args] * num_layers
    return [layer] * num_layers, [args] * num_layers

  @no_type_check
  def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
    """Run ``obs`` through the MLP.

    If ``self.device`` is set, ``obs`` is first converted to a float32 tensor
    on that device. If ``flatten_input`` is set, all dims after the first are
    flattened.
    """
    if self.device is not None:
      obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
    if self.flatten_input:
      obs = obs.flatten(1)
    return self.model(obs)


def log_probs_from_softmax(softmax_values):
  """Turn (possibly unnormalized) probabilities into normalized log-probs.

  Returns ``log(p) - log(sum(p))`` along the last dimension, so the result is
  a proper log-distribution even if the inputs do not sum to one. Zero
  entries map to ``-inf``.
  """
  normalizer = torch.sum(softmax_values, dim=-1, keepdim=True)
  return torch.log(softmax_values) - torch.log(normalizer)


class BaseNet(nn.Module):
  """nn.Module base that stores ``config.device`` and offers tensor helpers."""

  def __init__(self, config):
    super().__init__()
    # Device on which the helper methods below create new tensors.
    self.device = config.device

  def as_tensor(self, x):
    """Return ``x`` as a float32 tensor on ``self.device``.

    An existing tensor is returned unchanged (no dtype/device conversion).
    Anything else goes through ``np.asarray(..., dtype=float)`` first.
    """
    if not isinstance(x, torch.Tensor):
      as_array = np.asarray(x, dtype=float)
      x = torch.tensor(as_array, dtype=torch.float32, device=self.device)
    return x

  def as_bool_tensor(self, x):
    """Return ``x`` as a bool tensor on ``self.device`` (tensors pass through)."""
    if not isinstance(x, torch.Tensor):
      x = torch.tensor(x, dtype=torch.bool, device=self.device)
    return x

  def range_tensor(self, end):
    """Return ``[0, 1, ..., end - 1]`` as an int64 tensor on ``self.device``."""
    indices = torch.arange(end).long()
    return indices.to(self.device)


class PositionalEmbedding(nn.Module):
  """Add a sinusoidal or learned positional embedding to a sequence.

  ``config.pos_embed_type`` selects the variant:
    * ``"sin"``   -- fixed sinusoidal table (not trained).
    * ``"learn"`` -- trainable ``nn.Embedding(max_len, dmodel)``.
    * ``None``    -- no positional embedding; values pass through unchanged.
  Dropout is applied to the positional embedding (not to the input) before
  it is added.

  :param dmodel: embedding width (size of the input's last dim).
  :param max_len: maximum supported sequence length.
  :param pos_dropout: dropout probability, used only when ``config`` has no
      ``pos_dropout`` attribute (``config.pos_dropout`` takes precedence).
  :param config: object providing ``pos_embed_type`` (required) and
      optionally ``pos_dropout``.
  :raises ValueError: for an unrecognized ``pos_embed_type``.
  """

  def __init__(self, dmodel, max_len=10, pos_dropout=0.02, config=None):
    super().__init__()
    self.pos_embed_type = config.pos_embed_type

    if self.pos_embed_type == "sin":
      # A buffer (not a plain attribute) so .to(device)/.half() move it along
      # with the module; non-persistent so state_dict keys are unchanged.
      self.register_buffer(
          "pos_embed",
          self.create_sin_embedding(max_len, dmodel),
          persistent=False)
    elif self.pos_embed_type == "learn":
      self.pos_embed = nn.Embedding(max_len, dmodel)
    elif self.pos_embed_type is None:
      self.pos_embed = None
    else:
      raise ValueError(
          f"Pos_Embed_Type {self.pos_embed_type} not recognized. Use 'sin' or 'learn'."
      )

    self.dropout = nn.Dropout(getattr(config, "pos_dropout", pos_dropout))

  def create_sin_embedding(self, max_len, dmodel):
    """Return the (max_len, dmodel) Transformer sinusoidal table.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / dmodel)``. Odd ``dmodel`` is supported.
    """
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dmodel, 2).float() * (-np.log(10000.0) / dmodel))
    embedding = torch.zeros(max_len, dmodel)
    embedding[:, 0::2] = torch.sin(pos * div_term)
    # For odd dmodel there is one fewer cos column than frequencies.
    embedding[:, 1::2] = torch.cos(pos * div_term[:dmodel // 2])
    return embedding

  def forward(self, x):
    """Add positional embeddings to ``x`` of shape [bsz, seq_len, dmodel].

    A 2-D input [seq_len, dmodel] is treated as a batch of one and returned
    as 2-D. ``seq_len`` must not exceed ``max_len``.
    """
    out = x  # returned as-is (modulo 2-D reshaping) when pos_embed_type=None

    unbatched = out.ndim == 2
    if unbatched:
      out = out.unsqueeze(0)  # [seq_len, d] -> [1, seq_len, d]

    bsz, seq_len = out.size(0), out.size(1)
    if self.pos_embed_type == "sin":
      pos_embed = self.pos_embed[:seq_len, :].unsqueeze(0)  # [1, seq_len, d]
      out = out + self.dropout(pos_embed)
    elif self.pos_embed_type == "learn":
      # Indices must live on the embedding's device, not the default CPU.
      positions = torch.arange(
          seq_len, device=self.pos_embed.weight.device).unsqueeze(0).repeat(
              bsz, 1)  # [bsz, seq_len]
      out = out + self.dropout(self.pos_embed(positions))

    if unbatched:
      out = out.squeeze(0)
    return out


class CatSOLayer(BaseNet):
  '''Fuse a state embedding with up to two option embeddings.

  forward() combines ``state`` with O_{t-1}^l (``oembed_tm1_oLl``) and
  O_t^{l-1} (``oembed_t_oLlm1``); all inputs must share last dim ``dmodel``.

  Trainable Params:
    'feat_attn' and 'add': CatSOLayer itself only owns the LayerNorm params
      ('feat_attn' additionally holds a SeqFeatAttn submodule).
    'concat': additionally an MLP whose depth depends on ``hidden_sizes``.

  Settings actually read from ``config``: ``concat_dropout``, ``nheads``,
  ``res_conn``, ``concat_layer_norm`` (and ``device`` via BaseNet). The
  like-named constructor arguments ``neads``, ``concat_dropout`` and
  ``res_conn`` are accepted but ignored.
  '''

  def __init__(
      self,
      dmodel,
      hidden_sizes=(128, 64),
      ninputs=2,  # inputs of forward S, O_{t-1}^l, O_t^{l-1}
      concat_type='feat_attn',  # 'feat_attn', 'add', 'concat'
      # using config
      neads=1,
      concat_dropout=0.02,
      res_conn=False,
      config=None):
    # only use config concat_type, dropout and concat_layer_norm
    super().__init__(config)
    # Half of config.concat_dropout here; the other half goes to SeqFeatAttn.
    self.dropout = nn.Dropout(config.concat_dropout / 2)
    self.norm = nn.LayerNorm(dmodel)

    # Only the submodule for the construction-time concat_type is built.
    if concat_type == 'concat':
      self.concat_fnn = MLP(dmodel * ninputs, dmodel, hidden_sizes)
    elif concat_type == 'feat_attn':
      self.feat_attn = SeqFeatAttn(
          dmodel, config.nheads, dropout=config.concat_dropout / 2)

    self.concat_type = concat_type
    self.config = config

  def _soo_attn(self, A, B=None, C=None):
    """Average pairwise feature attention among A, B and C.

    Computes ``feat_attn(A, B)``, ``feat_attn(A, C)`` and ``feat_attn(C, B)``
    for the pairs whose inputs are present; absent terms count as zeros and
    the sum is always divided by 3.
    """
    # NOTE(review): the divisor is fixed at 3 even when B and/or C is None, so
    # fewer inputs shrink the output scale, and with both None the result is
    # all zeros. Confirm this is intended before changing it.
    # S, O_{t-1}^l, O_t^{l-1}: [bsz, dmodel]
    # state must be FFN(state), same size with B and C
    attended_A_to_B = torch.zeros_like(A)
    attended_C_to_B = torch.zeros_like(A)
    attended_A_to_C = torch.zeros_like(A)
    if B is not None:
      attended_A_to_B = self.feat_attn(A, B)
    if C is not None:
      attended_A_to_C = self.feat_attn(A, C)
      if B is not None:
        attended_C_to_B = self.feat_attn(C, B)
    # Combine the attended outputs
    output = (attended_A_to_B + attended_A_to_C + attended_C_to_B) / 3
    return output

  def forward(self,
              state,
              oembed_tm1_oLl=None,
              oembed_t_oLlm1=None,
              concat_type=None):
    """Fuse ``state`` with the optional option embeddings.

    :param state: state embedding, already passed through the preprocess net.
    :param oembed_tm1_oLl: optional embedding O_{t-1}^l.
    :param oembed_t_oLlm1: optional embedding O_t^{l-1}.
    :param concat_type: overrides ``self.concat_type`` for this call.
        'feat_attn' / 'concat' need the submodule built in ``__init__``,
        which exists only if that type was chosen at construction.
    :raises ValueError: for an unknown ``concat_type``.
    :returns: the fused tensor, plus ``state`` if ``config.res_conn``,
        layer-normed if ``config.concat_layer_norm``, then dropout.
    """
    # state should be after preprocess net
    # all inputs must have same dimension dmodel
    if concat_type is None:
      concat_type = self.concat_type

    if concat_type == 'feat_attn':
      out = self._soo_attn(state, oembed_tm1_oLl, oembed_t_oLlm1)

    elif concat_type == 'add':
      # Out-of-place adds: ``+=`` would mutate the caller's ``state`` (which
      # the residual below reuses) and can break autograd.
      out = state
      if oembed_tm1_oLl is not None:
        out = out + oembed_tm1_oLl
      if oembed_t_oLlm1 is not None:
        out = out + oembed_t_oLlm1

    elif concat_type == 'concat':
      embeds = [state]
      if oembed_tm1_oLl is not None:
        embeds.append(oembed_tm1_oLl)
      if oembed_t_oLlm1 is not None:
        embeds.append(oembed_t_oLlm1)
      out = torch.cat(embeds, dim=-1)
      out = self.concat_fnn(out)

    else:
      raise ValueError(f'wrong concat_type {concat_type!r}')

    if self.config.res_conn:
      out = state + out
    if self.config.concat_layer_norm:
      out = self.norm(out)
    return self.dropout(out)
