from contextlib import nullcontext

import torch
from torch import nn, optim


# Random network distillation wrapper
class RNDModel(nn.Module):
    """Random network distillation (RND) novelty wrapper.

    Holds a frozen ``random_network`` and a trainable ``learner_network``,
    both built by the same factory. Novelty is the L2 distance (over the
    last dimension) between their outputs. Training the learner towards the
    frozen network's outputs is done with this module's own RMSprop optimizer.

    When ``flags.separate_message_novelty`` is truthy, a second
    frozen/learner pair built by ``message_model_fn`` (with its own optimizer)
    produces a separate message novelty from the same forward inputs.
    """

    def __init__(self, base_model_fn, message_model_fn, flags, loss_coef=0.1):
        """
        Args:
            base_model_fn: zero-argument factory, called twice to build the
                random (frozen) and learner networks.
            message_model_fn: zero-argument factory for the message networks;
                only called when ``flags.separate_message_novelty`` is truthy.
            flags: config object. Reads ``separate_message_novelty``,
                ``rnd_lr`` and, when separate message novelty is enabled,
                ``separate_message_rnd_lr``.
            loss_coef: multiplier applied to the RND loss(es). Defaults to
                0.1, the value previously hard-coded here.
        """
        super().__init__()
        self.random_network = base_model_fn()
        self.learner_network = base_model_fn()
        self.flags = flags
        self.loss_coef = loss_coef

        # Freeze the random network: it only provides fixed targets.
        self._freeze(self.random_network)

        if self.flags.separate_message_novelty:
            self.message_random_network = message_model_fn()
            self.message_learner_network = message_model_fn()
            # Freeze the message random network as well.
            self._freeze(self.message_random_network)

        # Define the optimizers. These are separate from langexplore.optimizers.
        self.optimizer = self._make_optimizer(
            self.learner_network, self.flags.rnd_lr
        )
        if self.flags.separate_message_novelty:
            self.message_optimizer = self._make_optimizer(
                self.message_learner_network, self.flags.separate_message_rnd_lr
            )

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def _make_optimizer(module, lr):
        """Build the RMSprop optimizer used for an RND learner network."""
        return optim.RMSprop(
            module.parameters(),
            lr=lr,
            momentum=0,
            eps=1e-5,
            alpha=0.99,
        )

    @staticmethod
    def _step(optimizer, loss):
        """Run one zero_grad / backward / step cycle for ``loss``."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def _novelty_and_loss(self, random_net, learner_net, args, kwargs):
        """Return ``(novelty, loss)`` for one frozen/learner network pair.

        ``novelty`` is the L2 norm over the last dimension of the output
        difference. ``loss`` averages novelty over dim 1, sums the rest and
        scales by ``self.loss_coef``. Novelty must therefore have at least 2
        dims; presumably (time, batch) — TODO confirm against callers.
        """
        # Targets never need gradients; computing them under no_grad makes
        # an explicit .detach() unnecessary.
        with torch.no_grad():
            targets = random_net(*args, **kwargs)
        preds = learner_net(*args, **kwargs)
        novelty = torch.norm(targets - preds, dim=-1, p=2)
        loss = novelty.mean(1).sum() * self.loss_coef
        return novelty, loss

    def forward(self, *args, optimize=False, **kwargs):
        """Compute novelty as L2 distance between random network and learner network.

        Args:
            *args, **kwargs: forwarded unchanged to every wrapped network.
            optimize: when True, also take one optimizer step on the learner
                network(s) using the computed loss(es).

        Returns:
            ``(novelty, loss, message_novelty, message_loss)``, all detached.
            The novelty values and losses are computed before any optimizer
            step is taken. When separate message novelty is disabled, the
            message entries are 0-dim zeros on ``novelty``'s device and dtype.
        """
        separate = self.flags.separate_message_novelty

        # Build an autograd graph only when we are going to backprop. Using
        # set_grad_enabled (instead of inheriting the caller's grad mode) lets
        # optimize=True work even inside an outer torch.no_grad() block.
        with torch.set_grad_enabled(optimize):
            novelty, loss = self._novelty_and_loss(
                self.random_network, self.learner_network, args, kwargs
            )
            if separate:
                message_novelty, message_loss = self._novelty_and_loss(
                    self.message_random_network,
                    self.message_learner_network,
                    args,
                    kwargs,
                )

        if optimize:
            self._step(self.optimizer, loss)
            if separate:
                self._step(self.message_optimizer, message_loss)

        # Detach everything so the returned tensors do not keep the (already
        # consumed) autograd graphs alive.
        novelty = novelty.detach()
        loss = loss.detach()
        if separate:
            message_novelty = message_novelty.detach()
            message_loss = message_loss.detach()
        else:
            message_novelty = novelty.new_zeros(())
            message_loss = novelty.new_zeros(())

        return novelty, loss, message_novelty, message_loss
