import torch


class as_eval:
    """Context manager that temporarily puts models into evaluation mode.

    On entry every model has ``eval()`` called on it. On exit each model is
    returned to the mode it was in *before* entering the block: models that
    were training get ``train()`` called, models that were already in
    evaluation mode stay there.

    Only the top-level ``training`` attribute of each model is recorded.
    Objects without a ``training`` attribute are treated as having been in
    training mode, so ``train()`` is called on them on exit.

    Exceptions raised inside the ``with`` body are not suppressed.

    Args:
        *models: Objects exposing ``eval()`` and ``train()`` methods.
    """

    def __init__(self, *models):
        self.models = models
        # Default to "was training" so that calling __exit__ without a prior
        # __enter__ still switches every model to train mode.
        self._was_training = [True] * len(models)

    def __enter__(self):
        # Snapshot the modes before touching any model so that exit can
        # restore them instead of blindly forcing train mode.
        self._was_training = [
            getattr(model, "training", True) for model in self.models
        ]
        for model in self.models:
            model.eval()

    def __exit__(self, exc_type, exc_val, exc_tb):
        for model, was_training in zip(self.models, self._was_training):
            if was_training:
                model.train()
            else:
                # Re-assert eval in case the body changed the mode.
                model.eval()


def _flat_grad(loss, params):
    """Return the gradient of ``loss.mean()`` w.r.t. ``params`` as one 1-D tensor.

    Parameters that do not contribute to ``loss`` are represented by zeros so
    that the result always has the same length and layout for a given
    ``params`` list.
    """
    grads = torch.autograd.grad(
        loss.mean(),
        params,
        create_graph=False,
        allow_unused=True,
        # The caller may still need the graph (e.g. for a second loss sharing
        # it, or for a later backward pass).
        retain_graph=True,
    )
    pieces = [
        (torch.zeros_like(param) if grad is None else grad).reshape(-1)
        for param, grad in zip(params, grads)
    ]
    return torch.cat(pieces).detach()


def get_gradient_angle(
    loss1: torch.Tensor, loss2: torch.Tensor, net: torch.nn.Module
) -> torch.Tensor:
    """Return the cosine of the angle between the gradients of two losses.

    Each loss is reduced with ``.mean()`` and differentiated with respect to
    the trainable parameters of ``net`` (those with ``requires_grad=True``).
    The two gradients are flattened into vectors with identical layout and
    their cosine similarity ``dot(g1, g2) / (|g1| * |g2|)`` is returned.

    Despite the name, the value is the cosine, not the angle itself.

    Args:
        loss1: First loss tensor; must be attached to an autograd graph.
        loss2: Second loss tensor; must be attached to an autograd graph.
        net: Module whose parameters the gradients are taken with respect to.

    Returns:
        A 0-dim tensor holding the cosine similarity. It is NaN when either
        gradient has zero norm.

    Side effects:
        ``net.zero_grad()`` is called, so any gradients accumulated on the
        parameters of ``net`` are cleared. The autograd graphs of both losses
        are retained.
    """
    # torch.autograd.grad does not write to ``.grad``; the call is kept only
    # because callers may rely on gradients being cleared by this function.
    net.zero_grad()

    # Frozen parameters cannot be differentiated and would make
    # torch.autograd.grad raise, so restrict to trainable ones.
    params = [param for param in net.parameters() if param.requires_grad]

    grads1 = _flat_grad(loss1, params)
    grads2 = _flat_grad(loss2, params)
    return torch.dot(grads1, grads2) / (grads1.norm() * grads2.norm())
