import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

from utils import helpers as utl

# Module-level default device: the GPU with index 3 when CUDA is available, otherwise the CPU.
# NOTE(review): the GPU index is hard-coded; presumably tensors moved here fail on hosts
# exposing fewer than four GPUs — confirm against the deployment setup.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")


class RNNEncoder(nn.Module):
    """Recurrent encoder that maps (action, state, reward) sequences to a Gaussian latent.

    Pipeline, as built in ``__init__``:
    per-input feature extractors -> optional fully connected layers -> single-layer GRU
    -> optional fully connected layers -> two linear heads producing the mean and the
    log-variance of a ``latent_dim``-dimensional diagonal Gaussian.
    """

    def __init__(self,
                 args,
                 # network size
                 layers_before_gru=(),
                 # hidden_size is the size of the GRU's recurrent state; latent_dim is the
                 # size of the latent produced at every step (presumably consumed by the
                 # PPO policy & value networks -- TODO confirm against the caller).
                 hidden_size=64,
                 layers_after_gru=(),
                 latent_dim=32,
                 # actions, states, rewards
                 action_dim=2,
                 action_embed_dim=10,
                 state_dim=2,
                 state_embed_dim=10,
                 reward_size=1,
                 reward_embed_size=5,
                 ):
        """Build the encoder.

        Args:
            args: configuration object; stored and later passed to ``utl.squash_action``.
            layers_before_gru: widths of the fully connected layers placed before the GRU.
            hidden_size: size of the GRU hidden state.
            layers_after_gru: widths of the fully connected layers placed after the GRU.
            latent_dim: dimensionality of the latent mean / log-variance outputs.
            action_dim, action_embed_dim: input / embedding size of the action extractor.
            state_dim, state_embed_dim: input / embedding size of the state extractor.
            reward_size, reward_embed_size: input / embedding size of the reward extractor.
        """
        super(RNNEncoder, self).__init__()

        self.args = args
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.reparameterise = self._sample_gaussian

        # embed action, state, reward
        self.state_encoder = utl.FeatureExtractor(
            state_dim, state_embed_dim, F.relu)
        self.action_encoder = utl.FeatureExtractor(
            action_dim, action_embed_dim, F.relu)
        self.reward_encoder = utl.FeatureExtractor(
            reward_size, reward_embed_size, F.relu)

        # fully connected layers before the recurrent cell
        curr_input_dim = action_embed_dim + state_embed_dim + reward_embed_size
        self.fc_before_gru = nn.ModuleList([])
        for width in layers_before_gru:
            self.fc_before_gru.append(nn.Linear(curr_input_dim, width))
            curr_input_dim = width

        # recurrent unit
        # TODO: TEST RNN vs GRU vs LSTM
        self.gru = nn.GRU(input_size=curr_input_dim,
                          hidden_size=hidden_size,
                          num_layers=1,
                          )

        # zero biases, orthogonal weight matrices for the recurrent cell
        for name, param in self.gru.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.orthogonal_(param)

        # fully connected layers after the recurrent cell
        curr_input_dim = hidden_size
        self.fc_after_gru = nn.ModuleList([])
        for width in layers_after_gru:
            self.fc_after_gru.append(nn.Linear(curr_input_dim, width))
            curr_input_dim = width

        # output layer
        self.fc_mu = nn.Linear(curr_input_dim, latent_dim)
        self.fc_logvar = nn.Linear(curr_input_dim, latent_dim)

    def _sample_gaussian(self, mu, logvar, num=None):
        """Draw a reparameterised sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mu: mean tensor.
            logvar: log-variance tensor (same shape as ``mu``).
            num: number of samples; only ``None`` (a single sample) is supported.

        Raises:
            NotImplementedError: if ``num`` is not ``None``.
        """
        if num is not None:
            # TODO: multi-sample support; maybe use
            # .unsqueeze(0).expand((num, *logvar.shape)) rather than .repeat(num, 1)
            raise NotImplementedError("drawing several samples (num != None) is not supported")
        # the 0.5 converts a log-variance into a log-standard-deviation
        # (original note, 5_17: should this be calibrated against other parts of the code?)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def reset_hidden(self, hidden_state, reset_flag):
        """Zero the hidden state wherever ``reset_flag`` is 1 (i.e. we get a new task).

        ``reset_flag`` is unsqueezed so that it broadcasts against ``hidden_state`` when
        their ranks differ: a 2-d flag gains a leading axis, a 1-d flag gains a leading
        and a trailing axis.

        Returns:
            ``hidden_state * (1 - reset_flag)``.
        """
        if hidden_state.dim() != reset_flag.dim():
            if reset_flag.dim() == 2:
                reset_flag = reset_flag.unsqueeze(0)
            elif reset_flag.dim() == 1:
                reset_flag = reset_flag.unsqueeze(0).unsqueeze(2)
        return hidden_state * (1 - reset_flag)

    def _latent_from_hidden(self, h, sample):
        """Map GRU features to ``(latent_sample, latent_mean, latent_logvar)``."""
        # forward through fully connected layers after GRU
        for layer in self.fc_after_gru:
            h = F.relu(layer(h))

        # outputs
        latent_mean = self.fc_mu(h)
        latent_logvar = self.fc_logvar(h)
        if sample:
            latent_sample = self.reparameterise(latent_mean, latent_logvar)
        else:
            latent_sample = latent_mean
        return latent_sample, latent_mean, latent_logvar

    def prior(self, batch_size, sample=True):
        """Return the latent obtained from an all-zero hidden state.

        Args:
            batch_size: number of parallel sequences.
            sample: if True the latent is sampled, otherwise the mean is returned as sample.

        Returns:
            ``(latent_sample, latent_mean, latent_logvar, hidden_state)`` where
            ``hidden_state`` has shape ``(1, batch_size, hidden_size)``.
        """
        # TODO: add option to incorporate the initial state

        # We start out with a hidden state of zero. It is created on the device that holds
        # this module's parameters, so the linear heads below never see a device mismatch
        # (a fixed global device breaks as soon as the module lives somewhere else).
        hidden_state = torch.zeros(
            (1, batch_size, self.hidden_size), requires_grad=True).to(self.fc_mu.weight.device)

        latent_sample, latent_mean, latent_logvar = self._latent_from_hidden(
            hidden_state, sample)
        return latent_sample, latent_mean, latent_logvar, hidden_state

    def _run_gru_detached(self, h, hidden_state, detach_every):
        """Run the GRU over ``h`` in chunks of ``detach_every`` steps.

        The hidden state is detached between chunks (useful for BPTT when sequences are
        very long). Detaching only cuts gradients; the output values are unaffected.
        """
        chunks = []
        for start in range(0, h.shape[0], detach_every):
            # slicing past the end is capped automatically
            chunk_out, hidden_state = self.gru(
                h[start:start + detach_every], hidden_state)
            chunks.append(chunk_out)
            hidden_state = hidden_state.detach()
        return torch.cat(chunks, dim=0)

    def _run_gru_with_resets(self, h, r_t):
        """Run the GRU per trajectory, restarting from a zero hidden state at every reset.

        A reset is marked by ``r_t == 0``. Reset positions differ between trajectories,
        so each trajectory (batch entry) is processed on its own, segment by segment.
        Steps before the first reset marker form a segment of their own, so the output
        always covers every time step of ``h``.

        Args:
            h: GRU inputs, indexed as [time, batch, feature].
            r_t: reset markers, indexed as [time, batch, :]
                (assumes a trailing axis of size 1 -- TODO confirm).

        Returns:
            GRU outputs for every step, concatenated over time and over the batch.

        Raises:
            ValueError: if a trajectory contains no reset marker at all.
        """
        seq_len = h.shape[0]
        per_trajectory = []
        for batch_id in range(r_t.shape[1]):
            reset_steps = torch.where(r_t[:, batch_id, :] == 0)[0].tolist()
            if not reset_steps:
                raise ValueError(
                    "r_t must contain at least one reset (zero) per trajectory; "
                    "none found for batch index {}".format(batch_id))

            # segment starts: every reset position, plus step 0 if it is not one already
            starts = reset_steps if reset_steps[0] == 0 else [0] + reset_steps
            ends = starts[1:] + [seq_len]

            segments = []
            for begin, end in zip(starts, ends):
                if begin >= end:
                    # empty segment (e.g. duplicated marker index): nothing to encode
                    continue
                # no hidden state is passed: the GRU then starts from zeros
                segment_out, _ = self.gru(h[begin:end, batch_id:batch_id + 1, :])
                segments.append(segment_out)
            per_trajectory.append(torch.cat(segments, dim=0))
        return torch.cat(per_trajectory, dim=1)

    def forward(self, actions, states, rewards, hidden_state, return_prior, sample=True, detach_every=None, r_t=None):
        """
        Actions, states, rewards should be given in form [sequence_len * batch_size * dim].
        For one-step predictions, sequence_len=1 and hidden_state!=None.
        For feeding in entire trajectories, sequence_len>1 and hidden_state=None.
        In the latter case, we return embeddings of length sequence_len+1 since they include the prior.

        Args:
            actions, states, rewards: input tensors (see above).
            hidden_state: initial GRU hidden state or None; replaced by the prior's
                hidden state when ``return_prior`` is True, and not used at all when
                ``r_t`` is given.
            return_prior: if True, prepend the prior latent / hidden state to the outputs.
            sample: if True sample the latent, otherwise use the mean as the sample.
            detach_every: if given (and ``r_t`` is None), detach the hidden state every
                this many steps.
            r_t: optional reset markers; where ``r_t == 0`` the GRU restarts from a zero
                hidden state. When given, ``detach_every`` is ignored.

        Returns:
            ``(latent_sample, latent_mean, latent_logvar, output)`` where ``output`` holds
            the GRU outputs per step. The three latent tensors lose their leading axis
            when its length is 1.
        """
        # we do the action-normalisation (to the env bounds) here
        actions = utl.squash_action(actions, self.args)

        # shape should be: sequence_len x batch_size x hidden_size
        actions = actions.reshape((-1, *actions.shape[-2:]))
        states = states.reshape((-1, *states.shape[-2:]))
        rewards = rewards.reshape((-1, *rewards.shape[-2:]))
        if hidden_state is not None:
            # if the sequence_len is one, this will add a dimension at dim 0 (otherwise will be the same)
            hidden_state = hidden_state.reshape((-1, *hidden_state.shape[-2:]))

        if return_prior:
            # start with the prior
            # NOTE(review): the prior is always sampled here, even when sample=False -- confirm intent
            prior_sample, prior_mean, prior_logvar, prior_hidden_state = self.prior(
                actions.shape[1])
            hidden_state = prior_hidden_state.clone()

        # extract features for states, actions, rewards
        # (the step right before the GRU; just simple feature extractors)
        ha = self.action_encoder(actions)
        hs = self.state_encoder(states)
        hr = self.reward_encoder(rewards)
        h = torch.cat((ha, hs, hr), dim=2)

        # forward through fully connected layers before GRU
        for layer in self.fc_before_gru:
            h = F.relu(layer(h))

        if r_t is not None:
            # ignore detach_every, only consider r_t for now
            output = self._run_gru_with_resets(h, r_t)
        elif detach_every is None:
            # GRU cell (output is outputs for each time step)
            output, _ = self.gru(h, hidden_state)
        else:
            output = self._run_gru_detached(h, hidden_state, detach_every)
        gru_h = output.clone()

        latent_sample, latent_mean, latent_logvar = self._latent_from_hidden(
            gru_h, sample)

        if return_prior:
            latent_sample = torch.cat((prior_sample, latent_sample))
            latent_mean = torch.cat((prior_mean, latent_mean))
            latent_logvar = torch.cat((prior_logvar, latent_logvar))
            output = torch.cat((prior_hidden_state, output))

        if latent_mean.shape[0] == 1:
            latent_sample, latent_mean, latent_logvar = latent_sample[
                0], latent_mean[0], latent_logvar[0]

        return latent_sample, latent_mean, latent_logvar, output
