import torch
from torch import nn
from networks.utils.activations import get_activation


class RND(nn.Module):
    """
    Random Network Distillation exploration bonus (Burda et al., 2018).

    "Our exploration bonus is based on the observation that neural networks tend to have
    significantly lower prediction errors on examples similar to those on which they have
    been trained. This motivates the use of prediction errors of networks trained on the
    agent's past experience to quantify the novelty of new experience."

    Two MLPs with identical architecture are built: a randomly initialised, frozen
    ``random_target_net`` and a trainable ``predictor_net``. The predictor is regressed
    onto the target's outputs; the remaining squared error is the "surprise".

    Args:
        input_shape: Number of input features (in_features of the first ``nn.Linear``).
        output_shape: Currently unused; both networks output ``hidden_dims[-1]`` features.
            NOTE(review): possibly intended as the embedding size — confirm with callers.
        hidden_dims: Output sizes of the successive ``nn.Linear`` layers; must be non-empty.
        activation: Passed to ``get_activation`` to build the activation placed between
            consecutive linear layers (no activation follows the last layer).
        lr: Adam learning rate for the predictor.
        device: Device both networks are moved to.
    """

    def __init__(self, input_shape, output_shape, hidden_dims, activation, lr, device):
        super().__init__()

        # Build order (target first, then predictor) is kept so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.random_target_net = self._build_mlp(input_shape, hidden_dims, activation)
        self.predictor_net = self._build_mlp(input_shape, hidden_dims, activation)

        # The target only supplies regression labels; it is never optimised.
        self.random_target_net.requires_grad_(False)

        self.to(device)

        # Only the predictor's parameters are trained.
        self.optimizer = torch.optim.Adam(self.predictor_net.parameters(), lr=lr)

    @staticmethod
    def _build_mlp(input_shape, hidden_dims, activation):
        """Build ``Linear -> act -> Linear -> ... -> Linear`` with no trailing activation."""
        layers = [nn.Linear(input_shape, hidden_dims[0])]
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(get_activation(activation))
            layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    def train(self, batch=True):
        """Switch training mode (``nn.Module.train``) or run one predictor update.

        ``nn.Module.eval()`` calls ``self.train(False)``, and a parent module's
        ``train(mode)`` calls ``child.train(mode)``; a batch-only override would
        break both. A ``bool`` argument is therefore treated as a mode switch.
        Anything else is treated as a batch and forwarded to :meth:`update`,
        so ``rnd.train(batch)`` keeps working.

        Returns:
            ``self`` when switching mode, otherwise the float loss from :meth:`update`.
        """
        if isinstance(batch, bool):
            return super().train(batch)
        return self.update(batch)

    def update(self, batch):
        """Take one Adam step regressing the predictor onto the frozen target.

        Args:
            batch: Input tensor fed to both networks. Assumed to have ``input_shape``
                features on its last dimension and to already be on the module's
                device — TODO confirm with callers.

        Returns:
            The batch-mean summed squared error (computed before the step) as a float.
        """
        loss = self._squared_error(batch).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _squared_error(self, batch):
        """Squared L2 distance between target and predictor outputs.

        Sums over the last dimension, leaving the leading (batch) dimensions.
        Gradients flow only through the predictor.
        """
        with torch.no_grad():
            target = self.random_target_net(batch)
        pred = self.predictor_net(batch)
        return (target - pred).pow(2).sum(-1)

    @torch.no_grad()
    def compute_surprise(self, batch, reduction="mean"):
        """Return the predictor's squared error on ``batch`` without updating anything.

        Args:
            batch: Input tensor, same expectations as :meth:`update`.
            reduction: ``"mean"`` (default, scalar tensor — original behaviour),
                ``"sum"`` (scalar tensor), or ``"none"`` (one value per sample,
                e.g. for per-transition intrinsic rewards).

        Raises:
            ValueError: If ``reduction`` is not one of the values above.
        """
        error = self._squared_error(batch)
        if reduction == "mean":
            return error.mean()
        if reduction == "sum":
            return error.sum()
        if reduction == "none":
            return error
        raise ValueError(
            f"unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'"
        )
