from typing import NamedTuple
from typing import Optional, List

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import jit
from torch.distributions import Uniform, Normal, TransformedDistribution, Categorical


# weights initializer
def weights_init_(m):
    """Initialise an ``nn.Linear`` in place; leave every other module untouched.

    Intended to be passed to ``nn.Module.apply``. The weight gets Xavier-uniform
    values (gain 1) and the bias, when the layer has one, is set to zero.

    Args:
        m: any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # Layers built with ``bias=False`` expose ``bias`` as None.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)


def normalized_columns_initializer(weights, std=1.0):
    """Return random values shaped like ``weights`` whose rows have L2 norm ``std``.

    Standard-normal noise is drawn and each row (the reduction runs over dim 1)
    is rescaled to norm ``std``. ``weights`` is only used as a template for
    shape, dtype and device; it is not modified.

    Args:
        weights: tensor with at least two dimensions.
        std: target L2 norm of every row of the result.

    Returns:
        A new tensor with the shape, dtype and device of ``weights``.
    """
    # Match the template's dtype/device so the result can replace it directly.
    out = torch.randn(weights.size(), dtype=weights.dtype, device=weights.device)
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True))
    return out


class TransitionOutput(NamedTuple):
    """Result bundle of ``TransitionNetwork.forward``.

    Each tensor is built there with ``torch.stack(..., dim=0)`` over the
    rolled-out steps. The three posterior fields are ``None`` when ``forward``
    is called without observations.
    """
    # Outputs of the GRU cell, one entry per step.
    belief: Tensor
    # Samples drawn as prior_mean + prior_std_dev * standard-normal noise.
    prior_state: Tensor
    # First half of the prior head's output.
    prior_mean: Tensor
    # softplus of the second half of the prior head's output, plus min_std_dev.
    prior_std_dev: Tensor
    # Same three quantities from the posterior head (which also sees the observation).
    posterior_state: Optional[Tensor]
    posterior_mean: Optional[Tensor]
    posterior_std_dev: Optional[Tensor]


class ActorOutput(NamedTuple):
    """Parameters of the action Normal produced by ``ActorNetwork.forward``."""
    # Mean: mu_scale * tanh(mu_logit / mu_scale).
    mu: Tensor
    # Standard deviation: softplus(std_logit + raw_init_std) + min_std.
    std_dev: Tensor


class ActorRepeatOutput(NamedTuple):
    """Bundle returned by ``ActorRepeatNetwork.forward``."""
    # Raw output of the final linear layer, one column per action-repeat choice.
    logit: Tensor
    # softmax of `logit` over dim 1.
    prob: Tensor
    # NOTE(review): annotated as Tensor, but ActorRepeatNetwork.forward stores a
    # torch.distributions.Categorical built from `prob` in this field.
    dist: Tensor
    # -(log_softmax(logit) * prob) summed over dim 1 with keepdim=True.
    entropy: Tensor
    # Mean of log_softmax(logit) over dim 1 with keepdim=True.
    log_prob_mean: Tensor


class ValueNetwork(jit.ScriptModule):
    """MLP that maps a (belief, state) pair to one scalar per batch row.

    Three ELU hidden layers of width ``hidden_dim`` followed by a linear
    output of width 1.
    """

    def __init__(self, belief_size, state_size, hidden_dim):
        super().__init__()

        input_dim = belief_size + state_size
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, 1)

    @jit.script_method
    def forward(self, belief, state):
        """Return a ``(batch, 1)`` tensor for 2-D ``belief`` and ``state`` inputs."""
        assert belief.dim() == 2
        assert state.dim() == 2
        assert belief.size(0) == state.size(0), 'should have same batch size'

        joint = torch.cat([belief, state], dim=1)
        h1 = F.elu(self.linear1(joint))
        h2 = F.elu(self.linear2(h1))
        h3 = F.elu(self.linear3(h2))
        return self.linear4(h3)


class ActorRepeatNetwork(jit.ScriptModule):
    """Categorical policy over the entries of ``action_repeat_set``.

    ``forward`` scores every repeat choice from (belief, state, action);
    ``sample`` and ``uniform_sample`` turn a chosen index into a one-hot row
    plus the corresponding value from ``action_repeat_set``.
    """

    def __init__(self, belief_size: int, state_size: int, action_size: int, hidden_dim: int, action_repeat_set):
        super().__init__()
        self.action_repeat_set = action_repeat_set
        self.belief_size = belief_size
        self.state_size = state_size
        self.action_size = action_size

        input_dim = belief_size + state_size + action_size
        n_choices = len(self.action_repeat_set)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, n_choices)

        # Xavier init everywhere, then give the output layer small row-normalised weights.
        self.apply(weights_init_)
        head = self.linear5
        head.weight.data = normalized_columns_initializer(head.weight.data, 0.01)
        head.bias.data.fill_(0)

    # @jit.script_method
    def forward(self, belief, state, action):
        """Score each repeat choice and return an ``ActorRepeatOutput``.

        All three inputs must be 2-D with matching batch size and feature
        widths equal to the sizes given at construction.
        """
        assert belief.dim() == 2
        assert state.dim() == 2
        assert action.dim() == 2

        assert belief.shape[1] == self.belief_size
        assert state.shape[1] == self.state_size
        assert action.shape[1] == self.action_size
        assert belief.shape[0] == state.shape[0] == action.shape[0]

        features = torch.cat((belief, state, action), dim=1)
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            features = F.elu(layer(features))
        logits = self.linear5(features)

        prob = F.softmax(logits, dim=1)
        log_prob = F.log_softmax(logits, dim=1)
        entropy = -(log_prob * prob).sum(1, keepdim=True)
        log_prob_mean = log_prob.mean(1, keepdim=True)

        return ActorRepeatOutput(logits, prob, self.repeat_dist(prob), entropy, log_prob_mean)

    def repeat_dist(self, prob):
        """Wrap per-row probabilities in a ``Categorical``."""
        return Categorical(probs=prob)

    def _one_hot_and_repeats(self, idxs):
        """Map ``(batch, 1)`` choice indices to (one-hot rows, repeat values)."""
        n_choices = len(self.action_repeat_set)
        batch_size = idxs.shape[0]
        device = idxs.device

        values = torch.Tensor(self.action_repeat_set).to(device=device, dtype=torch.get_default_dtype())
        repeats = values.expand(batch_size, n_choices).gather(1, idxs)

        one_hot = torch.zeros(batch_size, n_choices, dtype=torch.float32, device=device)
        one_hot.scatter_(1, idxs, 1)
        return one_hot, repeats

    def uniform_sample(self, batch_size=1, device='cpu'):
        """Pick a repeat choice uniformly at random (NumPy RNG) for each batch row.

        Returns a ``(batch_size, n_choices)`` one-hot tensor and a
        ``(batch_size, 1)`` tensor with the selected repeat values.
        """
        n_choices = len(self.action_repeat_set)
        picked = np.random.choice(range(n_choices), batch_size)
        idxs = torch.LongTensor(picked).unsqueeze(1).to(device)
        return self._one_hot_and_repeats(idxs)

    def sample(self, actor_repeat_output, deterministic=False):
        """Pick a repeat choice from a forward result.

        With ``deterministic=True`` the arg-max of ``prob`` is used; otherwise
        an index is sampled from ``dist``. Returns the same pair as
        ``uniform_sample``.
        """
        if deterministic:
            idxs = actor_repeat_output.prob.argmax(dim=1, keepdim=True)
        else:
            idxs = actor_repeat_output.dist.sample().unsqueeze(1)
        return self._one_hot_and_repeats(idxs)


class ActorNetwork(jit.ScriptModule):
    """Gaussian policy head over continuous actions.

    ``forward`` maps a (belief, state) pair to the mean and standard deviation
    of a Normal. ``action_dist`` wraps that Normal in a tanh transform and
    ``action_sample`` draws actions from it, or uniformly from the action space.

    Args:
        belief_size: width of the belief input.
        state_size: width of the state input.
        hidden_dim: width of the four hidden layers.
        action_space: box whose ``shape[0]`` gives the action dimension and
            whose ``low``/``high`` bound uniform sampling.
        min_std: constant added to the standard deviation after the softplus.
        init_std: value whose inverse softplus is added to the raw std output.
        mu_scale: scale used in ``mu_scale * tanh(mu_logit / mu_scale)``.
    """

    def __init__(self, belief_size, state_size, hidden_dim, action_space: gym.spaces.Box,
                 min_std: float = 1e-4, init_std: float = 5., mu_scale=5):
        super(ActorNetwork, self).__init__()
        self.action_space = action_space
        self.action_space_n = action_space.shape[0]
        self._min_std = min_std
        self._init_std = init_std
        self._mu_scale = mu_scale
        # Inverse softplus of init_std. It only depends on the constructor
        # argument, so compute it once here instead of on every forward call.
        self._raw_init_std = torch.log(torch.exp(torch.tensor([init_std])) - 1.).item()

        self.linear1 = nn.Linear(belief_size + state_size, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = nn.Linear(hidden_dim, hidden_dim)
        self.mu_linear = nn.Linear(hidden_dim, action_space.shape[0])
        self.std_linear = nn.Linear(hidden_dim, action_space.shape[0])

    @jit.script_method
    def forward(self, belief, state):
        # Returns ActorOutput(mu, std) for 2-D belief/state with equal batch size.
        assert len(belief.shape) == 2
        assert len(state.shape) == 2
        assert belief.shape[0] == state.shape[0]

        hx = torch.cat((belief, state), dim=1)
        x = F.elu(self.linear1(hx))
        x = F.elu(self.linear2(x))
        x = F.elu(self.linear3(x))
        x = F.elu(self.linear4(x))
        mu_logit = self.mu_linear(x)
        std_logit = self.std_linear(x)

        # Soft-clip the mean to (-mu_scale, mu_scale); keep std above min_std.
        mu = self._mu_scale * torch.tanh(mu_logit / self._mu_scale)
        std = F.softplus(std_logit + self._raw_init_std) + self._min_std
        return ActorOutput(mu, std)

    def action_sample(self, actor_output=None, batch_size=1, uniform=False, deterministic=False, device='cpu'):
        """Draw actions either uniformly or from the policy distribution.

        Args:
            actor_output: result of ``forward``; required unless ``uniform``.
            batch_size: number of uniform samples (used only when ``uniform``).
            uniform: sample from ``Uniform(low, high + 1e-5)`` of the action space.
            deterministic: use the distribution's sampled mode instead of ``rsample``.
            device: target device for uniform samples.

        Returns:
            The sampled actions. In the uniform branch a leading dimension of
            size 1 is prepended via ``unsqueeze_(0)``; the policy branch returns
            whatever the distribution yields without that extra dimension.
        """
        if uniform:
            actions = Uniform(torch.FloatTensor(self.action_space.low),
                              torch.FloatTensor(self.action_space.high) + 1e-5).sample((batch_size,))
            actions = actions.to(device)
            actions.unsqueeze_(0)
        else:
            action_dist = self.action_dist(actor_output)
            if deterministic:
                actions = action_dist.mode()
            else:
                actions = action_dist.rsample()
        return actions

    def action_dist(self, actor_output):
        """Build the tanh-transformed Normal for ``actor_output``.

        The last dimension is treated as the event dimension (``Independent(..., 1)``)
        and the result is wrapped in ``SampleDist`` for sample-based mode/entropy.
        """
        dist = Normal(actor_output.mu, actor_output.std_dev)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        return dist


class TransitionNetwork(jit.ScriptModule):
    """Recurrent latent transition model with a stochastic prior and optional posterior.

    For each step, ``forward`` embeds (state, action, action-repeat vector),
    advances a GRU cell to get the next belief, and samples a prior state from
    that belief. When observations (encoder embeddings) are supplied it also
    samples a posterior state from (belief, observation) and feeds that into
    the next step instead of the prior sample.
    """
    __constants__ = ['min_std_dev']

    def __init__(self, belief_size, state_size, action_size, action_repeat_size,
                 hidden_size, embedding_size, min_std_dev=0.1):
        super(TransitionNetwork, self).__init__()
        # Sizes are kept on the module so forward() can validate input shapes.
        self.belief_size = belief_size
        self.state_size = state_size
        self.action_size = action_size
        self.embedding_size = embedding_size
        self.action_repeat_size = action_repeat_size

        # Constant added to every softplus-ed standard deviation.
        self.min_std_dev = min_std_dev
        self.fc_embed_state_action = nn.Linear(state_size + action_size + action_repeat_size, belief_size)
        self.rnn = nn.GRUCell(belief_size, belief_size)
        self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size)
        # Heads emit 2 * state_size values: mean and raw std, split with chunk().
        self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size)
        self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size)
        self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size)

    @jit.script_method
    def forward(self, prev_belief, prev_state, actions, actions_repeat_prob,
                observations: Optional[torch.Tensor] = None, non_terminals: Optional[torch.Tensor] = None):
        # Inputs (enforced by the asserts below):
        #   prev_belief          2-D, last dim == belief_size
        #   prev_state           2-D, last dim == state_size
        #   actions              3-D, last dim == action_size; dim 0 is iterated as steps
        #   actions_repeat_prob  3-D, last dim == action_repeat_size
        #   observations         optional 3-D, last dim == embedding_size
        #   non_terminals        optional 3-D, last dim == 1
        # Returns a TransitionOutput whose tensors are stacked over steps on dim 0;
        # the posterior entries are None when observations is None.
        # ensure input shapes
        assert len(prev_belief.shape) == 2 and prev_belief.shape[1] == self.belief_size, 'belief shape is incorrect'
        assert len(prev_state.shape) == 2 and prev_state.shape[1] == self.state_size, 'state shape is incorrect'
        assert len(actions.shape) == 3 and actions.shape[2] == self.action_size, 'action shape is incorrect'
        assert len(actions_repeat_prob.shape) == 3 and actions_repeat_prob.shape[2] == self.action_repeat_size, \
            'action repeat shape is incorrect'
        assert prev_state.shape[0] == prev_belief.shape[0] == actions.shape[1] == actions_repeat_prob.shape[1], \
            'batch shape should be same'
        if observations is not None:
            assert len(observations.shape) == 3 and observations.shape[2] == self.embedding_size
        if non_terminals is not None:
            assert len(non_terminals.shape) == 3 and non_terminals.shape[2] == 1

        # One slot per step plus slot 0 for the initial values; the lists are
        # pre-filled with empty placeholder tensors and overwritten in the loop.
        T = actions.size(0) + 1
        beliefs = [torch.empty(0)] * T
        prior_states, posterior_states = [torch.empty(0)] * T, [torch.empty(0)] * T
        prior_means, prior_std_devs = [torch.empty(0)] * T, [torch.empty(0)] * T
        posterior_means, posterior_std_devs = [torch.empty(0)] * T, [torch.empty(0)] * T

        beliefs[0] = prev_belief
        prior_states[0] = prev_state
        posterior_states[0] = prev_state

        for t in range(T - 1):
            # Roll forward from the posterior sample when observations exist, else from the prior sample.
            _state = prior_states[t] if observations is None else posterior_states[t]
            # Scale the state by the step's non-terminal mask
            # (presumably 0/1 so finished episodes are reset — confirm with caller).
            _state = _state if non_terminals is None else _state * non_terminals[t]

            # compute deterministic hidden state
            hidden = self.fc_embed_state_action(torch.cat([_state, actions[t], actions_repeat_prob[t]], dim=1))
            hidden = F.elu(hidden)
            beliefs[t + 1] = self.rnn(hidden, beliefs[t])

            # compute state prior
            hidden = F.elu(self.fc_embed_belief_prior(beliefs[t + 1]))
            prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
            prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev
            # Reparameterised sample: mean + std * standard-normal noise.
            prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])

            if observations is not None:
                # Compute state posterior
                # Note: t_ + 1 == t, so the observation read below is observations[t].
                t_ = t - 1  # Use t_ to deal with different time indexing for observations
                belief_obs_cat = torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1)
                hidden = F.elu(self.fc_embed_belief_posterior(belief_obs_cat))
                posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
                posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
                posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(
                    posterior_means[t + 1])

        # Slot 0 holds the inputs / placeholders, so only slots 1..T-1 are returned.
        return TransitionOutput(
            torch.stack(beliefs[1:], dim=0),
            torch.stack(prior_states[1:], dim=0),
            torch.stack(prior_means[1:], dim=0),
            torch.stack(prior_std_devs[1:], dim=0),
            torch.stack(posterior_states[1:], dim=0) if observations is not None else None,
            torch.stack(posterior_means[1:], dim=0) if observations is not None else None,
            torch.stack(posterior_std_devs[1:], dim=0) if observations is not None else None,
        )


class RewardNetwork(jit.ScriptModule):
    """MLP that maps a (belief, state) pair to one scalar per batch row.

    Two ELU hidden layers of width ``hidden_dim`` and a linear output of width 1.
    """

    def __init__(self, belief_size, state_size, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(belief_size + state_size, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self._reward = nn.Sequential(*layers)

    @jit.script_method
    def forward(self, belief, state):
        """Return a ``(batch, 1)`` tensor for 2-D ``belief`` and ``state`` inputs."""
        assert belief.dim() == 2
        assert state.dim() == 2
        assert belief.size(0) == state.size(0), 'should have same batch size'

        # The input is concatenated as (state, belief), in that order.
        joint = torch.cat([state, belief], dim=1)
        return self._reward(joint)


class ObservationNetwork(jit.ScriptModule):
    """MLP that maps a (belief, state) pair to a vector of width ``obs_size``.

    Two ELU hidden layers of width ``embedding_size`` and a linear output layer.
    """

    def __init__(self, obs_size, belief_size, state_size, embedding_size):
        super().__init__()
        latent_dim = belief_size + state_size
        self.fc1 = nn.Linear(latent_dim, embedding_size)
        self.fc2 = nn.Linear(embedding_size, embedding_size)
        self.fc3 = nn.Linear(embedding_size, obs_size)

    @jit.script_method
    def forward(self, belief, state):
        """Return a ``(batch, obs_size)`` tensor; inputs are concatenated on dim 1."""
        joint = torch.cat((belief, state), dim=1)
        h = F.elu(self.fc1(joint))
        h = F.elu(self.fc2(h))
        return self.fc3(h)


class Encoder(jit.ScriptModule):
    """Three-layer ELU MLP that embeds an observation vector.

    Args:
        observation_size: width of the observation input.
        embedding_size: width of every layer's output, including the result.
    """

    def __init__(self, observation_size, embedding_size):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(observation_size, embedding_size)
        self.fc2 = nn.Linear(embedding_size, embedding_size)
        self.fc3 = nn.Linear(embedding_size, embedding_size)

    @jit.script_method
    def forward(self, observation):
        """Return the embedding; the last dimension becomes ``embedding_size``."""
        hidden = F.elu(self.fc1(observation))
        hidden = F.elu(self.fc2(hidden))
        # fc3 was constructed but never applied (ELU was applied twice instead).
        # NOTE(review): the trailing ELU is kept from the original line; confirm
        # whether a purely linear output layer was intended.
        hidden = F.elu(self.fc3(hidden))
        return hidden


class DreamerNetwork(jit.ScriptModule):
    """Container that builds and owns every sub-network of the agent.

    Sub-modules: ``encoder``, ``observation``, ``transition``, ``reward``,
    ``actor_repeat``, ``actor`` and ``value``. The ``init_*`` helpers return
    zero-filled starting tensors for a given batch size.
    """

    def __init__(self, obs_size, belief_size, state_size, hidden_size, embedding_size, action_space,
                 action_repeat_set: List):
        super().__init__()

        n_actions = action_space.shape[0]

        self.belief_size = belief_size
        self.state_size = state_size
        self.action_space = action_space
        self.num_actions = n_actions
        self.embedding_size = embedding_size
        # Stored sorted, so index 0 is the smallest repeat value.
        self.action_repeat_set = sorted(action_repeat_set)

        # World-model components.
        self.encoder = Encoder(obs_size, embedding_size)
        self.observation = ObservationNetwork(obs_size, belief_size, state_size, embedding_size)
        self.transition = TransitionNetwork(
            belief_size, state_size, n_actions, len(action_repeat_set), hidden_size, embedding_size,
        )
        self.reward = RewardNetwork(belief_size, state_size, hidden_size)

        # Behaviour components.
        self.actor_repeat = ActorRepeatNetwork(
            belief_size, state_size, n_actions, hidden_size, self.action_repeat_set,
        )
        self.actor = ActorNetwork(belief_size, state_size, hidden_size, action_space)
        self.value = ValueNetwork(belief_size, state_size, hidden_size)

    @jit.script_method
    def init_action(self, batch_size: int = 1):
        # All-zero action, one row per batch element.
        return torch.zeros([batch_size, self.num_actions])

    @jit.script_method
    def init_action_repeat_one_hot(self, batch_size: int = 1):
        # One-hot rows selecting index 0 of action_repeat_set.
        one_hot = torch.zeros([batch_size, len(self.action_repeat_set)])
        one_hot[:, 0] = 1
        return one_hot

    @jit.script_method
    def init_belief(self, batch_size: int = 1):
        # All-zero belief, one row per batch element.
        return torch.zeros([batch_size, self.belief_size])

    @jit.script_method
    def init_state(self, batch_size: int = 1):
        # All-zero state, one row per batch element.
        return torch.zeros([batch_size, self.state_size])


# "atanh", "TanhBijector" and "SampleDist" are adapted from an external repository (TODO: the source link is missing from this comment).
def atanh(x):
    """Element-wise inverse hyperbolic tangent of tensor ``x``.

    Computed as ``0.5 * (log1p(x) - log1p(-x))``. Unlike
    ``0.5 * log((1 + x) / (1 - x))`` this keeps precision for small ``|x|``,
    where ``1 + x`` would round to exactly 1. Inputs of +1 / -1 give +inf / -inf
    and inputs outside [-1, 1] give NaN.
    """
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))


class TanhBijector(torch.distributions.Transform):
    """Bijective transform ``y = tanh(x)`` for use with ``TransformedDistribution``.

    ``domain`` and ``codomain`` are declared because current PyTorch releases
    read them when a ``TransformedDistribution`` is constructed; a transform
    without them fails there with ``AttributeError``.
    """
    domain = torch.distributions.constraints.real
    codomain = torch.distributions.constraints.interval(-1.0, 1.0)
    bijective = True

    def __init__(self):
        super().__init__()

    @property
    def sign(self):
        """tanh is monotonically increasing."""
        return 1.

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y: torch.Tensor):
        # Pull values inside [-1, 1] slightly away from the boundary so atanh stays
        # finite; values outside that range are passed through unchanged.
        y = torch.where((torch.abs(y) <= 1.),
                        torch.clamp(y, -0.99999997, 0.99999997), y)
        y = atanh(y)
        return y

    def log_abs_det_jacobian(self, x, y):
        """Return ``log|d tanh(x)/dx|`` as ``2 * (log 2 - x - softplus(-2x))``.

        Only ``x`` is used; ``y`` is accepted for interface compatibility.
        """
        return 2. * (np.log(2) - x - F.softplus(-2. * x))


class SampleDist:
    """Wrapper that estimates mean, mode and entropy of a distribution by sampling.

    Attributes that are not defined here are forwarded to the wrapped
    distribution.

    Args:
        dist: wrapped distribution; must support ``expand``, ``rsample``,
            ``log_prob``, ``sample`` and ``batch_shape``.
        samples: number of samples drawn for each estimate.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return 'SampleDist'

    def __getattr__(self, name):
        # Read `_dist` straight from the instance dict: going through attribute
        # access would re-enter __getattr__ forever when `_dist` is not set yet
        # (e.g. on the blank instance created by copy/pickle).
        try:
            wrapped = self.__dict__['_dist']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)

    def _expanded(self):
        """Return the wrapped distribution with a leading sample dimension."""
        return self._dist.expand((self._samples, *self._dist.batch_shape))

    def mean(self, dist=None):
        """Average one ``rsample`` of ``dist`` over its first dimension.

        Args:
            dist: distribution to sample from. When omitted, the wrapped
                distribution expanded to ``samples`` draws is used.
        """
        if dist is None:
            dist = self._expanded()
        sample = dist.rsample()
        return torch.mean(sample, 0)

    def mode(self):
        """Return, per batch element, the drawn sample with the highest log-probability."""
        dist = self._expanded()
        sample = dist.rsample()
        logprob = dist.log_prob(sample)
        best = torch.argmax(logprob, dim=0)
        # `sample` has the event dimensions that `logprob` lacks; broadcast the
        # winning sample index across them before gathering.
        event_ndim = sample.dim() - logprob.dim()
        index_shape = (1,) + tuple(best.shape) + (1,) * event_ndim
        indices = best.reshape(index_shape).expand((1,) + tuple(sample.shape[1:]))
        return torch.gather(sample, 0, indices).squeeze(0)

    def entropy(self, sample=None):
        """Estimate the entropy as the negative mean log-probability.

        Args:
            sample: samples to score, with a leading dimension of size
                ``samples``; drawn with ``rsample`` when omitted.
        """
        dist = self._expanded()
        if sample is None:
            sample = dist.rsample()
        logprob = dist.log_prob(sample)
        return -torch.mean(logprob, 0)

    def sample(self):
        """Draw one (non-reparameterised) sample from the wrapped distribution."""
        return self._dist.sample()
