import torch
import torch.nn.functional as F
from collections import OrderedDict

from functional.cross_entropy import cross_entropy

def inner_maml(model, inputs, labels, nstep_inner=5, lr_inner=0.4, first_order=False, param_init=None):
    """Adapt parameters to a support set with plain gradient descent (MAML inner loop).

    Each step runs ``model(inputs, param=...)``, computes mean cross-entropy
    against ``labels`` and applies ``p <- p - lr_inner * dL/dp`` to every
    parameter. One step is always taken; after that, ``nstep_inner - 1``
    further steps are taken (none when ``nstep_inner <= 1``).

    Args:
        model: module exposing ``meta_parameters()``, ``meta_named_parameters()``
            and a forward accepting a ``param=`` keyword (``None`` on the first
            step when ``param_init`` is not given).
        inputs: support-set inputs forwarded to ``model``.
        labels: class-index targets for ``F.cross_entropy``.
        nstep_inner: total number of gradient steps (at least one is taken).
        lr_inner: inner-loop step size.
        first_order: if True, gradients are computed without ``create_graph``,
            so an outer backward does not differentiate through them.
        param_init: optional name -> tensor mapping to start from instead of
            the model's own meta-parameters.

    Returns:
        OrderedDict mapping parameter names to adapted tensors.

    Side effects:
        Before every step, clears ``.grad`` (``model.zero_grad()`` or zeroing
        each ``param_init`` tensor's existing ``.grad``).
    """
    keep_graph = not first_order

    def reset_grads():
        # torch.autograd.grad never writes .grad, so this only clears gradients
        # that existed before the call.
        # NOTE(review): this can wipe meta-gradients a caller accumulated across
        # tasks before invoking this function — confirm that is intended.
        if param_init is not None:
            for tensor in param_init.values():
                if tensor.grad is not None:
                    tensor.grad.zero_()
        else:
            model.zero_grad()

    def descend(named_tensors, grads):
        # One SGD update, keeping the original parameter names and order.
        return OrderedDict(
            (key, tensor - lr_inner * g) for (key, tensor), g in zip(named_tensors, grads)
        )

    # First step: start from param_init, or from the model's own meta-parameters.
    reset_grads()
    logits = model(inputs, param=param_init)
    support_loss = F.cross_entropy(input=logits, target=labels, reduction='mean')
    if param_init is not None:
        grads = torch.autograd.grad(support_loss, param_init.values(), create_graph=keep_graph)
        adapted = descend(param_init.items(), grads)
    else:
        grads = torch.autograd.grad(support_loss, model.meta_parameters(), create_graph=keep_graph)
        adapted = descend(model.meta_named_parameters(), grads)

    # Remaining steps continue from the already-adapted parameters.
    for _step in range(1, nstep_inner):
        reset_grads()
        logits = model(inputs, param=adapted)
        support_loss = F.cross_entropy(input=logits, target=labels, reduction='mean')
        grads = torch.autograd.grad(support_loss, adapted.values(), create_graph=keep_graph)
        adapted = descend(adapted.items(), grads)

    return adapted


def inner_maml_meanvi(var_obj, inputs, labels, nstep_inner=5, lr_inner=0.4, first_order=False):
    """Run the MAML inner loop on the variational means only.

    Each step forwards ``inputs`` through ``var_obj.model`` with the current
    means and ``var_obj.exp_covar(var_obj.covar)``, computes the project
    ``cross_entropy`` (reduction='mean') against ``labels``, and applies
    ``mu <- mu - lr_inner * dL/dmu``. The covariance parameters are not updated.

    Args:
        var_obj: object exposing ``mean`` and ``covar`` (name -> tensor mappings),
            ``exp_covar(covar)`` and ``model(inputs, mean=..., cov=...)``.
        inputs: support-set inputs forwarded to ``var_obj.model``.
        labels: support-set targets.
        nstep_inner: number of gradient steps; values below 1 still take one step.
        lr_inner: inner-loop step size.
        first_order: if True, gradients are computed without ``create_graph``,
            so an outer backward does not differentiate through them.

    Returns:
        OrderedDict mapping each mean name to its adapted tensor.

    Side effects:
        Before every step, zeroes any existing ``.grad`` on the tensors in
        ``var_obj.mean`` and ``var_obj.covar``.
    """
    create_graph = not first_order
    mean_inner = var_obj.mean

    # Step 0 differentiates w.r.t. var_obj.mean itself; later steps w.r.t. the
    # adapted means. One step is always taken, so the total is max(1, nstep_inner).
    for _ in range(max(1, nstep_inner)):
        # Zero stale .grad on every mean and covariance tensor. Each tensor is
        # checked independently so a missing .grad on one never prevents zeroing
        # another, and unequal dict sizes cannot silently skip entries.
        for group in (var_obj.mean, var_obj.covar):
            for tensor in group.values():
                if tensor.grad is not None:
                    tensor.grad.zero_()

        # outputs: (n_sample, batch, num_way) per the original author's note — TODO confirm.
        outputs = var_obj.model(inputs, mean=mean_inner, cov=var_obj.exp_covar(var_obj.covar))
        loss = cross_entropy(input=outputs, target=labels, reduction='mean')
        # autograd.grad without grad_outputs requires `loss` to be a scalar.
        gradients = torch.autograd.grad(loss, mean_inner.values(), create_graph=create_graph)
        mean_inner = OrderedDict(
            (name, param - lr_inner * grad)
            for (name, param), grad in zip(mean_inner.items(), gradients)
        )

    return mean_inner
