import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from distributions import Categorical
from utils import init


class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) axis into one."""

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)


class FCNetwork(nn.Module):
    """Fully connected feed-forward network with ReLUs between Linear layers."""

    def __init__(self, dims, out_layer=None):
        """
        Creates a network using ReLUs between layers and no activation at the end
        (other than ``out_layer``, if one is given).

        :param dims: sequence of layer widths in the form (in, h1, ..., out),
            e.g. (100, 100, ..., 5); needs at least two entries.
        :param out_layer: optional module appended after the final Linear layer.
        :raises ValueError: if ``dims`` has fewer than two entries.
        """
        super().__init__()
        if len(dims) < 2:
            raise ValueError(f"dims needs at least an input and an output size, got {dims!r}")

        mods = []
        for in_size, out_size in zip(dims[:-1], dims[1:]):
            if mods:
                # ReLU only *between* Linear layers, never after the last one.
                mods.append(nn.ReLU())
            mods.append(nn.Linear(in_size, out_size))

        # Compare against None: containers such as an empty nn.Sequential define
        # __len__ and would otherwise be mistaken for "no output layer".
        if out_layer is not None:
            mods.append(out_layer)

        self.layers = nn.Sequential(*mods)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.layers(x)

    def hard_update(self, source):
        """Copy every parameter of ``source`` into this network in place.

        Parameters are paired positionally via ``zip``, so ``source`` is assumed
        to have the same architecture as this network.
        """
        for target_param, source_param in zip(self.parameters(), source.parameters()):
            target_param.data.copy_(source_param.data)

    def soft_update(self, source, t):
        """Blend ``source`` into this network: ``target <- (1 - t) * target + t * source``.

        :param source: network with the same architecture (parameters are paired positionally).
        :param t: interpolation weight given to ``source``.
        """
        for target_param, source_param in zip(self.parameters(), source.parameters()):
            target_param.data.copy_((1 - t) * target_param.data + t * source_param.data)


class Policy(nn.Module):
    """Actor-critic policy: an ``MLPBase`` trunk plus a ``Categorical`` action head.

    :param obs_space: observation space; only ``obs_space.shape[0]`` is used,
        as the flat input size passed to ``MLPBase``.
    :param action_space: discrete action space; ``action_space.n`` sets the
        number of actions.
    :param base: accepted but ignored — ``MLPBase`` is always used.
    :param base_kwargs: keyword arguments forwarded to ``MLPBase``.
    """

    def __init__(self, obs_space, action_space, base=None, base_kwargs=None):
        super(Policy, self).__init__()

        obs_shape = obs_space.shape

        if base_kwargs is None:
            base_kwargs = {}

        # NOTE: ``base`` is intentionally not consulted; the trunk is always MLPBase.
        self.base = MLPBase(obs_shape[0], **base_kwargs)

        num_outputs = action_space.n
        self.dist = Categorical(self.base.output_size, num_outputs)

    @property
    def is_recurrent(self):
        """Whether the underlying base uses a GRU."""
        return self.base.is_recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Size of rnn_hx."""
        return self.base.recurrent_hidden_state_size

    def forward(self, inputs, rnn_hxs, masks):
        raise NotImplementedError("Policy is not callable directly; use act(), get_value() or evaluate_actions()")

    def act(self, inputs, rnn_hxs, masks, deterministic=False):
        """Choose actions for a batch of observations.

        :param deterministic: take ``dist.mode()`` instead of sampling.
        :returns: ``(value, action, action_log_probs, rnn_hxs)``.
        """
        value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)
        # Entropy is not needed when acting; evaluate_actions() reports it.

        return value, action, action_log_probs, rnn_hxs

    def get_value(self, inputs, rnn_hxs, masks):
        """Return only the critic value for ``inputs``."""
        value, _, _ = self.base(inputs, rnn_hxs, masks)
        return value

    def evaluate_actions(self, inputs, rnn_hxs, masks, action):
        """Score given actions under the current policy.

        :returns: ``(value, action_log_probs, mean_entropy, rnn_hxs)``.
        """
        value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        return value, action_log_probs, dist_entropy, rnn_hxs


class NNBase(nn.Module):
    """Common scaffolding for actor-critic bases: size bookkeeping plus an optional GRU.

    :param recurrent: when truthy, an ``nn.GRU`` is created as ``self.gru``.
    :param recurrent_input_size: feature width fed to the GRU.
    :param hidden_size: GRU hidden width; also reported as ``output_size``.
    """

    def __init__(self, recurrent, recurrent_input_size, hidden_size):
        super().__init__()

        self._hidden_size = hidden_size
        self._recurrent = recurrent

        if not recurrent:
            return

        self.gru = nn.GRU(recurrent_input_size, hidden_size)
        # Zero every GRU bias and give every GRU weight matrix an orthogonal init.
        for param_name, tensor in self.gru.named_parameters():
            if "bias" in param_name:
                nn.init.constant_(tensor, 0)
            elif "weight" in param_name:
                nn.init.orthogonal_(tensor)

    @property
    def is_recurrent(self):
        return self._recurrent

    @property
    def recurrent_hidden_state_size(self):
        """Width of the recurrent state: ``hidden_size`` when recurrent, else 1."""
        return self._hidden_size if self._recurrent else 1

    @property
    def output_size(self):
        return self._hidden_size

    def _forward_gru(self, x, hxs, masks):
        """Run the GRU over a single step or over a flattened (T * N) sequence.

        If ``x`` has as many rows as ``hxs`` (N), it is one time step. Otherwise
        ``x`` is a (T, N, -1) tensor flattened to (T * N, -1) and ``masks`` holds
        T * N values in the same order. Wherever a mask is 0, that hidden state
        is zeroed before the GRU consumes the step (presumably an episode
        boundary — confirm against the rollout storage).

        :returns: ``(outputs, hxs)``: outputs has one row per row of ``x`` and
            the final hidden state has shape (N, hidden_size).
        """
        num_envs = hxs.size(0)

        if x.size(0) == num_envs:
            step_out, new_hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0))
            return step_out.squeeze(0), new_hxs.squeeze(0)

        steps = int(x.size(0) / num_envs)
        seq = x.view(steps, num_envs, x.size(1))
        step_masks = masks.view(steps, num_envs)

        # Time indices (shifted by +1 because of the [1:] slice) at which any
        # env's mask is zero; t=0 is always treated as a chunk start.
        resets = (step_masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu()
        if resets.dim() == 0:
            reset_steps = [resets.item() + 1]
        else:
            reset_steps = (resets + 1).numpy().tolist()
        boundaries = [0, *reset_steps, steps]

        # Feed each mask-free span through the GRU in one call; only the
        # hidden state at a span's first step needs masking.
        h = hxs.unsqueeze(0)
        chunks = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            chunk_out, h = self.gru(seq[start:end], h * step_masks[start].view(1, -1, 1))
            chunks.append(chunk_out)

        flat = torch.cat(chunks, dim=0).view(steps * num_envs, -1)
        return flat, h.squeeze(0)


class MLPBase(NNBase):
    """Actor-critic base with separate convolutional actor and critic trunks.

    Despite the name, both trunks are small CNNs. The environment is inferred
    from the flat observation length ``num_inputs``. That inference fixes the
    (H, W, C) grid each observation row is reshaped into.

    :param num_inputs: flat observation length; must be 147, 845 or 500.
    :param recurrent: forwarded to ``NNBase`` (see the NOTE in ``forward``).
    :param hidden_size: width of each trunk's output features and of the
        critic head's input (also reported as ``output_size``).
    :raises ValueError: if ``num_inputs`` is not a supported observation size.
    """

    # Flat observation length -> environment name.
    _ENV_BY_INPUT_SIZE = {
        147: "pursuit",
        845: "battle",
        500: "adversarial_pursuit",
    }
    # Environment name -> (height, width, channels) of the observation grid.
    _OBS_SHAPES = {
        "pursuit": (7, 7, 3),
        "battle": (13, 13, 5),
        "adversarial_pursuit": (10, 10, 5),
    }

    def __init__(self, num_inputs, recurrent=False, hidden_size=64):
        super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)

        if num_inputs not in self._ENV_BY_INPUT_SIZE:
            raise ValueError(
                f"Unsupported observation size {num_inputs}; "
                f"expected one of {sorted(self._ENV_BY_INPUT_SIZE)}"
            )
        self.env_name = self._ENV_BY_INPUT_SIZE[num_inputs]
        height, width, channels = self._OBS_SHAPES[self.env_name]

        def init_(module):
            # Delegates to utils.init with orthogonal_ / constant-0 initialisers
            # and gain sqrt(2); exact semantics live in utils.init.
            return init(module, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Build actor before critic: the order affects default parameter
        # initialisation under a fixed global seed.
        # Trunks end in hidden_size features so they match critic_linear and
        # output_size (previously hard-coded to 64, which broke other sizes).
        self.actor = self._conv_trunk(channels, height, width, hidden_size)
        self.critic = self._conv_trunk(channels, height, width, hidden_size)

        self.critic_linear = init_(nn.Linear(hidden_size, 1))

        self.train()

    @staticmethod
    def _conv_trunk(in_channels, height, width, out_features):
        """Three 2x2 Conv+ReLU stages, flatten, then Linear+ReLU to ``out_features``."""
        # Each unpadded, stride-1 2x2 convolution shrinks H and W by 1 (3 in total).
        flat_size = 64 * (height - 3) * (width - 3)
        return nn.Sequential(
            nn.Conv2d(in_channels, 32, [2, 2]),
            nn.ReLU(),
            nn.Conv2d(32, 64, [2, 2]),
            nn.ReLU(),
            nn.Conv2d(64, 64, [2, 2]),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, out_features),
            nn.ReLU(),
        )

    def forward(self, inputs, rnn_hxs, masks):
        """Compute the critic value and actor features for a batch of flat observations.

        :returns: ``(value, actor_features, rnn_hxs)``; value has shape
            (batch, 1) and actor_features has shape (batch, hidden_size).
        """
        x = inputs

        if self.is_recurrent:
            # NOTE(review): the GRU runs on the *flat* observation before it is
            # reshaped into a grid, so its hidden_size-wide output only reshapes
            # when hidden_size == num_inputs. Recurrent mode probably needs the
            # GRU after the conv trunks — confirm the intended design.
            x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

        height, width, channels = self._OBS_SHAPES[self.env_name]
        # Rows are flattened (H, W, C) grids; Conv2d expects (N, C, H, W).
        obs = x.reshape([x.shape[0], height, width, channels]).permute(0, 3, 1, 2)
        hidden_actor = self.actor(obs)
        hidden_critic = self.critic(obs)

        return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
