import torch
import torchattacks
import utils
import copy
import math
from functools import partial
from utils import rate_act_func
from models.layers import GetSubnet


def zero_init_delta_dict(delta_dict, device, rho):
    for param in delta_dict:
        delta_dict[param] = torch.zeros_like(param).to(device)

    delta_norm = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm()
    for param in delta_dict:
        delta_dict[param] *= rho / delta_norm

    return delta_dict


def random_init_on_sphere_delta_dict(delta_dict, device, rho, **unused_kwargs):
    for param in delta_dict:
        delta_dict[param] = torch.randn_like(param).to(device)

    delta_norm = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm()
    for param in delta_dict:
        delta_dict[param] *= rho / delta_norm

    return delta_dict


def random_gaussian_dict(delta_dict, device, rho):
    n_el = 0
    for param_name, p in delta_dict.items():
        delta_dict[param_name] = torch.randn_like(p).to(device)
        n_el += p.numel()

    for param_name in delta_dict.keys():
        delta_dict[param_name] *= rho / (n_el ** .5)

    return delta_dict


def random_init_lw(delta_dict, rho, orig_param_dict, device, norm='l2', adaptive=False):
    assert norm in ['l2', 'linf'], f'Unknown perturbation model {norm}.'

    for param in delta_dict:
        if norm == 'l2':
            delta_dict[param] = torch.randn_like(delta_dict[param]).to(device)
        elif norm == 'linf':
            delta_dict[param] = (2 * torch.rand_like(delta_dict[param], device='cuda') - 1)

    for param in delta_dict:
        param_norm_curr = orig_param_dict[param].abs() if adaptive else 1
        delta_dict[param] *= rho * param_norm_curr

    return delta_dict


def weight_ascent_step_momentum(
        model, x, y, loss_f, orig_param_dict, delta_dict, device, prev_delta_dict,
        step_size, rho, atk=None, momentum=0.75, layer_name_pattern='all', no_grad_norm=False,
        verbose=False, adaptive=False, eot_iter=0, eot_sigma=1., norm='linf', args=None):
    """
    model:              w[k]
    orig_param_dict:    w[0]
    delta_dict:         w[k]-w[0]
    prev_delta_dict:    w[k-1]-w[0]
    1-alpha:            momentum coefficient

    -----------------------------------------------
    z[k+1] = P(w[k] + step_size*Grad F(w[k]))
    w[k+1] = P(w[k] + alpha*(z[k+1]-w[k])+(1-alpha)*(w[k]-w[k-1]))

    versions
    - old           -> L2-bound on all parameters, layer-wise rescaling
    - lw_l2_indep   -> L2-bound on each layer of rho * norm of the layer (if adaptive)
    """

    delta_dict_backup = {param: delta_dict[param].clone() for param in
                         delta_dict}  # copy of perturbation dictionary  (w[k]-w[0])
    # curr_params = {name_p: p.clone() for name_p, p in model.named_parameters()}

    # TODO: What is this? Check if removing it something changes, because I think it makes no sense
    adv_x = torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20)(x, y)
    for param in model.parameters():
        param.requires_grad = True
    output = model(adv_x)
    # print(output.data)
    obj = loss_f(output, y)
    obj.backward()
    # # Average gradients in a neighborhood of current model.
    # for _ in range(eot_iter):
    #     for n, p in model.named_parameters():
    #         p.data = curr_params[n] + torch.randn_like(p) * eot_sigma
    #     obj = loss_f(model(x), y) / (eot_iter + 1)
    #     obj.backward()

    with torch.no_grad():
        if norm == 'linf' and (args.task_mode == 'harp_prune' or args.task_mode == 'harp_finetune') and not args.dense:
            for param_name, param in model.named_parameters():
                if ('popup_scores' in param_name):
                    if ('bn' not in param_name) and ('shortcut' not in param_name):
                        grad_sign_curr = param.grad.sign()
                        delta_dict[param] += (step_size * grad_sign_curr * (
                            1 if not adaptive else orig_param_dict[param].abs()))
                    else:
                        grad_sign_curr = param.grad.sign()
                        delta_dict[param] += (step_size * grad_sign_curr * (
                            1 if not adaptive else orig_param_dict[param].abs()))

        elif norm == 'linf' and (
                args.task_mode == 'score_prune' or args.task_mode == 'score_finetune' or args.task_mode == 'harp_finetune_lwm') and not args.dense:
            for param_name, param in model.named_parameters():
                if ('popup_scores'  in param_name):
                    if ('bn' not in param_name) and ('shortcut' not in param_name):
                        grad_sign_curr = param.grad.sign()
                        delta_dict[param] += (step_size * grad_sign_curr * (
                            1 if not adaptive else orig_param_dict[param].abs()))
                    else:
                        grad_sign_curr = param.grad.sign()
                        delta_dict[param] += (step_size * grad_sign_curr * (
                            1 if not adaptive else orig_param_dict[param].abs()))

        else:
            raise ValueError('wrong norm')
    utils.zero_grad(model)

    with torch.no_grad():
        # Projection step I, rescaling perturbations
        if norm == 'l2':  # Project onto L2-ball of radius rho (* ||w|| if adaptive)
            raise NotImplementedError

        if args.dense:
            for (param_name, _), param in zip(model.named_parameters(), delta_dict):
                if 'popup_scores' in param_name:
                    if ('bn' not in param_name) and ('shortcut' not in param_name):
                        param_curr = orig_param_dict[param].abs() if adaptive else torch.ones_like(
                            orig_param_dict[param])
                        delta_dict[param] = torch.max(
                            torch.min(delta_dict[param], param_curr * rho), -1. * param_curr * rho)

        elif norm == 'linf' and (args.task_mode == 'harp_prune' or args.task_mode == 'harp_finetune') and not args.dense:
            # Project onto Linf-ball of radius rho (* |w| if adaptive)
            for (param_name, _), param in zip(model.named_parameters(), delta_dict):
                if ('popup_scores' in param_name):
                    if ('bn' not in param_name) and ('shortcut' not in param_name):
                        param_curr = orig_param_dict[param].abs() if adaptive else torch.ones_like(
                            orig_param_dict[param])
                        delta_dict[param] = torch.max(
                            torch.min(delta_dict[param], param_curr * rho), -1. * param_curr * rho)
                    else:
                        param_curr = orig_param_dict[param].abs() if adaptive else torch.ones_like(
                            orig_param_dict[param])
                        delta_dict[param] = torch.max(
                            torch.min(delta_dict[param], param_curr * rho), -1. * param_curr * rho)

        elif norm == 'linf' and (
                args.task_mode == 'score_prune' or args.task_mode == 'score_finetune' or args.task_mode == 'harp_finetune_lwm') and not args.dense:
            # Project onto Linf-ball of radius rho (* |w| if adaptive)
            for (param_name, _), param in zip(model.named_parameters(), delta_dict):
                if 'popup_scores' in param_name:
                    if ('bn' not in param_name) and ('shortcut' not in param_name):
                        param_curr = orig_param_dict[param].abs() if adaptive else torch.ones_like(
                            orig_param_dict[param])
                        delta_dict[param] = torch.max(
                            torch.min(delta_dict[param], param_curr * rho), -1. * param_curr * rho)
                    else:
                        param_curr = orig_param_dict[param].abs() if adaptive else torch.ones_like(
                            orig_param_dict[param])
                        delta_dict[param] = torch.max(
                            torch.min(delta_dict[param], param_curr * rho), -1. * param_curr * rho)

        else:
            raise ValueError('wrong norm')

        # Average perturbations (apply momentum)

        # IF SCORE
        for param_name, param in model.named_parameters():
            if 'score' not in param_name:
                delta_dict[param] = (
                        momentum * delta_dict[param] + (1 - momentum) * prev_delta_dict[param])

        # Applying perturbations
        for param_name, param in model.named_parameters():
            if 'score' not in param_name:
                param.data = orig_param_dict[param] + delta_dict[param]

        for param_name, param in model.named_parameters():
            if 'score' not in param_name:
                prev_delta_dict[param] = delta_dict_backup[param]

    return delta_dict, prev_delta_dict


def eval_APGD_Rsharpness(
        model, batches, loss_f, train_err, train_loss, device, atk=None, rho=0.01,
        step_size_mult=1, n_iters=200, layer_name_pattern='all',
        n_restarts=1, min_update_ratio=0.75, rand_init=True,
        no_grad_norm=False, verbose=False, return_output=False,
        adaptive=False, version='default', norm='linf', args=None, **kwargs,
):
    """Computes worst-case sharpness for every batch independently, and returns
    the average values.
    """

    assert n_restarts == 1 or rand_init, 'Restarts need random init.'
    del train_err
    del train_loss
    gradient_step_kwargs = kwargs.get('gradient_step_kwargs', {})

    init_fn = partial(random_init_lw, norm=norm, adaptive=adaptive),

    def get_loss_and_err(model, loss_fn, x, y):
        """Compute loss and class. error on a single batch."""
        adv_x = torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20)(x, y)
        output = model(adv_x)
        loss = loss_fn(output, y)
        err = (output.max(1)[1] != y).float().mean()
        return loss.cpu().item(), err.cpu().item()

    # cloning the model
    orig_model_state_dict = copy.deepcopy(model.state_dict())
    orig_param_dict = {param: param.clone() for param in model.parameters()}

    n_batches, delta_norm = 0, 0.
    avg_loss, avg_err, avg_init_loss, avg_init_err = 0., 0., 0., 0.
    output = ""

    if version == 'default':
        p = [0, 0.22]
        w = [0, math.ceil(n_iters * 0.22)]

        while w[-1] < n_iters and w[-1] != w[-2]:
            p.append(p[-1] + max(p[-1] - p[-2] - 0.03, 0.06))
            w.append(math.ceil(p[-1] * n_iters))

        w = w[1:]  # No check needed at the first iteration.
        print(w)
        step_size_scaler = .5
    else:
        raise ValueError(f'Unknown version {version}')

    for i_batch, (x, _, y, _, _) in enumerate(batches):
        x, y = x.to(device), y.to(device)

        # Robust Loss and err on the unperturbed model.
        init_loss, init_err = get_loss_and_err(model, loss_f, x, y)

        # Accumulate over batches.
        avg_init_loss += init_loss
        avg_init_err += init_err

        worst_loss_over_restarts = init_loss
        worst_err_over_restarts = init_err
        worst_delta_norm_over_restarts = 0.

        for restart in range(n_restarts):

            if rand_init:
                delta_dict = {param: torch.zeros_like(param) for n, param in model.named_parameters()}
                delta_dict = init_fn(delta_dict, rho, orig_param_dict=orig_param_dict)
                for param in model.parameters():
                    param.data += delta_dict[param]
            else:
                # Creates a dictionary with key=param, value=0s_likeparam
                delta_dict = {param: torch.zeros_like(param) for n, param in model.named_parameters()}

                # this is initialized as the first instance, but is then modified throughout training.
            prev_delta_dict = {param: delta_dict[param].clone() for param in delta_dict}
            worst_model_dict = copy.deepcopy(model.state_dict())

            prev_worst_loss, worst_loss = init_loss, init_loss
            worst_err = init_err
            step_size, prev_step_size = 2 * rho * step_size_mult, 2 * rho * step_size_mult
            prev_cp = 0
            num_of_updates = 0

            for i in range(n_iters):

                delta_dict, prev_delta_dict = weight_ascent_step_momentum(
                    model, x, y, loss_f, orig_param_dict, delta_dict, device, prev_delta_dict,
                    step_size, rho, atk=atk, momentum=0.75, layer_name_pattern=layer_name_pattern,
                    no_grad_norm=no_grad_norm, verbose=False, adaptive=adaptive,
                    norm=norm, args=args, **gradient_step_kwargs)

                # with torch.no_grad():
                curr_loss, curr_err = get_loss_and_err(model, loss_f, x, y)
                delta_norm_total = torch.cat(
                    [delta_param.flatten() for delta_param in delta_dict.values()]).norm().item()

                if curr_loss > worst_loss:
                    worst_loss = curr_loss
                    worst_err = curr_err
                    worst_model_dict = copy.deepcopy(model.state_dict())
                    worst_delta_norm = delta_norm_total
                    num_of_updates += 1

                if i in w:
                    cond1 = num_of_updates < (min_update_ratio * (i - prev_cp))
                    cond2 = (prev_step_size == step_size) and (prev_worst_loss == worst_loss)
                    prev_step_size, prev_worst_loss, prev_cp = step_size, worst_loss, i
                    num_of_updates = 0

                    if cond1 or cond2:
                        print('Reducing step size.')
                        step_size *= step_size_scaler
                        model.load_state_dict(worst_model_dict)

                str_to_log = '[batch={} restart={} iter={}] Sharpness: obj={:.4f}, err={:.2%}, delta_norm={:.5f} (step={:.5f})'.format(
                    i_batch + 1, restart + 1, i + 1, curr_loss - init_loss, curr_err - init_err, delta_norm_total,
                    step_size)
                if verbose:
                    print(str_to_log)
                output += str_to_log + '\n'

            # Keep the best values over restarts.
            if worst_loss > worst_loss_over_restarts:
                worst_loss_over_restarts = worst_loss
                worst_err_over_restarts = worst_err
                worst_delta_norm_over_restarts = worst_delta_norm

            # Reload the unperturbed model for the next restart or batch.
            model.load_state_dict(orig_model_state_dict)

            if verbose:
                print('')

        # Accumulate over batches.
        n_batches += 1
        avg_loss += worst_loss_over_restarts
        avg_err += worst_err_over_restarts
        delta_norm = max(delta_norm, worst_delta_norm_over_restarts)

        if verbose:
            print('')

    vals = (
        (avg_loss - avg_init_loss) / n_batches,
        (avg_err - avg_init_err) / n_batches,
        delta_norm,
    )
    if return_output:
        vals += (output,)

    return vals


