import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import typing as ty

# from torch.distributions.categorical import Categorical, OneHotCategorical
from torch.distributions.one_hot_categorical import OneHotCategorical
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
# from .base import init_linear_weights_by_xavier
from .base import MLP

#
# class CategoricalPolicy(torch.nn.Module):
#
#     def __init__(
#             self,
#             observation_space, action_space,
#             device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#             representation: str = 'tabular',
#             n_hidden: list = [64], # used for NN
#             layer_norm: bool = False,
#         ):
#         super(CategoricalPolicy, self).__init__()
#
#         self.representation = representation
#         self.observation_space = observation_space
#         self.action_space = action_space
#         self.num_actions = action_space.n
#         self.device = device
#
#         if self.representation == 'tabular':
#             self.dim_observation = observation_space.n
#             self.logit_net = torch.nn.Linear(self.dim_observation, self.num_actions)
#         elif self.representation == 'linear':
#             self.dim_observation = observation_space.shape[0]
#             self.logit_net = torch.nn.Linear(self.dim_observation, self.num_actions)
#         elif self.representation == 'NN':
#             self.dim_observation = observation_space.shape[0]
#             # self.logit_net = torch.nn.Sequential(
#             #     torch.nn.Linear(self.dim_observation, n_hidden),
#             #     torch.nn.ReLU(),
#             #     torch.nn.Linear(n_hidden, self.num_actions),
#             # )
#             self.logit_net = MLP(
#                 n_units=[self.dim_observation]+n_hidden+[self.num_actions],
#                 activation=nn.ReLU, layer_norm=layer_norm,
#             )
#
#         else:
#             raise NotImplementedError
#
#         # self.logit_net.apply(lambda x: init_linear_weights_by_xavier(x, 3))
#
#
#     def forward(self, observation: torch.tensor) -> Categorical:
#         batch_size = observation.shape[0]
#
#         if self.representation == 'tabular':
#             observation = torch.nn.functional.one_hot(observation.long(), num_classes=self.dim_observation)
#
#         logits = self.logit_net(observation.float())
#         assert logits.shape == (batch_size, self.num_actions)
#
#         return Categorical(probs=torch.softmax(logits, dim=1))
#
#
#     def act(
#             self,
#             observation: torch.tensor,
#             greedy: bool=False,
#             as_tensor: bool=False, # used if we would like to backpropagate through actions.
#             with_log_prob: bool=False,
#             action: torch.tensor=None,
#         ) -> torch.tensor:
#
#         pi = self.forward(observation)
#
#         if greedy:
#             chosen = pi.probs.argmax(dim=1)
#         else:
#             chosen = pi.sample()
#         if not as_tensor:
#             chosen = chosen.cpu().item()
#
#         if with_log_prob:
#             if action is None:
#                 log_prob = pi.log_prob(chosen)
#             else:
#                 log_prob = pi.log_prob(action)
#             return chosen, log_prob
#         else:
#             return chosen
#
#
#     def log_prob(self, observation: torch.tensor, action: torch.tensor=None) -> torch.tensor:
#         pi = self.forward(observation)
#         if action is None:
#             return pi.logits
#         else:
#             return pi.log_prob(action)



class SquashedGaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are tanh-squashed into an action box.

    A sample ``u`` drawn from ``Normal(mu, std)`` is mapped to the action
    ``tanh(u) * action_expand + action_offset`` where ``action_expand`` is the
    half-width and ``action_offset`` the centre of ``[min_action, max_action]``.

    ``action_scaler`` selects how the change of variables enters the
    log-density (see :meth:`_log_prob`):

    * ``'naive'``      -- ``log(action_expand * (1 - tanh(u)^2) + 1e-6)``
    * ``'spinningup'`` -- stable tanh correction only; the affine rescaling is
      ignored in the density (it is only applied when selecting the action)
    * ``'strict'``     -- stable tanh correction plus ``log(action_expand)``
    """

    def __init__(
            self,
            observation_space, action_space,
            action_scaler: str='spinningup', # strict, naive
            # use_strict_scaler=False,
            min_action: float=None,
            max_action: float=None,
            device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            representation: str = 'NN',
            n_hidden: list = [64], # used for NN
            layer_norm: bool = False,
            log_std_min: float = -20,
            log_std_max: float = 2,
        ):
        """Build the mean / log-std heads and the action rescaling constants.

        Args:
            observation_space: object exposing ``shape``; ``shape[0]`` is used
                as the observation dimension.
            action_space: object exposing ``shape`` and, when ``min_action`` /
                ``max_action`` are not given, ``low`` / ``high``.
            action_scaler: ``'spinningup'``, ``'strict'`` or ``'naive'``.
            min_action: lower action bound; defaults to ``action_space.low``.
            max_action: upper action bound; defaults to ``action_space.high``.
            device: device on which the rescaling tensors are created.
            representation: ``'linear'`` (heads read the observation directly)
                or ``'NN'`` (heads read the output of an ``MLP`` encoder).
            n_hidden: encoder layer widths, used only for ``'NN'``.
            layer_norm: forwarded to ``MLP``, used only for ``'NN'``.
            log_std_min: stored on the instance.
            log_std_max: stored on the instance.

        Raises:
            NotImplementedError: for any other ``representation``.
        """
        super(SquashedGaussianPolicy, self).__init__()

        self.representation = representation
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

        self.dim_observation = observation_space.shape[0]
        self.dim_action = action_space.shape[0]

        self.action_scaler = action_scaler

        if min_action is None:
            min_action = self.action_space.low
        if max_action is None:
            max_action = self.action_space.high
        print(self.action_space)
        print(max_action)
        print(min_action)
        # Half-width and centre of the action box: action = tanh(u) * expand + offset.
        self.action_expand = torch.as_tensor(
            (max_action - min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        self.action_offset = torch.as_tensor(
            (max_action + min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        print(self.action_expand)
        print(self.action_offset)
        print(f'action lower bound = {min_action}')
        print(f'action upper bound = {max_action}')

        # NOTE(review): these two are stored but squash_and_exp_log_std() uses
        # its own fixed bounds -- confirm which one is intended.
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        if self.representation == 'linear':
            self.mu_net = torch.nn.Linear(self.dim_observation, self.dim_action)
            self.log_std_net = torch.nn.Linear(self.dim_observation, self.dim_action)

        elif self.representation == 'NN':
            self.encoder = MLP(
                n_units=[self.dim_observation]+n_hidden,
                activation=nn.ReLU, output_activation=nn.ReLU,
                layer_norm=layer_norm,
            )
            self.mu_net = torch.nn.Linear(n_hidden[-1], self.dim_action)
            self.log_std_net = torch.nn.Linear(n_hidden[-1], self.dim_action)

        else:
            raise NotImplementedError


    def squash_and_exp_log_std(self, log_std: torch.tensor) -> torch.tensor:
        """Map an unbounded head output to a standard deviation.

        The input is passed through ``tanh`` and rescaled linearly into
        ``[LOG_STD_MIN, LOG_STD_MAX] = [-5, 2]`` before exponentiation.
        """
        # NOTE(review): fixed bounds; self.log_std_min / self.log_std_max are not used here.
        LOG_STD_MAX = 2
        LOG_STD_MIN = -5
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
        return torch.exp(log_std)


    def forward(self, observation: torch.tensor) -> Normal:
        """Return the pre-squash ``Normal(mu, std)`` for ``observation``.

        Raises:
            NotImplementedError: if ``self.representation`` is unknown.
        """
        if self.representation == 'linear':
            features = observation.float()
        elif self.representation == 'NN':
            features = self.encoder(observation.float())
        else:
            raise NotImplementedError

        mu = self.mu_net(features)
        std = self.squash_and_exp_log_std(self.log_std_net(features))

        return Normal(mu, std)


    def act(
            self,
            observation: torch.tensor,
            greedy: bool=False,
            as_tensor: bool=True, # used if we would like to backpropagate through actions.
            with_log_prob: bool=False,
            with_pseudo_ent: bool=False,
            # action: torch.tensor=None,
        ) -> torch.tensor:
        """Select an action for ``observation``.

        Args:
            observation: input passed to :meth:`forward`.
            greedy: use the Gaussian mean instead of a reparameterised sample.
            as_tensor: return a tensor (keeps the graph); otherwise the action
                is detached and converted to a NumPy array.
            with_log_prob: also return ``_log_prob`` of the chosen sample.
            with_pseudo_ent: also return ``_pseudo_ent`` of the chosen sample;
                takes precedence over ``with_log_prob``.

        Returns:
            ``action``, or ``(action, extra)`` when one of the two flags is set.
        """
        pi = self.forward(observation)
        chosen = pi.loc if greedy else pi.rsample()

        extra = None
        if with_pseudo_ent:
            extra = self._pseudo_ent(pi, chosen)
        elif with_log_prob:
            extra = self._log_prob(pi, chosen)

        action = self._squash(chosen)
        if not as_tensor:
            action = action.cpu().detach().numpy()

        if with_pseudo_ent or with_log_prob:
            return action, extra
        return action


    def log_prob(self, observation: torch.tensor, action: torch.tensor=None) -> torch.tensor:
        """Log-density of ``action`` (or of a fresh sample when ``action`` is None)."""
        pi = self.forward(observation)
        before_squash = pi.rsample() if action is None else self._unsquash(action)
        return self._log_prob(pi, before_squash)


    def pseudo_ent(self, observation: torch.tensor, action: torch.tensor=None) -> torch.tensor:
        """``_pseudo_ent`` of ``action`` (or of a fresh sample when ``action`` is None)."""
        pi = self.forward(observation)
        before_squash = pi.rsample() if action is None else self._unsquash(action)
        return self._pseudo_ent(pi, before_squash)


    def _squash(self, action: torch.tensor) -> torch.tensor:
        """Map a pre-squash sample into the action box."""
        return torch.tanh(action) * self.action_expand + self.action_offset


    def _unsquash(self, action: torch.tensor) -> torch.tensor:
        """Invert :meth:`_squash`: map an action in the box back to pre-squash space.

        Raises:
            AssertionError: if the rescaled action leaves ``[-1, 1]``.
        """
        action = (action - self.action_offset) / self.action_expand
        assert -1. <= action.min()
        assert  1. >= action.max()
        # Shrink slightly so that atanh stays finite at exactly +-1.
        eps = 1e-6
        return torch.atanh(action * (1 - eps))


    def _log_prob(self, pi: Normal, before_squash: torch.tensor) -> torch.tensor:
        '''
            Log-density of the squashed action, summed over action dimensions.

            Originally based on [arXiv 1801.01290, Appendix C., Eq. 21].
            A numerically more stable expression is used for the
            'spinningup' and 'strict' scalers:
                log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).

            Raises ValueError for an unknown ``self.action_scaler``.
        '''

        if self.action_scaler == 'naive':
            epsilon = 1e-6
            log_prob = pi.log_prob(before_squash)
            squashed = torch.tanh(before_squash)
            # Enforcing Action Bound
            log_prob -= torch.log(self.action_expand * (1 - squashed.pow(2)) + epsilon)
            log_prob = log_prob.sum(axis=1)
        elif self.action_scaler == 'spinningup':
            # Remark: in SpinningUp RL, scaling is only considered in action selection.
            log_prob = pi.log_prob(before_squash).sum(axis=-1)
            log_prob -= (2*(np.log(2) - before_squash - nn.functional.softplus(-2*before_squash))).sum(axis=1)
        elif self.action_scaler == 'strict':
            log_prob = pi.log_prob(before_squash)
            jacobian = 2 * (np.log(2) - before_squash - nn.functional.softplus(-2*before_squash))
            # torch.log keeps the computation on the tensor's own device;
            # np.log would require a NumPy conversion of action_expand.
            jacobian = jacobian + torch.log(self.action_expand)
            log_prob -= jacobian
            log_prob = log_prob.sum(axis=1)
        else:
            raise ValueError

        return log_prob


    def _pseudo_ent(self, pi: Normal, before_squash: torch.tensor) -> torch.tensor:
        '''
            Same structure as ``_log_prob`` but starting from ``pi.entropy()``
            instead of ``pi.log_prob(before_squash)``; summed over action
            dimensions.

            NOTE(review): in the 'spinningup' and 'strict' branches the linear
            term of the tanh correction uses ``pi.loc`` while the softplus term
            uses ``before_squash`` -- confirm this is intended.

            Raises ValueError for an unknown ``self.action_scaler``.
        '''

        if self.action_scaler == 'naive':
            epsilon = 1e-6
            ent = pi.entropy()
            squashed = torch.tanh(before_squash)
            # Enforcing Action Bound
            ent -= torch.log(self.action_expand * (1 - squashed.pow(2)) + epsilon)
            ent = ent.sum(axis=1)
        elif self.action_scaler == 'spinningup':
            # Remark: in SpinningUp RL, scaling is only considered in action selection.
            ent = pi.entropy().sum(axis=-1)
            ent -= (2*(np.log(2) - pi.loc - nn.functional.softplus(-2*before_squash))).sum(axis=1)
        elif self.action_scaler == 'strict':
            ent = pi.entropy()
            jacobian = 2 * (np.log(2) - pi.loc - nn.functional.softplus(-2*before_squash))
            # torch.log keeps the computation on the tensor's own device.
            jacobian = jacobian + torch.log(self.action_expand)
            ent -= jacobian
            ent = ent.sum(axis=1)
        else:
            raise ValueError

        return ent



class MixturedSquashedGaussianPolicy(SquashedGaussianPolicy):
    """Mixture of tanh-squashed diagonal Gaussians sharing one encoder.

    ``forward`` returns a ``OneHotCategorical`` over ``n_mixtures`` components
    together with one pre-squash ``Normal`` per component. With
    ``single_head=True`` all components share one pair of heads that receive
    the encoder output concatenated with a one-hot component index; otherwise
    every component has its own pair of heads.
    """

    def __init__(
            self,
            observation_space, action_space,
            n_mixtures: int = 2,
            single_head: bool = False,
            action_scaler: str = 'spinningup', # strict, naive
            min_action: float=None,
            max_action: float=None,
            device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            representation: str = 'NN',
            n_hidden: list = [64], # used for NN
            layer_norm: bool = False,
            log_std_min: float = -20,
            log_std_max: float = 2,
        ):
        """Build the parent policy, then replace its heads with mixture heads.

        Args:
            n_mixtures: number of mixture components.
            single_head: share one pair of heads across components (see class
                docstring) instead of one pair per component.
            Remaining arguments are forwarded to ``SquashedGaussianPolicy``.

        Raises:
            NotImplementedError: if ``representation`` is not ``'NN'``.
        """
        super(MixturedSquashedGaussianPolicy, self).__init__(
            observation_space, action_space,
            # Forward the scaler; otherwise the parent default silently wins.
            action_scaler=action_scaler,
            min_action=min_action, max_action=max_action,
            device=device, representation=representation, n_hidden=n_hidden,
            layer_norm=layer_norm, log_std_min=log_std_min, log_std_max=log_std_max,
        )

        self.n_mixtures = n_mixtures
        self.single_head = single_head

        if self.representation == 'NN':
            # Overwrite the heads created by the parent constructor.
            self.weight_net = torch.nn.Linear(n_hidden[-1], self.n_mixtures)

            if self.single_head:
                self.mu_net = torch.nn.Linear(n_hidden[-1] + self.n_mixtures, self.dim_action)
                self.log_std_net = torch.nn.Linear(n_hidden[-1] + self.n_mixtures, self.dim_action)
                self.mixture_idx = F.one_hot(torch.arange(0, self.n_mixtures, dtype=torch.long, device=self.device))
                print(f'mixture_idx = {self.mixture_idx}')
            else:
                self.mu_net = torch.nn.ModuleList()
                self.log_std_net = torch.nn.ModuleList()
                for k in range(self.n_mixtures):
                    self.mu_net.append(torch.nn.Linear(n_hidden[-1], self.dim_action))
                    self.log_std_net.append(torch.nn.Linear(n_hidden[-1], self.dim_action))

        else:
            raise NotImplementedError


    def forward(self, observation: torch.tensor) -> ty.Tuple[OneHotCategorical, ty.List[Normal]]:
        """Return ``(mixer, normals)``: component weights and per-component Gaussians."""
        batch_size = observation.shape[0]
        embed = self.encoder(observation.float())
        mixer = OneHotCategorical(logits=self.weight_net(embed))
        normals = []
        for k in range(self.n_mixtures):
            if self.single_head:
                mixture_idx = self.mixture_idx[k].unsqueeze(0).repeat(batch_size, 1)
                assert mixture_idx.shape == (batch_size, self.n_mixtures)
                # The one-hot index is an integer tensor; match the embedding dtype
                # explicitly before concatenating.
                head_input = torch.cat([embed, mixture_idx.to(embed.dtype)], dim=-1)
                mu = self.mu_net(head_input)
                std = self.squash_and_exp_log_std(self.log_std_net(head_input))
            else:
                mu = self.mu_net[k](embed)
                std = self.squash_and_exp_log_std(self.log_std_net[k](embed))
            normals.append(Normal(mu, std))
        return mixer, normals


    def rsample(
        self,
        mixer: OneHotCategorical,
        normals: ty.List[Normal],
        greedy: bool=False,
    ) -> torch.tensor:
        """Draw one pre-squash sample per batch row.

        The sampled action can be backpropagated through a single component:
        a one-hot mask picks the component (the arg-max weight when ``greedy``,
        otherwise a draw from ``mixer``), and the value is that component's
        mean (``greedy``) or reparameterised sample.
        """
        batch_size = mixer.probs.shape[0]

        if greedy:
            max_idx = torch.argmax(mixer.probs, -1, keepdim=True)
            # zeros_like inherits dtype and device from the weights.
            mask = torch.zeros_like(mixer.probs)
            mask.scatter_(1, max_idx, 1.0)
        else:
            mask = mixer.sample()
        assert mask.shape == (batch_size, self.n_mixtures)

        candidates = [
            normal.loc if greedy else normal.rsample()
            for normal in normals
        ]
        candidates = torch.stack(candidates, dim=-1)
        assert candidates.shape == (batch_size, self.dim_action, self.n_mixtures)

        # Zero out every component except the selected one, then collapse.
        chosen = (candidates * mask.unsqueeze(1)).sum(dim=-1)
        assert chosen.shape == (batch_size, self.dim_action)

        return chosen


    def act(
        self,
        observation: torch.tensor,
        greedy: bool=False,
        as_tensor: bool=True, # used if we need to backpropagate through actions.
        with_log_prob: bool=False,
    ) -> torch.tensor:
        """Select an action; optionally also return its mixture log-density.

        Returns:
            ``action``, or ``(action, log_prob)`` when ``with_log_prob`` is set.
            ``action`` is a NumPy array when ``as_tensor`` is False.
        """
        mixer, normals = self.forward(observation)
        chosen = self.rsample(mixer, normals, greedy)

        log_prob = None
        if with_log_prob:
            log_prob = self._log_prob(mixer, normals, before_squash=chosen)

        action = self._squash(chosen)
        if not as_tensor:
            action = action.cpu().detach().numpy()

        if with_log_prob:
            return action, log_prob
        return action


    def log_prob(self, observation: torch.tensor, action: torch.tensor=None) -> torch.tensor:
        """Mixture log-density of ``action`` (or of a fresh sample when None)."""
        mixer, normals = self.forward(observation)
        if action is None:
            before_squash = self.rsample(mixer, normals, greedy=False)
        else:
            # Invert the squashing completely (rescale, then atanh) so that
            # _log_prob receives a genuine pre-squash value, as in the parent class.
            rescaled = (action - self.action_offset) / self.action_expand
            eps = 1e-6
            before_squash = torch.atanh(rescaled.clamp(-1.0, 1.0) * (1 - eps))
        return self._log_prob(mixer, normals, before_squash=before_squash)


    def _log_prob_component(self, normal: Normal, before_squash: torch.tensor) -> torch.tensor:
        """Squashed log-density of one component, computed by the parent class."""
        return super()._log_prob(normal, before_squash)


    def _log_prob(
        self,
        mixer: OneHotCategorical,
        normals: ty.List[Normal],
        before_squash: torch.tensor
    ) -> torch.tensor:
        '''
            Log-density of the mixture at ``before_squash``:
                logsumexp_k( log w_k + log p_k(before_squash) ).
            Per-component densities come from the parent ``_log_prob``
            (originally based on [arXiv 1801.01290, Appendix C., Eq. 21]).
        '''
        batch_size = mixer.probs.shape[0]

        log_prob = torch.stack(
            [self._log_prob_component(normal, before_squash) for normal in normals],
            dim=-1,
        )
        assert log_prob.shape == (batch_size, self.n_mixtures)
        assert log_prob.shape == mixer.probs.shape

        # mixer.logits are the normalised log-weights; using them avoids
        # taking log of a probability that may have underflowed to zero.
        return torch.logsumexp(log_prob + mixer.logits, dim=1, keepdim=False)



class AmortizedPolicy(nn.Module):
    """Implicit policy: a network maps (observation, uniform noise) to an action.

    The sampler output is passed through ``tanh`` and rescaled with
    ``action_expand`` / ``action_offset`` (half-width and centre of
    ``[action_space.low, action_space.high]``).
    """

    def __init__(
        self,
        observation_space,
        action_space,
        n_hidden: list = [256, 256],
        activation=nn.ReLU,
        layer_norm: bool = False,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """Build the sampler network and the action rescaling constants.

        Args:
            observation_space: object exposing ``shape``.
            action_space: object exposing ``shape``, ``low`` and ``high``.
            n_hidden: hidden layer widths passed to ``MLP``.
            activation: hidden activation passed to ``MLP``.
            layer_norm: forwarded to ``MLP``.
            device: device used for noise and rescaling tensors.
        """
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space
        self.dim_observation = observation_space.shape[0]
        self.dim_action = action_space.shape[0]

        self.n_hidden = n_hidden
        self.activation = activation
        self.device = device

        # Input: observation, noise (same dim as action)
        # Output: action
        self.sampler = MLP(
            n_units=[self.dim_observation + self.dim_action] + n_hidden + [self.dim_action],
            activation=activation, output_activation=nn.Identity,
            layer_norm=layer_norm,
        )

        min_action = action_space.low
        max_action = action_space.high
        print(self.action_space)
        print(max_action)
        print(min_action)
        self.action_expand = torch.as_tensor(
            (max_action - min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        self.action_offset = torch.as_tensor(
            (max_action + min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        print(self.action_expand)
        print(self.action_offset)
        print(f'action lower bound = {min_action}')
        print(f'action upper bound = {max_action}')


    def _noise_jacobian(self, outputs: torch.tensor, noise: torch.tensor) -> torch.tensor:
        """Return d outputs[..., i] / d noise[..., j] stored at ``[..., j, i]``.

        Uses ``torch.autograd.grad`` so that no ``.grad`` field (of the noise,
        the observation or the sampler parameters) is written as a side effect.
        The result is detached from the graph.
        """
        jacobian = torch.zeros(
            tuple(noise.shape) + (self.dim_action,),
            dtype=outputs.dtype, device=outputs.device,
        )
        for i in range(self.dim_action):
            # Select output coordinate i for every batch / particle entry.
            selector = torch.zeros_like(outputs)
            selector[..., i] = 1
            (column,) = torch.autograd.grad(
                outputs, noise, grad_outputs=selector, retain_graph=True,
            )
            jacobian[..., i] = column
        return jacobian


    def _forward(
        self,
        observation: torch.tensor,
        n_particles: int = 1,
        with_log_prob: bool = False,
        use_importance_sampling: bool = False,
    ) -> torch.tensor:
        """Sample actions, optionally with a log-density estimate.

        Args:
            observation: batch of observations, first dimension is the batch.
            n_particles: actions drawn per observation; when > 1 the output
                gains a particle dimension after the batch dimension.
            with_log_prob: also return ``log_pi``.
            use_importance_sampling: compute ``log_pi`` from the
                change-of-variables formula ``log p(noise) - log|det J|``;
                otherwise ``log_pi`` is a zero tensor of shape
                ``(batch_size, n_particles)``.

        Returns:
            ``samples`` or ``(samples, log_pi)``. When the batch size is not
            greater than one the leading dimension of ``samples`` is squeezed.
        """
        batch_size = observation.shape[0]
        observation = observation.float()

        if n_particles > 1:
            observation = observation.unsqueeze(1)
            observation = observation.repeat(1, n_particles, 1)
            assert observation.dim() == 3
            noise_shape = (batch_size, n_particles, self.dim_action)
        else:
            noise_shape = (batch_size, self.dim_action)

        if with_log_prob:
            # NOTE(review): kept from the original; when no copy was made above
            # this flips the flag on the caller's tensor -- confirm it is needed.
            observation.requires_grad = True

        # noise in [-2, 2)
        noise = torch.rand(noise_shape).to(self.device) * 4 - 2
        noise.requires_grad = True
        # Log-density of the uniform noise on a box of side 4.
        log_p_noise = - self.dim_action * torch.log(torch.tensor([4], dtype=torch.float32, device=self.device))

        raw = self.sampler(torch.cat([observation, noise], dim=-1))
        # Keep the un-squeezed result: it has the same leading shape as `noise`,
        # which the Jacobian computation below relies on.
        unsqueezed = torch.tanh(raw)
        samples = unsqueezed if batch_size > 1 else unsqueezed.squeeze(0)
        samples = samples * self.action_expand + self.action_offset

        if not with_log_prob:
            return samples

        if use_importance_sampling:
            scaled = unsqueezed * self.action_expand + self.action_offset
            jacobian = self._noise_jacobian(scaled, noise)
            eps = 1e-6
            _, logabsdet = torch.linalg.slogdet(jacobian+eps)
            log_pi = log_p_noise - logabsdet
            if not batch_size > 1:
                # Mirror the squeeze applied to `samples`.
                log_pi = log_pi.squeeze(0)
            assert samples.shape[:-1] == log_pi.shape
        else:
            log_pi = torch.zeros((batch_size, n_particles), dtype=torch.float32, device=self.device)

        return samples, log_pi


    def forward(
        self,
        observation: torch.tensor,
        n_particles: int = 1,
        with_log_prob: bool = False,
        use_importance_sampling: bool = False,
    ) -> torch.tensor:
        """Thin wrapper around :meth:`_forward` with the same arguments."""
        return self._forward(
            observation, n_particles=n_particles, with_log_prob=with_log_prob,
            use_importance_sampling=use_importance_sampling)


    def act(
        self,
        observation: torch.tensor,
        n_particles: int = 1,
        with_log_prob: bool = False,
    ) -> np.ndarray:
        """Sample actions without tracking gradients and return them as NumPy.

        A 1-D observation is treated as a batch of one. ``with_log_prob`` is
        accepted but not used: only the actions are returned.
        """
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)
        with torch.no_grad():
            return self._forward(observation, n_particles).cpu().detach().numpy()






class NN(nn.Module):
    """Two fully connected layers, each followed by ReLU.

    Maps inputs of size ``n_input`` to non-negative features of size
    ``n_hidden``.
    """

    def __init__(self, n_input, n_hidden):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x):
        """Return ``relu(fc2(relu(fc1(x))))``, evaluated once."""
        return F.relu(self.fc2(F.relu(self.fc1(x))))


class StateAwareScaleTranslationNN(nn.Module):
    """Scale / translation network of a coupling layer, conditioned on an embedding."""

    def __init__(self, n_input, n_output, n_state_embedding, n_hidden):
        """
            s,t: R^{d} -> R^{D-d}

            The first layer consumes the identity part of the data
            (``n_input``) concatenated with the state embedding
            (``n_state_embedding``).
        """
        super(StateAwareScaleTranslationNN, self).__init__()
        # Shared trunk.
        self.fc1 = nn.Linear(n_input + n_state_embedding, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        # Separate output heads for scale and translation.
        self.fc3_s = nn.Linear(n_hidden, n_output)
        self.fc3_t = nn.Linear(n_hidden, n_output)

    def forward(self, x, emb):
        """Return ``(s, t)``; ``s`` is bounded to (-1, 1) by tanh, ``t`` is unbounded."""
        joint = torch.cat([x, emb], dim=-1)
        first = F.relu(self.fc1(joint))
        trunk = F.relu(self.fc2(first))
        scale = torch.tanh(self.fc3_s(trunk))
        translation = self.fc3_t(trunk)
        return scale, translation


class StateAwareRealNVP(nn.Module):
    """Stack of affine coupling layers whose scale/translation depend on a state.

    Each layer leaves ``n_identity = data_dim // 2`` coordinates untouched and
    transforms the remaining ``data_dim - n_identity`` coordinates as
    ``x_st * exp(s) + t``. Even-indexed layers keep the leading coordinates,
    odd-indexed layers keep the trailing ones.
    """

    def __init__(self, n_flows, data_dim, state_dim, n_hidden, n_emb):
        """
        Args:
            n_flows: number of coupling layers.
            data_dim: dimension ``D`` of the transformed variable.
            state_dim: dimension of the conditioning state.
            n_hidden: hidden width of every scale/translation network.
            n_emb: size of the state embedding.
        """
        super(StateAwareRealNVP, self).__init__()
        self.n_flows = n_flows
        self.st_net = torch.nn.ModuleList()
        self.embedder = NN(state_dim, n_emb)

        self.data_dim = data_dim
        self.n_identity = data_dim // 2
        print(f'(d,D)={self.n_identity,data_dim}')

        for k in range(n_flows):
            # s,t: R^{d} -> R^{D-d}
            self.st_net.append(StateAwareScaleTranslationNN(self.n_identity, data_dim - self.n_identity, n_emb, n_hidden))


    def _split(self, x, k):
        """Split ``x`` into ``(identity_part, transformed_part)`` for layer ``k``."""
        if k % 2 == 0:
            return x[:, :self.n_identity], x[:, self.n_identity:]
        # Odd layers: the transformed part comes first and always has
        # data_dim - n_identity columns (one more than n_identity when D is odd).
        cut = self.data_dim - self.n_identity
        return x[:, cut:], x[:, :cut]


    def _merge(self, part_id, part_st, k):
        """Inverse of :meth:`_split`: reassemble the two parts in layer ``k``'s order."""
        if k % 2 == 0:
            return torch.cat([part_id, part_st], dim=1)
        return torch.cat([part_st, part_id], dim=1)


    def forward(self, x, state, n_layers=None):
        """Apply the first ``n_layers`` coupling layers (all of them by default).

        Returns:
            ``(y, log_det_jacobian)`` where ``log_det_jacobian`` is the sum of
            the scale outputs over transformed coordinates and layers, kept
            with a trailing dimension of size one.
        """
        if n_layers is None:
            n_layers = self.n_flows

        emb = self.embedder(state)

        log_det_jacobian = 0
        for k in range(n_layers):
            x_id, x_st = self._split(x, k)
            s, t = self.st_net[k](x_id, emb)
            x_st = torch.exp(s)*x_st + t
            x = self._merge(x_id, x_st, k)
            log_det_jacobian += s.sum(dim=1, keepdim=True)

        return x, log_det_jacobian


    def inverse(self, z, state, n_layers=None):
        """Undo :meth:`forward` for the same ``state`` and ``n_layers``."""
        if n_layers is None:
            n_layers = self.n_flows

        emb = self.embedder(state)

        for k in reversed(range(n_layers)):
            z_id, z_st = self._split(z, k)
            s, t = self.st_net[k](z_id, emb)
            z_st = (z_st - t) / torch.exp(s)
            z = self._merge(z_id, z_st, k)

        return z



class RealNVPPolicy(nn.Module):
    """Policy that turns base noise into bounded actions with a conditional flow.

    Noise drawn from a uniform or standard Gaussian base distribution is
    passed, together with the observation, through a ``StateAwareRealNVP``.
    The result is squashed with ``tanh`` and rescaled to the bounds of
    ``action_space``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        n_hidden: list = [256, 256],
        activation=nn.ReLU,
        layer_norm: bool = False,
        noise_type: str = 'uniform',
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """Build the flow and cache the action rescaling constants.

        Args:
            observation_space: Space whose ``shape[0]`` is the observation size.
            action_space: Space whose ``shape[0]`` is the action size and whose
                ``low`` / ``high`` give the action bounds.
            n_hidden: ``len(n_hidden)`` sets the number of flows; the last
                entry is used as both hidden and embedding width of the flow.
            activation: Stored on the instance; not otherwise used here.
            layer_norm: Accepted but not used by this class.
            noise_type: ``'uniform'`` or ``'gaussian'`` base distribution.
            device: Device on which the constant tensors are created.
        """
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space
        self.dim_observation = observation_space.shape[0]
        self.dim_action = action_space.shape[0]

        self.noise_type = noise_type
        self.n_hidden = n_hidden
        self.activation = activation
        self.device = device

        # Flow input: observation and noise (same dimension as the action).
        # Flow output: pre-squash action.
        self.sampler = StateAwareRealNVP(
            n_flows=len(n_hidden), data_dim=self.dim_action, state_dim=self.dim_observation,
            n_hidden=n_hidden[-1], n_emb=n_hidden[-1],
        )

        if self.noise_type == 'uniform':
            # noise in [-1, 1): [0, 1) -> [0, 2) -> [-1, 1)
            # Each coordinate has density 1/2, hence -dim_action * log(2).
            self.log_p_noise = - self.dim_action * torch.log(torch.tensor([2], dtype=torch.float32, device=self.device))
        elif self.noise_type == 'gaussian':
            self.noise_generator = MultivariateNormal(
                torch.zeros(self.dim_action, dtype=torch.float32, device=self.device),
                torch.eye(self.dim_action, dtype=torch.float32, device=self.device))

        min_action = action_space.low
        max_action = action_space.high
        print(self.action_space)
        print(max_action)
        print(min_action)
        # tanh output in [-1, 1] is mapped to [low, high] via
        # action = tanh * action_expand + action_offset.
        self.action_expand = torch.as_tensor(
            (max_action - min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        self.action_offset = torch.as_tensor(
            (max_action + min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        print(self.action_expand)
        print(self.action_offset)
        print(f'action lower bound = {min_action}')
        print(f'action upper bound = {max_action}')

    def _bound_correction(self, after_squash: torch.Tensor, noise_shape: tuple) -> torch.Tensor:
        """Return the per-sample log-Jacobian of the tanh squash and rescaling."""
        squash_grad = 1 - after_squash.view(noise_shape).pow(2)
        # The 1e-6 keeps the logarithm finite when tanh saturates.
        return torch.log(self.action_expand * squash_grad + 1e-6).sum(1, keepdim=True)

    def forward(
        self,
        observation: torch.Tensor,
        with_log_prob: bool = True,
        deterministic: bool = False,
    ):
        """Sample an action for each observation in the batch.

        Args:
            observation: Tensor whose first dimension is the batch size.
            with_log_prob: When true, also return the log-density term.
            deterministic: Use the centre of the base distribution as noise
                instead of sampling.

        Returns:
            ``action`` or ``(action, log_prob)``. For a batch of one the
            leading dimension of ``action`` is squeezed away.

        Raises:
            ValueError: If ``self.noise_type`` is not a supported value.
        """
        batch_size = observation.shape[0]
        noise_shape = (batch_size, self.dim_action)

        if self.noise_type == 'uniform':
            if deterministic:
                # Centre of the [-1, 1) support, matching the Gaussian branch
                # which uses its mean. The previous 0.5 is the centre of the
                # un-shifted [0, 1) range.
                noise = torch.zeros(noise_shape, dtype=torch.float32, device=self.device)
            else:
                # [0, 1) -> [0, 2) -> [-1, 1)
                noise = torch.rand(noise_shape).to(self.device) * 2 - 1
            log_p_noise = self.log_p_noise
        elif self.noise_type == 'gaussian':
            if deterministic:
                noise = torch.zeros(noise_shape, dtype=torch.float32, device=self.device)
            else:
                noise = self.noise_generator.sample((batch_size,)).to(self.device)
            log_p_noise = self.noise_generator.log_prob(noise)
        else:
            raise ValueError(f'unsupported noise_type: {self.noise_type!r}')

        before_squash, log_det = self.sampler(noise, observation)
        after_squash = torch.tanh(before_squash)
        if not batch_size > 1:
            after_squash = after_squash.squeeze(0)
        action = after_squash * self.action_expand + self.action_offset

        if not with_log_prob:
            return action

        bound_correction = self._bound_correction(after_squash, noise_shape)
        return action, log_p_noise.view(-1, 1) - log_det - bound_correction

    def act(
        self,
        observation: torch.Tensor,
        as_tensor: bool = True,
        with_log_prob: bool = False,
        greedy: bool = False,  # forwarded to forward() as `deterministic`
    ):
        """Choose an action, adding a batch dimension to a 1-D observation.

        Args:
            observation: A single observation (1-D) or a batch.
            as_tensor: When false, the action is returned as a NumPy array.
            with_log_prob: When true, return ``(action, log_prob)``.
            greedy: Use deterministic noise instead of sampling.

        Returns:
            The action, or ``(action, log_prob)`` when requested.
        """
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)

        action, log_prob = self.forward(observation, deterministic=greedy)
        if not as_tensor:
            action = action.cpu().detach().numpy()

        if with_log_prob:
            return action, log_prob
        return action

    def log_prob(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Evaluate the log-density term of given actions.

        The action is mapped back to the tanh range, un-squashed with
        ``atanh`` and inverted through the flow to recover the noise; the
        flow is then run forward again to obtain its log-determinant.

        Args:
            observation: Batch of observations.
            action: Batch of actions that must lie within the action bounds.

        Returns:
            Tensor with one value per sample (``keepdim`` column).

        Raises:
            AssertionError: If a rescaled action lies outside ``[-1, 1]``.
            ValueError: If ``self.noise_type`` is not a supported value.
        """
        batch_size = observation.shape[0]
        noise_shape = (batch_size, self.dim_action)
        after_squash = (action - self.action_offset) / self.action_expand

        assert -1. <= after_squash.min()
        assert  1. >= after_squash.max()
        # Shrink slightly so atanh stays finite at the bounds.
        eps = 1e-6
        before_squash = torch.atanh(after_squash * (1 - eps))

        noise = self.sampler.inverse(before_squash, observation)
        if self.noise_type == 'uniform':
            # NOTE(review): the recovered noise is not checked against the
            # [-1, 1) support; the constant density is used regardless.
            log_p_noise = self.log_p_noise
        elif self.noise_type == 'gaussian':
            log_p_noise = self.noise_generator.log_prob(noise)
        else:
            raise ValueError(f'unsupported noise_type: {self.noise_type!r}')

        # Only the log-determinant of this forward pass is needed.
        _before_squash, log_det = self.sampler(noise, observation)
        bound_correction = self._bound_correction(after_squash, noise_shape)

        return log_p_noise.view(-1, 1) - log_det - bound_correction


"""

Real NVP is stacked onto Gaussian Policy.

"""


class ScaleTranslationNN(nn.Module):
    """Small MLP producing the scale and translation of a coupling layer."""

    def __init__(self, n_input, n_output, n_hidden):
        """
            s,t: R^{d} -> R^{D-d}

            Two shared hidden layers feed two separate linear heads, one for
            the scale ``s`` and one for the translation ``t``.
        """
        super().__init__()
        # Shared trunk.
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        # Output heads.
        self.fc3_s = nn.Linear(n_hidden, n_output)
        self.fc3_t = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return ``(s, t)``; ``s`` is bounded to (-1, 1) by ``tanh``."""
        features = torch.relu(self.fc1(x))
        features = torch.relu(self.fc2(features))
        scale = self.fc3_s(features).tanh()
        shift = self.fc3_t(features)
        return scale, shift


class RealNVP(nn.Module):
    """Stack of affine coupling layers with alternating identity halves.

    Even layers keep the first ``n_identity`` columns fixed and transform the
    rest; odd layers keep the trailing columns fixed and transform the
    leading ones. Each layer applies ``x_st = exp(s) * x_st + t`` with
    ``s, t`` computed from the identity part.
    """

    def __init__(self, n_flows, data_dim, n_hidden):
        """Create ``n_flows`` scale/translation networks.

        Args:
            n_flows: Number of coupling layers.
            data_dim: Number of columns ``D`` of the transformed data.
            n_hidden: Hidden width of every ``ScaleTranslationNN``.
        """
        super().__init__()
        self.n_flows = n_flows
        self.st_net = torch.nn.ModuleList()

        self.data_dim = data_dim
        self.n_identity = int(data_dim/2)
        print(f'(d,D)={self.n_identity,data_dim}')

        # s,t: R^{d} -> R^{D-d}
        for _ in range(n_flows):
            self.st_net.append(ScaleTranslationNN(self.n_identity, data_dim - self.n_identity, n_hidden))

    def _split(self, x, k):
        """Return ``(identity_part, transformed_part)`` of ``x`` for layer ``k``."""
        if k % 2 == 0:
            return x[:, :self.n_identity], x[:, self.n_identity:]
        # Odd layers transform the leading columns. With an odd data_dim that
        # part is one column wider so the network input size stays n_identity.
        if self.data_dim % 2 == 0:
            cut = self.n_identity
        else:
            cut = self.n_identity + 1
        return x[:, cut:], x[:, :cut]

    @staticmethod
    def _merge(part_id, part_st, k):
        """Concatenate the two parts in the column order used by ``_split``."""
        if k % 2 == 0:
            return torch.cat([part_id, part_st], dim=1)
        return torch.cat([part_st, part_id], dim=1)

    def forward(self, x, n_layers=None):
        """Apply the first ``n_layers`` coupling layers.

        Args:
            x: 2-D tensor whose columns are split between the two parts.
            n_layers: Number of layers to apply; ``None`` means all of them.

        Returns:
            ``(y, log_det_jacobian)`` where the second item accumulates
            ``s.sum(dim=1, keepdim=True)`` over the applied layers (it is the
            integer ``0`` when no layer is applied).
        """
        if n_layers is None:
            n_layers = self.n_flows

        log_det_jacobian = 0
        for k in range(n_layers):
            x_id, x_st = self._split(x, k)

            s, t = self.st_net[k](x_id)
            x_st = torch.exp(s)*x_st + t

            x = self._merge(x_id, x_st, k)
            log_det_jacobian += s.sum(dim=1, keepdim=True)

        return x, log_det_jacobian

    def inverse(self, z, n_layers=None):
        """Invert the first ``n_layers`` coupling layers, last one first.

        Args:
            z: 2-D tensor produced by ``forward`` with the same ``n_layers``.
            n_layers: Number of layers to invert; ``None`` means all of them.

        Returns:
            The tensor recovered after undoing the selected layers.
        """
        if n_layers is None:
            n_layers = self.n_flows

        for k in reversed(range(n_layers)):
            z_id, z_st = self._split(z, k)

            s, t = self.st_net[k](z_id)
            z_st = (z_st - t) / torch.exp(s)

            z = self._merge(z_id, z_st, k)

        return z



class SequantialGaussRealNVPPolicy(nn.Module):
    """Gaussian policy whose samples are further transformed by a RealNVP.

    An encoder maps the observation to the mean and standard deviation of a
    diagonal Gaussian. A sample from it is pushed through ``RealNVP``,
    squashed with ``tanh`` and rescaled to the bounds of ``action_space``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        n_hidden: list = [256, 256],
        activation=nn.ReLU,
        layer_norm: bool = False,
        noise_type: str = 'uniform',
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """Build the encoder, the Gaussian heads and the flow.

        Args:
            observation_space: Space whose ``shape[0]`` is the observation size.
            action_space: Space whose ``shape[0]`` is the action size and whose
                ``low`` / ``high`` give the action bounds.
            n_hidden: Hidden widths of the encoder; ``len(n_hidden)`` is also
                the number of flows and the last entry the flow hidden width.
            activation: Stored on the instance; the encoder itself is built
                with ``nn.ReLU``.
            layer_norm: Forwarded to the encoder ``MLP``.
            noise_type: Accepted but not used by this class.
            device: Device on which the constant tensors are created.
        """
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space
        self.dim_observation = observation_space.shape[0]
        self.dim_action = action_space.shape[0]

        self.n_hidden = n_hidden
        self.activation = activation
        self.device = device

        self.encoder = MLP(
            n_units=[self.dim_observation]+n_hidden,
            activation=nn.ReLU, output_activation=nn.ReLU,
            layer_norm=layer_norm,
        )
        self.mu_net = torch.nn.Linear(n_hidden[-1], self.dim_action)
        self.log_std_net = torch.nn.Linear(n_hidden[-1], self.dim_action)

        # Flow input: Gaussian sample (same dimension as the action).
        # Flow output: pre-squash action.
        self.rnvp = RealNVP(
            n_flows=len(n_hidden), data_dim=self.dim_action,
            n_hidden=n_hidden[-1],
        )

        min_action = action_space.low
        max_action = action_space.high
        print(self.action_space)
        print(max_action)
        print(min_action)
        # tanh output in [-1, 1] is mapped to [low, high] via
        # action = tanh * action_expand + action_offset.
        self.action_expand = torch.as_tensor(
            (max_action - min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        self.action_offset = torch.as_tensor(
            (max_action + min_action) / 2.0,
            dtype=torch.float32, device=self.device
        )
        print(self.action_expand)
        print(self.action_offset)
        print(f'action lower bound = {min_action}')
        print(f'action upper bound = {max_action}')

    def _squash_and_exp_log_std(self, log_std: torch.Tensor) -> torch.Tensor:
        """Map raw log-std outputs into [LOG_STD_MIN, LOG_STD_MAX] and exponentiate."""
        LOG_STD_MAX = 2
        LOG_STD_MIN = -5
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
        return torch.exp(log_std)

    def _squash(self, before_squash: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Apply ``tanh``; for a batch of at most one, squeeze dimension 0."""
        return torch.tanh(before_squash) if batch_size > 1 else torch.tanh(before_squash).squeeze(0)

    def _get_prior(self, observation: torch.Tensor) -> Normal:
        """Return the observation-conditioned diagonal Gaussian."""
        embed   = self.encoder(observation.float())
        mu      = self.mu_net(embed)
        log_std = self.log_std_net(embed)
        std = self._squash_and_exp_log_std(log_std)
        return Normal(mu, std)

    def _get_bound_correction(self, after_squash: torch.Tensor, noise_shape: ty.Tuple[int, ...]) -> torch.Tensor:
        """Return the per-sample log-Jacobian of the tanh squash and rescaling."""
        # The 1e-6 keeps the logarithm finite when tanh saturates.
        return torch.log(self.action_expand * (1 - after_squash.view(noise_shape).pow(2)) + 1e-6).sum(1, keepdim=True)

    @staticmethod
    def _log_prob() -> None:
        """Unused placeholder; always returns ``None``.

        Declared static so it can be called on an instance as well as on the
        class (it takes no arguments).
        """
        return

    def forward(
        self,
        observation: torch.Tensor,
        as_tensor: bool = True,
        with_log_prob: bool = True,
        deterministic: bool = False,
    ):
        """Sample an action for each observation in the batch.

        Args:
            observation: Tensor whose first dimension is the batch size.
            as_tensor: When false, the action is returned as a NumPy array.
            with_log_prob: When true, also return the log-density term.
            deterministic: Use the Gaussian mean instead of a sample.

        Returns:
            ``action`` or ``(action, log_prob)``. For a batch of one the
            leading dimension of ``action`` is squeezed away.
        """
        batch_size = observation.shape[0]
        noise_shape = (batch_size, self.dim_action)

        prior = self._get_prior(observation)

        if deterministic:
            noise = prior.loc
        else:
            # rsample keeps the sample differentiable w.r.t. the prior parameters.
            noise = prior.rsample()
        before_squash, log_det_rnvp = self.rnvp(noise)

        after_squash = self._squash(before_squash, batch_size)
        action = after_squash * self.action_expand + self.action_offset
        if not as_tensor:
            action = action.cpu().detach().numpy()

        if with_log_prob:
            log_prior = prior.log_prob(noise).sum(1, keepdim=True)
            log_det_squash = self._get_bound_correction(after_squash, noise_shape)
            return action, log_prior - log_det_rnvp - log_det_squash
        else:
            return action

    def act(
        self,
        observation: torch.Tensor,
        as_tensor: bool = True,
        with_log_prob: bool = False,
        greedy: bool = False,  # forwarded to forward() as `deterministic`
    ):
        """Choose an action, adding a batch dimension to a 1-D observation.

        Returns:
            Whatever ``forward`` returns for the given flags.
        """
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)

        return self.forward(observation, as_tensor=as_tensor, with_log_prob=with_log_prob, deterministic=greedy)

    def log_prob(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Evaluate the log-density term of given (or freshly sampled) actions.

        Args:
            observation: Batch of observations.
            action: Batch of actions within the action bounds, or ``None`` to
                evaluate a new sample drawn from the prior.

        Returns:
            Tensor with one value per sample (``keepdim`` column).

        Raises:
            AssertionError: If a rescaled action lies outside ``[-1, 1]``.
        """
        batch_size = observation.shape[0]
        noise_shape = (batch_size, self.dim_action)

        prior = self._get_prior(observation)

        if action is None:
            noise = prior.rsample()
            # No action to rescale; the squashed value is derived from the
            # flow output below.
            after_squash = None
        else:
            after_squash = (action - self.action_offset) / self.action_expand
            assert -1. <= after_squash.min()
            assert  1. >= after_squash.max()
            # Shrink slightly so atanh stays finite at the bounds.
            eps = 1e-6
            before_squash = torch.atanh(after_squash * (1 - eps))
            noise = self.rnvp.inverse(before_squash)

        log_prior = prior.log_prob(noise).sum(1, keepdim=True)
        flow_output, log_det_rnvp = self.rnvp(noise)
        if after_squash is None:
            after_squash = torch.tanh(flow_output)
        log_det_squash = self._get_bound_correction(after_squash, noise_shape)
        return log_prior - log_det_rnvp - log_det_squash
