import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical


def layer_init_std_bias(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place with orthogonal weights and a constant bias.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (for example ``nn.Linear`` or ``nn.Conv2d``).
        std: Gain forwarded to ``torch.nn.init.orthogonal_``.
        bias_const: Value written into every element of the bias.

    Returns:
        The same ``layer`` object, so the call can be nested directly inside
        an ``nn.Sequential(...)`` constructor.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` carry ``bias = None``; filling ``None``
    # would raise, so the bias step is skipped when there is nothing to fill.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer


class RNDAtariAgent(nn.Module):
    """Actor-critic network with a shared conv trunk, a policy head and two value heads.

    The trunk takes 4-channel image batches and emits a 448-wide embedding.
    The policy head reads that embedding directly; both value heads read the
    embedding plus a residual branch (``extra_layer``).
    NOTE(review): ``critic_ext`` / ``critic_int`` presumably estimate extrinsic
    and intrinsic returns respectively -- confirm against the training loop.
    """

    def __init__(self, envs):
        """Create and initialise all sub-networks.

        Args:
            envs: Object whose ``single_action_space.n`` gives the number of
                discrete actions, i.e. the width of the policy logits.
        """
        super().__init__()

        # Layers are created and initialised one after another, in the same
        # sequence as they appear in the network, so random draws stay ordered.
        trunk = []
        for c_in, c_out, kernel, stride in ((4, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)):
            trunk.append(layer_init_std_bias(nn.Conv2d(c_in, c_out, kernel, stride=stride)))
            trunk.append(nn.ReLU())
        trunk.append(nn.Flatten())
        # The first linear layer expects 64 * 7 * 7 flattened features
        # (e.g. what 84x84 frames produce after the three convolutions).
        for d_in, d_out in ((64 * 7 * 7, 256), (256, 448)):
            trunk.append(layer_init_std_bias(nn.Linear(d_in, d_out)))
            trunk.append(nn.ReLU())
        self.network = nn.Sequential(*trunk)

        # Residual branch feeding the value heads only.
        residual_fc = layer_init_std_bias(nn.Linear(448, 448), std=0.1)
        self.extra_layer = nn.Sequential(residual_fc, nn.ReLU())

        # Policy head: small init gain (0.01) on both layers.
        policy_fc = layer_init_std_bias(nn.Linear(448, 448), std=0.01)
        policy_out = layer_init_std_bias(nn.Linear(448, envs.single_action_space.n), std=0.01)
        self.actor = nn.Sequential(policy_fc, nn.ReLU(), policy_out)

        # Two scalar value heads.
        self.critic_ext = layer_init_std_bias(nn.Linear(448, 1), std=0.01)
        self.critic_int = layer_init_std_bias(nn.Linear(448, 1), std=0.01)

    def get_action_and_value(self, x, action=None):
        """Run the policy and both value heads on a batch of observations.

        Args:
            x: Observation batch; divided by 255.0 before entering the trunk
                (presumably raw 0-255 pixel values -- confirm with the caller).
            action: Optional actions to evaluate. When ``None``, actions are
                sampled from the policy distribution.

        Returns:
            Tuple ``(action, log_prob, entropy, value_ext, value_int)`` where
            the two values come from ``critic_ext`` and ``critic_int``.
        """
        embedding = self.network(x / 255.0)
        policy = Categorical(logits=self.actor(embedding))
        residual = self.extra_layer(embedding)
        if action is None:
            action = policy.sample()
        log_prob = policy.log_prob(action)
        entropy = policy.entropy()
        value_ext = self.critic_ext(residual + embedding)
        value_int = self.critic_int(residual + embedding)
        return action, log_prob, entropy, value_ext, value_int

    def get_value(self, x):
        """Return ``(value_ext, value_int)`` for a batch of observations ``x``.

        The input is divided by 255.0, passed through the trunk, and the value
        heads are applied to the embedding plus its residual branch.
        """
        embedding = self.network(x / 255.0)
        residual = self.extra_layer(embedding)
        value_ext = self.critic_ext(residual + embedding)
        value_int = self.critic_int(residual + embedding)
        return value_ext, value_int


class RNDVizdoomAgent(nn.Module):
    """Actor-critic network for single-channel image input with two value heads.

    A convolutional encoder maps 1-channel image batches to a 448-wide
    embedding. The policy head consumes the embedding directly, while both
    value heads consume the embedding summed with a residual branch
    (``extra_layer``).
    NOTE(review): ``critic_ext`` / ``critic_int`` presumably estimate extrinsic
    and intrinsic returns respectively -- confirm against the training loop.
    """

    def __init__(self, envs):
        """Create and initialise all sub-networks.

        Args:
            envs: Object whose ``single_action_space.n`` gives the number of
                discrete actions, i.e. the width of the policy logits.
        """
        super().__init__()

        # Encoder: three conv layers, flatten, two linear layers, all with ReLU.
        # The 64 * 7 * 7 input width of the first linear layer fixes the
        # expected post-conv feature map size (e.g. 84x84 input frames).
        encoder = [
            layer_init_std_bias(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init_std_bias(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init_std_bias(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init_std_bias(nn.Linear(in_features=64 * 7 * 7, out_features=256)),
            nn.ReLU(),
            layer_init_std_bias(nn.Linear(in_features=256, out_features=448)),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*encoder)

        # Residual branch used only on the value path.
        self.extra_layer = nn.Sequential(
            layer_init_std_bias(nn.Linear(in_features=448, out_features=448), std=0.1),
            nn.ReLU(),
        )

        # Policy head producing one logit per discrete action.
        actor_layers = [
            layer_init_std_bias(nn.Linear(in_features=448, out_features=448), std=0.01),
            nn.ReLU(),
            layer_init_std_bias(nn.Linear(448, envs.single_action_space.n), std=0.01),
        ]
        self.actor = nn.Sequential(*actor_layers)

        # Two scalar value heads.
        self.critic_ext = layer_init_std_bias(nn.Linear(in_features=448, out_features=1), std=0.01)
        self.critic_int = layer_init_std_bias(nn.Linear(in_features=448, out_features=1), std=0.01)

    def get_action_and_value(self, x, action=None):
        """Run the policy and both value heads on a batch of observations.

        Args:
            x: Observation batch; divided by 255.0 before entering the encoder
                (presumably raw 0-255 pixel values -- confirm with the caller).
            action: Optional actions to evaluate. When ``None``, actions are
                sampled from the policy distribution.

        Returns:
            Tuple ``(action, log_prob, entropy, value_ext, value_int)`` where
            the two values come from ``critic_ext`` and ``critic_int``.
        """
        latent = self.network(x / 255.0)
        action_logits = self.actor(latent)
        dist = Categorical(logits=action_logits)
        value_branch = self.extra_layer(latent)
        chosen = dist.sample() if action is None else action
        log_prob = dist.log_prob(chosen)
        entropy = dist.entropy()
        value_ext = self.critic_ext(value_branch + latent)
        value_int = self.critic_int(value_branch + latent)
        return chosen, log_prob, entropy, value_ext, value_int

    def get_value(self, x):
        """Return ``(value_ext, value_int)`` for a batch of observations ``x``.

        The input is divided by 255.0, encoded, and each value head is applied
        to the embedding plus its residual branch.
        """
        latent = self.network(x / 255.0)
        value_branch = self.extra_layer(latent)
        value_ext = self.critic_ext(value_branch + latent)
        value_int = self.critic_int(value_branch + latent)
        return value_ext, value_int


class RNDAtariModel(nn.Module):
    """Pair of convolutional networks: a trainable predictor and a frozen target.

    Both networks take single-channel image batches and emit 512-wide feature
    vectors. The predictor has two extra linear layers on top of the shared
    layout; the target's parameters have ``requires_grad`` switched off.
    """

    def __init__(self, envs):
        """Build the predictor and the frozen target.

        Args:
            envs: Accepted for signature compatibility; not read here.
        """
        super().__init__()

        flat_features = 7 * 7 * 64
        embed_dim = 512

        def conv_stack():
            # Three conv layers with LeakyReLU, followed by a flatten.
            stack = []
            for c_in, c_out, kernel, stride in ((1, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)):
                conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride)
                stack.extend((layer_init_std_bias(conv), nn.LeakyReLU()))
            stack.append(nn.Flatten())
            return stack

        # Prediction network (built first, then the target, so that layer
        # initialisation happens in a fixed order).
        self.predictor = nn.Sequential(
            *conv_stack(),
            layer_init_std_bias(nn.Linear(flat_features, embed_dim)),
            nn.ReLU(),
            layer_init_std_bias(nn.Linear(embed_dim, embed_dim)),
            nn.ReLU(),
            layer_init_std_bias(nn.Linear(embed_dim, embed_dim)),
        )

        # Target network
        self.target = nn.Sequential(
            *conv_stack(),
            layer_init_std_bias(nn.Linear(flat_features, embed_dim)),
        )

        # The target network is not trainable.
        self.target.requires_grad_(False)

    def forward(self, next_obs):
        """Return ``(predict_feature, target_feature)`` for ``next_obs``.

        Args:
            next_obs: Single-channel image batch fed to both networks.
        """
        frozen_out = self.target(next_obs)
        learned_out = self.predictor(next_obs)
        return learned_out, frozen_out
