from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from datetime import datetime


def loadrb_onhydra(model, model_dict):
    """Rename the keys of ``model_dict`` and load the result into ``model``.

    Every key is passed through a fixed sequence of substring replacements
    and the renamed state dict is handed to ``model.load_state_dict``.
    (Judging by the function name this maps a RobustBench-style checkpoint
    onto this project's layer names -- TODO confirm.)

    Args:
        model: object exposing ``load_state_dict``.
        model_dict: mapping from parameter name to value (a state dict).

    Returns:
        The same ``model`` instance, after loading.
    """
    # Order matters and is applied exactly as listed:
    #   * 'layer.N' -> 'blockN' runs before 'block.' -> 'layer.', so the
    #     freshly produced 'layer.' substrings are not rewritten again.
    #   * 'batchnorm_0' / 'batchnorm_1' run before the bare 'batchnorm',
    #     which is a substring of both.
    #   * 'conv_0' / 'conv_1' run before 'init_conv' -> 'conv1'.
    renames = (
        ('layer.0', 'block1'),
        ('layer.1', 'block2'),
        ('layer.2', 'block3'),
        ('block.', 'layer.'),
        ('batchnorm_0', 'bn1'),
        ('batchnorm_1', 'bn2'),
        ('conv_0', 'conv1'),
        ('conv_1', 'conv2'),
        ('logits', 'fc'),
        ('batchnorm', 'bn1'),
        ('init_conv', 'conv1'),
        ('shortcut', 'convShortcut'),
    )

    checkpoint = model_dict
    for old, new in renames:
        # One full pass per rule keeps key-collision behaviour identical to
        # applying the rules one dictionary at a time.
        checkpoint = OrderedDict(
            (key.replace(old, new), value) for key, value in checkpoint.items()
        )

    model.load_state_dict(checkpoint)
    return model


def process_arg(args, arg):
    """Return the text that argument ``arg`` contributes to a run name.

    Arguments listed in ``omitted`` contribute an empty string; every other
    argument contributes ``str`` of its value on ``args``.

    Args:
        args: namespace-like object holding the run arguments.
        arg: name of the attribute to render.

    Returns:
        The rendered value, or ``''`` for omitted arguments.
    """
    # 'model_path' is deliberately left out of the name as well.
    omitted = ('gpu', 'eval_sharpness', 'log', 'rewrite', 'adaptive', 'model_path')
    return '' if arg in omitted else str(getattr(args, arg))


def get_path(args, log_folder):
    """Build the JSON log path for a run from a timestamp and its arguments.

    The file name is a millisecond-resolution timestamp immediately followed
    (no separator) by the ``'-'``-joined ``process_arg`` renderings of every
    attribute of ``args`` except ``adaptive``, in ``vars(args)`` order. When
    ``args.adaptive`` is truthy the suffix ``'-adaptive'`` is appended.

    Args:
        args: namespace-like object; must have an ``adaptive`` attribute.
        log_folder: directory prefix for the returned path.

    Returns:
        ``'{log_folder}/{name}.json'`` as a string.

    Raises:
        AttributeError: if ``args`` has no ``adaptive`` attribute.
    """
    fields = [process_arg(args, arg) for arg in vars(args) if arg != 'adaptive']

    # isoformat(timespec='milliseconds') always emits three fractional digits.
    # Slicing str(datetime.now())[:-3] instead would cut into the seconds
    # whenever the microsecond component is exactly 0 (str() then omits the
    # fractional part altogether).
    stamp = datetime.now().isoformat(sep='_', timespec='milliseconds')

    name = stamp + '-'.join(fields)
    if getattr(args, 'adaptive'):
        name += '-adaptive'
    return f'{log_folder}/{name}.json'


def zero_grad(model):
    """Zero, in place, every existing gradient of ``model``'s parameters.

    Parameters whose ``grad`` is ``None`` are left untouched.
    """
    for grad in (param.grad for param in model.parameters()):
        if grad is None:
            continue
        grad.zero_()


def compute_robust_err(batches, model, atk, loss_f=F.cross_entropy, n_batches=-1):
    """Compute classification error and average loss, optionally under attack.

    Each item of ``batches`` must unpack into five elements; the first is
    used as the inputs and the third as the labels, the rest are ignored.
    Inputs and labels are moved to CUDA before use.

    Args:
        batches: iterable of 5-tuples ``(X, _, y, _, _)``.
        model: callable mapping inputs to per-class scores; predictions are
            taken as the argmax over dimension 1.
        atk: ``None`` to evaluate on the clean inputs, otherwise a callable
            ``atk(X, y)`` returning the perturbed inputs to evaluate on.
        loss_f: callable ``loss_f(output, y)`` returning a scalar whose
            value is treated as a per-example mean for the batch.
        n_batches: maximum number of batches to process; ``-1`` means no
            limit.

    Returns:
        ``(err, avg_loss)``: fraction of misclassified examples and the
        example-weighted mean loss.

    Raises:
        ZeroDivisionError: if no example was processed (empty ``batches`` or
            ``n_batches`` that admits no batch).
    """
    n_wrong_classified, loss_sum, n_ex = 0, 0.0, 0

    for i, (X, _, y, _, _) in enumerate(batches):
        # `>=` so that exactly n_batches batches are consumed (indices
        # 0 .. n_batches - 1); `>` would let one extra batch through.
        if n_batches != -1 and i >= n_batches:
            break

        # load inputs on device
        X, y = X.cuda(), y.cuda()

        if atk is not None:
            # The attack runs with autograd enabled; it may need gradients.
            X = atk(X, y)

        # Only scalar statistics are extracted, so the evaluation forward
        # pass does not need an autograd graph.
        with torch.no_grad():
            output = model(X)
            loss = loss_f(output, y)

        batch_size = y.size(0)
        n_wrong_classified += (output.max(1)[1] != y).sum().item()
        loss_sum += loss.item() * batch_size
        n_ex += batch_size

    err = n_wrong_classified / n_ex
    avg_loss = loss_sum / n_ex

    return err, avg_loss


def estimate_robust_loss_err(model, batches, atk, loss_f):
    """Average the adversarial loss and error rate over ``batches``.

    Each item of ``batches`` must unpack into five elements; the first is
    used as the inputs and the third as the labels. Both are moved to CUDA,
    perturbed with ``atk(x, y)`` and evaluated with ``model``.

    Both results are unweighted means over batches (each batch counts
    equally regardless of its size).

    Args:
        model: callable mapping inputs to per-class scores.
        batches: sized iterable (``len`` is used) of 5-tuples.
        atk: callable ``atk(x, y)`` returning perturbed inputs.
        loss_f: callable ``loss_f(output, y)`` returning a scalar tensor.

    Returns:
        ``(loss, err)`` as Python floats.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    loss_sum, err_sum = 0.0, 0.0

    for x, _, y, _, _ in batches:
        x, y = x.cuda(), y.cuda()
        # The attack runs with autograd enabled; it may need gradients.
        adv_x = atk(x, y)

        # Accumulate plain floats under no_grad: summing the loss tensors
        # themselves would keep every batch's autograd graph alive until
        # the function returns.
        with torch.no_grad():
            curr_y = model(adv_x)
            loss_sum += loss_f(curr_y, y).item()
            err_sum += (curr_y.max(1)[1] != y).float().mean().item()

    n = len(batches)
    return loss_sum / n, err_sum / n


def rate_act_func(k_score, k_min):
    """Squash ``k_score`` into the interval between ``k_min`` and 1.

    Applies a sigmoid, scales the result by ``1 - k_min`` and shifts it up
    by ``k_min``. For example, with ``k_min = 0.01`` the sigmoid output is
    first scaled into [0, 0.99] and then shifted into [0.01, 1.0].
    """
    return torch.sigmoid(k_score) * (1 - k_min) + k_min


from models.layers import SubnetConv, SubnetLinear


def set_prune_rate_model(model, args, device):
    """Call ``set_prune_rate`` on every sub-module of ``model`` that has it.

    Each such module is called as
    ``set_prune_rate(args.k, args.k, 0.1, device)``; modules without the
    method are skipped.

    Args:
        model: object exposing ``named_modules()``.
        args: namespace-like object providing ``k``.
        device: forwarded unchanged as the last argument.
    """
    for _, module in model.named_modules():
        if hasattr(module, "set_prune_rate"):
            # NOTE(review): the meaning of the literal 0.1 is defined by the
            # layers' set_prune_rate signature, which is not visible here.
            module.set_prune_rate(args.k, args.k, 0.1, device)


def initialize_scaled_score(model, prune_reg='weight'):
    """Overwrite each module's ``popup_scores`` with weight-derived values.

    For every sub-module that has a ``popup_scores`` attribute the scores
    are replaced by magnitudes normalised to a maximum absolute value of 1
    and scaled by ``sqrt(6 / fan_in)`` (close to the Kaiming-uniform bound),
    where ``fan_in`` is computed from ``popup_scores``.

    Args:
        model: object exposing ``named_modules()``.
        prune_reg: ``'weight'`` sets one score per weight from the signed
            weight values; ``'channel'`` sets one score per index along
            weight dimension 1 from summed absolute weights, shaped
            ``(1, C, 1, 1)`` for ``SubnetConv`` modules and ``(1, C)``
            otherwise.

    Raises:
        NameError: if ``prune_reg`` is neither ``'weight'`` nor
            ``'channel'`` and a module with ``popup_scores`` is found.
    """
    print(
        "Initialization relevance score proportional to weight magnitudes (OVERWRITING SOURCE NET SCORES) | Prune_Reg: {}".format(
            prune_reg)
    )
    for name, m in model.named_modules():
        if not hasattr(m, "popup_scores"):
            continue

        n = nn.init._calculate_correct_fan(m.popup_scores, "fan_in")
        bound = math.sqrt(6 / n)
        weight = m.weight.data

        if prune_reg == 'weight':
            # Weight pruning: one score per weight entry.
            m.popup_scores.data = bound * weight / torch.max(torch.abs(weight))
        elif prune_reg == 'channel':
            # Channel pruning: one score per index of weight dimension 1.
            n_channels = weight.shape[1]
            # NOTE(review): reshape(n_channels, -1) does not transpose first,
            # so if the weight is laid out (out, in, ...) each row is a
            # contiguous chunk of the flattened weight, not a strict
            # per-input-channel slice -- confirm this grouping is intended.
            channel_mass = torch.sum(torch.abs(weight.reshape(n_channels, -1)), dim=1)
            scaled = bound * channel_mass / torch.max(torch.abs(channel_mass))
            # isinstance (not an exact type match) so subclasses of
            # SubnetConv also get the 4-D score shape.
            if isinstance(m, SubnetConv):
                score_shape = (1, n_channels, 1, 1)
            else:
                score_shape = (1, n_channels)
            m.popup_scores.data = scaled.reshape(score_shape)
        else:
            raise NameError('Please check prune_reg, current "{}" is not in [weight, channel] !'.format(prune_reg))
