import numpy as np
import torch
import torch.nn.functional as F

from attack.pgd_attack import PgdAttack
from attack.log_barrier_attack import LogBarrierAttacks, compute_consts
from attack.pgd_attack_restart import attack_pgd_restart
from utils.context import ctx_noparamgrad
from utils.math_utils import l2_norm_batch as l2b
import pdb


def _attack_loss(predictions, labels):
    return -torch.nn.CrossEntropyLoss(reduction='sum')(predictions, labels)

def g_func(losses, l):
    """Element-wise exponential weighting ``exp(losses / l)``."""
    scaled = losses / l
    return scaled.exp()


# Shared loss modules returning one value per sample (no reduction).
CE_LOSS, KL_LOSS = (
    torch.nn.CrossEntropyLoss(reduction='none'),
    torch.nn.KLDivLoss(reduction='none'),
)


class BatTrainer:
    """Adversarial trainer supporting several robust-training modes.

    ``train`` dispatches on ``args.mode``:
      * ``"ours"`` / ``"ours++"`` -- log-barrier attack (``LogBarrierAttacks``) for
        the inner problem, then a parameter update made of a direct gradient plus
        an indirect Jacobian-vector-product term, with exponentially weighted
        per-sample losses (``exp(loss / kl_coef)``).
      * ``"fast_at"`` / ``"pgd"`` -- PGD adversarial training via ``PgdAttack``.
      * ``"fast_at_ga"`` -- FGSM-style training plus a gradient-alignment regulariser.
      * ``"fast_bat"`` -- two-step attack followed by a cross-term-corrected update.
    """

    def __init__(self, args, log):
        """Store attack / training hyper-parameters read from ``args``.

        Args:
            args: namespace providing at least ``attack_step``, ``attack_eps``,
                ``attack_lr``, ``attack_rs``, ``lmbda``, ``mode``, ``mu`` and
                ``kl_coef`` (``ga_coef`` is also read in ``"fast_at_ga"`` mode).
            log: optional callable invoked with per-batch statistics, or None.
        """
        self.args = args
        self.steps = args.attack_step
        self.eps = args.attack_eps
        self.attack_lr = args.attack_lr
        self.attack_rs = args.attack_rs
        # lmbda == 0 means "unset": fall back to the inverse attack step size.
        self.lmbda = self.args.lmbda if self.args.lmbda != 0.0 else 1. / self.attack_lr

        self.constraint_type = np.inf  # only the L-inf threat model is implemented
        self.log = log
        self.mode = args.mode
        self.z_init_non_sign_attack_lr = 5000. / 255

        self.mu = args.mu  # passed to LogBarrierAttacks and used to scale c_inv
        self.kl_coef = args.kl_coef

        # Running (SCGD) estimate of the mean exp-weighted loss used in "ours" mode.
        # Created on the CPU and moved to the batch device when first used, so the
        # trainer is not pinned to 'cuda:0'.
        self.g = torch.zeros(1)
        self.use_scgd = True

    def test_sa(self, model, data, labels):
        """Return how many samples of the batch ``model`` classifies correctly.

        Switches ``model`` to eval mode and runs without autograd.
        """
        model.eval()
        with torch.no_grad():
            predictions_sa = model(data)
            correct = (torch.argmax(predictions_sa.data, 1) == labels).sum().cpu().item()
        return correct

    def get_input_grad(self, model, X, y, eps, delta_init='none', backprop=True):
        """Return the gradient of mean CE(model(X + delta), y) w.r.t. ``delta``.

        Args:
            delta_init: ``'none'`` (zeros), ``'random_uniform'`` (U(-eps, eps)) or
                ``'random_corner'`` (``eps * sign`` of a U(-eps, eps) draw).
            backprop: if True the gradient is built with ``create_graph=True`` so it
                can be differentiated again; otherwise it is returned detached.

        Raises:
            ValueError: for an unknown ``delta_init``.
        """
        if delta_init == 'none':
            delta = torch.zeros_like(X, requires_grad=True)
        elif delta_init == 'random_uniform':
            delta = torch.zeros_like(X).uniform_(-eps, eps).requires_grad_(True)
        elif delta_init == 'random_corner':
            delta = torch.zeros_like(X).uniform_(-eps, eps).requires_grad_(True)
            delta = eps * torch.sign(delta)
        else:
            raise ValueError('wrong delta init')

        output = model(X + delta)
        loss = F.cross_entropy(output, y)
        grad = torch.autograd.grad(loss, delta, create_graph=bool(backprop))[0]
        if not backprop:
            grad = grad.detach()
        return grad

    def get_perturbation_init(self, model, x, y, eps, device, method, z_init_detach=True):
        """Build an initial perturbation ``z`` for ``x`` using ``method``.

        Methods: ``"random"`` (uniform start), ``"fgsm"`` (signed step from zero),
        ``"pgd"`` (signed step from the random start), ``"ns-pgd"`` / ``"ns-gd"``
        (non-sign step from the random start, with / without projection) and
        ``"ns-pgd-zero"`` (projected non-sign step from zero). Projection clips to
        the eps-ball and keeps ``x + z`` in [0, 1].

        Args:
            z_init_detach: if True the result is detached; otherwise the graph of
                the gradient step is kept (``retain_graph``/``create_graph``).

        Raises:
            NotImplementedError: for an unknown ``method``.
        """
        # Uniform random start in the eps-ball with x + z clipped to [0, 1]. It is drawn
        # for every method, so the RNG stream is consumed identically in all cases.
        z_init = torch.clamp(
            x + torch.FloatTensor(x.shape).uniform_(-eps, eps).to(device),
            min=0, max=1
        ) - x
        z_init.requires_grad_(True)

        retain_graph = not z_init_detach
        pgd_attack_lr = 1.25 * eps
        fgsm_attack_lr = eps
        ns_attack_lr = self.z_init_non_sign_attack_lr

        def attack_loss_grad(z0):
            # Gradient of the negated CE attack loss w.r.t. the perturbation z0.
            model.clear_grad()
            model.with_grad()
            loss = _attack_loss(model(x + z0), y)
            return torch.autograd.grad(loss, z0, retain_graph=retain_graph, create_graph=retain_graph)[0]

        def project(z):
            # Clip to the L-inf eps-ball, then keep x + z inside [0, 1].
            z = torch.clamp(z, min=-eps, max=eps)
            return torch.clamp(x + z, min=0, max=1) - x

        if method == "random":
            z = z_init
        elif method == "fgsm":
            z0 = torch.zeros_like(x).requires_grad_(True)
            z = z0 - fgsm_attack_lr * torch.sign(attack_loss_grad(z0))
            z = torch.clamp(x + z, min=0, max=1) - x
        elif method == "pgd":
            z = project(z_init - pgd_attack_lr * torch.sign(attack_loss_grad(z_init)))
        elif method == "ns-pgd":
            z = project(z_init - ns_attack_lr * attack_loss_grad(z_init))
        elif method == "ns-gd":
            z = z_init - ns_attack_lr * attack_loss_grad(z_init)
        elif method == "ns-pgd-zero":
            z0 = torch.zeros_like(x).requires_grad_(True)
            z = project(z0 - ns_attack_lr * attack_loss_grad(z0))
        else:
            raise NotImplementedError

        return z.detach() if z_init_detach else z

    def _barrier_indirect_grads(self, model, data, labels, delta, delta_star, grad_delta):
        """Return ``(outputs, indirect_grads)`` for the "ours"/"ours++" update.

        ``indirect_grads`` is minus the mixed (params x delta) second derivative of
        ``_attack_loss`` at ``data + delta_star`` applied to the constant vector
        ``c_inv / mu * grad_delta``; ``outputs`` are the logits of that forward pass.
        """
        cup, clow = compute_consts(data, delta, self.eps)
        ct = torch.cat([cup, clow], 1) ** 2  # all constraints, squared
        dd = cup.shape[1]
        # Element-wise cup^2 * clow^2 / (cup^2 + clow^2), i.e. 1 / (1/cup^2 + 1/clow^2)
        # when both are non-zero. cup/clow presumably are the constraint slacks from
        # compute_consts -- confirm against its definition.
        c_inv = ct[:, :dd] * ct[:, dd:] / (ct[:, :dd] + ct[:, dd:])
        # Used only as a constant vector in the JVP below, hence detached.
        c_inv_grad_delta = (c_inv.view(delta.shape) * grad_delta / self.mu).detach()

        model.clear_grad()
        model.with_grad()
        delta = delta_star.clone().detach().requires_grad_(True)
        outputs = model(data + delta)
        cost = _attack_loss(outputs, labels)
        grad_attack_loss = torch.autograd.grad(cost, delta, create_graph=True)[0]
        jv_prods = torch.autograd.grad(outputs=grad_attack_loss,
                                       inputs=list(model.parameters()),
                                       grad_outputs=c_inv_grad_delta)
        return outputs, [-jvp for jvp in jv_prods]

    def train(self, model, train_dl, opt, loss_func, device, epoch, scheduler=None, kl_coef=10., wandb=None):
        """Run one training epoch in the mode selected by ``self.mode``.

        Args:
            model: network exposing ``clear_grad``, ``with_grad`` and ``no_grad``
                helpers in addition to the usual ``nn.Module`` API.
            train_dl: iterable of ``(data, labels)`` batches.
            opt: optimizer, stepped once per batch.
            loss_func: training loss; its value is divided by the batch size below,
                so it is presumably sum-reduced -- TODO confirm with callers.
            device: device the batches are moved to.
            epoch: current epoch index (wandb step and the "ours++" warm-up).
            scheduler: optional LR scheduler stepped after every batch.
            kl_coef: temperature of the ``exp(loss / kl_coef)`` weighting used by
                "ours" / "ours++".
            wandb: optional wandb module/run for per-batch logging.

        Returns:
            The model, updated in place.

        Raises:
            NotImplementedError: for an unknown ``self.mode``.
        """
        adversary_train = PgdAttack(
            model, loss_fn=loss_func, eps=self.eps, steps=self.steps,
            eps_lr=self.attack_lr, ord=self.constraint_type,
            rand_init=True, clip_min=0.0, clip_max=1.0, targeted=False,
            regular=0, sign=True
        )

        lb_attacker = LogBarrierAttacks(model, eps=self.eps, alpha=self.attack_lr, steps=self.steps,
                                        random_start=True,
                                        mu=self.mu,
                                        clip_min=0.,
                                        clip_max=1.,
                                        epsilon=1e-5)

        model.train()
        training_loss = torch.tensor([0.])
        train_sa = torch.tensor([0.])
        train_ra = torch.tensor([0.])

        total = 0
        ds = len(train_dl)

        for i, (data, labels) in enumerate(train_dl):
            data = data.type(torch.FloatTensor).to(device)
            labels = labels.to(device)
            real_batch = data.shape[0]
            channels = data.shape[1]
            image_size = data.shape[2]
            total += real_batch

            # Record SA along with each batch (test_sa switches the model to eval mode).
            train_sa += self.test_sa(model, data, labels)

            model.train()

            if self.mode == "ours":
                model.clear_grad()
                model.train()

                delta_star = lb_attacker.attack(data, labels) - data

                # Direct gradients of the exp-weighted cost w.r.t. delta and the parameters.
                model.clear_grad()
                model.with_grad()
                delta = delta_star.clone().detach().requires_grad_(True)
                predictions = model(data + delta)
                losses = CE_LOSS(predictions, labels)  # per-sample, not reduced

                if kl_coef < 10:  # shift losses to avoid large values inside exp for small kl_coef
                    losses = losses - losses.max().detach()

                g_losses = g_func(losses, l=kl_coef)
                cost = g_losses.sum()

                grad_delta = torch.autograd.grad(cost, delta, retain_graph=True, create_graph=False)[0]
                direct_grads = torch.autograd.grad(cost, list(model.parameters()), retain_graph=False, create_graph=False)

                outputs, indirect_grads = self._barrier_indirect_grads(
                    model, data, labels, delta, delta_star, grad_delta)

                if not self.use_scgd:
                    total_grads = [kl_coef * (gd + gi) / g_losses.sum().detach()
                                   for (gd, gi) in zip(direct_grads, indirect_grads)]
                else:
                    beta = 0.1  # moving-average rate of the SCGD estimate
                    self.g = self.g.to(g_losses.device)
                    self.g = (1 - beta) * self.g + beta * g_losses.sum().detach() / real_batch
                    self.g = self.g.detach()
                    total_grads = [(kl_coef / self.g) * (gd + gi) / real_batch
                                   for (gd, gi) in zip(direct_grads, indirect_grads)]

                with torch.no_grad():
                    for p, g in zip(model.parameters(), total_grads):
                        p.grad = g.data

                opt.step()

                train_loss = loss_func(outputs, labels).detach() / real_batch

            elif self.mode == "ours++":
                model.clear_grad()
                model.train()

                delta_star = lb_attacker.attack(data, labels) - data
                # Keep delta strictly inside the eps-ball and the [0, 1] box, presumably so
                # the constraint slacks used below stay strictly positive.
                delta_star = torch.clamp(delta_star, min=-self.eps + 1e-6, max=self.eps - 1e-6)
                delta_star = torch.clamp(delta_star + data, min=0. + 1e-6, max=1. - 1e-6) - data

                model.clear_grad()
                model.with_grad()
                delta = delta_star.clone().detach().requires_grad_(True)
                predictions = model(data + delta)
                losses = CE_LOSS(predictions, labels)  # per-sample, not reduced
                # Normalised exp weights from detached losses (constants for autograd).
                g_losses = g_func(losses.detach() - losses.max().detach(), l=kl_coef)
                weights = g_losses / g_losses.sum()

                if epoch < 1:  # warm-up with uniform weights
                    weights = torch.ones_like(weights) / real_batch

                cost = torch.sum(weights * losses)

                grad_delta = torch.autograd.grad(cost, delta, retain_graph=True, create_graph=False)[0]
                direct_grads = torch.autograd.grad(cost, list(model.parameters()), retain_graph=False, create_graph=False)

                outputs, indirect_grads = self._barrier_indirect_grads(
                    model, data, labels, delta, delta_star, grad_delta)

                total_grads = [gd + gi for (gd, gi) in zip(direct_grads, indirect_grads)]

                with torch.no_grad():
                    for p, g in zip(model.parameters(), total_grads):
                        p.grad = g.data

                opt.step()

                train_loss = loss_func(outputs, labels).detach() / real_batch

            elif self.mode == "fast_at":
                if self.steps == 0:
                    delta_star = torch.zeros_like(data)
                else:
                    model.train()
                    opt.zero_grad()

                    # always using random init
                    delta_init = self.get_perturbation_init(model, data, labels, self.eps, device, "random")

                    with ctx_noparamgrad(model):
                        delta_star = adversary_train.perturb(data, labels, delta_init=delta_init) - data

                delta_star.requires_grad = False

                # Update model with perturbed data
                model.clear_grad()
                model.with_grad()
                predictions = model(data + delta_star)
                train_loss = loss_func(predictions, labels) / real_batch
                train_loss.backward()
                opt.step()

            elif self.mode == "pgd":
                if self.steps == 0:
                    delta_star = torch.zeros_like(data)
                else:
                    model.train()
                    opt.zero_grad()

                    delta_init = self.get_perturbation_init(model=model, x=data, y=labels, eps=self.eps, device=device,
                                                            method="random")

                    with ctx_noparamgrad(model):
                        delta_star = adversary_train.perturb(data, labels, delta_init=delta_init) - data

                delta_star.requires_grad = False

                # Update model with perturbed data
                model.clear_grad()
                model.with_grad()
                predictions = model(data + delta_star)
                train_loss = loss_func(predictions, labels) / real_batch
                train_loss.backward()
                opt.step()

            elif self.mode == "fast_at_ga":
                double_bp = self.args.ga_coef > 0

                X, y = data.to(device), labels.to(device)
                delta = torch.zeros_like(X, requires_grad=True)

                X_adv = torch.clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                grad = torch.autograd.grad(loss, delta, create_graph=double_bp)[0]
                grad = grad.detach()

                argmax_delta = self.eps * torch.sign(grad)

                fgsm_alpha = 1.25
                delta.data = torch.clamp(delta.data + fgsm_alpha * argmax_delta, -self.eps, self.eps)
                delta.data = torch.clamp(X + delta.data, 0, 1) - X
                delta = delta.detach()

                predictions = model(X + delta)
                loss_function = torch.nn.CrossEntropyLoss()
                train_loss = loss_function(predictions, y)
                reg = self.get_ga_reg(model, data, labels, device, double_bp)
                train_loss += reg

                opt.zero_grad()
                train_loss.backward()
                opt.step()

            elif self.mode == "fast_bat":
                z_init = torch.clamp(
                    data + torch.FloatTensor(data.shape).uniform_(-self.eps, self.eps).to(device),
                    min=0, max=1
                ) - data
                z_init.requires_grad_(True)

                # Warm-up: one non-sign step from a random start. The result is detached
                # below, so this gradient needs no retained or higher-order graph.
                model.clear_grad()
                model.with_grad()
                attack_loss = _attack_loss(model(data + z_init), labels)
                grad_attack_loss_delta = torch.autograd.grad(attack_loss, z_init)[0]
                delta = z_init - self.attack_lr * grad_attack_loss_delta
                delta = torch.clamp(delta, min=-self.eps, max=self.eps)
                delta = torch.clamp(data + delta, min=0, max=1) - data

                # Second step: this gradient keeps its graph; it is differentiated w.r.t.
                # the parameters below to form the cross term.
                delta = delta.detach().requires_grad_(True)
                attack_loss_second = _attack_loss(model(data + delta), labels)
                grad_attack_loss_delta_second = \
                    torch.autograd.grad(attack_loss_second, delta, retain_graph=True, create_graph=True)[0] \
                        .view(real_batch, 1, channels * image_size * image_size)
                delta_star = delta - self.attack_lr * grad_attack_loss_delta_second.detach().view(data.shape)
                delta_star = torch.clamp(delta_star, min=-self.eps, max=self.eps)
                delta_star = torch.clamp(data + delta_star, min=0, max=1) - data
                z = delta_star.clone().detach().view(real_batch, -1)

                if self.constraint_type == np.inf:
                    # H: (batch, channels * image_size * image_size); 1 where z lies strictly
                    # inside both the eps-ball and the [0, 1] box.
                    z_min = torch.max(-data.view(real_batch, -1),
                                      -self.eps * torch.ones_like(data.view(real_batch, -1)))
                    z_max = torch.min(1 - data.view(real_batch, -1),
                                      self.eps * torch.ones_like(data.view(real_batch, -1)))
                    H = ((z > z_min + 1e-7) & (z < z_max - 1e-7)).to(torch.float32)
                else:
                    raise NotImplementedError

                delta_cur = delta_star.detach().requires_grad_(True)

                # Gradient of the training loss w.r.t. delta at delta_star, parameters frozen.
                model.no_grad()
                lgt = model(data + delta_cur)
                delta_star_loss = loss_func(lgt, labels)
                delta_star_loss.backward()
                delta_outer_grad = delta_cur.grad.view(real_batch, -1)  # (batch, C * H * W)

                hessian_inv_prod = delta_outer_grad / self.lmbda
                bU = (H * hessian_inv_prod).unsqueeze(-1)  # (batch, C * H * W, 1)

                # Cross term: gradient w.r.t. the parameters of <grad_delta attack loss, bU>.
                model.with_grad()
                model.clear_grad()
                b_dot_product = grad_attack_loss_delta_second.bmm(bU).view(-1).sum(dim=0)
                b_dot_product.backward()
                cross_term = [-param.grad / real_batch for param in model.parameters()]

                model.clear_grad()
                model.with_grad()
                # delta_star is built from detached tensors, so it carries no graph to the parameters.
                predictions = model(data + delta_star)
                train_loss = loss_func(predictions, labels) / real_batch
                opt.zero_grad()
                train_loss.backward()

                with torch.no_grad():
                    for p, cross in zip(model.parameters(), cross_term):
                        p.grad.add_(cross)

                del cross_term, H, grad_attack_loss_delta_second
                opt.step()

            else:
                raise NotImplementedError()

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == labels
                if self.log is not None:
                    self.log(model,
                             loss=train_loss.cpu(),
                             accuracy=correct.cpu(),
                             learning_rate=opt.param_groups[0]['lr'],
                             batch_size=real_batch)
            if scheduler:
                scheduler.step()

            training_loss += train_loss.cpu().sum().item()
            train_ra += correct.cpu().sum().item()
            if wandb:
                wandb.log({"training_ra": 100. * train_ra / total,
                           "training_sa": 100. * train_sa / total,
                           "lr": opt.param_groups[0]['lr']},
                          step=ds * epoch + i)

        return model

    def get_ga_reg(self, model, data, labels, device, double_bp):
        """Gradient-alignment regulariser ``ga_coef * (1 - mean cosine)``.

        The cosine is taken between the (detached) input gradient on clean data and
        the input gradient at a uniform random perturbation, over samples where both
        gradients are non-zero. Returns a 0-d zero tensor when ``ga_coef == 0`` or
        when no sample has two non-zero gradients (which would otherwise give NaN).
        The ``sum([1, 2, 3])`` reductions assume 4-D (N, C, H, W) inputs.
        """
        reg = torch.zeros(1).to(device)[0]
        delta = torch.zeros_like(data, requires_grad=True)
        output = model(torch.clamp(data + delta, 0, 1))
        clean_train_loss = F.cross_entropy(output, labels)
        # The clean gradient is detached immediately, so no higher-order graph is needed
        # (double_bp is kept for interface compatibility).
        grad = torch.autograd.grad(clean_train_loss, delta)[0].detach()

        if self.args.ga_coef != 0.0:
            grad_random_perturb = self.get_input_grad(model, data, labels, self.eps,
                                                      delta_init='random_uniform',
                                                      backprop=True)
            grads_nnz_idx = ((grad ** 2).sum([1, 2, 3]) ** 0.5 != 0) * (
                    (grad_random_perturb ** 2).sum([1, 2, 3]) ** 0.5 != 0)
            if grads_nnz_idx.any():
                grad_clean_data, grad_random_perturb = grad[grads_nnz_idx], grad_random_perturb[grads_nnz_idx]
                grad_clean_data_norms = l2b(grad_clean_data)
                grad_random_perturb_norms = l2b(grad_random_perturb)
                grad_clean_data_normalized = grad_clean_data / grad_clean_data_norms[:, None, None, None]
                grad_random_perturb_normalized = \
                    grad_random_perturb / grad_random_perturb_norms[:, None, None, None]
                cos = torch.sum(grad_clean_data_normalized * grad_random_perturb_normalized, (1, 2, 3))
                reg += self.args.ga_coef * (1.0 - cos.mean())

        return reg

    def _attack_batch(self, model, data, labels, attack_eps, attack_steps, attack_lr, attack_rs):
        """Return ``data`` perturbed by ``attack_pgd_restart`` (L-inf projection).

        When ``attack_steps == 0`` the clean ``data`` is returned. The attack call is
        still made in that case, as the original evaluation code did.
        """
        with ctx_noparamgrad(model):
            perturbed_data = attack_pgd_restart(
                model=model,
                X=data,
                y=labels,
                eps=attack_eps,
                alpha=attack_lr,
                attack_iters=attack_steps,
                n_restarts=attack_rs,
                rs=(attack_rs > 1),
                verbose=False,
                linf_proj=True,
                l2_proj=False,
                l2_grad_update=False,
                cuda=True
            ) + data

        if attack_steps == 0:
            perturbed_data = data
        return perturbed_data

    def eval(self, model, test_dl, attack_eps, attack_steps, attack_lr, attack_rs, device):
        """Evaluate clean (SA) and robust (RA) accuracy over ``test_dl``.

        Returns:
            ``(correct_total, robust_total, total, mean_robust_loss)`` where
            ``mean_robust_loss`` is the cross-entropy on perturbed inputs averaged
            per sample.
        """
        total = 0
        robust_total = 0
        correct_total = 0
        test_loss = 0.

        for data, labels in test_dl:
            data = data.type(torch.FloatTensor).to(device)
            labels = labels.to(device)
            real_batch = data.shape[0]
            total += real_batch

            perturbed_data = self._attack_batch(model, data, labels, attack_eps,
                                                attack_steps, attack_lr, attack_rs)

            with torch.no_grad():
                correct = torch.argmax(model(data), 1) == labels
                predictions = model(perturbed_data)
                robust = torch.argmax(predictions, 1) == labels
                # Summed (not batch-mean) so that test_loss / total is a per-sample mean.
                test_loss += F.cross_entropy(predictions, labels, reduction='sum').item()
            correct_total += correct.sum().item()
            robust_total += robust.sum().item()

            if self.log:
                self.log(model=model,
                         accuracy=correct.cpu(),
                         robustness=robust.cpu(),
                         batch_size=real_batch)

        return correct_total, robust_total, total, test_loss / total

    @staticmethod
    def _tail_accuracies(robust_counts, correct_counts, total_counts, robust_accs, fraction=0.3):
        """Pool RA/SA over the ``fraction`` of classes with the lowest robust accuracy.

        At least one class is selected (``int(0.3 * num_classes)`` is 0 for fewer than
        four classes). If the selected classes contain no samples, both pooled
        accuracies are NaN.

        Returns:
            ``(tail_class_indices, ra_tail, sa_tail)`` with indices as plain ints.
        """
        n_tail = max(1, int(fraction * len(robust_accs)))
        idx_tail = [int(c) for c in np.argsort(robust_accs)[:n_tail]]
        tail_total = float(np.sum(total_counts[idx_tail]))
        if tail_total == 0:
            return idx_tail, float('nan'), float('nan')
        ra_tail = float(np.sum(robust_counts[idx_tail])) / tail_total
        sa_tail = float(np.sum(correct_counts[idx_tail])) / tail_total
        return idx_tail, ra_tail, sa_tail

    def eval_per_class(self, model, test_dl, attack_eps, attack_steps, attack_lr, attack_rs, device, num_classes):
        """Evaluate SA/RA overall, per class, and over the worst 30% of classes by RA.

        Classes without test samples get ``np.inf`` in ``RAs``/``SAs``.

        Returns:
            ``(correct_total, robust_total, total, mean_robust_loss, out_str,
            ra_tail_30, sa_tail_30, RAs, SAs)``.
        """
        total = 0
        robust_total = 0
        correct_total = 0
        test_loss = 0.
        total_counter = np.zeros(num_classes)
        robust_counter = np.zeros(num_classes)
        correct_counter = np.zeros(num_classes)

        for data, labels in test_dl:
            data = data.type(torch.FloatTensor).to(device)
            labels = labels.to(device)
            real_batch = data.shape[0]
            total += real_batch

            perturbed_data = self._attack_batch(model, data, labels, attack_eps,
                                                attack_steps, attack_lr, attack_rs)

            with torch.no_grad():
                predictions = model(data)
                robust_predictions = model(perturbed_data)
                # Summed (not batch-mean) so that test_loss / total is a per-sample mean.
                test_loss += F.cross_entropy(robust_predictions, labels, reduction='sum').item()
            correct = torch.argmax(predictions, 1) == labels
            robust = torch.argmax(robust_predictions, 1) == labels
            correct_total += correct.sum().item()
            robust_total += robust.sum().item()

            # Per-class counts; np.add.at accumulates correctly over repeated labels.
            labels_np = labels.cpu().numpy()
            np.add.at(total_counter, labels_np, 1.)
            np.add.at(correct_counter, labels_np, correct.cpu().numpy().astype(np.float64))
            np.add.at(robust_counter, labels_np, robust.cpu().numpy().astype(np.float64))

            if self.log:
                self.log(model=model,
                         accuracy=correct.cpu(),
                         robustness=robust.cpu(),
                         batch_size=real_batch)

        with np.errstate(divide='ignore', invalid='ignore'):
            SAs = np.where(total_counter != 0, correct_counter / total_counter, np.inf)
            RAs = np.where(total_counter != 0, robust_counter / total_counter, np.inf)

        idx_tail_30, ra_tail_30, sa_tail_30 = self._tail_accuracies(
            robust_counter, correct_counter, total_counter, RAs)

        out_str = '\n================ Per class accuracies =================='
        for c in range(num_classes):
            out_str += f'\nRA for class {c}: {100.*RAs[c]:10.4f} ┃ SA for class {c}: {100.*SAs[c]:10.4f}'

        out_str += f'\nTail-30 classes: {idx_tail_30}'
        out_str += f'\nRA tail-30: {100.*ra_tail_30:10.4f} ┃ SA tail-30: {100.*sa_tail_30:10.4f}'

        return correct_total, robust_total, total, test_loss / total, out_str, ra_tail_30, sa_tail_30, RAs, SAs


def norm(x):
    """Return the Euclidean norm over all entries of ``x`` as a 0-d tensor."""
    squared_sum = (x * x).sum()
    return squared_sum.sqrt()
