import math
from typing import Optional, Union

import scipy.sparse as sp
import numpy as np
import torch
import torch.nn.functional as F
from scipy import linalg
from torch import Tensor
from torch.autograd import grad
from tqdm.auto import tqdm

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.functional import to_dense_adj
from torch_geometric.utils import get_laplacian, dense_to_sparse

class PGD:
    """Base class for :class:`PGDAttack`.

    Mixin that optimises a dense matrix of edge-flip scores with projected
    gradient ascent and then draws a discrete perturbation from it.

    The host object must provide the attributes read by the methods below
    (``perturbations``, ``num_budgets``, ``adj``, ``feat``, ``num_nodes``,
    ``device``, ``ori_eigvec``) as well as ``add_edge`` and ``remove_edge``.
    """
    # PGDAttack cannot ensure that there is not singleton node after attacks.
    _allow_singleton: bool = True
    is_undirected: bool = True

    def attack(
        self,
        num_budgets: int,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 100,  # follow the authors' hyper
        sample_epochs: int = 20,
        disable: bool = False,
    ) -> "PGD":
        """Optimise ``self.perturbations`` and apply the best sampled flips.

        Parameters
        ----------
        num_budgets : int
            Not read by this method; the budget that is enforced is
            ``self.num_budgets``.
        base_lr : float, optional
            Base step size; the step used at epoch ``e`` is
            ``base_lr * epochs / sqrt(e + 1)``.
        grad_clip : Optional[float], optional
            If given, gradients whose L2 norm exceeds this value are
            rescaled to have exactly this norm before the update.
        epochs : int, optional
            Number of gradient steps.
        sample_epochs : int, optional
            Number of Bernoulli draws taken from the optimised scores.
        disable : bool, optional
            Forwarded to ``tqdm`` to silence the progress bars.

        Returns
        -------
        PGD
            ``self``, after ``add_edge`` / ``remove_edge`` have been called
            for every selected entry.

        Raises
        ------
        RuntimeError
            If none of the ``sample_epochs`` draws stays within
            ``self.num_budgets`` (including ``sample_epochs == 0``).
        """
        perturbations = self.perturbations
        for epoch in tqdm(range(epochs), desc='PGD training...',
                          disable=disable):
            lr = base_lr * epochs / math.sqrt(epoch + 1)
            gradients = self.compute_gradients(perturbations, epoch)

            with torch.no_grad():
                if grad_clip is not None:
                    # Rescale the whole gradient so its L2 norm is at most
                    # `grad_clip`.
                    grad_norm_sq = gradients.square().sum()
                    if grad_norm_sq > grad_clip * grad_clip:
                        gradients = gradients * (
                            grad_clip / grad_norm_sq.sqrt())

                perturbations.data.add_(lr * gradients)
                if perturbations.clamp(0, 1).sum() <= self.num_budgets:
                    # Within budget: only the box constraint is needed.
                    perturbations.clamp_(0, 1)
                else:
                    # Over budget: bisect for a shift `mu` such that the
                    # clamped, shifted scores sum to the budget.
                    top = perturbations.max().item()
                    bot = (perturbations.min() - 1).clamp_min(0).item()
                    mu = (top + bot) / 2
                    while (top - bot) / 2 > 1e-5:
                        used_budget = (perturbations - mu).clamp(0, 1).sum()
                        if used_budget == self.num_budgets:
                            break
                        elif used_budget > self.num_budgets:
                            bot = mu
                        else:
                            top = mu
                        mu = (top + bot) / 2
                    perturbations.sub_(mu).clamp_(0, 1)

        best_loss = -np.inf
        best_pert = None

        perturbations.detach_()
        for _ in tqdm(range(sample_epochs), desc='Bernoulli sampling...',
                      disable=disable):
            sampled = perturbations.bernoulli()
            # Draws that exceed the budget are discarded.
            if sampled.count_nonzero() <= self.num_budgets:
                loss = self.compute_loss(symmetric(sampled), 0)
                if best_loss < loss:
                    best_loss = loss
                    best_pert = sampled

        if best_pert is None:
            raise RuntimeError(
                "PGD sampling produced no perturbation within the budget "
                f"(num_budgets={self.num_budgets}, "
                f"sample_epochs={sample_epochs}).")

        # NOTE(review): flips are read from the full sampled matrix, while
        # the loss above was evaluated on `symmetric(sampled)` (strict upper
        # triangle mirrored) -- confirm this is intended.
        row, col = torch.where(best_pert > 0.)
        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
            if self.adj[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)

        return self

    def compute_loss(
        self,
        perturbations: Tensor,
        epoch,
    ) -> Tensor:
        """Attack loss in GF-Attack for GCN/SGC.

        Parameters
        ----------
        perturbations : Tensor
            Flip scores with the same shape as ``self.adj``; an entry ``p``
            moves the matching adjacency entry ``a`` to ``a + p * (1 - 2a)``.
        epoch
            Unused by this method.

        Returns
        -------
        Tensor
            Scalar loss: the sum of the smallest transformed eigenvalues of
            the normalised perturbed adjacency, times the squared norm of
            the row-summed features projected on the matching columns of
            ``self.ori_eigvec``, divided by ``1e8``.
        """
        modified_adj = self.adj + perturbations * (1 - 2 * self.adj)
        # Add self-loops (identity matrix); this keeps every degree away
        # from zero, assuming the entries above are non-negative.
        modified_adj = modified_adj + torch.eye(self.num_nodes,
                                                device=self.device)
        x_mean = self.feat.sum(1)
        # Symmetrically normalised dense adjacency: D^{-1/2} A D^{-1/2}.
        deg = torch.sum(modified_adj, dim=1)
        deg_sqrt_inv = torch.sqrt(1.0 / deg)
        mod_A = (deg_sqrt_inv.unsqueeze(1) * modified_adj
                 * deg_sqrt_inv.unsqueeze(0))

        # Here we compute full eigen-decomposition for GF attack loss
        mod_eigval, mod_eigvec = torch.linalg.eigh(mod_A)
        T = 128
        K = 2
        mod_eigval = (mod_eigval + 1.).square().pow(K)
        # `topk` raises when k exceeds the number of elements, so graphs
        # with fewer than T nodes use all of their eigenvalues.
        k = min(T, mod_eigval.numel())
        # from small to large
        least_t = torch.topk(mod_eigval, k=k, largest=False).indices
        eig_vals_k_sum = mod_eigval[least_t].sum()
        u_k = self.ori_eigvec[:, least_t]
        u_x_mean = u_k.t() @ x_mean
        score = eig_vals_k_sum * torch.square(torch.linalg.norm(u_x_mean))
        # Fixed rescaling constant.
        loss = score / 1e8
        return loss

    def compute_gradients(
        self,
        perturbations: Tensor,
        epoch,
    ) -> Tensor:
        """Return the gradient of the attack loss w.r.t. ``perturbations``.

        The loss is evaluated on ``symmetric(perturbations)``; its gradient
        is first taken w.r.t. that symmetrised matrix and then propagated
        back to ``perturbations`` itself.

        Parameters
        ----------
        perturbations : Tensor
            Flip scores; must be part of an autograd graph.
        epoch
            Forwarded unchanged to :meth:`compute_loss`.
        """
        pert_sym = symmetric(perturbations)
        grad_outputs = grad(
            self.compute_loss(pert_sym, epoch), pert_sym)
        return grad(pert_sym, perturbations, grad_outputs=grad_outputs[0])[0]

 
class GFAttack(UntargetedAttacker, PGD):
    r"""Implementation of `GFA` attack from the:
    `"A Restricted Black - box Adversarial Framework Towards
    Attacking Graph Embedding Models"
    <https://arxiv.org/abs/1908.01297>`_ paper (AAAI'20)
    The original implementation focuses on targeted attack, we
    modified it into an untargeted version similar to SPAC.
    """

    _allow_singleton: bool = True

    def reset(self) -> "GFAttack":
        """Reset the attacker and cache the dense adjacency in ``self.adj``.

        Returns
        -------
        GFAttack
            ``self``.
        """
        super().reset()
        self.adj = self.get_dense_adj()
        return self

    def attack(
        self,
        num_budgets: Union[int, float] = 0.05,
        *,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 200,
        sample_epochs: int = 20,
        structure_attack: bool = True,
        feature_attack: bool = False,
        disable: bool = False,
    ) -> "GFAttack":
        """Run the attack: eigen-decompose the clean graph, then run PGD.

        Parameters
        ----------
        num_budgets : Union[int, float], optional
            Forwarded to the parent attacker's ``attack`` and also used,
            as passed, to scale the initial perturbation scores.
        base_lr, grad_clip, epochs, sample_epochs, disable
            Forwarded unchanged to :meth:`PGD.attack`.
        structure_attack, feature_attack : bool, optional
            Forwarded unchanged to the parent attacker's ``attack``.

        Returns
        -------
        GFAttack
            ``self`` (the return value of :meth:`PGD.attack`).
        """
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        density = self.edge_index.shape[1] / (self.num_nodes ** 2)

        # Generalised eigenproblem (A + I) v = lambda * D v on the clean
        # graph, where D is the diagonal degree matrix of A + I.
        adj = self.adjacency_matrix
        adj = adj + sp.eye(adj.shape[0], format='csr')
        # `toarray()` / `np.asarray` are used instead of the `.A` / `.A1`
        # shortcuts, which are not available on every SciPy sparse type.
        deg = np.diag(np.asarray(adj.sum(1)).ravel())
        eig_vals, eig_vec = linalg.eigh(adj.toarray(), deg)
        self.ori_eigval = torch.as_tensor(eig_vals, device=self.device,
                                          dtype=torch.float32)
        self.ori_eigvec = torch.as_tensor(eig_vec, device=self.device,
                                          dtype=torch.float32)

        # Initialize perturbation: a constant matrix shaped like `self.adj`.
        # NOTE(review): this is scaled by the raw `num_budgets` argument
        # (typed as int or float), not by `self.num_budgets` -- confirm.
        self.perturbations = (
            torch.ones_like(self.adj) * (density * num_budgets)
        ).requires_grad_()

        return PGD.attack(
            self,
            self.num_budgets,
            base_lr=base_lr,
            grad_clip=grad_clip,
            epochs=epochs,
            sample_epochs=sample_epochs,
            disable=disable,
        )


def symmetric(x: Tensor) -> Tensor:
    """Mirror the strict upper triangle of ``x`` across the diagonal.

    The diagonal and the original lower triangle are discarded, so the
    result has a zero diagonal and equals its own transpose.
    """
    strict_upper = torch.triu(x, diagonal=1)
    mirrored = strict_upper.T
    return strict_upper + mirrored


def margin_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    """Return the negated mean of ``tanh(true score - best other score)``.

    For every row of ``logit`` the margin is the score at the ``y_true``
    column minus the largest score among the remaining columns.
    """
    rows = torch.arange(y_true.size(0))

    # Score assigned to the true class of each row.
    true_scores = logit[rows, y_true]

    # Mask the true class out on a copy, then take the best remaining score.
    masked = logit.clone()
    masked[rows, y_true] = -np.inf
    best_other = masked.amax(dim=-1)

    margin = true_scores - best_other
    return -margin.tanh().mean()


def cross_entropy_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    """Return ``F.cross_entropy`` of ``logit`` against the targets ``y_true``."""
    loss = F.cross_entropy(input=logit, target=y_true)
    return loss
