from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.utils import coalesce, to_undirected, get_laplacian
from tqdm.auto import tqdm
from kmeans_pytorch import kmeans
from sklearn.cluster import KMeans

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.attack.untargeted.utils import linear_to_triu_idx, propagate, num_possible_edges, project
from greatx.functional import to_dense_adj
from greatx.nn.models.surrogate import Surrogate
from greatx.nn.models import GCN

class CosAttack:
    """Base class for SheAttack.

    Collects the loss functions (``she``, ``self_atk``, ``peega``, ``shift``,
    ``spec`` and ``gf``) that the attacker subclasses optimise.  The methods
    read attacker state such as ``proj_feat``, ``edge_index``,
    ``edge_weight``, ``feat``, ``num_nodes`` and ``device`` from ``self``;
    none of that state is created in this class, so it is only usable when
    combined with a class that provides it.
    """

    # Not read anywhere in this file; presumably consumed by the attacker
    # base classes -- TODO confirm.
    _allow_singleton: bool = False
    # Read by the block-sampling code of the subclasses to decide whether
    # candidate edges are symmetrised.
    is_undirected: bool = True
    # Shared tuning constants, looked up by key in the subclasses.
    coeffs: Dict[str, Any] = {
    # Number of candidate perturbations evaluated by `sample_final_edges`.
    'max_final_samples': 20,
    # Number of attempts the block samplers make before raising.
    'max_trials_sampling': 20,
    # Not read anywhere in this file.
    'with_early_stopping': True,
    # Initial weight of a freshly sampled candidate edge; weights at or below
    # this value are treated as inactive.
    'eps': 1e-7
    }

    def she(self, ptb_edge_index, ptb_edge_weight):
        """Return the negated mean silhouette ("she") score of the perturbed graph.

        The projected features are propagated over the perturbed graph.  A
        clustering of the nodes is computed on the first call and cached on
        ``self`` (``cluster_id``, ``centroids``, ``mask_a``, ``mask_b``); every
        call then measures how well the propagated features agree with it.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar tensor ``-mean(s)``, where ``s`` holds one silhouette value
            per node, plus ``self.reg`` times the mean feature shift when
            ``self.reg > 0``.

        Raises:
            ValueError: If the clustering has to be computed and
                ``self.setting`` is not ``'black'``, ``'white'`` or ``'soft'``.
        """
        x = propagate(self.proj_feat, ptb_edge_index, ptb_edge_weight)
        k = self.k

        # The clustering is a cache: it is only (re)built while `cluster_id`
        # is missing or None.
        if getattr(self, 'cluster_id', None) is None:
            if self.setting in ('black', 'white'):
                # Cluster the raw projected features (the original authors
                # use this for cSBM graphs; clustering the propagated
                # features instead is their noted alternative).
                cluster_feat = self.proj_feat
            elif self.setting == 'soft':
                logits = self.surrogate(self.feat, self.edge_index,
                                        self.edge_weight)
                cluster_feat = F.softmax(logits, dim=-1)
            else:
                # Previously this fell through and failed later with an
                # UnboundLocalError on `cluster_feat`.
                raise ValueError(
                    f"Unknown setting {self.setting!r}; expected 'black', "
                    "'white' or 'soft'.")

            # Per the original authors: the PyTorch k-means is faster, the
            # sklearn one is more stable, so large graphs use the former and
            # small graphs the latter.
            if x.shape[0] >= 10_000:
                cluster_id, centroids = kmeans(
                    X=cluster_feat, num_clusters=k, distance='cosine',
                    iter_limit=1000, tqdm_flag=False, device=self.device)
            else:
                model = KMeans(n_clusters=k, n_init=50)
                model.fit(cluster_feat.cpu().detach().numpy())
                cluster_id = torch.from_numpy(model.labels_).to(torch.int64)
                centroids = torch.from_numpy(model.cluster_centers_)

            self.cluster_id = cluster_id.to(self.device)
            # Match the dtype of the propagated features so `cdist` below
            # never sees mixed precision.
            self.centroids = centroids.to(device=self.device, dtype=x.dtype)

            # The white-box setting scores against the true labels, the other
            # settings against the k-means assignment.
            labels = self.label if self.setting == 'white' else self.cluster_id

            if self.approx:
                # (num_nodes, k) one-hot membership mask.
                self.mask_a = torch.zeros(self.num_nodes, self.k,
                                          device=self.device,
                                          dtype=torch.bool)
                self.mask_a.scatter_(1, labels.unsqueeze(1), 1)
            else:
                # Pairwise "same label" mask (includes the diagonal).
                self.mask_a = labels.unsqueeze(0) == labels.unsqueeze(1)
            self.mask_b = ~self.mask_a

        if self.approx:
            # Centroid-based approximation: `a` is the distance to the own
            # centroid, `b` the distance to the nearest other centroid.
            dist_matrix = torch.cdist(x, self.centroids, p=2)
            a = (dist_matrix * self.mask_a).sum(dim=1) / self.mask_a.sum(dim=1)
            other_dist = dist_matrix * self.mask_b
            # The own centroid is masked to 0 and is therefore the smallest
            # entry, so the second smallest is the nearest other centroid.
            _, indices = torch.topk(other_dist, k=2, dim=1, largest=False)
            # gather() yields shape (N, 1); squeeze it so that `b - a` stays
            # element-wise instead of broadcasting to an (N, N) matrix.
            b = torch.gather(other_dist, 1, indices[:, 1:2]).squeeze(1)
        else:
            dist_matrix = torch.cdist(x, x, p=2)
            # `- 1` excludes the node itself from its own cluster.
            a = (dist_matrix * self.mask_a).sum(dim=1) / (
                self.mask_a.sum(dim=1) - 1)
            # Mean distance to every node outside the own cluster.
            b = (dist_matrix * self.mask_b).sum(dim=1) / self.mask_b.sum(dim=1)

        # Silhouette value of each node.
        s = (b - a) / torch.max(a, b)
        loss = -s.mean()

        if self.reg > 0.0:
            # Regulariser: mean L2 shift between clean and perturbed
            # propagated features.  `x` already is the perturbed propagation.
            pro_x = propagate(self.proj_feat, self.edge_index,
                              self.edge_weight)
            shift = torch.norm(pro_x - x, p=2, dim=1).mean()
            loss = loss + self.reg * shift
        return loss


    def self_atk(self, ptb_edge_index, ptb_edge_weight):
        """Cross-entropy of a self-trained surrogate against k-means labels.

        On the first call (while ``self.cluster_id`` is missing or None) the
        nodes are clustered on the clean propagated features and a GCN is
        trained to predict the cluster ids; the clustering and the trained
        model are cached on ``self``.  Every call then evaluates that model on
        the perturbed graph.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar cross-entropy between the model output on the perturbed
            graph and the cached cluster ids.
        """
        if getattr(self, 'cluster_id', None) is None:
            k = self.k
            cluster_feat = propagate(self.proj_feat, self.edge_index,
                                     self.edge_weight)
            # Large graphs use the PyTorch k-means, small ones sklearn.
            # NOTE(review): assumes `propagate` keeps one row per node.
            if cluster_feat.shape[0] >= 10_000:
                cluster_id, centroids = kmeans(
                    X=cluster_feat, num_clusters=k, distance='cosine',
                    iter_limit=1000, tqdm_flag=False, device=self.device)
            else:
                model = KMeans(n_clusters=k, n_init=50)
                model.fit(cluster_feat.cpu().detach().numpy())
                cluster_id = torch.from_numpy(model.labels_).to(torch.int64)
                centroids = torch.from_numpy(model.cluster_centers_)

            self.cluster_id = cluster_id.to(self.device)
            self.centroids = centroids.to(self.device)

            surrogate = GCN(self.feat.shape[1], k, dropout=0.5,
                            hids=16).to(self.device)
            optimizer = torch.optim.Adam(surrogate.parameters(), lr=0.01,
                                         weight_decay=5e-4)
            best_loss = 1e9
            best_state = None
            for _ in range(200):
                surrogate.train()
                optimizer.zero_grad()
                output = surrogate(self.feat, self.edge_index)
                loss = F.cross_entropy(output, self.cluster_id)
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    # state_dict() hands out references to the live
                    # parameters; clone them, otherwise the snapshot is
                    # overwritten by the following optimiser steps.
                    best_state = {
                        name: tensor.detach().clone()
                        for name, tensor in surrogate.state_dict().items()
                    }
                loss.backward()
                optimizer.step()
            # `best_state` stays None only if the loss never dropped below
            # the initial sentinel (e.g. it was NaN); keep the last weights.
            if best_state is not None:
                surrogate.load_state_dict(best_state)
            self.model = surrogate

        self.model.eval()
        output = self.model(self.feat, ptb_edge_index, ptb_edge_weight)
        return F.cross_entropy(output, self.cluster_id)



    def peega(self, ptb_edge_index, ptb_edge_weight):
        """Return the feature gap across the original edges after perturbation.

        For every original edge ``(u, v)`` the perturbed propagated feature of
        ``u`` is compared with the clean propagated feature of ``v`` (the
        clean propagation uses unit edge weights).  A larger gap is better
        for the attacker.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar tensor: the mean L2 gap, plus ``self.reg`` times the mean
            feature shift when ``self.reg > 0``.
        """
        x = self.proj_feat
        ori_edge_index = self.edge_index
        ori_edge_weight = torch.ones(ori_edge_index.shape[1],
                                     device=self.device)

        # Clean propagation (unit weights) and perturbed propagation.
        ori_pro_x = propagate(x, ori_edge_index, ori_edge_weight)
        pro_x = propagate(x, ptb_edge_index, ptb_edge_weight)

        gap = pro_x[ori_edge_index[0], :] - ori_pro_x[ori_edge_index[1], :]
        homo = torch.norm(gap, p=2, dim=1).mean()
        if self.reg <= 0.0:
            return homo

        # Regulariser: mean L2 shift between the clean propagation (with the
        # stored edge weights) and the perturbed one computed above.
        clean_pro_x = propagate(x, self.edge_index, self.edge_weight)
        shift = torch.norm(clean_pro_x - pro_x, p=2, dim=1).mean()
        return homo + self.reg * shift


    def shift(self, ptb_edge_index, ptb_edge_weight):
        """Mean L2 change of the propagated features caused by the perturbation.

        A larger value means the perturbation moved the features further.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar tensor with the mean row-wise L2 norm of the difference
            between clean and perturbed propagated features.
        """
        feat = self.proj_feat
        clean = propagate(feat, self.edge_index, self.edge_weight)
        perturbed = propagate(feat, ptb_edge_index, ptb_edge_weight)
        per_node = torch.norm(clean - perturbed, p=2, dim=1)
        return torch.mean(per_node)


    def spec(self, ptb_edge_index, ptb_edge_weight):
        """Spectral loss (from SPAC); a larger spectral distance is better.

        Builds the dense symmetric-normalised Laplacian of the perturbed
        graph and compares its eigenvalues with ``self.ori_eigval``.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar tensor: L2 norm of the eigenvalue difference.
        """
        lap_index, lap_weight = get_laplacian(
            ptb_edge_index, ptb_edge_weight, normalization='sym')
        dense_lap = to_dense_adj(lap_index, lap_weight, self.num_nodes)
        # Only the eigenvalues are needed; the eigenvectors are discarded.
        eigenvalues, _ = torch.linalg.eigh(dense_lap)
        return torch.norm(eigenvalues - self.ori_eigval, p=2)


    def gf(self, ptb_edge_index, ptb_edge_weight):
        """Attack loss of GF-Attack for GCN/SGC.

        The perturbed adjacency is symmetrically normalised and eigen-
        decomposed; the loss combines its smallest transformed eigenvalues
        with the eigenvectors stored in ``self.ori_eigvec``.

        Args:
            ptb_edge_index: Edge index of the perturbed graph.
            ptb_edge_weight: Edge weights of the perturbed graph.

        Returns:
            Scalar tensor: the GF score divided by ``1e8``.
        """
        # Per-node sum over the feature dimension.
        x_mean = self.feat.sum(1)

        # Dense, symmetrically normalised perturbed adjacency.
        mod_dense_A = to_dense_adj(ptb_edge_index, ptb_edge_weight,
                                   self.num_nodes)
        deg = torch.sum(mod_dense_A, dim=1)
        # Zero-degree nodes would give 1/0 = inf and then inf * 0 = NaN in
        # the product below; give them a zero normalisation factor instead.
        has_degree = deg > 0
        safe_deg = torch.where(has_degree, deg, torch.ones_like(deg))
        deg_sqrt_inv = torch.sqrt(1.0 / safe_deg) * has_degree
        mod_A = (deg_sqrt_inv.unsqueeze(1) * mod_dense_A *
                 deg_sqrt_inv.unsqueeze(0))

        # Explicit eigen-decomposition for the GF attack loss.
        mod_eigval, _ = torch.linalg.eigh(mod_A)
        # Use at most 128 eigenvalues; topk() raises when asked for more
        # entries than exist, i.e. on graphs with fewer than 128 nodes.
        T = min(128, mod_eigval.numel())
        K = 2
        mod_eigval = (mod_eigval + 1.).square().pow(K)
        # Indices of the T smallest transformed eigenvalues.
        least_t = torch.topk(mod_eigval, k=T, largest=False).indices
        eig_vals_k_sum = mod_eigval[least_t].sum()
        u_k = self.ori_eigvec[:, least_t]
        u_x_mean = u_k.t() @ x_mean
        score = eig_vals_k_sum * torch.square(torch.linalg.norm(u_x_mean))
        # Fixed rescaling of the score.
        return score / 1e8


class GCosAttack(UntargetedAttacker, CosAttack, Surrogate):
    """
    Greedy attack based on specific loss.

    Every step samples a random block of candidate edges, takes the gradient
    of the selected loss with respect to the candidates' edge weights and
    flips the candidates with the largest gradient.
    """
    def reset(self) -> "GCosAttack":
        """Reset the attacker to the clean graph and clear all attack state.

        Returns:
            The attacker itself, so calls can be chained.
        """
        super().reset()
        # Candidate block state, filled by `sample_random_block`.
        self.current_block = None
        self.block_edge_index = None
        self.block_edge_weight = None
        self.loss = None
        # The clustering-based losses cache their clustering behind
        # `cluster_id`; clear it so a new attack (possibly with another `k`,
        # `setting` or `approx`) never reuses a stale clustering.
        self.cluster_id = None

        self.proj_feat = self.get_proj_feat()
        # Working copy of the graph; every edge starts with weight 1.
        self._edge_index = self.edge_index
        self._edge_weight = torch.ones(self.num_edges, device=self.device)

        # Edges flipped so far, as an empty (2, 0) index tensor.
        self.flipped_edges = self._edge_index.new_empty(2, 0)
        return self

    def setup_surrogate(
        self,
        surrogate: torch.nn.Module,
        tau: float = 1.0,
        freeze: bool = True,
    ):
        """Register the surrogate model by delegating to ``Surrogate``.

        Args:
            surrogate: The surrogate model, forwarded unchanged.
            tau: Forwarded unchanged to ``Surrogate.setup_surrogate``.
            freeze: Forwarded unchanged to ``Surrogate.setup_surrogate``.

        Returns:
            The attacker itself, so calls can be chained.
        """
        Surrogate.setup_surrogate(
            self,
            surrogate=surrogate,
            tau=tau,
            freeze=freeze,
        )
        return self

    def attack(
        self,
        num_budgets: Union[int, float] = 0.05,
        *,
        block_size: int = 250_000,
        epochs: int = 125,
        epochs_resampling: int = 100,
        loss: Optional[str] = 'she',
        lr: float = 1000,
        structure_attack: bool = True,
        feature_attack: bool = False,
        k: int = -1,
        setting: str = "black",
        reg: float = 0.0,
        approx: bool = False,
        disable: bool = False,
        **kwargs,
    ) -> "GCosAttack":
        """Run the attack and record the resulting edge flips.

        Args:
            num_budgets: Attack budget, forwarded to the base ``attack``.
            block_size: Number of candidate edges drawn per random block.
            epochs: Forwarded to ``prepare`` to build the step sequence.
            epochs_resampling: Epoch count used by the learning-rate decay
                (and by subclasses that resample the block).
            loss: Name of the loss to optimise: ``'she'``, ``'self_atk'``,
                ``'peega'``, ``'shift'``, ``'spec'`` or ``'gf'``.
            lr: Base learning rate used by ``update_edge_weights``.
            structure_attack: Forwarded to the base ``attack``.
            feature_attack: Forwarded to the base ``attack``.
            k: Number of clusters.  A value ``<= 0`` means
                ``|k| * self.num_classes`` clusters.
            setting: Clustering setting read by the ``she`` loss.
            reg: Weight of the feature-shift regulariser.
            approx: Whether ``she`` uses its centroid-based approximation.
            disable: Disable the progress bar.
            **kwargs: Accepted and ignored.

        Returns:
            The attacker itself.

        Raises:
            ValueError: If ``loss`` does not name a known loss and no loss
                callable is set on the attacker.
        """
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)
        self.block_size = block_size

        # Losses that need no precomputation.
        plain_losses = {
            'peega': self.peega,
            'shift': self.shift,
            'self_atk': self.self_atk,
            'she': self.she,
        }
        if loss in plain_losses:
            self.loss = plain_losses[loss]
        elif loss == 'spec':
            self.loss = self.spec
            # Spectrum of the clean symmetric-normalised Laplacian.
            L_edge_index, L_edge_weight = get_laplacian(
                self.edge_index, self.edge_weight, normalization='sym')
            self.L = to_dense_adj(L_edge_index, L_edge_weight, self.num_nodes)
            self.ori_eigval, self.ori_eigvec = torch.linalg.eigh(self.L)
        elif loss == 'gf':
            # `sp` and `linalg` were used here without ever being imported;
            # import them locally so only this branch depends on them.
            import scipy.sparse as sp
            from scipy import linalg

            self.loss = self.gf
            adj = self.adjacency_matrix
            adj = adj + sp.eye(adj.shape[0], format='csr')
            # Row sums as a flat array, for both sparse matrices and arrays.
            deg = np.diag(np.asarray(adj.sum(1)).ravel())
            # Generalised eigenproblem (A + I) v = w D v.  toarray() replaces
            # the `.A` shortcut, which recent SciPy releases removed.
            eig_vals, eig_vec = linalg.eigh(adj.toarray(), deg)
            self.ori_eigval = torch.as_tensor(eig_vals, device=self.device,
                                              dtype=torch.float32)
            self.ori_eigvec = torch.as_tensor(eig_vec, device=self.device,
                                              dtype=torch.float32)

        if getattr(self, 'loss', None) is None:
            # Fail here instead of with "'NoneType' object is not callable"
            # deep inside the optimisation loop.
            raise ValueError(
                f"Unknown loss {loss!r}; expected one of 'she', 'self_atk', "
                "'peega', 'shift', 'spec' or 'gf'.")

        self.epochs_resampling = epochs_resampling
        self.lr = lr
        # k can be set to -1, -2, ... meaning |k| * self.num_classes clusters.
        self.k = -k * self.num_classes if k <= 0 else k

        self.setting = setting
        self.reg = reg
        self.approx = approx

        for step in tqdm(self.prepare(self.num_budgets, epochs),
                         desc='Peturbing graph...', disable=disable):
            # The loss value itself is not needed, only its gradient.
            _, gradient = self.compute_gradients()
            self.update(step, gradient)

        flipped_edges = coalesce(self.get_flipped_edges())
        assert flipped_edges.size(1) <= self.num_budgets, (
            f'# perturbed edges {flipped_edges.size(1)} '
            f'exceeds num_budgets {self.num_budgets}')

        # Record each flip as a removal (edge exists in the clean graph) or
        # as an addition.
        for it, (u, v) in enumerate(zip(*flipped_edges.tolist())):
            if self.adjacency_matrix[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)
        return self

    def compute_gradients(self) -> Tuple[Tensor, Tensor]:
        """Evaluate the loss on the perturbed graph and differentiate it.

        Returns:
            ``(loss, gradient)`` where ``gradient`` is the gradient of the
            loss with respect to ``self.block_edge_weight``.
        """
        block_weight = self.block_edge_weight
        block_weight.requires_grad_()
        perturbed_graph = self.get_modified_graph(
            self._edge_index,
            self._edge_weight,
            self.block_edge_index,
            block_weight,
        )
        loss = self.loss(*perturbed_graph)
        (gradient,) = torch.autograd.grad(loss, block_weight)
        return loss, gradient

    def get_modified_graph(
        self,
        edge_index: Tensor,
        edge_weight: Tensor,
        block_edge_index: Tensor,
        block_edge_weight: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Merge the candidate block into a graph.

        Args:
            edge_index: Edge index of the base graph.
            edge_weight: Edge weights of the base graph.
            block_edge_index: Edge index of the candidate block.
            block_edge_weight: Weights of the candidate block.

        Returns:
            ``(edge_index, edge_weight)`` of the merged graph.  Weights of
            duplicate edges are summed, and any resulting weight ``w > 1`` is
            replaced by ``2 - w``.
        """
        if self.is_undirected:
            # Mirror the candidates so both directions carry the same weight.
            block_edge_index, block_edge_weight = to_undirected(
                block_edge_index, block_edge_weight,
                num_nodes=self.num_nodes, reduce='mean')

        merged_index, merged_weight = coalesce(
            torch.cat((edge_index, block_edge_index), dim=-1),
            torch.cat((edge_weight, block_edge_weight)),
            num_nodes=self.num_nodes,
            reduce='sum',
        )

        # Reflect weights above 1 back into range.
        merged_weight = torch.where(merged_weight > 1,
                                    2 - merged_weight,
                                    merged_weight)
        return merged_index, merged_weight


    def prepare(self, num_budgets: int, epochs: int) -> List[int]:
        """Split the budget into per-step flip counts and sample the first block.

        Args:
            num_budgets: Total number of edges that may be flipped.
            epochs: Maximum number of steps to spread the budget over.

        Returns:
            Number of edges to flip in each step.  The entries sum to
            ``num_budgets``; when the budget is smaller than ``epochs`` there
            is one single-edge step per budget unit.
        """
        per_step, remainder = divmod(num_budgets, epochs)
        if per_step > 0:
            # The first `remainder` steps absorb the leftover budget.
            steps = ([per_step + 1] * remainder +
                     [per_step] * (epochs - remainder))
        else:
            steps = [1] * num_budgets

        # The block must be able to serve the largest step.  The floor
        # quotient alone is too small whenever a step got an extra edge, and
        # it is 0 when the budget is smaller than `epochs`.
        self.sample_random_block(max(steps, default=0))
        return steps


    @torch.no_grad()
    def sample_random_block(self, num_budgets: int = 0):
        """Draw a fresh random block of candidate edges.

        Stores the unique linear indices in ``self.current_block``, the
        matching edge index in ``self.block_edge_index`` and an initial
        weight of ``coeffs['eps']`` per candidate in
        ``self.block_edge_weight``.

        Args:
            num_budgets: Minimum number of unique candidates required.

        Raises:
            RuntimeError: If no attempt yields at least ``num_budgets``
                unique candidates.
        """
        num_possible = num_possible_edges(self.num_nodes, self.is_undirected)
        attempts_left = self.coeffs['max_trials_sampling']

        while attempts_left > 0:
            attempts_left -= 1
            drawn = torch.randint(num_possible, (self.block_size, ),
                                  device=self.device)
            # Duplicate draws collapse, so the block may hold fewer than
            # `block_size` candidates.
            self.current_block = torch.unique(drawn, sorted=True)
            self.block_edge_index = linear_to_triu_idx(self.num_nodes,
                                                       self.current_block)
            self.block_edge_weight = torch.full(self.current_block.shape,
                                                self.coeffs['eps'],
                                                device=self.device)
            if self.current_block.size(0) >= num_budgets:
                return

        raise RuntimeError("Sampling random block was not successful. "
                           "Please decrease `num_budgets`.")


    def update_edge_weights(self, epoch: int, gradient: Tensor):
        """Take one in-place gradient-ascent step on the block edge weights.

        The step size is ``num_budgets / num_nodes * lr`` and decays with the
        square root of the number of epochs past ``epochs_resampling``.

        Args:
            epoch: Index of the current epoch.
            gradient: Gradient of the loss w.r.t. the block edge weights.
        """
        epochs_past_resampling = max(0, epoch - self.epochs_resampling)
        decay = np.sqrt(epochs_past_resampling + 1)
        lr = self.num_budgets / self.num_nodes * self.lr / decay
        self.block_edge_weight.data.add_(lr * gradient)

    def _filter_self_loops_in_block(self, with_weight: bool):
        """Drop candidates whose two endpoints are the same node.

        Args:
            with_weight: Also filter ``self.block_edge_weight``.
        """
        endpoints = self.block_edge_index
        not_loop = endpoints[0].ne(endpoints[1])
        self.current_block = self.current_block[not_loop]
        self.block_edge_index = endpoints[:, not_loop]
        if not with_weight:
            return
        self.block_edge_weight = self.block_edge_weight[not_loop]

    @torch.no_grad()
    def update(
        self,
        step_size: int,
        gradient: Tensor,
    ) -> Dict[str, Any]:
        """Greedily flip the candidates with the largest gradient.

        The chosen edges are appended to ``self.flipped_edges``, the working
        graph is updated and a new random block is sampled.  Nothing is
        returned.

        Args:
            step_size: Number of candidate edges to flip in this step.
            gradient: Gradient of the loss w.r.t. the block edge weights.
        """
        # Candidates with the largest gradient.
        chosen = torch.topk(gradient, step_size).indices
        flip_edge_index = self.block_edge_index[:, chosen].to(self.device)
        self.flipped_edges = torch.cat(
            (self.flipped_edges, flip_edge_index), dim=-1)

        unit_weight = torch.ones(flip_edge_index.size(1), device=self.device)
        sym_index, sym_weight = to_undirected(
            flip_edge_index, unit_weight, num_nodes=self.num_nodes,
            reduce='mean')

        merged_index, merged_weight = coalesce(
            torch.cat((self._edge_index, sym_index), dim=-1),
            torch.cat((self._edge_weight, sym_weight)),
            num_nodes=self.num_nodes,
            reduce='sum',
        )

        # Summed weight 1 means the edge is present after the flip (kept or
        # newly added); an existing edge that was flipped sums to 2 and is
        # dropped.
        present = torch.isclose(merged_weight, torch.tensor(1.))
        self._edge_index = merged_index[:, present]
        self._edge_weight = merged_weight[present]

        self.sample_random_block(step_size)


    def get_flipped_edges(self) -> Tensor:
        """Return the edge index of all flips collected so far."""
        flipped = self.flipped_edges
        return flipped


class PCosAttack(GCosAttack):
    """Projected-gradient variant of ``GCosAttack``.

    Instead of flipping edges greedily, the candidate edge weights are
    updated by gradient steps and projected back onto the budget; the final
    flips are sampled from the best weights found.
    """

    def reset(self) -> "PCosAttack":
        """Reset the attacker and the best-metric tracker.

        ``GCosAttack.reset`` already clears the block state and rebuilds
        ``proj_feat``, ``_edge_index`` and ``_edge_weight``; repeating that
        here only called ``get_proj_feat()`` a second time.

        Returns:
            The attacker itself, so calls can be chained.
        """
        super().reset()
        # Any metric beats -inf, so the first epoch always records a best.
        self.best_metric = float('-Inf')
        return self

    def prepare(self, num_budgets: int, epochs: int) -> Iterable[int]:
        """Sample the initial block and return the epoch indices to run.

        Args:
            num_budgets: Minimum number of candidates the block must hold.
            epochs: Number of optimisation epochs.

        Returns:
            ``range(epochs)``.
        """
        self.sample_random_block(num_budgets)
        epoch_indices = range(epochs)
        return epoch_indices


    def resample_random_block(self, num_budgets: int):
        """Keep the strongest candidates and refill the block with new ones.

        The candidates are ordered by weight; the inactive ones (weight at or
        below ``coeffs['eps']``), and at least the weaker half, are dropped.
        The block is then topped up to ``block_size`` with fresh random
        candidates that start at weight ``coeffs['eps']``.

        Args:
            num_budgets: The block must end up with more than this many
                unique candidates.

        Raises:
            RuntimeError: If no attempt yields a large enough block.
        """
        eps = self.coeffs['eps']
        order = torch.argsort(self.block_edge_weight)
        num_inactive = int((self.block_edge_weight <= eps).sum())
        # Always discard at least the weaker half of the block.
        num_dropped = max(num_inactive, order.size(0) // 2)
        kept_idx = order[num_dropped:]

        # Snapshot the survivors before the loop.  A retry must start from
        # these again; re-indexing the block tensors that a previous attempt
        # has already replaced would pair old indices with new contents.
        kept_block = self.current_block[kept_idx]
        kept_weight = self.block_edge_weight[kept_idx]
        num_kept = kept_block.size(0)

        num_possible = num_possible_edges(self.num_nodes, self.is_undirected)
        n_edges_resample = self.block_size - num_kept

        for _ in range(self.coeffs['max_trials_sampling']):
            lin_index = torch.randint(num_possible, (n_edges_resample, ),
                                      device=self.device)
            self.current_block, unique_idx = torch.unique(
                torch.cat((kept_block, lin_index)), sorted=True,
                return_inverse=True)

            self.block_edge_index = linear_to_triu_idx(
                self.num_nodes, self.current_block)

            # New candidates start at eps; survivors keep their weight.  The
            # first `num_kept` inverse indices locate the survivors in the
            # merged block.
            self.block_edge_weight = torch.full(self.current_block.shape,
                                                eps, device=self.device)
            self.block_edge_weight[unique_idx[:num_kept]] = kept_weight

            if not self.is_undirected:
                self._filter_self_loops_in_block(with_weight=True)

            if self.current_block.size(0) > num_budgets:
                return

        raise RuntimeError("Sampling random block was not successful. "
                           "Please decrease `num_budgets`.")

    @torch.no_grad()
    def sample_final_edges(self) -> Tensor:
        """Pick the final set of edge flips from the current block weights.

        The first candidate is the deterministic top-``num_budgets``
        selection; the remaining ``coeffs['max_final_samples'] - 1``
        candidates are Bernoulli samples of the weights.  Candidates that
        exceed the budget are skipped, and the candidate with the largest
        loss wins.

        Returns:
            Edge index (two rows) of the selected flips.
        """
        eps = self.coeffs['eps']
        num_budgets = self.num_budgets
        block_edge_weight = self.block_edge_weight
        # Inactive candidates must never be sampled.
        block_edge_weight[block_edge_weight <= eps] = 0

        # Top-k heuristic.  It also serves as the fallback if no candidate
        # ever compares greater than -inf (e.g. the metric is NaN).
        topk_edges = torch.zeros_like(block_edge_weight)
        topk_edges[torch.topk(block_edge_weight, num_budgets).indices] = 1

        best_metric = float('-Inf')
        best_edge_weight = topk_edges
        for i in range(self.coeffs['max_final_samples']):
            if i == 0:
                sampled_edges = topk_edges
            else:
                sampled_edges = torch.bernoulli(block_edge_weight).float()
            if sampled_edges.sum() > num_budgets:
                # Allowed num_budgets is exceeded.
                continue

            self.block_edge_weight = sampled_edges
            # Score every candidate against the same base graph.  Writing
            # the result back to self._edge_index / self._edge_weight would
            # stack each candidate on top of the previous ones.
            edge_index, edge_weight = self.get_modified_graph(
                self._edge_index, self._edge_weight, self.block_edge_index,
                sampled_edges)
            metric = self.loss(edge_index, edge_weight)

            # Save best sample.
            if metric > best_metric:
                best_metric = metric
                best_edge_weight = sampled_edges.clone()

        # Leave the attacker holding the winning selection.
        self.block_edge_weight = best_edge_weight
        selected = (best_edge_weight != 0).to(self.block_edge_index.device)
        return self.block_edge_index[:, selected]

    @torch.no_grad()
    def update(self, epoch: int, gradient: Tensor) -> None:
        """Update edge weights given gradient.

        Takes a gradient step, projects the weights back onto the budget,
        records the block if its top-``num_budgets`` selection gives the best
        loss so far, and then either resamples the block (before the last
        resampling epoch) or restores the best block (at that epoch).

        Args:
            epoch: Index of the current epoch.
            gradient: Gradient of the loss w.r.t. the block edge weights.
        """
        self.update_edge_weights(epoch, gradient)
        self.block_edge_weight = project(self.num_budgets,
                                         self.block_edge_weight,
                                         self.coeffs['eps'])

        # Score the discrete top-k selection of the current weights.
        topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)
        topk_block_edge_weight[torch.topk(self.block_edge_weight,
                                          self.num_budgets).indices] = 1

        edge_index, edge_weight = self.get_modified_graph(
            self._edge_index, self._edge_weight, self.block_edge_index,
            topk_block_edge_weight)

        metric = self.loss(edge_index, edge_weight)
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_block = self.current_block.detach().cpu().clone()
            self.best_edge_index = self.block_edge_index.detach().cpu().clone()
            # clone() is essential: on a CPU device `.cpu()` returns the same
            # tensor, and the next in-place gradient step would otherwise
            # overwrite the saved best weights.
            self.best_pert_edge_weight = (
                self.block_edge_weight.detach().cpu().clone())

        last_resampling_epoch = self.epochs_resampling - 1
        if epoch < last_resampling_epoch:
            self.resample_random_block(self.num_budgets)
        elif epoch == last_resampling_epoch:
            # Resampling is over: continue from the best block found so far.
            self.current_block = self.best_block.to(self.device)
            self.block_edge_index = self.best_edge_index.to(self.device)
            self.block_edge_weight = (
                self.best_pert_edge_weight.clone().to(self.device))

    def get_flipped_edges(self) -> Tensor:
        """Restore the best recorded block and sample the final flips from it.

        Returns:
            The result of ``sample_final_edges()`` on the best block.
        """
        device = self.device
        self.current_block = self.best_block.to(device)
        self.block_edge_index = self.best_edge_index.to(device)
        self.block_edge_weight = self.best_pert_edge_weight.to(device)
        flipped_edges = self.sample_final_edges()
        return flipped_edges