from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchmeta.modules import MetaModule

import torch.nn as nn

def trades_criterion(input_clean, input_adv, target, advw=1.0):
    """Return cross-entropy on the clean logits plus a weighted KL term.

    Args:
        input_clean: logits for the clean inputs; softmax is taken over dim 1.
        input_adv: logits for the adversarial inputs; log-softmax is taken
            over dim 1.
        target: labels passed to ``F.cross_entropy`` together with
            ``input_clean``.
        advw: multiplier applied to the KL term.

    Returns:
        ``cross_entropy(input_clean, target) + advw * KL`` as a tensor.

    Note:
        ``F.kl_div`` is called with its default reduction, so the KL term is
        whatever that default produces (an element-wise mean in current
        PyTorch, not a per-sample ``batchmean``).
    """
    clean_ce = F.cross_entropy(input_clean, target)

    # F.kl_div expects log-probabilities as input and probabilities as target.
    adv_log_probs = F.log_softmax(input_adv, dim=1)
    clean_probs = F.softmax(input_clean, dim=1)
    robust_kl = F.kl_div(adv_log_probs, clean_probs).mean()

    return clean_ce + advw * robust_kl

def maml_inner_adapt(model, criterion, inputs, targets, step_size, num_steps,
                     first_order=False, params=None, inner_update_type='both', trades=False, adv_inputs=None):
    """Run ``num_steps`` inner-loop gradient updates and return the result.

    Each step runs ``model(inputs, params=params)``, computes a loss, and
    replaces ``params`` with the output of ``gradient_update_parameters``.

    Args:
        model: callable as ``model(x, params=...)``; it is also handed to
            ``gradient_update_parameters``, which requires a torchmeta
            ``MetaModule``.
        criterion: ``criterion(outputs, targets)`` loss, used when ``trades``
            is false.
        inputs: inputs for the forward pass.
        targets: labels for the loss.
        step_size: forwarded to ``gradient_update_parameters``.
        num_steps: number of inner updates to perform.
        first_order: forwarded to ``gradient_update_parameters``.
        params: starting parameters; ``None`` is passed through to the model
            and the update helper unchanged.
        inner_update_type: forwarded to ``gradient_update_parameters``.
        trades: when true, the loss is ``trades_criterion`` on the clean and
            adversarial outputs instead of ``criterion``.
        adv_inputs: adversarial inputs; only read when ``trades`` is true.

    Returns:
        ``(params, loss)`` where ``loss`` is the loss computed in the last
        step *before* that step's update. If ``num_steps`` is not positive no
        step runs, ``params`` is returned untouched and ``loss`` is ``None``.
    """
    # Initialised up front so a zero-step call returns cleanly instead of
    # raising UnboundLocalError at the return statement.
    loss = None

    for _ in range(num_steps):
        outputs_train = model(inputs, params=params)
        if trades:
            adv_outputs_train = model(adv_inputs, params=params)
            loss = trades_criterion(outputs_train, adv_outputs_train, targets)
        else:
            loss = criterion(outputs_train, targets)

        model.zero_grad()
        params = gradient_update_parameters(
            model,
            loss,
            params=params,
            step_size=step_size,
            first_order=first_order,
            inner_update_type=inner_update_type,
        )

    return params, loss


def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order. Raises ``AssertionError`` when ``x``
    is not square.
    """
    rows, cols = x.shape
    assert rows == cols

    # Dropping the last element and re-reading the rest with a row length of
    # ``rows + 1`` puts every diagonal entry in column 0 of the reshaped
    # matrix, so slicing that column away leaves only off-diagonal entries.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

def maximize_inner_adapt(model, criterion, inputs, targets, target_param,pr_type, step_size, num_steps,
                     first_order=False, params=None, inner_update_type='both', trades=False, adv_inputs=None):
    """Inner-loop adaptation with an extra parameter-space term in the loss.

    Each step computes the task loss exactly like ``maml_inner_adapt``, then
    adds a term comparing the model's meta-parameters with ``target_param``
    before calling ``gradient_update_parameters``.

    Args:
        model: callable as ``model(x, params=...)`` that also exposes
            ``meta_named_parameters()``.
        criterion: ``criterion(outputs, targets)`` loss, used when ``trades``
            is false.
        inputs: inputs for the forward pass.
        targets: labels for the loss.
        target_param: mapping of reference parameters, paired positionally
            (by iteration order) with ``model.meta_named_parameters()``.
        pr_type: ``'mse'`` subtracts the MSE between each parameter and the
            detached reference; ``'corr'`` adds a cross-correlation penalty;
            any other value adds nothing.
        step_size: forwarded to ``gradient_update_parameters``.
        num_steps: number of inner updates to perform.
        first_order: forwarded to ``gradient_update_parameters``.
        params: starting parameters for the forward pass and the update.
        inner_update_type: forwarded to ``gradient_update_parameters``.
        trades: when true, use ``trades_criterion`` on clean and adversarial
            outputs instead of ``criterion``.
        adv_inputs: adversarial inputs; only read when ``trades`` is true.

    Returns:
        ``(params, loss)`` where ``loss`` is the total loss of the last step
        before its update. If ``num_steps`` is not positive no step runs and
        ``loss`` is ``None``.
    """
    mse_loss = nn.MSELoss()

    # Initialised up front so a zero-step call returns cleanly instead of
    # raising UnboundLocalError at the return statement.
    loss = None

    for _ in range(num_steps):
        outputs_train = model(inputs, params=params)
        if trades:
            adv_outputs_train = model(adv_inputs, params=params)
            loss = trades_criterion(outputs_train, adv_outputs_train, targets)
        else:
            loss = criterion(outputs_train, targets)

        # NOTE(review): the penalty is built from the model's own
        # meta-parameters, not from `params`; confirm this is intended when
        # `params` is supplied or when num_steps > 1.
        orig_params = OrderedDict(model.meta_named_parameters())
        for param1, param2 in zip(orig_params.values(), target_param.values()):
            if pr_type == 'mse':
                # Negative sign: a larger distance to the reference lowers the loss.
                loss += -1.0 * mse_loss(param1, param2.detach())
            elif pr_type == 'corr':
                In = param1.size(0)
                flat1 = param1.view(In, -1)
                flat2 = param2.view(In, -1)
                out = flat1.size(1)

                # Parameter-free batch norm standardises each column. It is
                # created on the parameter's own device so the code also runs
                # on CPU and on a non-default GPU (previously hard-coded
                # `.cuda()`).
                bn = nn.BatchNorm1d(out, affine=False).to(param1.device)

                # empirical cross-correlation matrix, averaged over dim 0
                c = bn(flat1).T @ bn(flat2)
                c.div_(In)

                # `pow_` rewrites the diagonal of `c` in place; the
                # off-diagonal entries read next are not affected by that.
                on_diag = torch.diagonal(c).pow_(2).sum()
                off_diag = off_diagonal(c).add_(-1).pow_(2).sum()
                loss += on_diag + off_diag

        model.zero_grad()
        params = gradient_update_parameters(
            model,
            loss,
            params=params,
            step_size=step_size,
            first_order=first_order,
            inner_update_type=inner_update_type,
        )

    return params, loss

def gradient_update_parameters(model,
                               loss,
                               params=None,
                               step_size=0.5,
                               first_order=False,
                               inner_update_type='both'):
    """Take one gradient-descent step on ``params`` and return the new values.

    Args:
        model: must be a torchmeta ``MetaModule``.
        loss: tensor to differentiate.
        params: ordered mapping of name -> tensor; when ``None`` the model's
            ``meta_named_parameters()`` are used.
        step_size: either a single step size, or a dict giving one step size
            per parameter name.
        first_order: when true, gradients are computed without building a
            graph for higher-order derivatives.
        inner_update_type: ``'linear_only'`` updates only parameters whose
            name contains ``'classifier'``; ``'encoder_only'`` updates only
            the remaining ones; any other value updates every parameter.

    Returns:
        ``OrderedDict`` with the same names as ``params``. Parameters that are
        excluded by ``inner_update_type`` are passed through as the same
        tensor objects.

    Raises:
        ValueError: if ``model`` is not a ``MetaModule``.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    # allow_unused: some parameters may not contribute to the loss
    # (the original comment cites the ANIL implementation).
    grads = torch.autograd.grad(
        loss,
        params.values(),
        create_graph=not first_order,
        allow_unused=True,
    )

    per_name_step = isinstance(step_size, (dict, OrderedDict))

    updated_params = OrderedDict()
    for (name, param), grad in zip(params.items(), grads):
        in_classifier = 'classifier' in name
        skip_update = (
            (inner_update_type == 'linear_only' and not in_classifier)
            or (inner_update_type == 'encoder_only' and in_classifier)
        )
        if skip_update:
            updated_params[name] = param
            continue

        # The per-name step is looked up only for parameters that are actually
        # updated, so a dict step_size need not cover the skipped ones.
        lr = step_size[name] if per_name_step else step_size
        # A parameter unused by the loss has no gradient; treat it as zero.
        delta = 0. if grad is None else grad
        updated_params[name] = param - lr * delta

    return updated_params
