import torch
import torch.nn.functional as F

from robust_diffusion.attacks.base_attack import SparseAttack
from evotorch.operators.base import CopyingOperator
from evotorch.core import Problem, SolutionBatch
from evotorch import Problem
from evotorch.algorithms import GeneticAlgorithm
from evotorch import operators as evo_ops
from evotorch.logging import StdOutLogger
from copy import deepcopy
from sklearn.linear_model import LogisticRegression

from torch_sparse import SparseTensor

class PositiveIntMutation(CopyingOperator):
    """Mutation for integer genomes in which a negative gene is an unused slot.

    Every gene is picked for mutation independently with probability
    ``mutation_rate``.  For each picked gene a second, independent draw
    against ``toggle_rate`` decides what happens to it:

    * active gene (``>= 0``) and toggle drawn   -> written as ``-1``
    * active gene and no toggle                 -> replaced by a random index
    * inactive gene (``< 0``) and toggle drawn  -> replaced by a random index
    * inactive gene and no toggle               -> written as ``-1``
    """

    def __init__(self, problem, mutation_rate=0.1, toggle_rate=0.5):
        """Store the mutation hyper-parameters.

        Args:
            problem: Problem instance forwarded to ``CopyingOperator``; its
                ``upper_bounds`` is the sampling limit for new gene values.
            mutation_rate: Per-gene probability of being mutated.
            toggle_rate: Per-mutated-gene probability of the "toggle" draw
                described in the class docstring.
        """
        super().__init__(problem)
        self.toggle_rate = toggle_rate
        self.mutation_rate = mutation_rate
        # Not read anywhere inside this class.
        self.last_variance = -1

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; ``batch`` itself is untouched."""
        mutated = deepcopy(batch)
        # Written to in place below — presumably a view onto the copy's
        # decision values; verify against the EvoTorch SolutionBatch API.
        genes = mutated.access_values()

        # Draw order (selection, toggle, replacement) matters for reproducibility.
        selected = torch.rand(size=genes.shape, device=genes.device) < self.mutation_rate
        current = genes[selected]
        toggled = torch.rand(size=current.shape, device=current.device) < self.toggle_rate
        # NOTE(review): torch.randint excludes its upper limit, so the value
        # `upper_bounds` itself is never drawn here — confirm this is intended.
        replacement = torch.randint(0, self.problem.upper_bounds, size=current.shape, device=genes.device)

        # "active and toggled" or "inactive and not toggled" both end up as -1,
        # i.e. exactly the genes where the two boolean flags agree.
        is_active = current >= 0
        replacement[is_active == toggled] = -1

        genes[selected] = replacement
        # TODO: Leave some part of the population unmutated
        return mutated

class AdaptiveIntMutation(CopyingOperator):
    """Integer mutation whose rate follows the population's fitness variance.

    The per-gene mutation itself is the same scheme as a fixed-rate
    toggle/resample mutation (negative gene = unused slot).  In addition,
    callers may report the current population variance through
    :meth:`set_variance`, which rescales ``mutation_rate`` within
    ``[mutation_rate / 3, mutation_rate * 3]`` of its initial value.
    """

    def __init__(self, problem, mutation_rate=0.1, toggle_rate=0.5, patience=10):
        """Store the hyper-parameters and derive the allowed rate interval.

        Args:
            problem: Problem instance forwarded to ``CopyingOperator``; its
                ``upper_bounds`` is the sampling limit for new gene values.
            mutation_rate: Initial per-gene mutation probability.
            toggle_rate: Per-mutated-gene probability of the "toggle" draw.
            patience: Stored on the instance.
                NOTE(review): not read anywhere in this class; the look-back
                in ``set_variance`` is hard-coded to 10 — confirm whether the
                two were meant to be linked.
        """
        super().__init__(problem)
        self.mutation_rate = mutation_rate
        self.initial_mutation_rate = mutation_rate
        self.patience = patience
        self.toggle_rate = toggle_rate

        self.min_mutation_rate = mutation_rate / 3
        self.max_mutation_rate = mutation_rate * 3

        self.variances = []
        # The three attributes below are initialised but not read in this class.
        self.second_previous_variance = -1
        self.variance_derivative = 0
        self.eta = -1 * mutation_rate / 10

    def set_variance(self, var):
        """Record ``var`` and adapt ``mutation_rate`` to the variance trend.

        Nothing is adapted until ten values have been recorded.  After that
        the rate is updated as ``rate + ((var - ref) / ref) * rate`` where
        ``ref`` is the tenth-most-recent recorded value (``var`` included),
        and the result is clamped to ``[min_mutation_rate, max_mutation_rate]``.
        A NaN result (e.g. ``0 / 0``) resets the rate to its initial value.

        Args:
            var: Scalar variance, either a 0-dim tensor or a plain number.
        """
        self.variances.append(var)

        if len(self.variances) < 10:
            return

        # NOTE(review): this is the single value ten steps back, not the mean
        # of the last ten (`[-10:]`); the original took a no-op `.mean()` of
        # it, so a window may have been intended — confirm.
        # `as_tensor` avoids the copy-construct warning that `torch.tensor`
        # raises for tensor inputs and also accepts plain floats.
        reference = torch.as_tensor(self.variances[-10])

        relative_change = (var - reference) / reference
        updated = self.mutation_rate + relative_change * self.mutation_rate

        if torch.isnan(updated):
            # NaN survives min()/max() clamping, so it is handled explicitly.
            self.mutation_rate = self.initial_mutation_rate
            return

        self.mutation_rate = max(
            self.min_mutation_rate, min(updated.item(), self.max_mutation_rate))

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch`` using the current rate."""
        mutated = deepcopy(batch)
        # Written to in place below — presumably a view onto the copy's
        # decision values; verify against the EvoTorch SolutionBatch API.
        genes = mutated.access_values()

        # Draw order (selection, toggle, replacement) matters for reproducibility.
        selected = torch.rand(size=genes.shape, device=genes.device) < self.mutation_rate
        current = genes[selected]
        toggled = torch.rand(size=current.shape, device=current.device) < self.toggle_rate
        replacement = torch.randint(0, self.problem.upper_bounds, size=current.shape, device=genes.device)

        # Active genes that toggle are switched off; inactive genes that do
        # not toggle stay off.  Everything else receives the random index.
        is_active = current >= 0
        replacement[is_active == toggled] = -1

        genes[selected] = replacement
        # TODO: Leave some part of the population unmutated
        return mutated

class GraphEvalProblem(Problem):
    """EvoTorch problem that scores candidate perturbations in chunks.

    Evaluation is delegated to an attack object: each solution is turned into
    a perturbed graph with ``_create_perturbed_graph`` and up to ``capacity``
    graphs at a time are scored together with ``bulk_evaluation``.
    """

    def __init__(self, attack_class, capacity=1, *args, **kwargs):
        """Remember the attack object and chunk size, then build the problem.

        Args:
            attack_class: Attack *instance* (despite the name) exposing
                ``attr``, ``adj``, ``labels``, ``model``, ``mask_attack``,
                ``device``, ``_create_perturbed_graph`` and ``bulk_evaluation``.
            capacity: Maximum number of solutions scored per
                ``bulk_evaluation`` call.
            *args, **kwargs: Forwarded unchanged to ``Problem.__init__``.
        """
        self.capacity = capacity
        self.attack_class = attack_class
        super().__init__(*args, **kwargs)

    def _evaluate_batch(self, solutions: SolutionBatch):
        """Evaluate every solution of ``solutions`` and store the results.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        n_solutions = len(solutions)
        if n_solutions == 0:
            # Nothing to score; torch.cat would fail on an empty list.
            return
        if self.capacity < 1:
            raise ValueError(f"capacity must be a positive integer, got {self.capacity!r}")

        attack = self.attack_class
        chunk_scores = []
        # The original loop stopped at `len(solutions) - 1`, which skipped the
        # final solution whenever a chunk started exactly on the last index.
        for start in range(0, n_solutions, self.capacity):
            stop = min(start + self.capacity, n_solutions)
            perturbed = [
                attack._create_perturbed_graph(
                    attr=attack.attr, adj=attack.adj,
                    perturbation=solution.values, device=attack.device)
                for solution in solutions[start:stop]
            ]
            attr_instances = [graph[0] for graph in perturbed]
            adj_instances = [graph[1] for graph in perturbed]

            chunk_scores.append(attack.bulk_evaluation(
                attr_list=attr_instances, adj_list=adj_instances,
                labels_list=[attack.labels] * len(attr_instances),
                model=attack.model, mask_attack=attack.mask_attack))

        solutions.set_evals(torch.cat(chunk_scores))

class EvaAttack(SparseAttack):
    """Genetic-algorithm search for adversarial edge flips.

    A candidate is a fixed-length integer vector.  Each non-negative entry is
    a linear index into the strict upper triangle of the adjacency matrix and
    denotes one undirected edge to flip; negative entries are unused slots.
    Candidates are scored with ``self.metric`` (lower is better for the
    attacker) and evolved with EvoTorch's ``GeneticAlgorithm``.

    The attributes ``attr``, ``adj``, ``labels``, ``idx_attack``,
    ``attacked_model`` and ``device`` are read but not set here — presumably
    provided by ``SparseAttack.__init__``; verify against the base class.
    """

    def __init__(
            self, metric=None, n_steps=1000,
            tournament_size=2, num_cross_over=30, num_population=500,
            mutation_rate=0.006, mutation_toggle_rate=0.0, mutation_method="uniform", training_idx=None,

            **kwargs):
        """Configure the attack.

        Args:
            metric: Callable with the keyword signature of :meth:`_metric`;
                defaults to :meth:`_metric` (accuracy on the attacked nodes).
            n_steps: Number of genetic-algorithm generations.
            tournament_size: Tournament size of the cross-over operator.
            num_cross_over: Number of cross-over points.
            num_population: Population size.
            mutation_rate: Per-gene mutation probability.
            mutation_toggle_rate: Toggle probability of the mutation operator.
            mutation_method: ``"uniform"`` selects ``PositiveIntMutation``,
                anything else ``AdaptiveIntMutation``.
            training_idx: Stored as ``self.training_idx``; not used in this class.
            **kwargs: Forwarded to ``SparseAttack.__init__``.
        """
        # Assigned before the base constructor, as in the original code.  A
        # caller-supplied metric used to be dropped, leaving `self.metric` unset.
        self.metric = self._metric if metric is None else metric

        super().__init__(**kwargs)
        # general attributes
        self.n_nodes = self.attr.shape[0]

        # Boolean mask over all nodes marking the attacked ones.
        self.mask_attack = torch.zeros(self.n_nodes, dtype=torch.bool)
        self.mask_attack[self.idx_attack] = True

        # genetic algorithm parameters
        self.n_steps = n_steps
        self.tournament_size = tournament_size
        self.num_cross_over = num_cross_over
        self.mutation_rate = mutation_rate
        self.mutation_toggle_rate = mutation_toggle_rate
        self.num_population = num_population
        self.mutation_method = mutation_method

        self.model = self.attacked_model
        self.training_idx = training_idx
        # Filled by `_attack`.
        self.best_perturbation = None
        # Per-generation best objective; `find_optimal_perturbation` appends
        # to it, so it has to exist even when no subclass creates it.
        self.objective_track = []

    def _attack(self, n_perturbations: int, **kwargs):
        """Search for the best perturbation and store the perturbed graph.

        Sets ``best_perturbation``, ``attr_adversary`` and ``adj_adversary``.
        """
        self.best_perturbation = self.find_optimal_perturbation(n_perturbations)
        self.attr_adversary, self.adj_adversary = self._create_perturbed_graph(
            self.attr, self.adj, self.best_perturbation)

    @staticmethod
    def _metric(attr, adj, labels, model, mask_test, perturbation_adj=None, device=None):
        """Return the model's accuracy on the nodes selected by ``mask_test``.

        Args:
            attr: Node attributes passed to the model.
            adj: Adjacency passed to the model.
            labels: Ground-truth labels, indexed with ``mask_test``.
            model: Callable ``model(attr, adj)`` returning per-node scores.
            mask_test: Index/mask of the nodes to score.
            perturbation_adj: Unused.
            device: Device for the forward pass; defaults to ``attr.device``.

        Returns:
            Accuracy as a Python float.
        """
        if device is None:
            device = attr.device

        model.eval()
        # Only the argmax is needed, so no autograd graph has to be built.
        with torch.no_grad():
            predictions = model(attr.to(device), adj.to(device)).argmax(dim=1)
        target = labels[mask_test]
        return (predictions[mask_test] == target).sum().item() / len(target)

    @staticmethod
    def _linear_to_triu_idx(n, lin_idx):
        """Linear index to upper triangular matrix without diagonal. This is
        similar to
        https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498
        with number nodes decremented and col index incremented by one.

        Args:
            n: Number of nodes (matrix side length).
            lin_idx: Integer tensor of linear indices.

        Returns:
            Tuple ``(row_idx, col_idx)`` of tensors on ``lin_idx``'s device.
        """
        nn = n * (n - 1)
        row_idx = n - 2 - torch.floor(
            torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long()
        col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div(
            (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor')
        return row_idx.to(lin_idx.device), col_idx.to(lin_idx.device)

    def _evaluate_sparse_perturbation(self, attr, adj, labels, mask_attack, perturbation, model, device=None):
        """Apply ``perturbation`` to the graph and return ``self.metric`` on it."""
        if device is None:
            device = attr.device

        pert_attr, pert_adj = self._create_perturbed_graph(attr, adj, perturbation)
        return self.metric(attr=pert_attr, adj=pert_adj, labels=labels,
                           model=model, mask_test=mask_attack, device=device)

    def _create_perturbed_graph(self, attr, adj, perturbation, device=None):
        """Flip the edges encoded in ``perturbation``.

        Args:
            attr: Node attributes; returned unchanged.
            adj: Adjacency exposing ``to_torch_sparse_coo_tensor()``.
            perturbation: Integer tensor of linear upper-triangle indices;
                negative entries are ignored.
            device: Device of the flip matrix; defaults to ``attr.device``.

        Returns:
            ``(attr, perturbed_adj)`` with ``perturbed_adj`` a ``SparseTensor``.
        """
        if device is None:
            device = attr.device

        # Index 0 is a valid edge (0, 1); only negative values mark empty
        # slots.  The original `> 0` silently discarded that edge.
        valid_perturbation = perturbation[perturbation >= 0]
        rows, cols = self._linear_to_triu_idx(self.n_nodes, valid_perturbation)

        # Symmetric matrix holding both directions of every flipped edge.
        flip_indices = torch.stack(
            [torch.cat([rows, cols]), torch.cat([cols, rows])]).to(device)
        pert_matrix = torch.sparse_coo_tensor(
            indices=flip_indices,
            values=torch.ones(rows.shape[0] * 2).to(device),
            size=(self.n_nodes, self.n_nodes)).coalesce().to(device)

        adj_original = adj.to_torch_sparse_coo_tensor()
        # A + P - 2 * (A * P) is an element-wise XOR for 0/1 matrices.
        # NOTE(review): a linear index occurring twice is summed to 2 by
        # coalesce(), which breaks the XOR identity — confirm whether
        # duplicate genes can occur and how they should be treated.
        pert_adj = ((adj_original + pert_matrix) - (adj_original * pert_matrix * 2)).coalesce()
        return attr, SparseTensor.from_torch_sparse_coo_tensor(pert_adj)

    def find_optimal_perturbation(self, n_perturbations: int, **kwargs):
        """Run the genetic algorithm for ``n_steps`` generations.

        Args:
            n_perturbations: Length of each candidate (edge-flip budget).

        Returns:
            Decision values of the best solution of the final population.
        """
        def evaluate_ga(candidate):
            return self._evaluate_sparse_perturbation(
                attr=self.attr, adj=self.adj, labels=self.labels,
                mask_attack=self.mask_attack, perturbation=candidate,
                model=self.model, device=self.device)

        n_candidate_edges = (self.n_nodes * (self.n_nodes - 1)) // 2
        problem = Problem(
            "min",
            objective_func=evaluate_ga,
            solution_length=n_perturbations,
            dtype=torch.int64,
            bounds=(0, n_candidate_edges - 1), device=self.device
        )

        mutation_class = PositiveIntMutation if self.mutation_method == "uniform" else AdaptiveIntMutation
        mutation_operator = mutation_class(
            problem, mutation_rate=self.mutation_rate, toggle_rate=self.mutation_toggle_rate)
        cross_over = evo_ops.MultiPointCrossOver(
            problem, tournament_size=self.tournament_size, num_points=self.num_cross_over)
        searcher = GeneticAlgorithm(
            problem, operators=[cross_over, mutation_operator],
            popsize=self.num_population, re_evaluate=False
        )
        # Kept referenced for the duration of the search.
        logger = StdOutLogger(searcher, interval=1)

        adaptive = mutation_class is AdaptiveIntMutation
        for _ in range(self.n_steps):
            searcher.step()
            evals = searcher.population.access_evals().reshape((-1,))
            population_var = evals.var()
            if adaptive:
                mutation_operator.set_variance(population_var)
            self.objective_track.append(evals.min())

        return searcher.status["pop_best"].values
    
class EvaFast(EvaAttack):
    """Variant of ``EvaAttack`` that scores several candidates per forward pass.

    Up to ``capacity`` perturbed graphs are stacked into one block-diagonal
    graph and pushed through the model together (see ``bulk_evaluation``).
    """

    def __init__(self, **kwargs):
        """Read ``capacity`` and ``debug`` from ``kwargs``.

        All keyword arguments, including those two, are also forwarded to the
        parent constructor.
        """
        super().__init__(**kwargs)
        self.capacity = kwargs.get("capacity", 1)
        print("eva fast called with capacity =", self.capacity)
        self.debug = kwargs.get("debug", False)
        self.debug_info = []
        self.objective_track = []
        # Not read anywhere inside this class.
        self.population_variance = -1

    def find_optimal_perturbation(self, n_perturbations: int, **kwargs):
        """Run the genetic algorithm on a ``GraphEvalProblem``.

        Appends the best objective of every generation to
        ``objective_track`` and, when ``debug`` is set, a CPU copy of the
        current best solution's values to ``debug_info``.

        Returns:
            Decision values of the best solution of the final population.
        """
        def evaluate_ga(candidate):
            return self._evaluate_sparse_perturbation(
                attr=self.attr, adj=self.adj, labels=self.labels,
                mask_attack=self.mask_attack, perturbation=candidate,
                model=self.model, device=self.device)

        highest_index = (self.n_nodes * (self.n_nodes - 1)) // 2 - 1
        problem = GraphEvalProblem(
            objective_sense="min",
            objective_func=evaluate_ga,
            solution_length=n_perturbations,
            dtype=torch.int64,
            bounds=(0, highest_index), device=self.device,
            capacity=self.capacity, attack_class=self
        )

        if self.mutation_method == "uniform":
            mutation_class = PositiveIntMutation
        else:
            mutation_class = AdaptiveIntMutation
        mutation_operator = mutation_class(
            problem, mutation_rate=self.mutation_rate, toggle_rate=self.mutation_toggle_rate)
        cross_over = evo_ops.MultiPointCrossOver(
            problem, tournament_size=self.tournament_size, num_points=self.num_cross_over)
        searcher = GeneticAlgorithm(
            problem, operators=[cross_over, mutation_operator],
            popsize=self.num_population, re_evaluate=False
        )
        logger = StdOutLogger(searcher, interval=100)

        adaptive = mutation_class == AdaptiveIntMutation
        for _ in range(self.n_steps):
            searcher.step()
            evals = searcher.population.access_evals().reshape((-1,))
            population_var = evals.var()
            if adaptive:
                mutation_operator.set_variance(population_var)
            if self.debug:
                best_now = searcher.status["pop_best"].clone()
                self.debug_info.append(best_now.access_values().cpu())
            self.objective_track.append(evals.min())

        return searcher.status["pop_best"].values

    def bulk_evaluation(self, attr_list, adj_list, labels_list, model, mask_attack):
        """Score several graphs with a single forward pass.

        Args:
            attr_list: Per-graph node attributes, concatenated along dim 0.
            adj_list: Per-graph adjacencies, combined block-diagonally.
            labels_list: Per-graph label vectors, concatenated along dim 0.
            model: Callable ``model(attr, adj)`` returning per-node scores.
            mask_attack: Boolean mask of the attacked nodes within one graph.

        Returns:
            One value per graph: the fraction of attacked nodes whose
            predicted class equals the label.  The reshape assumes every
            graph has the same number of nodes.
        """
        stacked_attr = torch.cat(attr_list, dim=0)
        block_adj = self.cat_block_sparse(adj_list, block_size=self.n_nodes)
        stacked_labels = torch.cat(labels_list, dim=0)

        model.eval()
        predicted = model(stacked_attr, block_adj).argmax(dim=-1)
        # One row of hit/miss flags per graph.
        correct = (predicted == stacked_labels).view(len(attr_list), -1)

        attacked_idx = mask_attack.nonzero(as_tuple=True)[0]
        return correct[:, attacked_idx].sum(dim=1) / mask_attack.sum()

    @staticmethod
    def cat_block_sparse(matrices, block_size):
        """Combine square sparse matrices into one block-diagonal sparse tensor.

        Args:
            matrices: ``SparseTensor`` objects or torch sparse COO tensors.
            block_size: Side length shared by all blocks (int) or a list with
                one side length per block.

        Returns:
            Torch sparse COO tensor of side ``sum(block_size)``.
        """
        if isinstance(block_size, int):
            block_size = [block_size] * len(matrices)

        blocks = [
            (m.to_torch_sparse_coo_tensor() if isinstance(m, SparseTensor) else m).coalesce()
            for m in matrices
        ]

        # Block i starts at the cumulative size of all blocks before it.
        offsets = [0] + list(torch.cumsum(torch.tensor(block_size), 0).numpy())
        shifted_indices = [block.indices() + offsets[i] for i, block in enumerate(blocks)]

        all_indices = torch.cat(shifted_indices, dim=1)
        all_values = torch.cat([block.values() for block in blocks], dim=0)

        side = sum(block_size)
        return torch.sparse_coo_tensor(
            indices=all_indices, values=all_values, size=(side, side))

    def get_pertubations(self):
        """Return ``(adj_adversary, attr_adversary)`` — adjacency first."""
        return self.adj_adversary, self.attr_adversary
    
class EvaLocal(EvaFast):
    """``EvaFast`` with a per-node (local) budget on edge flips.

    Before a candidate is applied, every flip that touches a node whose
    number of requested flips, relative to its original degree, exceeds
    ``delta - 1`` is dropped (see ``_local_adjacency_filter``).
    """

    def __init__(self, **kwargs):
        """Read ``delta`` from ``kwargs`` and cache the original node degrees.

        All keyword arguments, including ``delta``, are also forwarded to the
        parent constructor.
        """
        super().__init__(**kwargs)
        self.delta = kwargs.get("delta", 1e4)
        print("eva local called with relative degree diff = ", self.delta)
        # Row sums of the clean adjacency.
        self.orig_degrees = self.adj.sum(dim=1)

    def _symmetric_flip_matrix(self, rows, cols, device):
        """Build the coalesced symmetric sparse matrix with ones at (rows, cols) and (cols, rows)."""
        flip_indices = torch.stack(
            [torch.cat([rows, cols]), torch.cat([cols, rows])]).to(device)
        return torch.sparse_coo_tensor(
            indices=flip_indices,
            values=torch.ones(rows.shape[0] * 2).to(device),
            size=(self.n_nodes, self.n_nodes)).coalesce().to(device)

    def _local_adjacency_filter(self, perturbation_rows, perturbation_cols, device):
        """Drop flips incident to nodes that exceed the local budget.

        Args:
            perturbation_rows: Row index of every requested flip.
            perturbation_cols: Column index of every requested flip.
            device: Device used for the intermediate tensors.

        Returns:
            ``(rows, cols)`` restricted to flips whose two endpoints both
            respect the budget.
        """
        pert_matrix = self._symmetric_flip_matrix(perturbation_rows, perturbation_cols, device)

        # Number of requested flips per node, scattered into a dense vector.
        flips_per_node = torch.sparse.sum(pert_matrix, 1)
        touched_nodes = flips_per_node.indices()[0]
        pert_degrees = torch.zeros(self.attr.shape[0], dtype=torch.int32, device=device)
        pert_degrees[touched_nodes] = flips_per_node.values().int()

        relative_degrees = pert_degrees / self.orig_degrees
        violations = relative_degrees > (self.delta - 1)

        # The counts above include every requested flip, so all flips at a
        # violating node are removed together.
        keep = ~violations[perturbation_rows] & ~violations[perturbation_cols]
        return perturbation_rows[keep], perturbation_cols[keep]

    def _create_perturbed_graph(self, attr, adj, perturbation, device=None):
        """Flip the budget-respecting edges encoded in ``perturbation``.

        Args:
            attr: Node attributes; returned unchanged.
            adj: Adjacency exposing ``to_torch_sparse_coo_tensor()``.
            perturbation: Integer tensor of linear upper-triangle indices;
                negative entries are ignored.
            device: Device of the flip matrix; defaults to ``attr.device``.

        Returns:
            ``(attr, perturbed_adj)`` with ``perturbed_adj`` a coalesced
            torch sparse COO tensor.
            NOTE(review): the parent class returns a ``SparseTensor`` here;
            confirm whether the differing type is intentional.
        """
        if device is None:
            device = attr.device

        # Index 0 is a valid edge (0, 1); only negative values mark empty
        # slots.  The original `> 0` silently discarded that edge.
        valid_perturbation = perturbation[perturbation >= 0]
        perturbation_rows, perturbation_cols = self._linear_to_triu_idx(self.n_nodes, valid_perturbation)

        filtered_rows, filtered_cols = self._local_adjacency_filter(
            perturbation_rows=perturbation_rows, perturbation_cols=perturbation_cols, device=device)

        pert_matrix = self._symmetric_flip_matrix(filtered_rows, filtered_cols, device)
        adj_original = adj.to_torch_sparse_coo_tensor()
        # A + P - 2 * (A * P) is an element-wise XOR for 0/1 matrices.
        pert_adj = ((adj_original + pert_matrix) - (adj_original * pert_matrix * 2)).coalesce()
        return attr, pert_adj

class EvaTarget(EvaFast):
    """Targeted variant: drive a chosen set of nodes to misclassification.

    The objective of a candidate is the mean top-class probability over the
    target nodes that are still predicted as their true class, and 0 once
    none is.  The search stops early when the best objective drops below 1e-3.
    """

    def __init__(self, **kwargs):
        """Read ``attacking_nodes`` from ``kwargs``.

        All keyword arguments, including ``attacking_nodes``, are also
        forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.attacking_nodes = kwargs.get("attacking_nodes", None)
        # Generation index at which the search succeeded; -1 if it never did.
        self.success_steps = -1
        print("eva target is called with target = ", self.attacking_nodes)

    def find_optimal_perturbation(self, n_perturbations: int, **kwargs):
        """Run the genetic algorithm until success or ``n_steps`` generations.

        Always uses ``PositiveIntMutation``.

        Returns:
            Decision values of the best solution of the final population.
        """
        def evaluate_ga(candidate):
            return self._evaluate_sparse_perturbation(
                attr=self.attr, adj=self.adj, labels=self.labels,
                mask_attack=self.mask_attack, perturbation=candidate,
                model=self.model, device=self.device)

        highest_index = (self.n_nodes * (self.n_nodes - 1)) // 2 - 1
        problem = GraphEvalProblem(
            objective_sense="min",
            objective_func=evaluate_ga,
            solution_length=n_perturbations,
            dtype=torch.int64,
            bounds=(0, highest_index), device=self.device,
            capacity=self.capacity, attack_class=self
        )
        searcher = GeneticAlgorithm(
            problem,
            operators=[
                evo_ops.MultiPointCrossOver(
                    problem, tournament_size=self.tournament_size, num_points=self.num_cross_over),
                PositiveIntMutation(
                    problem, mutation_rate=self.mutation_rate, toggle_rate=self.mutation_toggle_rate),
            ],
            popsize=self.num_population, re_evaluate=False
        )
        logger = StdOutLogger(searcher, interval=10)

        for i in range(self.n_steps):
            searcher.step()
            if searcher.status["pop_best"].evals[0] < 1e-3:
                print("Search finished at step", i)
                self.success_steps = i
                break

        return searcher.status["pop_best"].values

    def bulk_evaluation(self, attr_list, adj_list, labels_list, model, mask_attack):
        """Score several graphs by the remaining confidence on the targets.

        Args:
            attr_list: Per-graph node attributes, concatenated along dim 0.
            adj_list: Per-graph adjacencies, combined block-diagonally.
            labels_list: Per-graph label vectors, concatenated along dim 0.
            model: Callable ``model(attr, adj)`` returning per-node scores.
            mask_attack: Boolean node mask; used as the target set when
                ``attacking_nodes`` is ``None``.

        Returns:
            CPU float tensor with one value per graph: the mean top-class
            softmax probability over target nodes whose top-class and
            true-class probabilities differ by less than 1e-3, or 0 when no
            such node exists.
        """
        if self.attacking_nodes is not None:
            idx_attack = torch.as_tensor(self.attacking_nodes)
        else:
            idx_attack = mask_attack.nonzero(as_tuple=True)[0]

        n_graphs = len(attr_list)
        n_targets = idx_attack.shape[0]

        # Shift the target indices by each graph's offset in the stacked batch.
        graph_offsets = torch.arange(
            n_graphs, device=idx_attack.device).reshape(-1, 1) * attr_list[0].shape[0]
        idx_attack_bulk = (idx_attack + graph_offsets).reshape(-1)

        labels_bulk = torch.cat(labels_list, dim=0)[idx_attack_bulk]
        attr_bulk = torch.cat(attr_list, dim=0)
        adj_bulk = self.cat_block_sparse(adj_list, block_size=self.n_nodes)

        model.eval()
        # The result is only used as a fitness value; no gradients needed.
        with torch.no_grad():
            pred = model(attr_bulk, adj_bulk)
        pred_selected = torch.softmax(pred[idx_attack_bulk], 1)

        labels_bulk_mask = F.one_hot(labels_bulk, pred.shape[1]).bool().to(pred.device)
        true_softmax = pred_selected[labels_bulk_mask]
        top_softmax = pred_selected.max(dim=1).values

        true_margin = (top_softmax - true_softmax).reshape((-1, n_targets))
        softmaxes = top_softmax.reshape((-1, n_targets))

        # Masked mean per graph.  This replaces a Python loop wrapped in a
        # bare `except` that turned any failure into the scalar 0.
        still_correct = true_margin < 1e-3
        n_still_correct = still_correct.sum(dim=1)
        # masked_fill (not multiplication) so NaNs at masked-out positions
        # cannot leak into the sum.
        confidence_sum = softmaxes.masked_fill(~still_correct, 0.0).sum(dim=1)
        # clamp avoids 0/0; rows without any remaining node evaluate to 0.
        loss = confidence_sum / n_still_correct.clamp(min=1)

        return loss.cpu()





class EvaSGCPoisoningAttack(EvaAttack):
    """Poisoning variant scored with a two-hop linear (SGC-style) surrogate.

    For every candidate, a logistic regression is fitted on features
    propagated twice over the perturbed, normalised adjacency; the objective
    is that classifier's accuracy on the attacked nodes.
    """

    def __init__(self, **kwargs):
        """Read ``capacity`` and ``training_idx`` from ``kwargs``.

        All keyword arguments are also forwarded to the parent constructor.
        NOTE(review): the fallback ``training_idx=1`` is not a usable index
        set for ``_evaluate_sparse_perturbation`` — callers presumably always
        pass ``training_idx``; confirm.
        """
        super().__init__(**kwargs)
        self.capacity = kwargs.get("capacity", 1)
        self.training_idx= kwargs.get("training_idx", 1)
        print("EvaSGCPoisoningAttack called with capacity =", self.capacity)

    def _evaluate_sparse_perturbation(self, attr, adj, labels, mask_attack, perturbation, model, device=None):
        """Return the surrogate's accuracy on the attacked nodes.

        Args:
            attr: Node attributes.
            adj: Clean adjacency; perturbed with ``perturbation`` first.
            labels: Ground-truth labels for all nodes.
            mask_attack: Boolean mask of the nodes to score.
            perturbation: Candidate vector of linear edge indices.
            model: Unused; a fresh ``LogisticRegression`` is trained instead.
            device: Device for the dense propagation; defaults to
                ``attr.device``.  (The original hard-coded ``'cuda'`` and
                therefore failed on CPU-only hosts.)

        Returns:
            Accuracy (numpy float) of the surrogate on ``mask_attack``.
        """
        if device is None:
            device = attr.device

        _, pert_adj = self._create_perturbed_graph(attr, adj, perturbation)

        with torch.no_grad():
            features = attr.to(device)
            # NOTE(review): `normalize_adj_torch` is neither defined nor
            # imported in the visible part of this file — confirm it is in scope.
            adj_hat = normalize_adj_torch(pert_adj.to(device).to_dense())
            # Two propagation steps: (A_hat @ A_hat) @ X.
            propagated = ((adj_hat @ adj_hat) @ features).cpu().numpy()

        # scikit-learn works on CPU numpy arrays.
        y = labels.to("cpu").numpy()
        train_idx = self.training_idx.to("cpu").numpy()
        test_mask = mask_attack.to("cpu").numpy()

        surrogate = LogisticRegression(random_state=123, C=10).fit(
            propagated[train_idx], y[train_idx])
        y_pred = surrogate.predict(propagated)

        return (y_pred[test_mask] == y[test_mask]).mean()

