import os
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt

from torch import optim
from sklearn.neighbors import KernelDensity
from scipy.spatial.distance import cdist
from scipy.stats import multivariate_normal

class LinfPGDAttack(object):
    """L-infinity projected gradient descent (PGD) adversary.

    Attributes:
        model: callable mapping an input batch to logits.
        epsilon: radius of the L-inf ball around the clean input.
        num_steps: number of signed-gradient ascent steps.
        step_size: size of each signed-gradient step.
    """

    def __init__(self, model, epsilon, num_steps, step_size):
        self.model = model
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        """Build an adversarial version of ``x_natural`` for labels ``y``.

        The attack starts from ``x_natural`` plus uniform noise in
        ``[-epsilon, epsilon]``. It then takes ``num_steps`` ascent steps of
        ``step_size * sign(grad)`` on the cross-entropy loss. After each step,
        the iterate is projected back into the epsilon-ball around
        ``x_natural`` and clipped to ``[0, 1]``.

        The model is queried in whatever train/eval mode it is currently in.
        The random start itself is not clipped, so with ``num_steps == 0``
        the result may fall outside ``[0, 1]``.
        """
        eps = self.epsilon
        # The projection box never changes, so compute it once.
        lower_bound = x_natural - eps
        upper_bound = x_natural + eps

        x_adv = x_natural.detach()
        x_adv = x_adv + torch.zeros_like(x_adv).uniform_(-eps, eps)

        for _ in range(self.num_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss = F.cross_entropy(self.model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, [x_adv])
            stepped = x_adv.detach() + self.step_size * torch.sign(grad.detach())
            projected = torch.min(torch.max(stepped, lower_bound), upper_bound)
            x_adv = torch.clamp(projected, 0, 1)
        return x_adv

def get_grad_diff(args, model, base_loader, basic=False):
    """Sum the gradients of the negated cross-entropy loss over ``base_loader``.

    For every ``(images, labels)`` batch, the summed (``reduction='sum'``)
    cross-entropy loss is evaluated on one of two inputs:

    * the clean inputs, when ``basic=True``;
    * L-inf PGD adversarial examples (eps=0.0314, 7 steps of 0.00784),
      otherwise.

    The gradient of ``-loss`` w.r.t. every trainable parameter is then
    taken, and the per-batch gradients are added together.

    Side effect: ``model`` is switched to train mode.

    Args:
        args: object with a ``device`` attribute; batches are moved there.
        model: network whose trainable parameters are differentiated.
        base_loader: iterable of ``(images, labels)`` batches.
        basic: if True, use clean inputs instead of adversarial ones.

    Returns:
        A list with one gradient tensor per trainable parameter (in
        ``model.parameters()`` order), or ``[]`` if ``base_loader`` yields
        no batches.
    """
    loss_func = nn.CrossEntropyLoss(reduction='sum')
    model.train()
    # The adversary is only needed when attacking the inputs.
    adversary = None
    if not basic:
        adversary = LinfPGDAttack(model, epsilon=0.0314, num_steps=7, step_size=0.00784)
    differentiable_params = [p for p in model.parameters() if p.requires_grad]

    # Keep a running sum instead of storing every batch's gradients. The
    # addition order is the same as summing afterwards, so the result is
    # unchanged. Memory drops from O(#batches * #params) to O(#params).
    total = None
    for images, labels in base_loader:
        images, labels = images.to(args.device), labels.to(args.device)
        inputs = images if basic else adversary.perturb(images, labels)
        loss_diff = -loss_func(model(inputs), labels)
        gradients = torch.autograd.grad(loss_diff, differentiable_params)
        if total is None:
            total = list(gradients)
        else:
            total = [acc + g for acc, g in zip(total, gradients)]
    return [] if total is None else total

def hvp(model, x, y, v):
    """Hessian-vector product of the loss differentiated by ``get_gradients``.

    First-order gradients are obtained (with their graph kept) from
    ``get_gradients(model, x, y, v)``. The scalar ``sum_i <grad_i, v_i>`` is
    then differentiated once more w.r.t. the trainable parameters, giving
    ``H v``.

    Args:
        model: network whose trainable parameters define the Hessian.
        x, y: input batch and labels.
        v: sequence of tensors, presumably one per trainable parameter with
            matching shapes (zipped against the gradients) - confirm at callers.

    Returns:
        A list of tensors, one per trainable parameter.
    """
    first_order = get_gradients(model, x, y, v)
    params = [p for p in model.parameters() if p.requires_grad]

    terms = []
    for g, vec in zip(first_order, v):
        terms.append((g * vec).sum())
    inner = torch.stack(terms).sum()

    return list(torch.autograd.grad(inner, params, retain_graph=True))

def get_gradients(model, x, y, v):
    """Calculate dL/dW of the summed adversarial cross-entropy loss.

    An L-inf PGD adversarial example (eps=0.0314, 7 steps of size 0.00784)
    is generated for ``(x, y)``, and the summed cross-entropy loss is
    evaluated on it. The gradients are returned with their autograd graph
    kept (``create_graph=True``) so they can be differentiated again; ``hvp``
    relies on this.

    Args:
        model: network to differentiate; only parameters with
            ``requires_grad`` are used.
        x: clean input batch.
        y: target labels for ``x``.
        v: unused. It is kept so the signature stays compatible with
            existing callers (``hvp`` passes it).

    Returns:
        A tuple of gradient tensors, one per trainable parameter, in
        ``model.parameters()`` order.
    """
    loss_func = nn.CrossEntropyLoss(reduction='sum')

    adversary = LinfPGDAttack(model, epsilon=0.0314, num_steps=7, step_size=0.00784)
    adv = adversary.perturb(x, y)
    loss = loss_func(model(adv), y)

    differentiable_params = [p for p in model.parameters() if p.requires_grad]
    # `only_inputs` is deprecated in torch.autograd.grad (it is always True),
    # so it is no longer passed.
    return torch.autograd.grad(loss, differentiable_params, retain_graph=True,
                               create_graph=True)

def get_inv_hvp(args, model, data_loader, v, damping=0.1, scale=200, rounds=1):
    """Approximate the inverse-Hessian-vector product ``H^{-1} v`` with LiSSA.

    Each round runs the stochastic recursion

        u_0 = 0
        u_j = v + (1 - damping) * u_{j-1} - H_j u_{j-1} / scale

    with one step per batch of ``data_loader``. Here ``H_j u`` is the
    Hessian-vector product returned by ``hvp`` on batch j. The round's
    result ``u / scale`` is averaged over ``rounds``.

    Args:
        args: object with a ``device`` attribute; batches are moved there.
        model: network whose loss Hessian is used.
        data_loader: iterable of ``(images, labels)`` batches.
        v: list of tensors (one per trainable parameter) to multiply by H^{-1}.
        damping: damping term added to the Hessian for stability.
        scale: scaling factor. It should be large enough that the recursion
            converges; presumably it must exceed the Hessian's largest
            eigenvalue - confirm for the model at hand.
        rounds: number of independent recursions to average.

    Returns:
        A list of tensors shaped like ``v``, or ``None`` when ``rounds == 0``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    estimate = None
    for _ in range(rounds):
        u = [torch.zeros_like(v_i) for v_i in v]
        n_batches = 0
        for images, labels in data_loader:
            images, labels = images.to(args.device), labels.to(args.device)
            # The Hessian must act on the current estimate u, and u must be
            # carried to the next step. Previously u stayed at zero and H was
            # applied to v, so only the last batch's `v - H v / scale` survived.
            batch_hvp = hvp(model, images, labels, u)
            u = [a + (1 - damping) * b - c / scale for a, b, c in zip(v, u, batch_hvp)]
            n_batches += 1
        if n_batches == 0:
            raise ValueError("data_loader yielded no batches; cannot estimate the inverse HVP")

        res_upscaled = [u_i / scale for u_i in u]
        if estimate is None:
            estimate = [r_i / rounds for r_i in res_upscaled]
        else:
            for j in range(len(estimate)):
                estimate[j] += res_upscaled[j] / rounds
    return estimate

def unrolling_sgd(args, model, base_loader, train_loader):
    """Fine-tune a copy of ``model``, then apply an unrolled-SGD gradient correction.

    1. A deep copy ``M`` of ``model`` is trained with PGD adversarial
       examples on ``train_loader`` for ``args.finetune_epoch`` epochs. The
       optimizer is picked by ``args.opt`` and uses learning rate
       ``args.un_lr``.
    2. For every batch of ``base_loader``, the gradient of the mean
       adversarial cross-entropy loss of the original ``model`` is computed.
       ``args.finetune_epoch * args.un_lr * grad`` is then added to ``M``'s
       parameters (a gradient-ascent step).

    Args:
        args: namespace with ``opt`` ('SGD' or 'Adam'), ``un_lr``,
            ``finetune_epoch`` and ``device``.
        model: original network. Its parameter values are not changed, but
            it is run forward for the PGD attack in its current mode.
        base_loader: iterable of ``(inputs, targets)`` batches used for the
            ascent correction (presumably the data to forget - confirm).
        train_loader: iterable of ``(inputs, targets)`` batches used for
            fine-tuning.

    Returns:
        The updated copy of ``model``, in train mode.

    Raises:
        ValueError: if ``args.opt`` is neither 'SGD' nor 'Adam'.
    """
    loss_func = nn.CrossEntropyLoss()
    M = copy.deepcopy(model)
    M.train()

    if args.opt == 'SGD':
        optimizer_M = optim.SGD(M.parameters(), lr=args.un_lr, momentum=0.9, weight_decay=5e-4)
    elif args.opt == 'Adam':
        optimizer_M = optim.Adam(M.parameters(), lr=args.un_lr, weight_decay=0.0002)
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unsupported optimizer {args.opt!r}; expected 'SGD' or 'Adam'")

    adversary_M = LinfPGDAttack(M, epsilon=0.0314, num_steps=7, step_size=0.00784)

    for ep in range(args.finetune_epoch):
        print('Fine-tune epoch =', ep)
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            # Only M is updated.
            adv = adversary_M.perturb(inputs, targets)
            optimizer_M.zero_grad()
            loss_M = loss_func(M(adv), targets)
            loss_M.backward()
            optimizer_M.step()

    # M is already a private copy, so it is updated in place; the former
    # second deepcopy was redundant.
    M_unlearned = M
    M_unlearned.train()

    ascent_scale = args.finetune_epoch * args.un_lr
    adversary = LinfPGDAttack(model, epsilon=0.0314, num_steps=7, step_size=0.00784)
    model_params = list(model.parameters())
    for inputs, targets in base_loader:
        inputs, targets = inputs.to(args.device), targets.to(args.device)

        adv = adversary.perturb(inputs, targets)
        loss_unl = loss_func(model(adv), targets)
        # autograd.grad alone is enough. The former extra backward() also
        # accumulated stale .grad into the caller's model. create_graph=True
        # was not needed because the gradients are only used as values.
        grads = torch.autograd.grad(loss_unl, model_params)

        # M_unlearned is a deepcopy of model, so parameter order matches.
        with torch.no_grad():
            for param, grad in zip(M_unlearned.parameters(), grads):
                param.add_(ascent_scale * grad)

    return M_unlearned

def get_kd(data, bandwidth=0.2):
    """Gaussian kernel-density estimate of the first sample under the others.

    A ``KernelDensity`` model with a Gaussian kernel is fit on rows
    ``data[1:]``, and the density (not the log-density) of row ``data[0]``
    is returned.

    Args:
        data: 2-D array-like; row 0 is the query point and the remaining rows
            form the reference set.
        bandwidth: Gaussian kernel bandwidth. It was previously ignored
            because 0.2 was hard-coded; the default keeps that value.

    Returns:
        The estimated density of ``data[0]`` as a float.
    """
    reference = data[1:]
    query = data[:1]

    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
    kde.fit(reference)

    log_density = kde.score_samples(query)[0]
    return np.exp(log_density)

def get_lid(data, k=20):
    """Maximum-likelihood Local Intrinsic Dimensionality (LID) of the first sample.

    Row 0 of ``data`` is the query point. It is also part of the reference
    set, so its zero self-distance is skipped by dropping the smallest
    distance. With ``r_1 <= ... <= r_k`` the query's k nearest-neighbour
    distances, the estimate is ``-k / sum_i log(r_i / r_k)``.

    NOTE(review): if the query has exact duplicates in ``data``, some r_i are
    0, and the estimate degenerates (0, or NaN when r_k == 0).

    Args:
        data: array-like of samples (anything ``np.asarray`` accepts, e.g. a
            list of lists). It is converted to float32.
        k: number of neighbours; clipped to ``len(data) - 1``.

    Returns:
        The LID estimate as a NumPy float.

    Raises:
        ValueError: if ``data`` has fewer than two samples.
    """
    # Convert before slicing so plain lists and other array-likes work too.
    data = np.asarray(data, dtype=np.float32)
    x = data[:1]
    if x.ndim == 1:
        x = x.reshape((-1, x.shape[0]))

    k = min(k, len(data) - 1)
    if k < 1:
        raise ValueError("get_lid needs at least two samples (the query plus one neighbour)")

    # Sort the distances; column 0 is the query's distance to itself.
    dists = np.sort(cdist(x, data), axis=1)[:, 1:k + 1]
    lids = -k / np.sum(np.log(dists / dists[:, -1:]), axis=1)
    return lids[0]

def get_unique(input_ids):
    """Return the distinct elements found across all rows of ``input_ids``.

    ``input_ids`` is iterated as a sequence of rows, each of which is itself
    iterable. The result is a list built from a set, so its order is
    arbitrary (neither sorted nor first-seen). Elements are deduplicated by
    hash/equality.

    NOTE(review): 0-d torch tensors hash by identity, so passing a tensor
    would not deduplicate equal values. Convert with ``.tolist()`` first if
    value-level uniqueness is intended.
    """
    return list({token for row in input_ids for token in row})