import warnings
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from einops import repeat, rearrange
import torch
from torch.distributions import Independent, Normal

from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy
from utils.utils_config import ModelDict


class VMOCPolicy(BasePolicy):

  def __init__(
      self,
      actor_dict: ModelDict,
      critic_dict: ModelDict,
      tau: float = 0.005,
      gamma: float = 0.99,
      reward_normalization: bool = False,
      estimation_step: int = 1,
      exploration_noise=None,
      deterministic_eval: bool = True,
      action_scaling: bool = True,
      action_bound_method: str = "clip",
      option_embeds=None,
      config=None,
      **kwargs: Any,
  ) -> None:
    """Store the models and hyper-parameters used by the policy.

    Args:
      actor_dict: container of the actors; this class reads
        ``actor_dict.oL1`` (option policy) and ``actor_dict.oL0`` (action
        policy).
      critic_dict: container of the critics; this class reads
        ``critic_dict.oL1`` / ``critic_dict.oL0``, each exposing ``qO1``,
        ``qO2`` and their ``*_old`` counterparts. ``set_qold_eval()`` is
        called on it here.
      tau: coefficient passed to ``soft_update`` in ``sync_weight``; must be
        in [0, 1].
      gamma: discount passed to ``compute_nstep_return``; must be in [0, 1].
      reward_normalization: forwarded to ``compute_nstep_return``.
      estimation_step: n-step horizon forwarded to ``compute_nstep_return``.
      exploration_noise: optional callable ``shape -> noise`` used by
        ``exploration_noise``; ``None`` disables noise.
      deterministic_eval: stored as ``_deterministic_eval``.
      action_scaling: forwarded to ``BasePolicy``.
      action_bound_method: forwarded to ``BasePolicy``; ``"tanh"`` is
        rejected.
      option_embeds: callable mapping option labels to embeddings.
      config: required configuration object. Read here: ``auto_alpha`` and
        either ``a_log_alpha`` / ``o_log_alpha`` (auto mode) or
        ``a_P_alpha`` / ``o_P_alpha`` / ``a_Q_alpha`` / ``o_Q_alpha``.
      **kwargs: forwarded to ``BasePolicy``.

    Raises:
      ValueError: if ``config`` is ``None``.
      AssertionError: if ``action_bound_method`` is ``"tanh"`` or ``tau`` /
        ``gamma`` are outside [0, 1].
    """
    # `config` is dereferenced right below; fail with a clear message instead
    # of "'NoneType' object has no attribute 'auto_alpha'".
    if config is None:
      raise ValueError("VMOCPolicy requires a `config` object")

    super().__init__(
        action_scaling=action_scaling,
        action_bound_method=action_bound_method,
        **kwargs)
    # Adjacent string literals are concatenated verbatim, so each piece needs
    # its own trailing space.
    assert action_bound_method != "tanh", (
        "tanh mapping is not supported "
        "in policies where action is used as input of qA, because "
        "raw action in range (-inf, inf) will cause instability in training")

    # VMOC Params
    self.actor_dict = actor_dict

    self.critic_dict = critic_dict
    self.critic_dict.set_qold_eval()

    self._is_auto_alpha = config.auto_alpha
    if self._is_auto_alpha:
      self.a_P_alpha = config.a_log_alpha.detach().exp()
      self.o_P_alpha = config.o_log_alpha.detach().exp()
      # `_target_qA` / `_target_qO` read `a_Q_alpha` / `o_Q_alpha`
      # unconditionally, so they must exist in auto-alpha mode as well.
      # Prefer an explicit config value; otherwise start from the initial
      # learned alpha.
      # NOTE(review): these are not refreshed when `learn` updates the
      # P-alphas -- confirm whether the critic targets should track them.
      a_q_alpha = getattr(config, "a_Q_alpha", None)
      o_q_alpha = getattr(config, "o_Q_alpha", None)
      self.a_Q_alpha = self.a_P_alpha if a_q_alpha is None else a_q_alpha
      self.o_Q_alpha = self.o_P_alpha if o_q_alpha is None else o_q_alpha
    else:
      self.a_P_alpha = config.a_P_alpha
      self.o_P_alpha = config.o_P_alpha
      self.a_Q_alpha = config.a_Q_alpha
      self.o_Q_alpha = config.o_Q_alpha

    self._deterministic_eval = deterministic_eval

    # DDPG Params
    assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
    self.tau = tau
    assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
    self._gamma = gamma
    self._noise = exploration_noise
    # it is only a little difference to use GaussianNoise
    # self.noise = OUNoise()
    self._rew_norm = reward_normalization
    self._n_step = estimation_step

    # SOMO Params
    self.option_embeds = option_embeds
    # Filled in by `process_fn`, consumed by `learn`.
    self._indices = None
    self._buffer = None

    self.config = config

  def train(self, mode: bool = True) -> "VMOCPolicy":
    self.training = mode
    self.actor_dict.set_train(mode)
    self.critic_dict.set_train(mode)
    return self

  def exploration_noise(self, act: Union[np.ndarray, Batch],
                        batch: Batch) -> Union[np.ndarray, Batch]:
    # Called in Collector, only used for rollout
    if self._noise is None:
      return act
    if isinstance(act, np.ndarray):
      return act + self._noise(act.shape)
    warnings.warn("Cannot add exploration noise to non-numpy_array action.")
    return act

  def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                 indices: np.ndarray) -> Batch:
    self._indices = indices
    self._buffer = buffer
    return batch

  def sync_weight(self) -> None:
    for layer in self.critic_dict:
      self.soft_update(self.critic_dict[layer].qO1_old,
                       self.critic_dict[layer].qO1, self.tau)
      self.soft_update(self.critic_dict[layer].qO2_old,
                       self.critic_dict[layer].qO2, self.tau)

  #* Functional Implementations
  def _get_option_embeds(self, olabel_t, detach=False):
    """Look up the embedding of every option label in ``olabel_t``.

    Args:
      olabel_t: option labels shaped [bsz, 1]; either a ``torch.Tensor`` or
        an array that ``torch.tensor`` can convert (it is then placed on
        ``self.config.device``).
      detach: when true, the embeddings are detached from the autograd graph
        before being returned.

    Returns:
      ``self.option_embeds`` applied to the labels flattened to [bsz].
    """
    # The pattern also asserts that the trailing axis has length 1.
    flat_labels = rearrange(olabel_t, 'b 1 -> b')
    if isinstance(flat_labels, torch.Tensor):
      label_tensor = flat_labels
    else:
      label_tensor = torch.tensor(flat_labels, device=self.config.device)
    embeds = self.option_embeds(label_tensor)
    return embeds.detach() if detach else embeds

  def _target_qO(self, buffer: ReplayBuffer,
                 indices: np.ndarray) -> torch.Tensor:
    with torch.no_grad():
      # batch.obs.shape [bsz, state_dim]; indices: (bsz, )
      batch = buffer[indices]  # batch.obs: s_{t+n}
      oembed_t = self._get_option_embeds(batch.policy.oL1_result.olabel_t)

      # use sample o_t, NOT new_o_t
      # act_t.shape: [bsz, act_dim]
      (new_a_act_t, new_a_logp_t,
       new_a_ent_t), hidden = self.actor_dict.oL0.probO(
           batch.obs, oembed_t=oembed_t)

      # self.critic_dict.Q(s_t,new_a_t,o_t): [bsz, 1]
      layer_dict = self.critic_dict.oL0
      # target_q.shape: [bsz, 1]
      min_target_q = torch.min(
          layer_dict.qO1_old(batch.obs, new_a_act_t, oembed_t),
          layer_dict.qO2_old(batch.obs, new_a_act_t, oembed_t),
      )
      target_q = min_target_q - self.o_Q_alpha * new_a_logp_t
    return target_q

  def _target_qA(self, buffer: ReplayBuffer,
                 indices: np.ndarray) -> torch.Tensor:
    with torch.no_grad():
      batch = buffer[indices]
      # s_{t+1} and o_t
      obs, initial_state_flags, prev_o = self._get_obs_done_prev_o(
          batch, forward_mode='next')
      # oembed_tm1: [num_envs, dmodel]
      oembed_tm1 = self._get_option_embeds(prev_o)
      # o_{t+1}
      (olabel_t, probO_t,
       ent_t), hidden = self.actor_dict.oL1.probO(obs, oembed_tm1)

      layer_dict = self.critic_dict.oL1
      # p(o_{t+1}) Q(s_{t+1})
      # min_target_q: [bsz, num_options]
      min_target_q = torch.min(layer_dict.qO1_old(obs), layer_dict.qO2_old(obs))
      target_q = (probO_t * min_target_q).sum(
          dim=-1, keepdim=True) + self.a_Q_alpha * ent_t
    return target_q

  def _mse_step(self, current_q, target_q, optim, prio_weight=1.0):
    td = current_q - target_q
    loss = (td.pow(2) * prio_weight).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    return td, loss

  # self.critic_dict[f'oL{layer_id}'][f'qO{q_id}_optim']
  def _mse_qA(self, batch: Batch, layer_critic_dict):
    # use batch.act/obs/olabel_t
    # weight: scalar 1.0 before unassigned
    # weight: [bsz, 1] after assigned by (td1+td2)/2
    prio_weight = getattr(batch, "weight", 1.0)
    # target_q: [bsz, 1]
    target_q = batch.policy.oL0_result.target_q

    # current_q: [bsz, 1]
    # qO(batch.obs, batch.act): [bsz, 1]
    oembed_t = self._get_option_embeds(batch.policy.oL1_result.olabel_t)
    current_q1 = layer_critic_dict.qO1(batch.obs, batch.act, oembed_t)
    td1, loss1 = self._mse_step(current_q1, target_q,
                                layer_critic_dict.qO1_optim, prio_weight)
    oembed_t = self._get_option_embeds(batch.policy.oL1_result.olabel_t)
    current_q2 = layer_critic_dict.qO2(batch.obs, batch.act, oembed_t)
    td2, loss2 = self._mse_step(current_q2, target_q,
                                layer_critic_dict.qO2_optim, prio_weight)
    return td1, loss1, td2, loss2

  def _mse_qO(self, batch, layer_critic_dict):
    prio_weight = getattr(batch, "weight", 1.0)
    # target_q: [bsz, 1]
    target_q = batch.policy.oL1_result.target_q

    # qO(batch.obs, batch.act): [bsz, 1]
    # current_q: [bsz, 1]
    olabel_t = batch.policy.oL1_result.olabel_t
    current_q1 = layer_critic_dict.qO1(batch.obs).gather(1, olabel_t)
    td1, loss1 = self._mse_step(current_q1, target_q,
                                layer_critic_dict.qO1_optim, prio_weight)

    current_q2 = layer_critic_dict.qO2(batch.obs).gather(1, olabel_t)
    td2, loss2 = self._mse_step(current_q2, target_q,
                                layer_critic_dict.qO2_optim, prio_weight)
    return td1, loss1, td2, loss2

  def _mutual_info(self, batch):
    bsz = batch.obs.shape[0]
    with torch.no_grad():
      a_logp_list = []
      for i in range(self.config.num_options):
        olabel_t = torch.tensor(i, device=self.config.device).repeat(bsz, 1)
        oembed_t = self._get_option_embeds(olabel_t)
        (act, a_logp_t, a_ent_t), hidden = self.actor_dict.oL0.probO(
            batch.obs, oembed_t=oembed_t)
        a_logp_list.append(a_logp_t.squeeze())

      a_logp_O_t = rearrange(torch.stack(a_logp_list), 'd b ... -> b d ...')
      # Ensure the device is consistent
      a_logp_O_t = a_logp_O_t.to(self.config.device)

      layer_res = batch.policy.oL1_result
      logp_O_t = layer_res.probO_t.log()
      olabel_t = layer_res.olabel_t

      # p(a,O|s,o')=p(a|s,O)p(O|s,o')
      logp_joint_aO_stm1_t = a_logp_O_t + logp_O_t
      # p(a|s,o') = \sum_o p(a|s,o)p(o|s,o')
      log_marg_a_dist = torch.logsumexp(
          logp_joint_aO_stm1_t, dim=-1, keepdim=True)
      # p(O|s,a,o') = p(a,O|s,o') / p(a|s,o')
      log_O_cond_prob = logp_joint_aO_stm1_t - log_marg_a_dist

      # H(O) = \E_{s,a,o'}\sim D -logp(O|s,a,o')
      hO = -log_O_cond_prob.mean(dim=0, keepdim=True).repeat(bsz, 1)
      # H(o) = H(O).gather(olable_t)
      ho = torch.gather(hO, 1, olabel_t)

      # p(o|s,a,o') = gather(olabel_t)
      log_conditional_prob = torch.gather(log_O_cond_prob, 1, olabel_t)
      # H(o|s,a,o') = -p(o|s,a,o')logp(o|s,a,o') [bsz, 1]
      conditional_entropy = -torch.sum(
          torch.exp(log_conditional_prob) * log_conditional_prob,
          dim=-1,
          keepdim=True)

      mutual_info = ho - conditional_entropy
    return mutual_info

  def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    """Run one update of both critics and both policies on a sampled batch.

    Steps, in order (each optimiser is stepped before the next stage runs):
      1. option critic (``critic_dict.oL1``) via ``_mse_qO``;
      2. action critic (``critic_dict.oL0``) via ``_mse_qA``;
      3. option policy (``actor_dict.oL1``), plus its alpha when
         ``auto_alpha`` is enabled;
      4. action policy (``actor_dict.oL0``), plus its alpha when
         ``auto_alpha`` is enabled;
      5. soft update of the ``*_old`` critics via ``sync_weight``.

    Relies on ``self._buffer`` / ``self._indices`` having been set by
    ``process_fn`` for this batch.

    Args:
      batch: sampled batch; ``batch.policy`` is converted to torch tensors on
        ``config.device`` in place, and ``target_q`` entries are written into
        ``batch.policy.oL1_result`` / ``batch.policy.oL0_result``.
      **kwargs: unused here.

    Returns:
      Dict of python scalars: the losses, mean entropies, the mean MI (0 when
      ``config.use_MI`` is off) and, in auto-alpha mode, the alpha values and
      their losses.
    """
    batch.policy.to_torch(device=self.config.device)
    ## Critic
    # Option Critic
    if self.config.use_MI:
      # Temporarily swap the reward for the mutual-information term, build
      # the n-step return from it, then restore the original reward.
      # NOTE(review): confirm that `compute_nstep_return` reads the reward
      # from `batch` rather than from the buffer; otherwise this swap has no
      # effect on `batch.returns`.
      reward = batch.rew
      MI = self._mutual_info(batch)
      batch.rew = MI
      batch = self.compute_nstep_return(batch, self._buffer, self._indices,
                                        self._target_qO, self._gamma,
                                        self._n_step, self._rew_norm)
      batch.policy.oL1_result.target_q = batch.returns
      batch.rew = reward
    else:
      # Without MI the target is the value estimate alone (no reward term).
      batch.policy.oL1_result.target_q = self._target_qO(
          self._buffer, self._indices)
    td_O1, qO1_loss, td_O2, qO2_loss = self._mse_qO(batch, self.critic_dict.oL1)

    # Action Critic
    # assign batch.returns
    batch = self.compute_nstep_return(batch, self._buffer, self._indices,
                                      self._target_qA, self._gamma,
                                      self._n_step, self._rew_norm)
    batch.policy.oL0_result.target_q = batch.returns
    td_A1, qA1_loss, td_A2, qA2_loss = self._mse_qA(batch, self.critic_dict.oL0)

    # TODO: bug, this should be assigned to the buffer, not the batch;
    # the current implementation takes no effect
    # weight: [bsz, 1], only used in target_q for qA loss
    # batch.weight = (td_O1 + td_O2 + td_A1 + td_A2) / 4.0  # prio-buffer

    ## Policy
    result = dict()

    # Option Policy
    obs, initial_state_flags, prev_o = self._get_obs_done_prev_o(
        batch, forward_mode='learn')
    # oembed_tm1: [num_envs, dmodel]
    oembed_tm1 = self._get_option_embeds(prev_o)
    # olabel_t: [bsz, 1]; probO_t: [bsz, num_options]; entropy: [bsz, 1]
    (olabel_t, probO_t,
     o_ent_t), hidden = self.actor_dict.oL1.probO(obs, oembed_tm1)

    layer_dict = self.critic_dict.oL1
    # Critic values act as constants for the policy loss.
    with torch.no_grad():
      current_qO1 = layer_dict.qO1(batch.obs)
      current_qO2 = layer_dict.qO2(batch.obs)
      qO = torch.min(current_qO1, current_qO2)
    # Maximise entropy bonus + expected option value.
    pO_loss = -(self.o_P_alpha * o_ent_t +
                (probO_t * qO).sum(dim=-1, keepdim=True)).mean()
    self.actor_dict.oL1.probO_optim.zero_grad()
    pO_loss.backward()
    self.actor_dict.oL1.probO_optim.step()

    if self._is_auto_alpha:
      config = self.config
      # Negative entropy stands in for the log-probability of the option
      # distribution.
      o_log_prob = -o_ent_t.detach() + config.o_target_entropy
      o_P_alpha, o_P_alpha_loss = self.learn_alpha(o_log_prob,
                                                   config.o_target_entropy,
                                                   config.o_log_alpha,
                                                   config.o_alpha_optim)
      self.o_P_alpha = o_P_alpha

      result["loss/O_P_alpha"] = o_P_alpha_loss.item()
      result["O_P_alpha"] = o_P_alpha.item()

    # Action Policy
    # use sample o_t, NOT new_o_t
    oembed_t = self._get_option_embeds(batch.policy.oL1_result.olabel_t)
    (act, a_logp_t, a_ent_t), hidden = self.actor_dict.oL0.probO(
        batch.obs, oembed_t=oembed_t)
    # current_q1: [bsz, 1]
    # The critics are evaluated with gradients enabled so the loss can reach
    # the actor through `act`; only the actor optimiser is stepped below.
    current_qA1 = self.critic_dict.oL0.qO1(batch.obs, act, oembed_t)
    current_qA2 = self.critic_dict.oL0.qO2(batch.obs, act, oembed_t)
    qA = torch.min(current_qA1, current_qA2)
    # obs_result.policy.oL0_result.logp_t: [bsz, 1]
    # pA_loss: scalar
    pA_loss = (self.a_P_alpha * a_logp_t - qA).mean()
    self.actor_dict.oL0.probO_optim.zero_grad()
    pA_loss.backward()
    self.actor_dict.oL0.probO_optim.step()

    if self._is_auto_alpha:
      config = self.config
      a_log_prob = a_logp_t.detach() + config.a_target_entropy
      a_P_alpha, a_alpha_loss = self.learn_alpha(a_log_prob,
                                                 config.a_target_entropy,
                                                 config.a_log_alpha,
                                                 config.a_alpha_optim)
      self.a_P_alpha = a_P_alpha

      result["loss/A_P_alpha"] = a_alpha_loss.item()
      result["A_P_alpha"] = a_P_alpha.item()

    # Move the `*_old` critics toward the freshly updated ones.
    self.sync_weight()

    result.update({
        "loss/pO": pO_loss.item(),
        "loss/qO1": qO1_loss.item(),
        "loss/qO2": qO2_loss.item(),
        "loss/pA": pA_loss.item(),
        "loss/qA1": qA1_loss.item(),
        "loss/qA2": qA2_loss.item(),
        "stats/a_ent": a_ent_t.mean().item(),
        "stats/o_ent": o_ent_t.mean().item(),
        # `MI` only exists when the MI branch above ran.
        "stats/MI": MI.mean().item() if self.config.use_MI else 0,
    })

    return result

  def learn_alpha(self, log_prob, target_ent, log_alpha, alpha_optim):
    # please take a look at issue #258 if you'd like to change this line
    alpha_loss = -(log_alpha * log_prob).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()
    log_alpha = log_alpha.detach().exp().clamp(-0.5, 0.5)
    return log_alpha, alpha_loss

  def _get_obs_done_prev_o(self, batch, forward_mode):
    '''for all init states, use the 0th option'''
    num_envs = batch.obs.shape[0]
    obs = batch.obs
    if forward_mode == 'rollout':
      # in collector.py:reset(), done at t=1 is set to be an empty Batch
      if isinstance(batch.done, Batch):
        # t=1, first execution, no done/terminated/truncated flag and act etc.
        # batch.done will be Batch()
        # but will be np.ndarray at 2cd, 3rd etc.
        initial_state_flags = np.ones(num_envs, dtype=bool)
      else:
        # t>=2, batch.done will be numpy.ndarray
        # in rollout mode, this is from the execution of last step
        initial_state_flags = batch.done

      # Off-policy can have warm up random samples, with `done` but without batch.policy
      # if `batch.done` is missing, then `policy` must also be missing (never sampled once)
      # unlike `done` is Batch at 1st step and np.ndarray at 2cd, 3rd etc.
      # `batch.policy` is always Batch, so can use `is_empty()`
      if batch.policy.is_empty():
        # For init state (t=1), always use the first option (o=0)
        prev_o = np.zeros([num_envs, 1], dtype=int)
      else:
        # in rollout mode, this is from the execution of last step
        prev_o = batch.policy.oL1_result.olabel_t

    elif forward_mode == 'learn' or forward_mode == 'replay_buffer_ot':
      # in computing gradient mode, use data before fed into policy
      # in rollout mode, it recorded `done` at `t+1` timestep
      # (such data are updated after execution of action, see collector.py:collect()
      # self.data.update (line 290, 301-307) ) self.buffer.add (line 328)
      initial_state_flags = batch.policy.prev_done
      prev_o = batch.policy.prev_o
    elif forward_mode == 'next':
      # for calc P(O_{t+1}|S_{t+1},O_t)
      obs = batch.obs_next
      initial_state_flags = batch.done
      prev_o = batch.policy.oL1_result.olabel_t

    if initial_state_flags.ndim == 1:
      initial_state_flags = rearrange(initial_state_flags, 'b -> b 1')
    # has to use [:, 0]; prev_o is torch.Tensor, init is numpy ndarray
    # 2d indexing leads to error
    prev_o[initial_state_flags[:, 0]] = 0
    return obs, initial_state_flags, prev_o

  def forward(  # type: ignore
      self,
      batch: Batch,
      state: Optional[Union[dict, Batch, np.ndarray]] = None,
      forward_mode='rollout',
      **kwargs: Any,
  ) -> Batch:
    """Sample an option, then an action conditioned on that option.

    Args:
      batch: input data; ``obs``, ``done``, ``policy`` and ``info`` are read.
      state: recurrent state handed to the action policy only.
      forward_mode: passed to ``_get_obs_done_prev_o`` to select where the
        observation, initial-state flags and previous option come from.
      **kwargs: unused here.

    Returns:
      ``Batch(act, state, policy)`` where ``state`` is the hidden value
      returned by the action policy and ``policy`` (converted to numpy) holds
      ``prev_o``, ``prev_done`` and the per-layer results ``oL1_result`` /
      ``oL0_result``.
    """
    # obs: [num_envs, state_dim]; initial_state_flags, prev_o: [num_envs, 1]
    obs, initial_state_flags, prev_o = self._get_obs_done_prev_o(
        batch, forward_mode)

    # Option level: condition on the previous option's embedding
    # ([num_envs, dmodel]).
    # olabel_t: [bsz, 1]; probO_t: [bsz, num_options]; entropy: [bsz, 1]
    prev_option_embed = self._get_option_embeds(prev_o)
    (olabel_t, probO_t, o_ent_t), _ = self.actor_dict.oL1.probO(
        obs, prev_option_embed)

    # Action level: condition on the freshly sampled option.
    option_embed = self._get_option_embeds(olabel_t)
    (act, logp_a_t, a_ent_t), hidden = self.actor_dict.oL0.probO(
        obs, state=state, info=batch.info, oembed_t=option_embed)

    # Both layers share one record layout; slots a layer does not produce
    # are left as empty dicts.
    option_layer_result = {
        'olabel_t': olabel_t,
        'probO_t': probO_t,
        'logp_t': {},
        'ent_t': o_ent_t,
    }
    action_layer_result = {
        'olabel_t': act,
        'probO_t': {},
        'logp_t': logp_a_t,
        'ent_t': a_ent_t,
    }

    policy_res = Batch(
        prev_o=prev_o,
        prev_done=initial_state_flags,
        oL1_result=option_layer_result,
        oL0_result=action_layer_result,
    )
    policy_res.to_numpy()
    return Batch(act=act, state=hidden, policy=policy_res)
