from itertools import chain
import torch
import torch.nn as nn

__all__ = ['f_and_jac_fn', 'f_and_jac_fn_amortized', 'f_and_jac_fn_low_rank', 'f_and_jac_reg_fn']


class FuncAndDiagJac(torch.autograd.Function):
    """Given a f: R^d -> R^d, computes both f(x) and the diagonals of the Jacobian of f(x).

    f is composed of two networks:

    * ``exclusive_net(t, x, mask=mask)`` produces features ``nx``, reshaped here
      to one row per element of ``x``;
    * ``dimwise_net(t, x_, nx, params=..., am_params=..., mask=...)`` receives
      every element of ``x`` as its own row of ``x_`` (shape ``(numel, 1)``)
      together with ``nx``.

    In ``forward`` the network ``dimwise_net`` sees a *detached* copy of ``nx``,
    so differentiating the summed output w.r.t. ``x`` only follows the direct
    path through ``x_``. Assuming ``dimwise_net`` acts row-wise (not checked
    here), this is the Jacobian diagonal. ``backward`` re-attaches the path
    through ``nx`` so that ``x`` and the parameters get the gradient of the
    full computation.

    Module parameters are not inputs of ``apply``. The callers in this module
    instead pass their concatenation as ``flat_params``, and ``backward``
    returns a gradient of the same flat layout for it.
    """

    @staticmethod
    def forward(ctx, exclusive_net, dimwise_net, t, x, flat_params, order,
                dimwise_params=None, dimwise_am_params=None, mask=None, **kwargs):
        """Return ``(output, djac)``, both with the shape of ``x``.

        Args:
            exclusive_net, dimwise_net: the two sub-networks (kept on ``ctx``).
            t: forwarded unchanged to both networks; it never receives a gradient.
            x: input tensor.
            flat_params: flattened parameters; only used to route gradients.
            order: how many times the summed derivative w.r.t. ``x`` is taken
                (values <= 1 give the first derivative).
            dimwise_params: optional explicit ``params`` for ``dimwise_net``.
            dimwise_am_params: optional tensor forwarded as ``am_params``.
            mask: forwarded to both networks.
            **kwargs: ignored.

        Returns:
            Detached copies of the output and of ``djac``. Each copy keeps its
            ``requires_grad`` flag.
        """
        ctx.exclusive_net = exclusive_net
        ctx.dimwise_net = dimwise_net
        with torch.enable_grad():
            # Fresh leaf: the local graph built below is independent of the caller's.
            x = x.detach().requires_grad_(True)
            nx = exclusive_net(t, x, mask=mask)
            shape = x.shape
            x_ = x.contiguous().view(-1, 1)
            # reshape rather than view: the network output may be non-contiguous.
            nx = nx.reshape(x_.shape[0], -1)
            # Cut the dependence through nx so d(output)/dx only sees the direct path.
            nx_detached = nx.clone().detach().requires_grad_(True)
            output_check = dimwise_net(t, x_, nx_detached, params=dimwise_params,
                                       am_params=dimwise_am_params, mask=mask).reshape(*shape)
            # create_graph=True keeps djac differentiable for higher orders and for backward.
            djac = torch.autograd.grad(output_check.sum(), x, create_graph=True)[0]
            while order > 1:
                djac = torch.autograd.grad(djac.sum(), x, create_graph=True)[0]
                order -= 1
            ctx.save_for_backward(x, nx, nx_detached, output_check, djac, dimwise_params, dimwise_am_params)
            return safe_detach(output_check), safe_detach(djac)

    @staticmethod
    def backward(ctx, grad_output, grad_djac):
        """Return gradients for the nine ``forward`` inputs.

        Only ``x``, ``flat_params`` and the supplied one of ``dimwise_params`` /
        ``dimwise_am_params`` can receive a gradient; every other input gets
        ``None``. If ``dimwise_params`` is given, it takes precedence.
        """
        x, nx, nx_detached, output_check, djac, dimwise_params, dimwise_am_params = ctx.saved_tensors
        exclusive_net, dimwise_net = ctx.exclusive_net, ctx.dimwise_net

        # f_params must match the members and order of the caller's flat_params.
        # When explicit dimwise_params are given, dimwise_net's own parameters
        # are not part of flat_params.
        if dimwise_params is not None:
            f_params = list(exclusive_net.parameters())
            extra = [dimwise_params]
        else:
            f_params = list(exclusive_net.parameters()) + list(dimwise_net.parameters())
            extra = [] if dimwise_am_params is None else [dimwise_am_params]

        # Step 1: direct path, treating nx_detached as an independent input.
        grad_x, grad_nx, *rest = _autograd_grad_or_none(
            [output_check, djac],
            [x, nx_detached] + extra + f_params,
            [grad_output, grad_djac],
        )
        grad_extra = rest[0] if extra else None
        grad_flat_params = _flatten_convert_none_to_zeros(rest[len(extra):], f_params)

        # Step 2: push the gradient w.r.t. nx back through exclusive_net.
        if grad_nx is not None:
            grad_x_from_nx, *grad_params_from_nx = _autograd_grad_or_none(
                [nx], [x] + f_params, [grad_nx]
            )
            # Either term may be None (unused input); only add what exists.
            if grad_x is None:
                grad_x = grad_x_from_nx
            elif grad_x_from_nx is not None:
                grad_x = grad_x + grad_x_from_nx
            grad_flat_params = grad_flat_params + _flatten_convert_none_to_zeros(grad_params_from_nx, f_params)

        grad_dimwise_params = grad_extra if dimwise_params is not None else None
        grad_dimwise_am_params = grad_extra if dimwise_params is None else None

        return None, None, None, grad_x, grad_flat_params, None, grad_dimwise_params, grad_dimwise_am_params, None


def _autograd_grad_or_none(outputs, inputs, grad_outputs):
    """Call ``torch.autograd.grad`` while tolerating tensors that are off the graph.

    Output/grad_output pairs whose output does not require grad are dropped.
    ``None`` is returned for each input that does not require grad or is
    unused, instead of raising. The graph is retained (``retain_graph=True``)
    so that the tensors saved on ``ctx`` stay differentiable.

    Returns:
        A list with one entry (tensor or ``None``) per element of ``inputs``.
    """
    pairs = [(out, g) for out, g in zip(outputs, grad_outputs) if out.requires_grad]
    wanted = [i for i, inp in enumerate(inputs) if inp.requires_grad]
    grads = [None] * len(inputs)
    if pairs and wanted:
        computed = torch.autograd.grad(
            [out for out, _ in pairs],
            [inputs[i] for i in wanted],
            [g for _, g in pairs],
            retain_graph=True,
            allow_unused=True,
        )
        for i, g in zip(wanted, computed):
            grads[i] = g
    return grads


def f_and_jac_fn(exclusive_net, dimwise_net, t, x, order=1, mask=None, **kwargs):
    """Compute ``f(x)`` and its Jacobian diagonal with ``FuncAndDiagJac``.

    Args:
        exclusive_net, dimwise_net: ``nn.Module`` sub-networks making up f.
        t: forwarded unchanged to both networks.
        x: input tensor.
        order: derivative order passed to ``FuncAndDiagJac`` (1 = first derivative).
        mask: forwarded to both networks, as ``f_and_jac_reg_fn`` does.
        **kwargs: ignored.

    Returns:
        ``(y, djac)`` from ``FuncAndDiagJac.apply``.

    Raises:
        ValueError: if either network is not an ``nn.Module``.
    """
    _check_if_nn(exclusive_net, dimwise_net)

    # The module parameters are handed to the autograd Function as one flat
    # tensor. That is how its backward can return gradients for them.
    flat_params = _flatten(chain(exclusive_net.parameters(), dimwise_net.parameters()))
    return FuncAndDiagJac.apply(exclusive_net, dimwise_net, t, x, flat_params, order, None, None, mask)

def f_and_jac_reg_fn(exclusive_net, dimwise_net, t, x, order=1, mask=None, **kwargs):
    """Compute ``f(x)``, its Jacobian diagonal, and a per-sample squared-norm penalty.

    Args:
        exclusive_net, dimwise_net: ``nn.Module`` sub-networks making up f.
        t: forwarded unchanged to both networks.
        x: input tensor.
        order: derivative order passed to ``FuncAndDiagJac``.
        mask: forwarded to both networks.
        **kwargs: ignored.

    Returns:
        ``(y, jac, sqnorm)``. ``sqnorm[b]`` is the mean of ``jac ** 2`` over
        all entries belonging to sample ``b``, where samples are indexed by
        the first dimension of ``y``.

    Raises:
        ValueError: if either network is not an ``nn.Module``.
    """
    _check_if_nn(exclusive_net, dimwise_net)

    all_params = chain(exclusive_net.parameters(), dimwise_net.parameters())
    y, jac = FuncAndDiagJac.apply(
        exclusive_net, dimwise_net, t, x, _flatten(all_params), order, None, None, mask
    )
    per_sample = jac.view(y.shape[0], -1)
    sqnorm = (per_sample ** 2).mean(dim=1)
    return y, jac, sqnorm

def f_and_jac_fn_amortized(exclusive_net, dimwise_net, dimwise_params, t, x, order=1, mask=None, **kwargs):
    """Like ``f_and_jac_fn``, but with explicit parameters for ``dimwise_net``.

    ``dimwise_params`` is passed to ``dimwise_net`` as ``params``. It receives
    its own gradient, so only ``exclusive_net``'s parameters go into the flat
    parameter tensor.

    Args:
        exclusive_net, dimwise_net: ``nn.Module`` sub-networks making up f.
        dimwise_params: tensor passed to ``dimwise_net`` as ``params``.
        t: forwarded unchanged to both networks.
        x: input tensor.
        order: derivative order passed to ``FuncAndDiagJac``.
        mask: forwarded to both networks, as ``f_and_jac_reg_fn`` does.
        **kwargs: ignored.

    Returns:
        ``(y, djac)`` from ``FuncAndDiagJac.apply``.

    Raises:
        ValueError: if either network is not an ``nn.Module``.
    """
    _check_if_nn(exclusive_net, dimwise_net)

    flat_params = _flatten(exclusive_net.parameters())
    return FuncAndDiagJac.apply(exclusive_net, dimwise_net, t, x, flat_params, order, dimwise_params, None, mask)


def f_and_jac_fn_low_rank(exclusive_net, dimwise_net, dimwise_am_params, t, x, order=1, mask=None, **kwargs):
    """Like ``f_and_jac_fn``, additionally passing ``dimwise_am_params`` to ``dimwise_net``.

    ``dimwise_am_params`` is forwarded as ``am_params`` and receives its own
    gradient. Both networks' parameters still go into the flat parameter
    tensor.

    Args:
        exclusive_net, dimwise_net: ``nn.Module`` sub-networks making up f.
        dimwise_am_params: tensor passed to ``dimwise_net`` as ``am_params``.
        t: forwarded unchanged to both networks.
        x: input tensor.
        order: derivative order passed to ``FuncAndDiagJac``.
        mask: forwarded to both networks, as ``f_and_jac_reg_fn`` does.
        **kwargs: ignored.

    Returns:
        ``(y, djac)`` from ``FuncAndDiagJac.apply``.

    Raises:
        ValueError: if either network is not an ``nn.Module``.
    """
    _check_if_nn(exclusive_net, dimwise_net)

    flat_params = _flatten(chain(exclusive_net.parameters(), dimwise_net.parameters()))
    return FuncAndDiagJac.apply(exclusive_net, dimwise_net, t, x, flat_params, order, None, dimwise_am_params, mask)


# -------------- Helper functions --------------


def safe_detach(tensor):
    """Return ``tensor`` cut from its autograd graph, keeping its ``requires_grad`` flag."""
    needs_grad = tensor.requires_grad
    detached = tensor.detach()
    return detached.requires_grad_(needs_grad)

def _check_if_nn(exclusive_net, dimwise_net):
    if not isinstance(exclusive_net, nn.Module) or not isinstance(dimwise_net, nn.Module):
        raise ValueError('both exclusive_net and dimwise_net are required to be an instance of nn.Module.')

def _flatten(sequence):
    flat = [p.contiguous().view(-1) for p in sequence]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


def _flatten_convert_none_to_zeros(sequence, like_sequence):
    flat = [
        p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1)
        for p, q in zip(sequence, like_sequence)
    ]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
