"""
Defines torch modules related to standard kinds of policies.
"""

import logging

import numpy as np
import torch
import torch.nn.functional as F

from . import op

# Module-level logger. Attaching a NullHandler keeps this library silent
# (no "last resort" stderr output) unless the application configures logging.
LOG = logging.getLogger(name=__name__)
LOG.addHandler(logging.NullHandler())


class NNStochasticPolicy(torch.nn.Module):
    """Base class for stochastic policies.

    Every method here is abstract: the base implementations only raise
    ``NotImplementedError`` naming the method and the concrete class.
    """
    def to_raw_sample(self, sample):
        """Convert a post-processed ``sample`` back to the raw distribution sample."""
        raise NotImplementedError(f"to_raw_sample not implemented for class {type(self).__name__}")
    def log_prob(self, raw_sample, x, dist=None):
        """Log-probability of ``raw_sample`` given input ``x`` (``dist`` may be reused)."""
        raise NotImplementedError(f"log_prob not implemented for class {type(self).__name__}")
    def noise_of(self, raw_sample, x):
        """Noise that would have produced ``raw_sample`` at input ``x``."""
        raise NotImplementedError(f"noise_of not implemented for class {type(self).__name__}")
    def forward(self, x, deterministic=False, with_logprob=True, noise=None, with_info=False):
        """Run the policy on ``x``; see concrete subclasses for the returned tuple."""
        raise NotImplementedError(f"forward not implemented for class {type(self).__name__}")


class NNGaussianPolicy(NNStochasticPolicy):
    """
    A Gaussian policy, outputting both the sampled action and its
    log-probability on a forward pass.

    The network computes ``c = common_head(x)``, the mean ``mu = mu_head(c)``
    and ``std = exp(clamp(log_std_head(c), log_std_min, log_std_max))``.

    If tanh_refit is specified, then output from the gaussian distribution will
    be put through tanh and then refitted according to the provided module.
    """
    def __init__(self, common_head, mu_head, log_std_head,
                 tanh_refit : op.NNTanhRefit = None,
                 log_std_min=-20.0, log_std_max=2.0):
        super().__init__()
        self.common_head = common_head
        self.mu_head = mu_head
        self.log_std_head = log_std_head
        self.tanh_refit = tanh_refit
        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)

    def _mu_std(self, x):
        """Return ``(mu, std)`` of the Gaussian at input ``x``.

        The log-std is clamped to ``[log_std_min, log_std_max]`` before
        exponentiating so that ``std`` stays in a numerically safe range.
        """
        c = self.common_head(x)
        mu = self.mu_head(c)
        log_std = torch.clamp(self.log_std_head(c), self.log_std_min, self.log_std_max)
        return mu, log_std.exp()

    def log_prob(self, raw_sample, x, dist=None):
        """Log-probability of the pre-tanh sample ``raw_sample`` at input ``x``.

        ``dist`` may be passed to reuse an already-built Normal distribution
        instead of re-running the network. Per-dimension log-probabilities are
        summed over the last axis. When ``tanh_refit`` is set, the tanh
        change-of-variables correction is subtracted; no Jacobian term for the
        refit module itself is included.
        """
        if dist is None:
            mu, std = self._mu_std(x)
            dist = torch.distributions.normal.Normal(mu, std)

        logp = dist.log_prob(raw_sample).sum(axis=-1)
        if self.tanh_refit is not None:
            # tanh correction from https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py#L60
            # log(1 - tanh(u)^2) == 2 * (log 2 - u - softplus(-2u)), in a stable form.
            # Reduce over the same (last) axis as above so unbatched inputs work too.
            logp = logp - (2*(np.log(2) - raw_sample - F.softplus(-2*raw_sample))).sum(axis=-1)

        return logp

    def noise_of(self, raw_sample, x):
        """Recover the standard-normal noise that yields ``raw_sample`` at ``x``.

        Inverse of the reparameterisation ``raw_sample = mu + std * noise``
        used by ``forward``.
        """
        mu, std = self._mu_std(x)
        return (raw_sample - mu) / std

    def to_raw_sample(self, sample):
        """This should not have to be used under normal circumstances. Just
        save the outputted logprob instead.

        Undoes the tanh/refit post-processing (if any) to recover the raw
        Gaussian sample.
        """
        if self.tanh_refit is not None:
            # Clamp strictly inside (-1, 1) so arctanh stays finite.
            bound = torch.tensor(1 - 1e-7)
            raw_sample = torch.arctanh(torch.clamp(self.tanh_refit.undo(sample), -bound, bound))
        else:
            raw_sample = sample
        return raw_sample

    def forward(self, x, deterministic=False, with_logprob=True, noise=None, with_info=False):
        """Sample an action for input ``x``.

        Args:
            deterministic: return the mean ``mu`` instead of sampling.
            with_logprob: also compute the log-probability of the sample.
            noise: optional standard-normal noise; when given (and not
                deterministic) the sample is ``mu + std * noise``.
            with_info: also return a dict holding a detached copy of ``std``.

        Returns:
            ``(sample, raw_sample, logp, info)``: ``raw_sample`` is the Gaussian
            sample before tanh/refit, ``sample`` is the post-processed action,
            ``logp`` is ``None`` unless ``with_logprob``, and ``info`` is ``{}``
            unless ``with_info``.
        """
        mu, std = self._mu_std(x)
        dist = torch.distributions.normal.Normal(mu, std)

        if deterministic:
            raw_sample = mu
        elif noise is None:
            # rsample keeps the sample differentiable w.r.t. mu and std.
            raw_sample = dist.rsample()
        else:
            raw_sample = mu + std * noise

        logp = self.log_prob(raw_sample, x, dist=dist) if with_logprob else None

        sample = raw_sample
        if self.tanh_refit is not None:
            sample = self.tanh_refit(torch.tanh(sample))

        info = {"std": std.detach().clone()} if with_info else {}

        return sample, raw_sample, logp, info


class NNCategoricalPolicy(NNStochasticPolicy):
    """
    A Categorical policy whose logits are produced by ``net``.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net

    def log_prob(self, raw_sample, x, dist=None):
        """Log-probability of ``raw_sample`` under the policy at input ``x``.

        ``dist`` may be passed to reuse an already-built Categorical instead of
        re-running ``net``. The result is summed over the last axis.
        """
        if dist is None:
            logits = self.net(x)
            dist = torch.distributions.categorical.Categorical(logits=logits)

        # NOTE(review): Categorical.log_prob already drops the category axis,
        # so for a batch of single discrete actions (sample shape [N]) this sum
        # collapses the whole batch into one scalar. It is only a per-sample
        # reduction if ``net`` emits several independent categorical dims
        # (logits [..., K, C]) — confirm against callers.
        return dist.log_prob(raw_sample).sum(axis=-1)

    def to_raw_sample(self, sample):
        """This should not have to be used under normal circumstances. Just
        save the outputted logprob instead."""
        # No conversions needed here: samples are category indices already.
        return sample

    def forward(self, x, deterministic=False, with_logprob=True, noise=None, with_info=False):
        """
        Pick an action for input ``x``.

        deterministic selects ``argmax`` of the logits instead of sampling.

        Returns ``(sample, raw_sample, logp)`` where ``raw_sample is sample``
        and ``logp`` is ``None`` unless ``with_logprob``.

        NOTE(review): this is a 3-tuple, unlike the 4-tuple (with ``info``)
        returned by the Gaussian policies; ``with_info`` is ignored. Kept as-is
        for caller compatibility — unify only after checking callers.

        Currently unsupported:
         - noise (what should the noise be here?)
        """
        logits = self.net(x)
        # Built unconditionally: log_prob below needs it in deterministic mode too.
        dist = torch.distributions.categorical.Categorical(logits=logits)

        if deterministic:
            sample = logits.argmax(dim=-1)
        else:
            sample = dist.sample()

        raw_sample = sample
        logp = self.log_prob(raw_sample, x, dist=dist) if with_logprob else None

        return sample, raw_sample, logp



class NNStochasticVarLenPolicy(torch.nn.Module):
    """Base class for stochastic policies over a variable-length action memory.

    ``forward`` takes an observation batch ``s_obs``, a padded action memory
    ``a_mem`` and ``a_memlen``, the number of valid memory entries per sample.
    The base implementation only raises ``NotImplementedError``.
    """
    def forward(self, s_obs, a_mem, a_memlen=None, deterministic=False, with_logprob=True, noise=None, with_info=False):
        # a_memlen mirrors the signature of the concrete subclasses.
        raise NotImplementedError(f"forward not implemented for class {type(self).__name__}")




class NNGaussianMultiVarLengthPolicy(NNStochasticVarLenPolicy):
    """
    A Gaussian policy taking a (s_obs: [N, S], a_mem: [N, L, A], a_memlen: [N]) input.

    Sample ``i`` is encoded by ``encoder_heads[a_memlen[i] - 1]`` applied to the
    concatenation of ``s_obs[i]`` and the first ``a_memlen[i]`` entries of
    ``a_mem[i]`` (flattened). The stacked encodings then go through
    ``common_head`` and the mu / log-std heads.

    If tanh_refit is specified, then output from the gaussian distribution will
    be put through tanh and then refitted according to the provided module.
    """
    def __init__(self, encoder_heads, common_head, mu_head, log_std_head,
                 tanh_refit : op.NNTanhRefit = None,
                 log_std_min=-20.0, log_std_max=2.0):
        super().__init__()
        self.encoder_heads = torch.nn.ModuleList(encoder_heads)
        self.common_head = common_head
        self.mu_head = mu_head
        self.log_std_head = log_std_head
        self.tanh_refit = tanh_refit
        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)

    def _encode(self, s_obs, a_mem, a_memlen):
        """Encode each sample with the encoder head matching its memory length.

        Raises ValueError if a length is outside ``[1, len(encoder_heads)]``.
        """
        N, _ = s_obs.shape
        # Lengths only select an encoder; nothing is differentiated through them.
        lengths = a_memlen.cpu().tolist()
        n_encoders = len(self.encoder_heads)
        encodings = []
        for i in range(N):
            n = lengths[i]
            # Without this check, n <= 0 would index encoder_heads negatively
            # and could silently use the wrong encoder.
            if not 1 <= n <= n_encoders:
                raise ValueError(f"a_memlen[{i}] = {n} is outside [1, {n_encoders}]")
            flatinput = torch.cat((s_obs[i], a_mem[i, :n].flatten()))
            # encoder_heads[0] is for delay=1, etc.
            encodings.append(self.encoder_heads[n - 1](flatinput.unsqueeze(0)).squeeze(0))
        return torch.stack(encodings)

    def forward(self, s_obs, a_mem, a_memlen, deterministic=False, with_logprob=True, noise=None, with_info=False):
        """
        s_obs: [N, S] float
        a_mem: [N, L, A] float
        a_memlen: [N] long, each value in [1, len(encoder_heads)]

        Returns ``(sample, raw_sample, logp, info)``: ``raw_sample`` is the
        Gaussian sample (``mu`` if deterministic, ``mu + std * noise`` if noise
        is given), ``sample`` is it after tanh/refit, ``logp`` is ``None`` unless
        ``with_logprob``, ``info`` is ``{}`` unless ``with_info``.

        Raises ValueError if some a_memlen value is out of range.
        """
        e = self._encode(s_obs, a_mem, a_memlen)

        c = self.common_head(e)
        mu = self.mu_head(c)
        std = torch.clamp(self.log_std_head(c), self.log_std_min, self.log_std_max).exp()
        dist = torch.distributions.normal.Normal(mu, std)

        if deterministic:
            sample = mu
        elif noise is None:
            # rsample keeps the sample differentiable w.r.t. mu and std.
            sample = dist.rsample()
        else:
            sample = mu + std * noise

        raw_sample = sample

        if with_logprob:
            logp = dist.log_prob(raw_sample).sum(axis=-1)
            if self.tanh_refit is not None:
                # tanh correction from https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py#L60
                # Reduced over the same (last) axis as the Gaussian term.
                logp = logp - (2*(np.log(2) - raw_sample - F.softplus(-2*raw_sample))).sum(axis=-1)
        else:
            logp = None

        if self.tanh_refit is not None:
            sample = self.tanh_refit(torch.tanh(sample))

        info = {"std": std.detach().clone()} if with_info else {}

        return sample, raw_sample, logp, info


class NNGaussianTransformerVarLengthPolicy(NNStochasticVarLenPolicy):
    """
    A Gaussian policy taking a (s_obs: [N, S], a_mem: [N, L, A], a_memlen: [N]) input.

    The observation and the remembered actions form a token sequence
    ``[obs, a_0, ..., a_{L-1}]`` of length ``L + 1`` that is passed through
    ``enc_transformer``, together with an attention mask. The mask follows the
    ``torch.nn.MultiheadAttention`` boolean convention (True = may NOT attend),
    so ``enc_transformer`` should be a batch-first ``nn.TransformerEncoder``
    whose layers have ``n_heads`` heads. The output at token ``a_memlen[i]``
    (the last valid token of sample ``i``) feeds the mu / log-std heads.

    pos_embed maps the position indices ``arange(L)`` (a LongTensor of shape
    [L]) to an [L, H] positional embedding tensor, e.g. an ``nn.Embedding``.

    If tanh_refit is specified, then output from the gaussian distribution will
    be put through tanh and then refitted according to the provided module.
    """
    def __init__(self,
                 state_embed, act_embed, pos_embed,
                 enc_transformer,
                 n_heads,
                 mu_head, log_std_head,
                 tanh_refit : op.NNTanhRefit = None,
                 log_std_min=-20.0, log_std_max=2.0):
        super().__init__()
        self.state_embed = state_embed
        self.act_embed = act_embed
        self.pos_embed = pos_embed
        self.enc_transformer = enc_transformer
        self.n_heads = n_heads
        self.mu_head = mu_head
        self.log_std_head = log_std_head
        self.tanh_refit = tanh_refit
        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)

    def _attention_mask(self, a_memlen, T, device):
        """Boolean mask [N * n_heads, T, T]; True marks a BLOCKED query/key pair.

        Query ``t`` may attend to key ``s`` iff ``s <= t`` (causal) and
        ``s <= a_memlen`` (key is not padding). Padding queries therefore
        still see the valid prefix, so no row is fully blocked (fully blocked
        rows can produce NaNs that then leak into valid tokens in later layers).
        """
        pos = torch.arange(T, device=device)
        causal = pos.unsqueeze(1) >= pos.unsqueeze(0)            # [T, T]
        key_valid = pos.unsqueeze(0) <= a_memlen.unsqueeze(-1)   # [N, T]
        allowed = causal.unsqueeze(0) & key_valid.unsqueeze(1)   # [N, T, T]
        # MultiheadAttention indexes 3-D masks batch-major (b * n_heads + h),
        # so each sample's mask must be repeated consecutively per head.
        return (~allowed).repeat_interleave(self.n_heads, dim=0)

    def forward(self, s_obs, a_mem, a_memlen, deterministic=False, with_logprob=True, noise=None, with_info=False):
        """
        s_obs: [N, S] float
        a_mem: [N, L, A] float
        a_memlen: [N] long

        a_memlen values must be between 1 and L. (inclusively)

        Returns ``(sample, raw_sample, logp, info)``: ``raw_sample`` is the
        Gaussian sample (``mu`` if deterministic, ``mu + std * noise`` if noise
        is given), ``sample`` is it after tanh/refit, ``logp`` is ``None`` unless
        ``with_logprob``, ``info`` is ``{}`` unless ``with_info``.
        """
        enc_obs = self.state_embed(s_obs)  # [N, H]
        (N, H) = enc_obs.shape
        (N, L, A) = a_mem.shape
        enc_act = self.act_embed(a_mem.reshape(N * L, A)).reshape(N, L, H)  # [N, L, H]

        enc_pos = self.pos_embed(torch.arange(L, device=s_obs.device))  # [L, H]
        enc_act = enc_act + enc_pos.unsqueeze(0)

        # Token 0 is the observation, tokens 1..L are the remembered actions.
        tx_input = torch.cat((enc_obs.unsqueeze(1), enc_act), dim=1)  # [N, L + 1, H]
        tx_mask = self._attention_mask(a_memlen, L + 1, s_obs.device)  # [N * n_heads, L + 1, L + 1]

        # Assuming that the transformer is batch_first
        embeds = self.enc_transformer(tx_input, tx_mask)  # [N, L + 1, H]

        # Gather, per sample, the output at its last valid token (index a_memlen).
        e = torch.gather(embeds, 1, a_memlen[..., None, None].repeat(1, 1, H)).squeeze(1)
        assert e.shape == (N, H)

        mu = self.mu_head(e)
        std = torch.clamp(self.log_std_head(e), self.log_std_min, self.log_std_max).exp()
        dist = torch.distributions.normal.Normal(mu, std)

        if deterministic:
            sample = mu
        elif noise is None:
            # rsample keeps the sample differentiable w.r.t. mu and std.
            sample = dist.rsample()
        else:
            sample = mu + std * noise

        raw_sample = sample

        if with_logprob:
            logp = dist.log_prob(raw_sample).sum(axis=-1)
            if self.tanh_refit is not None:
                # tanh correction from https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py#L60
                # Reduced over the same (last) axis as the Gaussian term.
                logp = logp - (2*(np.log(2) - raw_sample - F.softplus(-2*raw_sample))).sum(axis=-1)
        else:
            logp = None

        if self.tanh_refit is not None:
            sample = self.tanh_refit(torch.tanh(sample))

        info = {"std": std.detach().clone()} if with_info else {}

        return sample, raw_sample, logp, info

