from typing import Optional, List
import numpy as np
import torch
from torch import jit, nn
from torch.nn import functional as F
from env import CONTROL_SUITE_ENVS


# Wraps the input tuple for a function to process a time x batch x features sequence in batch x features (assumes one output)
def bottle(f, x_tuple):
  """Apply `f` to time x batch x ... tensors by folding time and batch into one dimension.

  Every tensor in `x_tuple` has its first two dimensions merged before being
  passed (positionally, in order) to `f`. The single tensor returned by `f`
  then has its leading dimension split back into the first two dimensions of
  `x_tuple[0]`.

  `reshape` is used rather than `view` so that non-contiguous inputs (e.g. the
  result of a `transpose`) are accepted; contiguous tensors are still reshaped
  without a copy.

  Args:
    f: Callable taking one flattened tensor per entry of `x_tuple` and
      returning a single tensor whose first dimension is the merged one.
    x_tuple: Sequence of tensors with at least two leading dimensions.

  Returns:
    The output of `f` with its first dimension unfolded to the leading two
    dimensions of `x_tuple[0]`.
  """
  time_size, batch_size = x_tuple[0].size(0), x_tuple[0].size(1)
  flat_inputs = [x.reshape(x.size(0) * x.size(1), *x.size()[2:]) for x in x_tuple]
  y = f(*flat_inputs)
  return y.reshape(time_size, batch_size, *y.size()[1:])

def bottle_action(f, x_tuple):
  """Variant of `bottle` for a callable that returns five tensors.

  Each tensor in `x_tuple` has its first two dimensions merged and is passed
  positionally to `f`, which must return exactly five tensors (named here
  generated actions, log-probabilities, deterministic actions, mean and std,
  matching what `GaussianPolicy.sample` returns). Every output then has its
  leading dimension split back into the first two dimensions of `x_tuple[0]`.

  `reshape` is used rather than `view` so that non-contiguous inputs are
  accepted; contiguous tensors are still reshaped without a copy.

  Args:
    f: Callable taking one flattened tensor per entry of `x_tuple` and
      returning five tensors whose first dimension is the merged one.
    x_tuple: Sequence of tensors with at least two leading dimensions.

  Returns:
    Tuple of the five outputs of `f`, in order, each unfolded to the leading
    two dimensions of `x_tuple[0]`.
  """
  time_size, batch_size = x_tuple[0].size(0), x_tuple[0].size(1)
  flat_inputs = [x.reshape(x.size(0) * x.size(1), *x.size()[2:]) for x in x_tuple]
  # Unpack explicitly so that a callable returning the wrong number of outputs fails loudly
  generated_actions, log_pis, deterministic, mean, std = f(*flat_inputs)
  outputs = (generated_actions, log_pis, deterministic, mean, std)
  return tuple(y.reshape(time_size, batch_size, *y.size()[1:]) for y in outputs)

def bottle_transition(transition, q_actions, x_tuple, generated_obs=None, encoder=None):
  """Run a single transition step over time x batch x ... tensors.

  The first two dimensions of every input are merged, `transition` is called
  for one step (length 1) on the flattened batch, and each tensor it returns
  has its length-1 leading dimension removed and its merged dimension split
  back into the first two dimensions of `x_tuple[0]`.

  `reshape` is used rather than `view` so that non-contiguous inputs are
  accepted; contiguous tensors are still reshaped without a copy.

  Args:
    transition: Callable invoked as
      `transition(flat(x_tuple[1]), flat(x_tuple[0]), 1, actions[, observations])`,
      i.e. `x_tuple[1]` is the first positional argument and `x_tuple[0]` the
      second (`prev_state` and `prev_belief` for `TransitionModel.forward`).
      It must return a sequence of tensors with a leading dimension of size 1.
    q_actions: Actions with two leading dimensions matching `x_tuple[0]`.
    x_tuple: Sequence whose first two entries are the tensors described above.
    generated_obs: Optional observations; when given they are flattened,
      passed through `encoder` and handed to `transition` as observations.
    encoder: Callable applied to the flattened `generated_obs`; required
      whenever `generated_obs` is not None.

  Returns:
    Tuple with one unfolded tensor per output of `transition`, in order. With
    `TransitionModel` this is (beliefs, prior states, prior means, prior std
    devs) and, when `generated_obs` is given, additionally (posterior states,
    posterior means, posterior std devs).
  """
  time_size, batch_size = x_tuple[0].size(0), x_tuple[0].size(1)

  def flatten(x):
    # Merge the first two dimensions
    return x.reshape(x.size(0) * x.size(1), *x.size()[2:])

  def unflatten(y):
    # Split the merged dimension back into the original leading two dimensions
    return y.reshape(time_size, batch_size, *y.size()[1:])

  prev_belief, prev_state = flatten(x_tuple[0]), flatten(x_tuple[1])
  # The transition expects a leading time dimension, here of length 1
  actions = flatten(q_actions).unsqueeze(dim=0)
  if generated_obs is None:
    outputs = transition(prev_state, prev_belief, 1, actions)
  else:
    observations = encoder(flatten(generated_obs)).unsqueeze(dim=0)
    outputs = transition(prev_state, prev_belief, 1, actions, observations)
  return tuple(unflatten(y.squeeze(dim=0)) for y in outputs)


# Calculate prior and posterior at each time step
class TransitionModel(jit.ScriptModule):
  """Recurrent latent dynamics model: a GRU-cell belief plus Gaussian prior/posterior states.

  At every step the previous state and action are concatenated, embedded and
  fed to a GRU cell to update the belief. A Gaussian prior over the next state
  is predicted from the belief alone and, when observations are supplied, a
  Gaussian posterior is predicted from the belief concatenated with the
  observation. Standard deviations are `softplus(raw) + min_std_dev`, and
  states are sampled as `mean + std * noise`.
  """
  __constants__ = ['min_std_dev']

  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional (e.g. 'relu' -> F.relu)
    self.act_fn = getattr(F, activation_function)
    # Added to every softplus-ed standard deviation, so it is a strict lower bound
    self.min_std_dev = min_std_dev
    # Embeds the concatenated (state, action) pair before the recurrent update
    self.fc_embed_state = nn.Linear(state_size + action_size, belief_size)
    self.rnn = nn.GRUCell(belief_size, belief_size)
    # Prior head: belief -> hidden -> (mean, raw std), hence the 2 * state_size output
    self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size)
    self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size)
    # Posterior head: (belief, observation with embedding_size features) -> hidden -> (mean, raw std)
    self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size)
    self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size)

  # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
  # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
  # t :  0  1  2  3  4  5
  # o :    -X--X--X--X--X-
  # a : -X--X--X--X--X-
  # n : -X--X--X--X--X-
  # pb: -X-
  # ps: -X-
  # b : -x--X--X--X--X--X-
  # s : -x--X--X--X--X--X-
  #
  # Arguments:
  #   prev_state, prev_belief: initial state and belief; they seed index 0 of the rollout and are not returned.
  #   length: number of steps to roll out; actions[t] (and observations[t], if given) are read for t in range(length).
  #   actions: indexed by time step along dim 0; each slice is concatenated with the state along dim 1.
  #   observations: optional; when given, posteriors are computed and the posterior state (instead of the prior
  #     state) is fed to the next step. Concatenated with the belief along dim 1, so presumably these are encoder
  #     embeddings with embedding_size features - confirm against the caller.
  #   nonterminals: accepted but currently unused (the masking line inside the loop is commented out).
  # Returns a list of tensors, each stacked over the `length` steps along a new leading dimension:
  #   [beliefs, prior_states, prior_means, prior_std_devs], followed by
  #   [posterior_states, posterior_means, posterior_std_devs] only when observations is not None.
  @jit.script_method
  def forward(self, prev_state:torch.Tensor, prev_belief:torch.Tensor, length:int, actions:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
    # Create lists for hidden states (cannot use single tensor as buffer because autograd won't work with inplace writes)
    T = length + 1  # One extra slot so that index 0 can hold the initial belief/state
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
    # The initial state seeds both the prior and the posterior chains
    beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state
    # Loop over time sequence
    for t in range(T - 1):
      _state = prior_states[t] if observations is None else posterior_states[t]  # Select appropriate previous state
      # NOTE(review): terminal masking is disabled, so `nonterminals` has no effect on the rollout
      # _state = _state if nonterminals is None else _state * nonterminals[t]  # Mask if previous transition was terminal
      # Compute belief (deterministic hidden state)
      hidden = self.act_fn(self.fc_embed_state(torch.cat([_state, actions[t]], dim=1)))
      beliefs[t + 1] = self.rnn(hidden, beliefs[t])
      # Compute state prior by applying transition dynamics
      hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1]))
      # The head output is split in half along dim 1: first half is the mean, second half the raw std dev
      prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
      prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev
      prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])  # Sample as mean + std * standard normal noise
      if observations is not None:
        # Compute state posterior by applying transition dynamics and using current observation
        t_ = t - 1  # Use t_ to deal with different time indexing for observations
        # Note that observations[t_ + 1] is simply observations[t]
        hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1)))  # beliefs[t] (50, 200)
        posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
        posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
        posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])
    # Return new hidden states (index 0 of every list is dropped before stacking)
    hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]  # [50 Torch(50,200)] -> Torch(50,50,200)
    if observations is not None:
      hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
    return hidden


class DeterministicPolicy(jit.ScriptModule):
  """MLP policy mapping (belief, state) to an action clamped to the action-space bounds.

  The clamp limits are scalar TorchScript constants taken from the FIRST
  action dimension only and applied to every action dimension.
  """
  __constants__ = ['action_min', 'action_max']

  def __init__(self, belief_size, state_size, hidden_size, action_size, env, activation_function='relu', action_space=None):
    """Build the network and record the clamp bounds.

    Args:
      belief_size, state_size: Feature sizes of the two inputs to `forward`.
      hidden_size: Width of the two hidden layers.
      action_size: Size of the produced action vector.
      env: Environment name; selects how the bounds are read from `action_space`
        (`minimum`/`maximum` for control-suite envs, `low`/`high` otherwise).
      activation_function: Name of a function in `torch.nn.functional`.
      action_space: Object exposing the bounds described above. Required, even
        though the parameter defaults to None.

    Raises:
      ValueError: If `action_space` is None.
    """
    super().__init__()
    if action_space is None:
      raise ValueError('DeterministicPolicy requires an action_space to derive its clamp bounds')
    if env in CONTROL_SUITE_ENVS:
      lower, upper = action_space.minimum, action_space.maximum
    else:
      lower, upper = action_space.low, action_space.high
    # TorchScript constants must be plain Python numbers, so coerce whatever numpy scalar type the bounds use
    self.action_min = float(lower[0])
    self.action_max = float(upper[0])
    self.act_fn = getattr(F, activation_function)
    self.fc1 = nn.Linear(belief_size + state_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, action_size)

  @jit.script_method
  def forward(self, belief, state):
    # Belief and state are concatenated along dim 1 before the MLP
    hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1)))
    hidden = self.act_fn(self.fc2(hidden))
    action = self.fc3(hidden)
    # Hard-limit the raw output to the recorded scalar bounds
    action = torch.clamp(action, min=self.action_min, max=self.action_max)
    return action

class GaussianPolicy(jit.ScriptModule):
  """Tanh-squashed Gaussian policy conditioned on belief and state.

  `forward` returns the pre-squash mean and the clamped log standard
  deviation. `sample` draws a reparameterised action, squashes it with tanh
  and rescales it with `action_scale` / `action_bias`, which are plain tensor
  attributes (not parameters) derived from the action-space bounds.
  """

  def __init__(self, belief_size, state_size, hidden_size, action_size, env, activation_function='relu', action_space=None):
    """Build the network and the action rescaling tensors.

    Args:
      belief_size, state_size: Feature sizes of the two inputs.
      hidden_size: Width of the two hidden layers.
      action_size: Size of the mean / log-std outputs.
      env: Environment name; selects how bounds are read from `action_space`
        (`minimum`/`maximum` for control-suite envs, `low`/`high` otherwise).
      activation_function: Name of a function in `torch.nn.functional`.
      action_space: Optional bounds object. When None, actions are left in
        the tanh range (scale 1, bias 0).
    """
    super().__init__()
    self.act_fn = getattr(F, activation_function)
    self.fc1 = nn.Linear(belief_size + state_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.mean_linear = nn.Linear(hidden_size, action_size)
    self.log_std_linear = nn.Linear(hidden_size, action_size)

    # action rescaling: maps the tanh range [-1, 1] onto [lower, upper]
    if action_space is None:
      self.action_scale = torch.tensor(1.)
      self.action_bias = torch.tensor(0.)
    else:
      if env in CONTROL_SUITE_ENVS:
        lower, upper = action_space.minimum, action_space.maximum
      else:
        lower, upper = action_space.low, action_space.high
      self.action_scale = torch.FloatTensor((upper - lower) / 2.)
      self.action_bias = torch.FloatTensor((upper + lower) / 2.)

  @jit.script_method
  def forward(self, belief, state):
    # Belief and state are concatenated along dim 1 before the MLP
    hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1)))
    hidden = self.act_fn(self.fc2(hidden))
    mean = self.mean_linear(hidden)
    log_std = self.log_std_linear(hidden)
    # Keep the standard deviation within [exp(-20), exp(2)]
    log_std = torch.clamp(log_std, min=-20, max=2)
    return mean, log_std

  def sample(self, belief, state):
    """Sample a squashed, rescaled action together with its log-probability.

    Returns:
      Tuple `(action, log_prob, deterministic_action, mean, std)` where
      `log_prob` is summed over dim 1 (kept as a size-1 dimension) and
      `deterministic_action` is the squashed, rescaled mean.
    """
    mean, log_std = self.forward(belief, state)
    std = log_std.exp()
    normal = torch.distributions.Normal(mean, std)
    x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
    y_t = torch.tanh(x_t)
    action = y_t * self.action_scale + self.action_bias
    log_prob = normal.log_prob(x_t)
    # Enforcing Action Bound: change-of-variables correction for the tanh squashing and rescaling
    # (1e-6 keeps the log finite when tanh saturates)
    log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
    log_prob = log_prob.sum(1, keepdim=True)
    deterministic_action = torch.tanh(mean) * self.action_scale + self.action_bias
    return action, log_prob, deterministic_action, mean, std

  def to(self, *args, **kwargs):
    """Move/cast the module and its rescaling tensors.

    Accepts the same call forms as `torch.nn.Module.to` (for example
    `to(device)`, `to(device=...)`, `to(device, dtype)` or
    `to(device, non_blocking=True)`), not just a single device argument.
    """
    # Let nn.Module validate the arguments first so an invalid request leaves the rescaling tensors untouched
    moved = super(GaussianPolicy, self).to(*args, **kwargs)
    # memory_format only makes sense for 4D/5D tensors, which the rescaling tensors are not
    tensor_kwargs = {key: value for key, value in kwargs.items() if key != 'memory_format'}
    if args or tensor_kwargs:
      # action_scale / action_bias are not parameters or buffers, so they must follow the module explicitly
      self.action_scale = self.action_scale.to(*args, **tensor_kwargs)
      self.action_bias = self.action_bias.to(*args, **tensor_kwargs)
    return moved


class SymbolicObservationModel(jit.ScriptModule):
  """Three-layer MLP that predicts a flat observation vector from belief, state and action."""

  def __init__(self, observation_size, belief_size, state_size, embedding_size, action_size, activation_function='relu'):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional
    self.act_fn = getattr(F, activation_function)
    input_size = belief_size + state_size + action_size
    self.fc1 = nn.Linear(input_size, embedding_size)
    self.fc2 = nn.Linear(embedding_size, embedding_size)
    self.fc3 = nn.Linear(embedding_size, observation_size)

  @jit.script_method
  def forward(self, belief, state, action):
    # Inputs are joined feature-wise in the order belief, state, action
    features = torch.cat([belief, state, action], dim=1)
    features = self.act_fn(self.fc1(features))
    features = self.act_fn(self.fc2(features))
    # The output layer is linear (no activation)
    return self.fc3(features)


class VisualObservationModel(jit.ScriptModule):
  """Transposed-convolution decoder that turns (belief, state) into a 3-channel image.

  The concatenated input is projected to `embedding_size` features, treated
  as an `embedding_size` x 1 x 1 feature map and upsampled by four stride-2
  transposed convolutions. `action_size` is accepted but not used.
  """
  __constants__ = ['embedding_size']

  def __init__(self, belief_size, state_size, embedding_size, action_size, activation_function='relu'):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional
    self.act_fn = getattr(F, activation_function)
    self.embedding_size = embedding_size
    input_size = belief_size + state_size
    self.fc1 = nn.Linear(input_size, embedding_size)
    self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2)
    self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
    self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
    self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)

  @jit.script_method
  def forward(self, belief, state):
    # Linear projection only: no nonlinearity is applied here
    projected = self.fc1(torch.cat([belief, state], dim=1))
    # Each row becomes an embedding_size x 1 x 1 feature map
    feature_map = projected.view(-1, self.embedding_size, 1, 1)
    feature_map = self.act_fn(self.conv1(feature_map))
    feature_map = self.act_fn(self.conv2(feature_map))
    feature_map = self.act_fn(self.conv3(feature_map))
    # The final layer is left linear
    return self.conv4(feature_map)


def ObservationModel(symbolic, observation_size, belief_size, state_size, embedding_size, action_size, activation_function='relu'):
  """Factory returning the observation model that matches the observation type.

  A truthy `symbolic` selects `SymbolicObservationModel`; otherwise a
  `VisualObservationModel` is built (which does not take `observation_size`).
  """
  if not symbolic:
    return VisualObservationModel(belief_size, state_size, embedding_size, action_size, activation_function)
  return SymbolicObservationModel(observation_size, belief_size, state_size, embedding_size, action_size, activation_function)


class RewardModel(jit.ScriptModule):
  """Three-layer MLP predicting one scalar reward per (belief, state) row.

  `action_size` is accepted but not used.
  """

  def __init__(self, belief_size, state_size, hidden_size, action_size, activation_function='relu'):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional
    self.act_fn = getattr(F, activation_function)
    input_size = belief_size + state_size
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, 1)

  @jit.script_method
  def forward(self, belief, state):
    features = torch.cat([belief, state], dim=1)
    features = self.act_fn(self.fc1(features))
    features = self.act_fn(self.fc2(features))
    # Drop the size-1 output dimension so one value per row is returned
    return self.fc3(features).squeeze(dim=1)


class SymbolicEncoder(jit.ScriptModule):
  """Three-layer MLP that embeds a flat observation vector into `embedding_size` features."""

  def __init__(self, observation_size, embedding_size, activation_function='relu'):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional
    self.act_fn = getattr(F, activation_function)
    self.fc1 = nn.Linear(observation_size, embedding_size)
    self.fc2 = nn.Linear(embedding_size, embedding_size)
    self.fc3 = nn.Linear(embedding_size, embedding_size)

  @jit.script_method
  def forward(self, observation):
    embedding = self.act_fn(self.fc1(observation))
    embedding = self.act_fn(self.fc2(embedding))
    # The output layer is linear (no activation)
    return self.fc3(embedding)


class VisualEncoder(jit.ScriptModule):
  """Convolutional encoder for 3-channel images.

  Four stride-2 convolutions are followed by a flatten to 1024 features per
  row, so the input resolution must make the last feature map hold exactly
  1024 values. The result is projected to `embedding_size` unless that is
  already 1024.
  """
  __constants__ = ['embedding_size']

  def __init__(self, embedding_size, activation_function='relu'):
    super().__init__()
    # Activation is looked up by name in torch.nn.functional
    self.act_fn = getattr(F, activation_function)
    self.embedding_size = embedding_size
    self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
    self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
    self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
    self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
    # No projection is needed when the flattened features already have the requested size
    if embedding_size == 1024:
      self.fc = nn.Identity()
    else:
      self.fc = nn.Linear(1024, embedding_size)

  @jit.script_method
  def forward(self, observation):
    feature_map = self.act_fn(self.conv1(observation))
    feature_map = self.act_fn(self.conv2(feature_map))
    feature_map = self.act_fn(self.conv3(feature_map))
    feature_map = self.act_fn(self.conv4(feature_map))
    flat = feature_map.view(-1, 1024)
    # Identity if embedding size is 1024 else linear projection
    return self.fc(flat)


def Encoder(symbolic, observation_size, embedding_size, activation_function='relu'):
  """Factory returning the encoder that matches the observation type.

  A truthy `symbolic` selects `SymbolicEncoder`; otherwise a `VisualEncoder`
  is built (which does not take `observation_size`).
  """
  if not symbolic:
    return VisualEncoder(embedding_size, activation_function)
  return SymbolicEncoder(observation_size, embedding_size, activation_function)


class GNetwork(jit.ScriptModule):
  """Three-layer ReLU MLP mapping a state vector to one scalar per row.

  `action_size` is accepted but not used.
  """

  def __init__(self, state_size, action_size, hidden_size):
    super().__init__()
    self.linear1 = nn.Linear(state_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, hidden_size)
    self.linear3 = nn.Linear(hidden_size, 1)

  @jit.script_method
  def forward(self, state):
    hidden = F.relu(self.linear1(state))
    hidden = F.relu(self.linear2(hidden))
    # Drop the size-1 output dimension so one value per row is returned
    return self.linear3(hidden).squeeze(dim=1)
