import torch
import torch.nn.functional as F
from eva_attack import EvaAttack, PositiveIntMutation, AdaptiveIntMutation
from sparse_smoothing.prediction import sample_multiple_graphs
from tqdm import tqdm
from torch_sparse import SparseTensor

from evotorch import Problem
from evotorch.algorithms import GeneticAlgorithm
from evotorch import operators as evo_ops
from evotorch.logging import StdOutLogger
from copy import deepcopy



def smooth_model_p(test_attr, test_adj, labels, test_mask, model, smoothing_config, dataset_info, batch_size=50, device='cpu'):
    """Estimate how often the smoothed model predicts each node's true label.

    ``smoothing_config["n_samples"]`` randomly perturbed copies of the graph
    are drawn with ``sample_multiple_graphs`` in batches of at most
    ``batch_size``. Each batch is classified as one large graph of
    ``batch * n_nodes`` nodes, every copy casts one vote per node, and the
    fraction of votes for ``labels[i]`` is returned for the nodes selected by
    ``test_mask``.

    Args:
        test_attr: node attribute matrix; ``test_attr.shape[0]`` is the node
            count (assumed dense -- TODO confirm).
        test_adj: adjacency as a torch_sparse ``SparseTensor`` or a torch
            tensor (sparse COO or dense).
        labels: true class index of every node.
        test_mask: selector (boolean mask or index tensor) applied to the
            per-node probabilities.
        model: callable ``model(attr, adj)`` returning per-node class scores.
        smoothing_config: forwarded to ``sample_multiple_graphs``; must
            contain ``"n_samples"``.
        dataset_info: must contain ``"n_classes"``.
        batch_size: maximum number of sampled graphs per model call.
        device: device for the model inputs and the vote counts.

    Returns:
        1-D float tensor: for every node selected by ``test_mask``, the
        fraction of sampled graphs on which the model predicted its label.

    Raises:
        ValueError: if ``test_adj`` has an unsupported type or
            ``batch_size < 1``.
    """
    if isinstance(test_adj, SparseTensor):
        row, col, _ = test_adj.coo()
        edge_idx = torch.stack([row, col])
    elif isinstance(test_adj, torch.Tensor):
        # indices() is only valid on coalesced sparse tensors; dense input
        # falls back to the coordinates of its non-zero entries.
        edge_idx = test_adj.coalesce().indices() if test_adj.is_sparse else test_adj.nonzero().t()
    else:
        raise ValueError("Invalid test_adj type")
    edge_idx = edge_idx.long().to(device)

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    attr_idx = torch.stack(test_attr.nonzero(as_tuple=True))

    n_nodes = test_attr.shape[0]
    n_classes = dataset_info["n_classes"]
    n_samples = smoothing_config["n_samples"]
    attr_on_device = test_attr.to(device)

    votes = torch.zeros((n_nodes, n_classes), device=device)

    n_drawn = 0
    # Predictions are reduced with argmax, so no gradients are needed.
    with torch.no_grad():
        while n_drawn < n_samples:
            # The last batch may be smaller so that exactly n_samples graphs
            # are drawn, matching the normalisation below.
            current = min(batch_size, n_samples - n_drawn)
            # NOTE(review): attr_idx_batch is not used -- the model always sees
            # the unperturbed attributes, so pf_*_att has no effect here.
            attr_idx_batch, edge_idx_batch = sample_multiple_graphs(
                attr_idx=attr_idx, edge_idx=edge_idx,
                sample_config=smoothing_config, n=n_nodes, d=0, nsamples=current)

            attr_bulk = attr_on_device.repeat(current, 1)
            bulk_size = attr_bulk.shape[0]
            adj_bulk = torch.sparse_coo_tensor(
                indices=edge_idx_batch,
                values=torch.ones_like(edge_idx_batch[0]).float(),
                size=(bulk_size, bulk_size)).to(device)

            predictions = model(attr_bulk, adj_bulk).argmax(1)
            # Rows are grouped per sampled copy: (copies, n_nodes, n_classes).
            votes += F.one_hot(predictions, n_classes).float().reshape(-1, n_nodes, n_classes).sum(dim=0)
            n_drawn += current

    # Column labels[i] of row i: the vote share of each node's true class.
    p_emps = (votes / n_samples).gather(1, labels.reshape(-1, 1).to(device)).squeeze(1)
    return p_emps[test_mask]

class EvaCert(EvaAttack):
    """Evolutionary attack against a randomized-smoothing certificate.

    The genetic algorithm searches for edge perturbations that minimize
    :meth:`certificate_metric`, the mean amount by which the smoothed
    true-class probability of the targeted test nodes exceeds ``p_emp``.

    Keyword arguments read here (all are also forwarded to ``EvaAttack``):
        p_emp: probability threshold (default 0.7).
        n_samples: Monte-Carlo samples per metric evaluation (default 1000).
        pf_plus_att, pf_minus_att: attribute sampling probabilities forwarded
            to ``smooth_model_p`` (default 0).
        pf_plus_adj, pf_minus_adj: edge sampling probabilities forwarded to
            ``smooth_model_p`` (defaults 0.01 and 0.6).
        batch_size: sampled graphs per model call (default 500).
        remove_only: stored but not read by the methods below (default False).
        target_mask: optional selector into the ``mask_test`` nodes; ``None``
            means every test node is targeted.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Adaptive mutation state.
        self.population_variance = -1
        # Lowest objective value of each generation, appended by
        # find_optimal_perturbation.
        self.objective_track = []

        # Certificate configuration.
        self.p_emp = kwargs.get("p_emp", 0.7)
        self.n_samples = kwargs.get("n_samples", 1000)
        self.pf_plus_att = kwargs.get("pf_plus_att", 0)
        self.pf_minus_att = kwargs.get("pf_minus_att", 0)
        self.pf_plus_adj = kwargs.get("pf_plus_adj", 0.01)
        self.pf_minus_adj = kwargs.get("pf_minus_adj", 0.6)
        self.batch_size = kwargs.get("batch_size", 500)
        self.remove_only = kwargs.get("remove_only", False)

        self.target_mask = kwargs.get("target_mask", None)

        self.metric = self.certificate_metric

    def _smoothing_config(self):
        """Return the sampling configuration passed to ``smooth_model_p``."""
        return {
            "n_samples": self.n_samples,
            "pf_plus_att": self.pf_plus_att,
            "pf_minus_att": self.pf_minus_att,
            "pf_plus_adj": self.pf_plus_adj,
            "pf_minus_adj": self.pf_minus_adj,
        }

    def certificate_metric(self, attr, adj, model, labels, mask_test, perturbation_adj=None, device=None):
        """Mean excess of the smoothed true-class probability over ``p_emp``.

        Args:
            attr, adj: (possibly perturbed) graph to evaluate.
            model: classifier passed to ``smooth_model_p``.
            labels: true class of every node.
            mask_test: selects the nodes whose probabilities are computed.
            perturbation_adj: accepted for interface compatibility; unused.
            device: evaluation device; defaults to ``self.attr.device``.

        Returns:
            Scalar tensor: mean of ``max(p - p_emp, 0)`` over the targeted
            nodes (0 once every targeted node is at or below ``p_emp``).
        """
        if device is None:
            device = self.attr.device

        p_emps_test = smooth_model_p(
            attr, adj, labels, mask_test, model,
            smoothing_config=self._smoothing_config(),
            dataset_info={"n_classes": labels.max().item() + 1},
            batch_size=self.batch_size, device=device)

        # p_emps_test is already restricted to mask_test, so it must not be
        # indexed with mask_test again. An explicit target_mask picks a further
        # subset (assumed to index into the mask_test nodes -- TODO confirm).
        if self.target_mask is not None:
            p_emps_test = p_emps_test[self.target_mask]

        return (p_emps_test - self.p_emp).clamp(min=0).mean()

    def find_optimal_perturbation(self, n_perturbations: int, **kwargs):
        """Run the genetic algorithm and return the best perturbation found.

        Each candidate is a vector of ``n_perturbations`` int64 genes bounded
        by ``(0, n_nodes * (n_nodes - 1) // 2 - 1)``, i.e. linear node-pair
        indices (decoded by ``EvaAttack``). Fitness is
        ``self._evaluate_sparse_perturbation`` and is minimized.

        Args:
            n_perturbations: number of genes per candidate.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The genes of the best solution in the final population.
        """
        def evaluate_ga(solution):
            return self._evaluate_sparse_perturbation(
                attr=self.attr, adj=self.adj, labels=self.labels,
                mask_attack=self.mask_attack, perturbation=solution,
                model=self.model, device=self.device)

        problem = Problem(
            "min",
            objective_func=evaluate_ga,
            solution_length=n_perturbations,
            dtype=torch.int64,
            bounds=(0, (self.n_nodes * (self.n_nodes - 1)) // 2 - 1),
            device=self.device,
        )

        adaptive = self.mutation_method != "uniform"
        mutation_class = AdaptiveIntMutation if adaptive else PositiveIntMutation
        mutation_operator = mutation_class(problem, mutation_rate=self.mutation_rate, toggle_rate=self.mutation_toggle_rate)
        crossover = evo_ops.MultiPointCrossOver(problem, tournament_size=self.tournament_size, num_points=self.num_cross_over)
        searcher = GeneticAlgorithm(
            problem, operators=[crossover, mutation_operator],
            popsize=self.num_population, re_evaluate=False,
        )
        # Prints the searcher status after every generation.
        _logger = StdOutLogger(searcher, interval=1)

        for _ in range(self.n_steps):
            searcher.step()
            evals = searcher.population.access_evals().reshape((-1,))
            if adaptive:
                # The adaptive operator is given the fitness variance of the
                # current population (presumably to scale its step size).
                mutation_operator.set_variance(evals.var())
            self.objective_track.append(evals.min())

        return searcher.status["pop_best"].values
    


class EvaCertGlobal(EvaAttack):
    """Evolutionary attack minimizing the fraction of still-certified nodes.

    :meth:`certificate_metric` is the share of test nodes whose smoothed
    true-class probability exceeds ``p_emp``. With ``remove_only`` (the
    default) genes index stored edges of ``adj`` to delete; otherwise they
    are linear node-pair indices whose entries are flipped.

    Keyword arguments read here (all are also forwarded to ``EvaAttack``):
        p_emp: probability threshold (default 0.7).
        n_samples: Monte-Carlo samples per metric evaluation (default 1000).
        pf_plus_att, pf_minus_att: attribute sampling probabilities forwarded
            to ``smooth_model_p`` (default 0).
        pf_plus_adj, pf_minus_adj: edge sampling probabilities forwarded to
            ``smooth_model_p`` (defaults 0.01 and 0.6).
        batch_size: sampled graphs per model call (default 500).
        remove_only: restrict perturbations to edge deletions (default True).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Adaptive mutation state.
        self.population_variance = -1
        # Lowest objective value of each generation, appended by
        # find_optimal_perturbation.
        self.objective_track = []

        # Certificate configuration.
        self.p_emp = kwargs.get("p_emp", 0.7)
        self.n_samples = kwargs.get("n_samples", 1000)
        self.pf_plus_att = kwargs.get("pf_plus_att", 0)
        self.pf_minus_att = kwargs.get("pf_minus_att", 0)
        self.pf_plus_adj = kwargs.get("pf_plus_adj", 0.01)
        self.pf_minus_adj = kwargs.get("pf_minus_adj", 0.6)
        self.batch_size = kwargs.get("batch_size", 500)
        self.remove_only = kwargs.get("remove_only", True)

        self.metric = self.certificate_metric

    def _smoothing_config(self):
        """Return the sampling configuration passed to ``smooth_model_p``."""
        return {
            "n_samples": self.n_samples,
            "pf_plus_att": self.pf_plus_att,
            "pf_minus_att": self.pf_minus_att,
            "pf_plus_adj": self.pf_plus_adj,
            "pf_minus_adj": self.pf_minus_adj,
        }

    def certificate_metric(self, attr, adj, model, labels, mask_test, perturbation_adj=None, device=None):
        """Fraction of ``mask_test`` nodes whose smoothed probability exceeds ``p_emp``.

        Args:
            attr, adj: (possibly perturbed) graph to evaluate.
            model: classifier passed to ``smooth_model_p``.
            labels: true class of every node.
            mask_test: selects the evaluated nodes (boolean mask or indices).
            perturbation_adj: accepted for interface compatibility; unused.
            device: evaluation device; defaults to ``self.attr.device``.

        Returns:
            Scalar float tensor in ``[0, 1]``.
        """
        if device is None:
            device = self.attr.device

        p_emps_test = smooth_model_p(
            attr, adj, labels, mask_test, model,
            smoothing_config=self._smoothing_config(),
            dataset_info={"n_classes": labels.max().item() + 1},
            batch_size=self.batch_size, device=device)

        # Averaging over the returned probabilities (instead of dividing by
        # mask_test.sum()) is also correct when mask_test is an index tensor.
        return (p_emps_test > self.p_emp).float().mean()

    def find_optimal_perturbation(self, n_perturbations: int, **kwargs):
        """Run the genetic algorithm and return the best perturbation found.

        Each candidate is a vector of ``n_perturbations`` int64 genes. With
        ``remove_only`` a gene indexes a stored entry of ``self.adj``;
        otherwise it is a linear node-pair index. Fitness is
        ``self._evaluate_sparse_perturbation`` and is minimized.

        Args:
            n_perturbations: number of genes per candidate.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The genes of the best solution in the final population.
        """
        def evaluate_ga(solution):
            return self._evaluate_sparse_perturbation(
                attr=self.attr, adj=self.adj, labels=self.labels,
                mask_attack=self.mask_attack, perturbation=solution,
                model=self.model, device=self.device)

        # Bounds are treated as inclusive, as the node-pair bound shows
        # (largest index n(n-1)/2 - 1), so the last stored-edge index is nnz - 1.
        if self.remove_only:
            upper = self.adj.nnz() - 1
        else:
            upper = (self.n_nodes * (self.n_nodes - 1)) // 2 - 1

        problem = Problem(
            "min",
            objective_func=evaluate_ga,
            solution_length=n_perturbations,
            dtype=torch.int64,
            bounds=(0, upper),
            device=self.device,
        )

        adaptive = self.mutation_method != "uniform"
        mutation_class = AdaptiveIntMutation if adaptive else PositiveIntMutation
        mutation_operator = mutation_class(problem, mutation_rate=self.mutation_rate, toggle_rate=self.mutation_toggle_rate)
        crossover = evo_ops.MultiPointCrossOver(problem, tournament_size=self.tournament_size, num_points=self.num_cross_over)
        searcher = GeneticAlgorithm(
            problem, operators=[crossover, mutation_operator],
            popsize=self.num_population, re_evaluate=False,
        )
        # Prints the searcher status after every generation.
        _logger = StdOutLogger(searcher, interval=1)

        for _ in range(self.n_steps):
            searcher.step()
            evals = searcher.population.access_evals().reshape((-1,))
            if adaptive:
                # The adaptive operator is given the fitness variance of the
                # current population (presumably to scale its step size).
                mutation_operator.set_variance(evals.var())
            self.objective_track.append(evals.min())

        return searcher.status["pop_best"].values

    def _symmetric_indicator(self, rows, cols, device):
        """Return a binary, symmetric, coalesced sparse ``(n_nodes, n_nodes)``
        matrix with ones at every ``(rows[k], cols[k])`` and its mirror."""
        indices = torch.stack([torch.cat([rows, cols]), torch.cat([cols, rows])]).long().to(device)
        summed = torch.sparse_coo_tensor(
            indices, torch.ones(indices.shape[1], device=device),
            size=(self.n_nodes, self.n_nodes), device=device).coalesce()
        # coalesce() adds up repeated positions (an edge and its mirror both
        # selected, or a self-loop); reset them to 1 so the XOR update never
        # produces weights of 2 or -1.
        return torch.sparse_coo_tensor(
            summed.indices(), torch.ones_like(summed.values()),
            size=(self.n_nodes, self.n_nodes), device=device).coalesce()

    def _create_perturbed_graph(self, attr, adj, perturbation, device=None):
        """Apply ``perturbation`` to ``adj`` and return ``(attr, perturbed_adj)``.

        Gene value 0 is treated as "no perturbation" in both modes, and
        repeated genes count once. With ``remove_only`` each remaining gene
        indexes a stored entry of the coalesced ``adj``; otherwise it is
        decoded by ``_linear_to_triu_idx`` into a node pair. The selected
        entries and their mirrors are flipped (0 <-> 1). Explicit zeros are
        dropped from the result.
        """
        if device is None:
            device = attr.device

        adj_original = adj.to_torch_sparse_coo_tensor().coalesce().to(device)

        genes = perturbation.to(device)
        genes = genes[genes > 0].unique()

        if self.remove_only:
            adj_idx = adj_original.indices()
            rows, cols = adj_idx[0, genes], adj_idx[1, genes]
        else:
            rows, cols = self._linear_to_triu_idx(self.n_nodes, genes)
        pert_matrix = self._symmetric_indicator(rows.to(device), cols.to(device), device)

        # XOR on binary matrices: a + p - 2ap flips every entry where p == 1.
        pert_adj = ((adj_original + pert_matrix) - (adj_original * pert_matrix * 2)).coalesce()
        # Removed edges survive coalesce() as explicit zeros; drop them so
        # consumers that only read indices() (e.g. smooth_model_p) see them gone.
        keep = pert_adj.values() != 0
        pert_adj = torch.sparse_coo_tensor(
            pert_adj.indices()[:, keep], pert_adj.values()[keep],
            size=pert_adj.shape, device=device).coalesce()
        return attr, pert_adj