import pdb
from collections import OrderedDict

import torch
from torchmeta.modules import MetaModule


def bidirectional_gradient_updated_parameters(model,
                                              loss,
                                              params=None,
                                              first_order=False,
                                              r=1e-2):
    """Perturb the meta-parameters in both directions along the gradient of
    the loss function.

    The gradient of `loss` with respect to every parameter is computed, and a
    scalar `epsilon = r / ||g||_2` is derived from the L2 norm of all the
    gradients concatenated into a single vector. Two sets of parameters are
    then built: `param + epsilon * grad` and `param - epsilon * grad`.

    Parameters
    ----------
    model : `torchmeta.modules.MetaModule` instance
        The model.
    loss : `torch.Tensor` instance
        The value of the inner-loss. This is the result of the training dataset
        through the loss function.
    params : `collections.OrderedDict` instance, optional
        Dictionary containing the meta-parameters of the model. If `None`, then
        the values stored in `model.meta_named_parameters()` are used.
    first_order : bool (default: `False`)
        If `True`, the gradients are computed without building the graph
        needed to differentiate through them (`create_graph=False`).
    r : float (default: 1e-2)
        Numerator of the perturbation scale; the overall perturbation
        `epsilon * g` has an L2 norm of `r`.

    Returns
    -------
    updated_params_plus : `collections.OrderedDict` instance
        Parameters shifted by `+epsilon * grad`, keyed like `params`.
    updated_params_minus : `collections.OrderedDict` instance
        Parameters shifted by `-epsilon * grad`, keyed like `params`.
    epsilon : `torch.Tensor` instance
        The scalar `r / ||g||_2` used to scale the gradients.

    Raises
    ------
    ValueError
        If `model` is not a `torchmeta.modules.MetaModule` instance.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    grads = torch.autograd.grad(loss,
                                tuple(params.values()),
                                create_graph=not first_order)

    # `reshape` (unlike `view`) also accepts non-contiguous gradient tensors.
    grad_norm = torch.cat([grad.reshape(-1) for grad in grads]).norm(p=2)
    # Floor the norm at the smallest positive normal value so that a vanishing
    # gradient gives a finite epsilon; otherwise epsilon would be `inf` and
    # `inf * 0` would turn every returned parameter into NaN.
    grad_norm = grad_norm.clamp(min=torch.finfo(grad_norm.dtype).tiny)
    epsilon = torch.tensor(r).div(grad_norm)

    updated_params_plus = OrderedDict()
    updated_params_minus = OrderedDict()

    for (name, param), grad in zip(params.items(), grads):
        # Compute the scaled step once and apply it in both directions.
        step = epsilon * grad
        updated_params_plus[name] = param + step
        updated_params_minus[name] = param - step

    return updated_params_plus, updated_params_minus, epsilon
