import math
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from scipy import linalg
from torch import Tensor
from torch.autograd import grad
from tqdm.auto import tqdm

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.functional import to_dense_adj
from torch_geometric.utils import get_laplacian, dense_to_sparse

class PGD:
    """Projected-gradient core used by :class:`SpecAttack` (SPAC).

    The relaxed attack keeps one flip probability per node pair in
    ``self.perturbations``. It maximises the spectral distance returned by
    :meth:`compute_loss` with projected gradient ascent, then samples a
    discrete set of flips from those probabilities.

    Attributes the host class must provide (all read below):
    ``perturbations`` (square tensor requiring grad; only its strict upper
    triangle is meaningful, since :func:`symmetric` mirrors it), ``adj``
    (dense adjacency), ``edge_index``, ``edge_weight``, ``num_nodes``,
    ``num_budgets`` and the ``add_edge`` / ``remove_edge`` methods.
    """
    # PGDAttack cannot ensure that there is not singleton node after attacks.
    _allow_singleton: bool = True
    is_undirected: bool = True
    # Number of lowest / highest Laplacian eigenvalues compared by the
    # approximate spectral loss (clipped to the graph size at runtime).
    num_low_eigs: int = 128
    num_high_eigs: int = 64

    def attack(
        self,
        num_budgets: int,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 100,  # follow the authors' hyper
        approx: bool = True,
        sample_epochs: int = 20,
        disable: bool = False,
    ) -> "PGD":
        """Run projected gradient ascent, then sample the final edge flips.

        Args:
            num_budgets: kept for interface compatibility; the budget that is
                actually enforced is ``self.num_budgets``.
            base_lr: base step size; the step at (0-based) epoch ``t`` is
                ``base_lr * epochs / sqrt(t + 1)``.
            grad_clip: if given, the gradient is rescaled before each step so
                that its L2 norm does not exceed this value.
            epochs: number of projected-gradient iterations.
            approx: use the truncated / first-order spectral loss of
                :meth:`compute_loss` instead of the full eigenvalue distance.
            sample_epochs: number of Bernoulli draws used to discretise the
                relaxed perturbation.
            disable: disable the tqdm progress bars.

        Returns:
            ``self``, after ``add_edge`` / ``remove_edge`` has been called for
            every selected node pair ``(u, v)`` with ``u < v``.
        """
        self.approx = approx
        perturbations = self.perturbations

        # Only the strict upper triangle feeds the loss (see `symmetric`).
        # Entries elsewhere get zero gradient, yet they would still consume
        # budget and be emitted as (duplicate or self-loop) flips, so drop
        # them up front.
        with torch.no_grad():
            perturbations.triu_(diagonal=1)

        L_edge_index, L_edge_weight = get_laplacian(
            self.edge_index, self.edge_weight, normalization='sym')
        self.L = to_dense_adj(L_edge_index, L_edge_weight, self.num_nodes)
        self.ori_eigenval, self.ori_eigenvec = torch.linalg.eigh(self.L)

        for epoch in tqdm(range(epochs), desc='PGD training...',
                          disable=disable):
            lr = base_lr * epochs / math.sqrt(epoch + 1)  # T * lr / sqrt{t}
            gradients = self.compute_gradients(perturbations, epoch)

            if grad_clip is not None:
                grad_norm = gradients.norm()
                if grad_norm > grad_clip:
                    gradients = gradients * (grad_clip / grad_norm)

            with torch.no_grad():
                # Gradient *ascent*: the spectral distance is maximised.
                perturbations.add_(lr * gradients)
                self._pgd_project(perturbations, self.num_budgets)

        best_loss = -np.inf
        best_pert = None

        perturbations.detach_()
        for _ in tqdm(range(sample_epochs), desc='Bernoulli sampling...',
                      disable=disable):
            sampled = perturbations.bernoulli()
            if sampled.count_nonzero() > self.num_budgets:
                continue
            with torch.no_grad():
                loss = self.compute_loss(symmetric(sampled), 0)
            if best_loss < loss:
                best_loss = loss
                best_pert = sampled

        if best_pert is None:
            # Every draw exceeded the budget: fall back to deterministically
            # flipping the most probable pairs instead of crashing.
            flat = perturbations.flatten()
            k = min(int(self.num_budgets), flat.numel())
            values, indices = flat.topk(k)
            best_pert = torch.zeros_like(flat)
            best_pert[indices[values > 0]] = 1.
            best_pert = best_pert.view_as(perturbations)

        row, col = best_pert.triu(diagonal=1).nonzero(as_tuple=True)
        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
            if self.adj[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)

        return self

    @staticmethod
    def _pgd_project(perturbations: Tensor, budget: float,
                     eps: float = 1e-5) -> None:
        """Project ``perturbations`` in place onto ``[0, 1]`` with sum <= budget.

        If plain clipping already respects the budget, clip. Otherwise find
        the shift ``mu`` with ``sum(clip(p - mu, 0, 1)) == budget`` by
        bisection (tolerance ``eps``) and apply it.
        """
        if perturbations.clamp(0, 1).sum() <= budget:
            perturbations.clamp_(0, 1)
            return

        top = perturbations.max().item()
        bot = (perturbations.min() - 1).clamp_min(0).item()
        mu = (top + bot) / 2
        while (top - bot) / 2 > eps:
            used_budget = (perturbations - mu).clamp(0, 1).sum()
            if used_budget == budget:
                break
            elif used_budget > budget:
                bot = mu
            else:
                top = mu
            mu = (top + bot) / 2
        perturbations.sub_(mu).clamp_(0, 1)

    @staticmethod
    def _spectral_indices(num_nodes: int, k_low: int, k_high: int,
                          device=None) -> Tensor:
        """Return the indices of the ``k_low`` smallest and ``k_high``
        largest eigenvalues, clipped to ``num_nodes`` and without duplicates
        (ascending order, matching ``torch.linalg.eigh``).
        """
        n_low = min(k_low, num_nodes)
        high_start = max(num_nodes - k_high, n_low)
        return torch.cat([
            torch.arange(n_low, device=device),
            torch.arange(high_start, num_nodes, device=device),
        ])

    def compute_loss(
        self,
        perturbations: Tensor,
        epoch,
    ) -> Tensor:
        """Spectral distance between the perturbed and the clean graph.

        Both graphs use the symmetric-normalised Laplacian.
        ``perturbations`` is a symmetric flip matrix: an entry of 1 toggles
        the corresponding adjacency entry.

        * ``approx=False``: L2 distance between all eigenvalues.
        * ``approx=True`` and ``epoch % 10 == 0``: exact eigendecomposition,
          but only the lowest / highest eigenvalues are compared.
        * ``approx=True`` otherwise: the same eigenvalues, estimated to first
          order from the clean eigenpairs as
          ``u_i^T (dL - lambda_i * diag(dL @ 1)) u_i``.
        """
        modified_adj = self.adj + perturbations * (1 - 2 * self.adj)
        mod_edge_index, mod_edge_weight = dense_to_sparse(modified_adj)
        L_edge_index, L_edge_weight = get_laplacian(
            mod_edge_index, mod_edge_weight, normalization='sym')
        mod_L = to_dense_adj(L_edge_index, L_edge_weight, self.num_nodes)

        if not self.approx:
            mod_eigval, _ = torch.linalg.eigh(mod_L)
            return torch.norm(mod_eigval - self.ori_eigenval, p=2)

        idx = self._spectral_indices(self.num_nodes, self.num_low_eigs,
                                     self.num_high_eigs,
                                     device=self.ori_eigenval.device)
        if epoch % 10 != 0:
            # First-order estimate, vectorised over the selected eigenvectors.
            U = self.ori_eigenvec[:, idx]
            lam = self.ori_eigenval[idx]
            delta_L = mod_L - self.L
            delta_rowsum = delta_L.sum(dim=1)
            quad = (U * (delta_L @ U)).sum(dim=0)
            diag_quad = (U.square() * delta_rowsum.unsqueeze(1)).sum(dim=0)
            delta_eig = quad - lam * diag_quad
        else:
            mod_eigval, _ = torch.linalg.eigh(mod_L)
            delta_eig = mod_eigval[idx] - self.ori_eigenval[idx]

        return torch.sqrt(delta_eig.square().sum())

    def compute_gradients(
        self,
        perturbations: Tensor,
        epoch,
    ) -> Tensor:
        """Gradient of the spectral loss w.r.t. the raw ``perturbations``.

        The gradient is zero outside the strict upper triangle, because only
        that part is mirrored by :func:`symmetric`.
        """
        loss = self.compute_loss(symmetric(perturbations), epoch)
        return grad(loss, perturbations)[0]

 
class SpecAttack(UntargetedAttacker, PGD):
    r"""Implementation of the `SPAC` attack from the paper
    `"Graph Structural Attack by Perturbing Spectral Distance"
    <https://arxiv.org/abs/2111.00684>`_ (KDD 2022).

    Edge flips are relaxed to continuous values in ``self.perturbations``.
    These are optimised by :meth:`PGD.attack` to increase the distance
    between the Laplacian spectra of the clean and the perturbed graph, and
    are then sampled into a discrete set of flips.
    """

    _allow_singleton: bool = True

    def reset(self) -> "SpecAttack":
        """Reset the attacker and cache the dense adjacency in ``self.adj``."""
        super().reset()
        self.adj = self.get_dense_adj()
        return self

    def attack(
        self,
        num_budgets: Union[int, float] = 0.05,
        *,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 200,
        approx: bool = True,
        sample_epochs: int = 20,
        structure_attack: bool = True,
        feature_attack: bool = False,
        disable: bool = False,
    ) -> "SpecAttack":
        """Run the SPAC attack.

        Args:
            num_budgets: attack budget, forwarded to the base attacker.
                Presumably a float is a ratio and an int an absolute count;
                confirm against ``UntargetedAttacker.attack``.
            base_lr, grad_clip, epochs, approx, sample_epochs, disable:
                forwarded unchanged to :meth:`PGD.attack`.
            structure_attack, feature_attack: forwarded to the base attacker.

        Returns:
            ``self`` (the value returned by :meth:`PGD.attack`).
        """
        # The base attacker is expected to resolve `num_budgets` into
        # `self.num_budgets`, which is what PGD.attack enforces.
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        # Fraction of the n x n adjacency entries that are (directed) edges.
        density = self.edge_index.shape[1] / (self.num_nodes ** 2)

        # Initialise the relaxed perturbation uniformly.
        # NOTE(review): this scales with the raw `num_budgets` argument
        # (ratio or count), not with the resolved `self.num_budgets`;
        # confirm that this is intended.
        self.perturbations = (torch.ones_like(self.adj) * density *
                              num_budgets).requires_grad_()

        return PGD.attack(
            self,
            self.num_budgets,
            base_lr=base_lr,
            grad_clip=grad_clip,
            epochs=epochs,
            approx=approx,
            sample_epochs=sample_epochs,
            disable=disable,
        )


def symmetric(x: Tensor) -> Tensor:
    """Mirror the strict upper triangle of ``x`` onto its lower triangle.

    The diagonal and the original lower triangle are discarded, so the result
    is symmetric with a zero diagonal.
    """
    upper = torch.triu(x, diagonal=1)
    mirrored = upper.T
    return upper + mirrored


def margin_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    """Negative mean of ``tanh(true score - best other score)``.

    For each row of ``logit``, the score of class ``y_true`` is compared with
    the highest score among the remaining classes. Larger values mean more
    samples are pushed towards misclassification.
    """
    rows = torch.arange(y_true.size(0))
    target_idx = (rows, y_true)
    true_score = logit[target_idx]
    others = logit.clone()
    others[target_idx] = -np.inf
    runner_up = others.amax(dim=-1)
    return (true_score - runner_up).tanh().mean().neg()


def cross_entropy_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    """Cross-entropy between ``logit`` and integer class labels ``y_true``.

    Uses PyTorch's default ``'mean'`` reduction.
    """
    loss = F.cross_entropy(input=logit, target=y_true)
    return loss
