from typing import Callable

import torch
from torch import Tensor
from torch.nn import Module

from algorithms.convergence_algorithms.utils import ball_perturb
from algorithms.nn.distributions import WeightsDistributionBase
from algorithms.nn.grad import hessian_from_gradient_network


class GradientLoss(Module):
    """First-order Taylor consistency loss for a gradient-predicting network.

    For a pair of points ``(x_i, x_j)`` with observed values ``(y_i, y_j)``
    the module compares ``<x_j - x_i, grad_network(x')> + taylor_remainder``
    against ``y_j - y_i`` using ``calc_loss``.  ``x'`` is ``x_i`` itself, or
    the output of ``ball_perturb`` when ``ball_perturb_size`` is truthy.

    Args:
        grad_network: Callable mapping a batch of points to a same-shaped
            batch of predicted gradients.
        ball_perturb_size: Size forwarded to ``ball_perturb``; a falsy value
            disables the perturbation.
            NOTE(review): presumably the radius of the sampling ball -
            confirm against ``ball_perturb``.
        calc_loss: Callable ``(value, target) -> loss``.
    """

    def __init__(self, grad_network, ball_perturb_size, calc_loss):
        super().__init__()
        self.grad_network = grad_network
        self.ball_perturb_size = ball_perturb_size
        self.calc_loss = calc_loss

    def forward(self, x_i, x_j, y_i, y_j):
        """Return ``calc_loss(predicted_difference, y_j - y_i)``.

        Args:
            x_i: Expansion points, at most 2-D (``(batch, dim)`` or ``(dim,)``).
            x_j: Target points, same shape as ``x_i``.
            y_i: Values observed at ``x_i``.
            y_j: Values observed at ``x_j``.
        """
        assert len(x_i.shape) <= 2, "Cant handle multiple batches yet"

        if self.ball_perturb_size:
            eval_point = ball_perturb(
                x_i, self.ball_perturb_size, len(x_i), x_i.dtype, x_i.device
            )
        else:
            eval_point = x_i

        # The gradient network must see an input that tracks gradients.
        # Do not flip the flag on the caller's own tensor: detach() gives a
        # fresh leaf sharing the same storage.  A tensor that already requires
        # grad is left alone, since assigning ``.requires_grad`` on a non-leaf
        # tensor raises.
        if not eval_point.requires_grad:
            if eval_point is x_i:
                eval_point = x_i.detach()
            eval_point.requires_grad_(True)

        grad_on_perturb = self.grad_network(eval_point)

        x_delta = x_j - x_i
        # Reduce over the feature axis; dim=-1 equals dim=1 for 2-D input and
        # also supports the 1-D input that the assertion above allows.
        first_order = (x_delta * grad_on_perturb).sum(dim=-1)
        value = first_order + self.taylor_remainder(x_delta)
        target = y_j - y_i

        return self.calc_loss(value, target)

    def taylor_remainder(self, x_delta):
        """Higher-order correction added to the first-order term.

        The base class uses a pure first-order model, so this is ``0``;
        subclasses override it.
        """
        return 0


class NaturalHessianLoss(GradientLoss):
    """GradientLoss with a quadratic term built from the gradient network's Hessian.

    NOTE(review): the textbook second-order Taylor term is
    ``0.5 * dx^T H(x_i) dx``.  This class applies no 1/2 factor and evaluates
    the Hessian at the displacement ``x_delta`` rather than at the expansion
    point.  Both are kept as-is here; confirm whether they are intentional.
    """

    def calculate_hessian(self, x):
        """Return the Hessian batch for *x*, derived from ``grad_network``.

        The result must have shape ``(batch, dim, dim)`` for the batched
        products in ``taylor_remainder`` to be valid.
        """
        return hessian_from_gradient_network(self.grad_network, x)

    def taylor_remainder(self, x_delta):
        """Return the per-sample quadratic form ``x_delta^T H x_delta``.

        Args:
            x_delta: Displacements of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        hessian = self.calculate_hessian(x_delta)
        batch = x_delta.shape[0]
        dim = x_delta.shape[1]
        row_vectors = x_delta.reshape((batch, 1, dim))
        col_vectors = x_delta.reshape((batch, dim, 1))
        quadratic = torch.bmm(torch.bmm(row_vectors, hessian), col_vectors)
        # Collapse (batch, 1, 1) to (batch,) explicitly.  A bare squeeze()
        # would also drop the batch axis when batch == 1 and return a scalar.
        return quadratic.reshape(batch)


class HessianWithDifferentNetwork(NaturalHessianLoss):
    """NaturalHessianLoss whose Hessian is produced by a dedicated network.

    Instead of deriving the Hessian from ``grad_network``, the quadratic term
    of the parent class uses the output of ``hessian_network`` directly.
    """

    def __init__(self, *args, hessian_network, **kwargs):
        """Forward all other arguments to the parent and store the Hessian network.

        Args:
            *args: Positional arguments for the parent constructor.
            hessian_network: Keyword-only callable mapping a batch of inputs
                to a batch of Hessian matrices.
            **kwargs: Keyword arguments for the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.hessian_network = hessian_network

    def calculate_hessian(self, x):
        """Return ``hessian_network`` evaluated on *x*."""
        predicted_hessian = self.hessian_network(x)
        return predicted_hessian


def loss_with_quantile(
    value: Tensor,
    target: Tensor,
    quantile: int,
    weights_creator: WeightsDistributionBase,
    loss: Callable,
) -> Tensor:
    """Return the mean of an element-wise loss re-weighted by absolute error.

    Args:
        value: Predicted values.
        target: Reference values, broadcastable with ``value``.
        quantile: Forwarded unchanged to ``weights_creator.distribute_weights``.
        weights_creator: Object whose ``distribute_weights(errors, quantile)``
            returns weights broadcastable with the element-wise loss.
        loss: Callable ``(value, target) -> Tensor`` returning unreduced losses.

    Returns:
        Scalar tensor: ``mean(loss(value, target) * weights)``.
    """
    # The weights are derived from a detached copy of the absolute error, so
    # no gradient flows through the weighting itself.
    abs_error = torch.abs(value - target).clone().detach()
    element_weights = weights_creator.distribute_weights(abs_error, quantile)

    element_loss = loss(value, target)
    weighted_loss = element_loss * element_weights
    return weighted_loss.mean()


def wgan_gradient_penalty_loss(
    model: Module,
    generated_data: Tensor,
    real_data: Tensor,
    gradient_penalty_factor: float,
    batch_size: int,
) -> Tensor:
    """Return the WGAN critic loss with a gradient penalty.

    The result is ``mean(model(generated)) - mean(model(real))`` plus
    ``gradient_penalty_factor * mean((||grad||_2 - 1) ** 2)``, where the
    gradient of the model output is taken at random element-wise
    interpolations between the real and generated samples.

    Args:
        model: Critic applied to data batches.
        generated_data: Generated samples; may be detached from the generator.
        real_data: Real samples, same shape as ``generated_data``.
        gradient_penalty_factor: Weight of the penalty term.
        batch_size: Number of samples; each sample's gradient is flattened
            into one row before its norm is taken.

    Returns:
        Scalar loss tensor.
    """
    basic_loss = model(generated_data).mean() - model(real_data).mean()

    eps = torch.rand(generated_data.shape, device=generated_data.device)
    interpolation = eps * real_data + (1 - eps) * generated_data
    # autograd.grad needs its inputs to require grad.  When both batches are
    # plain data (e.g. generator output was detached for the critic step) the
    # interpolation is a leaf without grad, so enable it explicitly.
    if not interpolation.requires_grad:
        interpolation.requires_grad_(True)

    interp_logits = model(interpolation)
    gradients = torch.autograd.grad(
        outputs=interp_logits,
        inputs=interpolation,
        grad_outputs=torch.ones_like(interp_logits),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    # Flatten each sample's gradient into a single row (relevant for
    # multi-dimensional samples).  reshape also handles non-contiguous
    # gradients, where view would raise.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = torch.mean((gradients.norm(2, 1) - 1) ** 2)
    return basic_loss + gradient_penalty_factor * gradient_penalty
