import collections
import collections.abc
import types

import torch
from torch import autograd


def gradient(outputs, variables, retain_graph=True, create_graph=True, to_cpu=False,
             **kwargs):
    """Differentiate ``outputs`` with respect to ``variables`` via ``autograd.grad``.

    A 0-d tensor is differentiated directly. Any other input is iterated and
    each element is differentiated recursively, so the result mirrors the
    nesting of ``outputs``.

    Args:
        outputs: A tensor, or an iterable of tensors / nested iterables.
        variables: Tensor(s) to differentiate with respect to; handed to
            ``autograd.grad`` as its inputs.
        retain_graph: Forwarded to ``autograd.grad``.
        create_graph: Forwarded to ``autograd.grad``.
        to_cpu: If true, every gradient is detached and moved to the CPU.
        **kwargs: Extra keyword arguments forwarded to ``autograd.grad``.

    Returns:
        A list of gradients: one per variable for a 0-d output, one entry per
        element for an iterable output. At every level of the recursion a list
        of length one is replaced by its only element.

    Raises:
        AssertionError: If ``outputs`` or ``variables`` is not iterable.
    """
    assert isinstance(outputs, collections.abc.Iterable), \
        f'invalid type of outputs: {type(outputs)}'
    assert isinstance(variables, collections.abc.Iterable), \
        f'invalid type of variables: {type(variables)}'

    if isinstance(outputs, torch.Tensor) and outputs.dim() == 0:
        grads = autograd.grad(outputs, variables,
                              retain_graph=retain_graph,
                              create_graph=create_graph,
                              **kwargs
                              )
        if to_cpu:
            # With allow_unused=True autograd.grad reports unused inputs as
            # None; pass those through instead of calling .detach() on them.
            grads = [None if grad is None else grad.detach().to('cpu')
                     for grad in grads]
        else:
            grads = list(grads)
    else:
        grads = [gradient(output, variables, retain_graph=retain_graph,
                          create_graph=create_graph, to_cpu=to_cpu, **kwargs)
                 for output in outputs]

    # Single-element results are unwrapped so callers get the bare gradient.
    if len(grads) == 1:
        grads = grads[0]

    return grads


def _flatten(t: torch.Tensor) -> torch.Tensor:
    return torch.flatten(t)


def flatten(t) -> torch.Tensor:
    """Flatten a tensor, or an arbitrarily nested iterable of tensors, to 1-D.

    Args:
        t: A tensor, or an iterable whose items are tensors or further
            iterables of tensors.

    Returns:
        A one-dimensional tensor. For an iterable, the flattened items are
        concatenated in iteration order.

    Raises:
        AssertionError: If ``t`` is neither a tensor nor an iterable.
    """
    if isinstance(t, torch.Tensor):
        return torch.flatten(t)

    assert isinstance(t, collections.abc.Iterable), f'type {type(t)}'
    return torch.cat([flatten(x) for x in t])


def _var_list(variables) -> list:
    if isinstance(variables, types.GeneratorType):
        variables = list(variables)
    if not isinstance(variables, collections.Iterable):
        variables = [variables]

    return variables


def hessian(f: torch.nn, x: torch.Tensor, variables=None,
            hierarchical: bool = False) -> torch.Tensor:
    """Compute second derivatives of ``f(x)`` by applying ``gradient`` twice.

    Reference noted by the original author: http://learn2learn.net/

    Args:
        f: Callable evaluated as ``f(x)``.
        x: Input passed to ``f``.
        variables: What to differentiate with respect to; ``x`` is used when
            this is ``None``.
        hierarchical: If true, return the nested result of the second
            ``gradient`` call as-is instead of a matrix.

    Returns:
        With ``hierarchical=False``, a tensor of shape ``(n_var, n_var)`` where
        ``n_var`` is the total number of elements across the variables;
        otherwise the raw nested output of ``gradient``.
    """
    wrt = _var_list(x if variables is None else variables)

    # Per the original author's note: do not flatten and concatenate the
    # variables into one tensor beforehand; that makes `gradient` fail.
    first_order = gradient(f(x), wrt)
    second_order = gradient(first_order, wrt)

    if hierarchical:
        return second_order

    total = sum(len(_flatten(var)) for var in wrt)
    return torch.reshape(flatten(second_order), (total, total))


def inference_and_hgrad(f: torch.nn, x: torch.Tensor, *vars_list)\
        -> torch.Tensor:
    """Evaluate ``f(x)`` and return its ``higher_grad`` w.r.t. ``vars_list``.

    Args:
        f: Callable evaluated as ``f(x)``.
        x: Input passed to ``f``.
        *vars_list: One or more variable groups, forwarded in order to
            ``higher_grad``.

    Returns:
        Whatever ``higher_grad`` returns for ``f(x)`` and ``vars_list``.

    Raises:
        AssertionError: If no variable group is given.
    """
    assert len(vars_list) > 0
    return higher_grad(f(x), *vars_list)


def higher_grad(y: torch.Tensor, *vars_list, **kwargs) -> torch.Tensor:
    """Differentiate ``y`` successively with respect to each group in ``vars_list``.

    ``y`` is differentiated w.r.t. ``vars_list[0]``, that result w.r.t.
    ``vars_list[1]``, and so on. The nested result is flattened, reshaped and
    its derivative axes are reversed.

    Args:
        y: Tensor to differentiate.
        *vars_list: One or more variable groups (each normalized with
            ``_var_list``), applied in order.
        **kwargs: Forwarded to every ``gradient`` call.

    Returns:
        A tensor of shape ``y.shape + (n_last, ..., n_first)``, where ``n_k`` is
        the total number of elements in ``vars_list[k]``; the axis for the last
        group comes first among the derivative axes.

    Raises:
        AssertionError: If no variable group is given.
    """
    assert len(vars_list) > 0

    out_shape = list(y.shape)
    var_sizes = []
    grads = y
    for group in vars_list:
        wrt = _var_list(group)
        var_sizes.append(len(flatten(wrt)))
        grads = gradient(grads, wrt, **kwargs)

    # Flattening yields the layout y.shape + (n_first, ..., n_last).
    stacked = flatten(grads).reshape(out_shape + var_sizes)

    # Keep the leading y axes in place and reverse only the derivative axes.
    lead = len(out_shape)
    order = [*range(lead), *reversed(range(lead, lead + len(var_sizes)))]
    return stacked.permute(*order)
