from copy import deepcopy

from sympy import im
import torch
from evotorch.core import Problem, SolutionBatch
from evotorch.operators.base import CopyingOperator

from eva.utils import triangular_column_indices, triangular_matrix_index
from eva.utils import linear_to_triu_idx, triu_to_linear_idx
from torch_geometric.utils import degree

from tqdm import tqdm
import numpy as np


class GraphMutation(CopyingOperator):
    """Common base class for the mutation operators defined in this module.

    Subclasses implement ``_do`` to produce a mutated batch and may override
    ``update_locals`` to refresh internal state from the current population.
    """

    def __init__(self, problem, **kwargs):
        # Extra keyword arguments are accepted and ignored; only `problem`
        # is forwarded to the parent operator.
        super().__init__(problem)

    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Produce the mutated batch; must be provided by subclasses."""
        raise NotImplementedError()

    def update_locals(self, population, **kwargs):
        """Hook for refreshing operator state; a no-op in the base class."""
        return None

class LocalBudgetMutation(GraphMutation):
    """Repair operator that re-samples perturbation entries until every node
    respects its local budget.

    The budget of a node is ``delta`` times its degree in the original
    adjacency. ``_do`` does not introduce random mutations on its own: it only
    rewrites the entries of each solution that make some node exceed its
    budget.
    """
    def __init__(self, problem, delta=0.5, adj=None, n_nodes=None, device=None, idx_attack=None, **kwargs):
        """
        Args:
            problem: Problem forwarded to the parent operator.
            delta: Fraction of a node's original degree that may be perturbed.
            adj: Sparse adjacency tensor (``coalesce``/``indices``/``values``
                are called on it).
                NOTE(review): the default is ``None`` but the value is
                dereferenced unconditionally, so it is effectively required.
            n_nodes: Number of nodes; defaults to ``adj.shape[0]``.
            device: Device for ``idx_attack``; defaults to CUDA when available,
                otherwise CPU.
            idx_attack: Node indices allowed as the first endpoint of a
                re-sampled entry; defaults to all nodes.
            **kwargs: Accepted and ignored.
        """

        super().__init__(problem)

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.delta = delta
        self.adj = adj.coalesce()
        self.n_nodes = n_nodes or self.adj.shape[0]
        # Keep only the stored entries whose value is non-zero.
        edge_index = self.adj.indices()[:, self.adj.values() != 0]
        # Degree is counted over the first index row only.
        self.orig_degrees = degree(edge_index[0, :], num_nodes=self.n_nodes)
        # Per-node budget (may be fractional); it stays on the device of `adj`.
        self.local_budget = self.delta * self.orig_degrees
        if idx_attack is not None:
            self.idx_attack = idx_attack.to(self.device)
        else:
            self.idx_attack = torch.arange(self.n_nodes).to(self.device)

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a deep copy of ``batch`` whose values passed the budget repair."""
        result = deepcopy(batch)
        perturbations = result.access_values()
        corrected_perturbations = self._local_adjacency_filter(perturbations)
        result.set_values(corrected_perturbations)
        return result

    def _local_adjacency_filter(self, perturbation):
        """Re-sample budget-violating entries of every solution.

        Args:
            perturbation: Tensor indexed as ``[solution, entry]`` holding
                linear indices (presumably of the upper-triangular adjacency;
                the exact encoding lives in ``linear_to_triu_idx`` -- confirm).

        Returns:
            Tensor of linear indices produced by ``triu_to_linear_idx`` from
            the repaired (row, col) pairs.
        """
        device = perturbation.device
        # Work on a copy; negative entries are clamped to linear index 0
        # before being decoded.
        valid_perturbations = perturbation.clone()
        valid_perturbations[valid_perturbations < 0] = 0
        perturbation_rows, perturbation_cols = linear_to_triu_idx(self.n_nodes, valid_perturbations)

        # Number of repair passes, used only for the diagnostic print below.
        icall = 0
        while True:
            icall += 1
            # Per solution: how many entries touch each node, counting both
            # endpoints of every entry.
            pert_degrees = torch.stack(
                [degree(perturbation_rows[i], num_nodes=self.n_nodes) for i in range(perturbation.shape[0])]
                ) + torch.stack(
                    [degree(perturbation_cols[i], num_nodes=self.n_nodes) for i in range(perturbation.shape[0])]
                )
            # Positive excess over the budget, rounded up; zero where the
            # budget is respected.
            violations = (pert_degrees - self.local_budget)
            violations[violations < 0] = 0
            violations[violations > 0] = violations[violations > 0].ceil()

            # Stop as soon as no solution violates any node budget.
            if violations.sum() <= 0:
                break

            violations = violations.long()
            
            for i in range(perturbation.shape[0]):
                violation_nodes = violations[i].nonzero(as_tuple=True)[0]
                if violation_nodes.shape[0] == 0:
                    continue
                # NOTE(review): computed but never used below.
                n_violation_nodes = violations[i][violation_nodes]
                # coincidence[v, e] is True when entry e of solution i touches
                # the v-th violating node (as row or as column).
                coincidence = (
                    perturbation_rows[i] == violation_nodes.reshape(-1, 1)
                    ) | (perturbation_cols[i] == violation_nodes.reshape(-1, 1))

                # Offset of the first hit of every violating node inside
                # coincidence.nonzero(): exclusive cumulative sum of the
                # per-node hit counts.
                removing_index = torch.cat([torch.tensor([0]).to(device), coincidence.sum(1).cumsum(0)[:-1]])
                # filter out
                # Entry positions to re-sample: one entry per violating node
                # in this pass; the outer loop repeats until nothing is left.
                population_index = coincidence.nonzero()[removing_index][:, 1]

                # Candidate endpoints: nodes that are currently not violating.
                all_available_nodes = torch.arange(self.n_nodes).to(device)[violations[i] == 0]
                all_available_idx_attack = self.idx_attack[violations[i][self.idx_attack] == 0]


                # Rows are drawn from the attackable nodes, columns from any
                # node; both draws are uniform over the non-violating candidates.
                # NOTE(review): the two draws are independent, so row == col
                # is possible; an empty candidate pool would make randint fail.
                perturbation_rows[i, population_index] = all_available_idx_attack[torch.randint(0, all_available_idx_attack.shape[0], (population_index.shape[0],)).to(device)]
                perturbation_cols[i, population_index] = all_available_nodes[torch.randint(0, all_available_nodes.shape[0], (population_index.shape[0],)).to(device)]

        # Diagnostic output when the repair needed many passes.
        if icall > 30:
            print("called: ", icall)

        # Order every pair so that row <= col before encoding it again.
        swap_idx = perturbation_rows > perturbation_cols
        perturbation_rows[swap_idx], perturbation_cols[swap_idx] = perturbation_cols[swap_idx], perturbation_rows[swap_idx]
        
        new_perturbation = triu_to_linear_idx(self.n_nodes, perturbation_rows, perturbation_cols)
        return new_perturbation

class PositiveIntMutation(GraphMutation):
    """Integer mutation over genes that are either non-negative or ``-1``.

    Each gene is selected for mutation with probability ``mutation_rate``.
    A selected gene draws a fresh value from ``[0, problem.upper_bounds)``
    and is then overwritten with ``-1`` in two cases: it was non-negative and
    its toggle draw fired, or it was negative and its toggle draw did not
    fire. (Negative genes presumably encode "no perturbation" -- confirm
    against the problem definition.)
    """

    def __init__(self, problem, mutation_rate=0.1, toggle_rate=0.0):
        super().__init__(problem)
        self.toggle_rate = toggle_rate
        self.mutation_rate = mutation_rate

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; the input is left untouched."""
        mutated = deepcopy(batch)
        genes = mutated.access_values()

        # Which genes are mutated at all.
        selected = torch.rand(size=genes.shape, device=genes.device) < self.mutation_rate
        old_genes = genes[selected]
        # Per selected gene: whether its negative/non-negative state toggles.
        flip = torch.rand(size=old_genes.shape, device=old_genes.device) < self.toggle_rate

        fresh = torch.randint(0, self.problem.upper_bounds, size=old_genes.shape, device=genes.device)
        was_active = old_genes >= 0
        was_inactive = old_genes < 0
        # Non-negative genes that toggle and negative genes that do not
        # toggle both end up as -1.
        turn_off = (was_active & flip) | (was_inactive & ~flip)
        fresh[turn_off] = -1

        genes[selected] = fresh
        # TODO: Leave some part of the population unmutated
        return mutated

    def update_locals(self, population, **kwargs):
        """This operator keeps no population-dependent state."""
        return None


class VarinacePreservingMutation(PositiveIntMutation):
    """``PositiveIntMutation`` whose rate follows the population's fitness variance.

    After every generation ``update_locals`` moves ``mutation_rate`` by the
    relative change of the evaluation variance, scaled by ``variance_coef``.
    The rate is kept inside ``initial_rate * (1 -/+ radius)`` and never drops
    below zero. When the variance collapses (``<= 1e-6``) the rate jumps to
    the upper end of that interval.

    The class name keeps its original spelling so existing imports keep working.
    """

    def __init__(self, problem, mutation_rate=0.1, variance_coef=0.1, radius=0.5):
        super().__init__(problem, mutation_rate=mutation_rate)

        self.initial_mutation_rate = mutation_rate
        # -1 marks "no variance observed yet" (a real variance is never negative).
        self.last_variance = -1
        self.population_var = -1
        self.radius = radius
        self.variance_coef = variance_coef

    def update_locals(self, population, **kwargs):
        """Adapt ``mutation_rate`` from the variance of ``population``'s evaluations.

        Args:
            population: Object exposing ``access_evals()``.
            **kwargs: Accepted for signature compatibility with
                ``GraphMutation.update_locals``; ignored.
        """
        import math

        evals = population.access_evals()
        self.last_variance = self.population_var
        self.population_var = evals.var()

        # First call: there is no previous variance to compare against.
        if self.last_variance == -1:
            return

        lower = self.initial_mutation_rate - self.initial_mutation_rate * self.radius
        upper = self.initial_mutation_rate + self.initial_mutation_rate * self.radius

        # Collapsed population: mutate as aggressively as allowed.
        if self.population_var <= 1e-6:
            self.mutation_rate = upper
            return

        relative_change = (self.population_var - self.last_variance) / self.last_variance
        # Store a plain float so the rate has one consistent type.
        proposed = float(self.mutation_rate + relative_change * self.variance_coef)

        if math.isnan(proposed):
            # A NaN variance would otherwise poison the rate for all later
            # generations (NaN compares False, so nothing would be mutated).
            self.mutation_rate = self.initial_mutation_rate
            return

        # Clamp to the allowed band; the floor at zero matters when radius > 1.
        proposed = min(max(proposed, lower), upper)
        self.mutation_rate = max(proposed, 0.0)

    def statistics(self, population):
        """Return the current mutation rate and the last observed variance."""
        return {"mutation_rate": self.mutation_rate, "variance": self.population_var}
    


class AdaptiveIntMutation(CopyingOperator):
    """Integer mutation whose rate is adapted from externally reported variances.

    The mutation itself matches ``PositiveIntMutation``: selected genes draw a
    fresh value from ``[0, problem.upper_bounds)`` or become ``-1`` depending
    on their previous sign and a toggle draw. ``set_variance`` rescales
    ``mutation_rate`` by the relative change of the variance with respect to
    the value recorded ten reports earlier, bounded to
    ``[mutation_rate / 3, mutation_rate * 3]`` of the initial rate.
    """

    def __init__(self, problem, mutation_rate=0.1, toggle_rate=0.5, patience=10):
        super().__init__(problem)
        self.mutation_rate = mutation_rate
        self.initial_mutation_rate = mutation_rate
        # NOTE(review): `patience` is stored but not read anywhere in this
        # class; the comparison window in `set_variance` is fixed at 10.
        self.patience = patience
        self.toggle_rate = toggle_rate

        self.min_mutation_rate = mutation_rate / 3
        self.max_mutation_rate = mutation_rate * 3

        self.variances = []
        self.second_previous_variance = -1
        self.variance_derivative = 0
        self.eta = -1 * mutation_rate / 10

    def set_variance(self, var):
        """Record ``var`` and update ``mutation_rate`` once 10 values are stored.

        Args:
            var: Scalar variance, either a Python number or a 0-dim tensor.
        """
        import math

        self.variances.append(var)

        if len(self.variances) < 10:
            return

        # Compare against the value recorded ten reports ago. Tensor
        # arithmetic keeps inf/NaN semantics for a zero reference instead of
        # raising ZeroDivisionError on Python floats.
        current = torch.as_tensor(var)
        reference = torch.as_tensor(self.variances[-10])
        relative_change = (current - reference) / reference

        proposed = float(self.mutation_rate + relative_change * self.mutation_rate)

        if math.isnan(proposed):
            # 0/0 or a NaN variance: fall back to the configured rate.
            self.mutation_rate = self.initial_mutation_rate
            return

        proposed = min(proposed, self.max_mutation_rate)
        self.mutation_rate = max(proposed, self.min_mutation_rate)

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; the input is left untouched."""
        result = deepcopy(batch)
        data = result.access_values()

        # Genes chosen for mutation and, among those, the ones that toggle.
        mutation_mask = torch.rand(size=data.shape, device=data.device) < self.mutation_rate
        mutant_data = data[mutation_mask]
        toggle_mutations = torch.rand(size=mutant_data.shape, device=mutant_data.device) < self.toggle_rate

        new_vals = torch.randint(0, self.problem.upper_bounds, size=mutant_data.shape, device=data.device)
        # Non-negative genes that toggle, and negative genes that do not
        # toggle, are written as -1.
        new_vals[(mutant_data >= 0) & toggle_mutations] = -1
        new_vals[(mutant_data < 0) & (~toggle_mutations)] = -1

        data[mutation_mask] = new_vals
        # TODO: Leave some part of the population unmutated
        return result
    
    



class IdxMutation(GraphMutation):
    """Mutation that replaces genes with random edges incident to attack nodes.

    Each gene is mutated with probability ``mutation_rate``. A mutated gene
    becomes the linear index of a pair ``(source, target)``: ``source`` is
    drawn from the first (at most 20) entries of ``idx_attack`` and ``target``
    from all nodes. ``update_locals`` refreshes ``idx_attack`` to the attack
    nodes the model still classifies correctly, in random order, so those
    first entries are a random subset.
    """

    def __init__(self, problem, idx_attack, n_nodes, mutation_rate=0.1, adversary=None):
        """
        Args:
            problem: Problem forwarded to the parent operator.
            idx_attack: Candidate source nodes (tensor or array-like).
            n_nodes: Number of nodes in the graph.
            mutation_rate: Per-gene mutation probability.
            adversary: Optional object used by ``update_locals``; it must
                expose ``searcher``, ``model``, ``attr``, ``adj``, ``labels``,
                ``idx_attack`` and ``_create_perturbed_graph``.
        """
        super().__init__(problem)
        self.mutation_rate = mutation_rate
        if isinstance(idx_attack, torch.Tensor):
            self.idx_attack = idx_attack.clone()
        else:
            self.idx_attack = torch.tensor(idx_attack)
        self.n_nodes = n_nodes
        self.adversary = adversary
        self.initial_mutation_rate = mutation_rate
        self.last_variance = -1
        self.population_var = -1
        self.step = 0

    def _fallback_pair(self, candidates):
        """Return an ordered ``(low, high)`` pair used to replace self-loops."""
        first = int(candidates[0])
        second = int(candidates[1]) if candidates.numel() > 1 else first
        if second == first:
            # Only one distinct candidate: pair it with a neighbouring node id.
            second = first + 1 if first + 1 < self.n_nodes else first - 1
        return min(first, second), max(first, second)

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; the input is left untouched."""
        result = deepcopy(batch)
        data = result.access_values()
        mutation_mask = torch.rand(size=data.shape, device=data.device) < self.mutation_rate
        mutant_shape = data[mutation_mask].shape

        # Sampling happens on the CPU; the encoded result is moved to the
        # data's device at the end.
        candidates = torch.as_tensor(self.idx_attack).long().cpu()
        if candidates.numel() == 0:
            # No attackable node left: nothing sensible to sample.
            return result

        pool_size = min(20, candidates.numel())
        sources = candidates[torch.randint(0, pool_size, mutant_shape)]
        targets = torch.randint(0, self.n_nodes, mutant_shape)

        # Order every pair as (row <= col) before encoding it.
        rows = torch.minimum(sources, targets)
        cols = torch.maximum(sources, targets)

        # Self-loops are replaced by a fixed pair. The pair is ordered as
        # well: idx_attack is shuffled by update_locals, so its first two
        # entries are not guaranteed to be ascending.
        self_loops = rows == cols
        if self_loops.any():
            low, high = self._fallback_pair(candidates)
            rows[self_loops] = low
            cols[self_loops] = high

        new_pop = triu_to_linear_idx(self.n_nodes, rows, cols).to(data.device)

        data[mutation_mask] = new_pop
        # TODO: Leave some part of the population unmutated
        return result

    @torch.no_grad()
    def update_locals(self, population, **kwargs):
        """Restrict ``idx_attack`` to attack nodes that are still classified correctly.

        The best perturbation found so far is applied to the graph, the model
        is evaluated on it (inference only, hence no autograd), and the
        surviving nodes are stored in random order. Does nothing when no
        adversary was supplied.
        """
        adversary = self.adversary
        if adversary is None:
            return

        best_perturbation = adversary.searcher.status["pop_best"].values
        attr_adversary, adj_adversary = adversary._create_perturbed_graph(
            adversary.attr, adversary.adj, best_perturbation
        )
        logits = adversary.model(attr_adversary, adj_adversary)
        labels = adversary.labels
        preds = logits.max(1)[1].type_as(labels)

        attack_nodes = adversary.idx_attack
        still_correct = torch.where(preds[attack_nodes] == labels[attack_nodes])[0].cpu()
        survivors = attack_nodes[still_correct]
        self.idx_attack = survivors[torch.randperm(len(survivors))]


class LocalIdxMutation(IdxMutation):
    """``IdxMutation`` variant that samples only nodes within their local budget.

    For every solution, the number of genes touching each node is compared
    with the adversary's integer ``local_budget``. Replacement pairs are drawn
    only from nodes with no excess: sources from ``idx_attack``, targets from
    all nodes.
    """

    def __init__(self, problem, idx_attack, n_nodes, mutation_rate=0.1, adversary=None):
        """
        Raises:
            ValueError: If ``adversary`` is ``None``; its ``local_budget`` is required.
        """
        super().__init__(problem, idx_attack, n_nodes, mutation_rate, adversary)
        if self.adversary is None:
            raise ValueError("LocalIdxMutation requires an adversary providing `local_budget`.")
        self.local_budget = self.adversary.local_budget.int()

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; the input is left untouched.

        NOTE(review): genes are decoded as-is, so negative genes are handed to
        ``linear_to_triu_idx`` unchanged, and a sampled pair may have
        source == target -- confirm both are acceptable.
        """
        result = deepcopy(batch)
        data = result.access_values()
        device = data.device

        self.idx_attack = torch.as_tensor(self.idx_attack).to(device)
        local_budget = self.local_budget.to(device)
        mutation_mask = torch.rand(size=data.shape, device=device) < self.mutation_rate

        # Per solution: how many genes touch each node (both endpoints count).
        data_rows, data_cols = linear_to_triu_idx(self.n_nodes, data)
        pert_degrees = torch.stack(
            [degree(row_idx, num_nodes=self.n_nodes) for row_idx in data_rows]
        ) + torch.stack(
            [degree(col_idx, num_nodes=self.n_nodes) for col_idx in data_cols]
        )
        # available[s, v] is True when node v has no excess in solution s.
        available = torch.relu(pert_degrees - local_budget) == 0

        # Every solution samples as many pairs as the busiest one needs; the
        # surplus is dropped when writing back.
        n_max_mutants = int(mutation_mask.sum(1).max())

        new_mutant_rows = []
        new_mutant_cols = []
        for i in range(data.shape[0]):
            free_nodes = available[i].nonzero(as_tuple=True)[0]
            free_sources = self.idx_attack[available[i][self.idx_attack]]

            if free_nodes.shape[0] == 0 or free_sources.shape[0] == 0:
                # Nothing can be sampled within budget: keep this solution's
                # genes and append placeholders that are never selected.
                mutation_mask[i] = False
                new_mutant_rows.append(free_sources.new_zeros(n_max_mutants))
                new_mutant_cols.append(free_nodes.new_ones(n_max_mutants))
                continue

            dest_pick = torch.randint(0, free_nodes.shape[0], (n_max_mutants,)).to(device)
            src_pick = torch.randint(0, free_sources.shape[0], (n_max_mutants,)).to(device)

            new_mutant_rows.append(free_sources[src_pick])
            new_mutant_cols.append(free_nodes[dest_pick])

        new_mutant_rows = torch.stack(new_mutant_rows)
        new_mutant_cols = torch.stack(new_mutant_cols)

        # Order every pair as (row <= col) before encoding it.
        ordered_rows = torch.minimum(new_mutant_rows, new_mutant_cols)
        ordered_cols = torch.maximum(new_mutant_rows, new_mutant_cols)

        # Row s of this mask has its first k_s slots set, where k_s is the
        # number of mutated genes in solution s; row-major flattening then
        # lines the sampled values up with data[mutation_mask].
        slot_mask = mutation_mask.long().sort(1, descending=True)[0].bool()[:, :n_max_mutants]
        data[mutation_mask] = triu_to_linear_idx(self.n_nodes, ordered_rows, ordered_cols)[slot_mask]

        return result

        
        


class MarginMutation(GraphMutation):
    """Mutation that targets the attack nodes with the smallest classification margin.

    ``update_locals`` stores in ``idx_attack`` the attack nodes with a
    non-negative margin, sorted by ascending margin (at most 300). ``_do``
    draws the source of every new pair from the head of that list -- the
    first ``max(10, len // 10)`` entries, capped at the list length -- and the
    target from all nodes.
    """

    def __init__(self, problem, idx_attack, n_nodes, mutation_rate=0.1, adversary=None):
        """
        Args:
            problem: Problem forwarded to the parent operator.
            idx_attack: Candidate source nodes (tensor or array-like).
            n_nodes: Number of nodes in the graph.
            mutation_rate: Per-gene mutation probability.
            adversary: Optional object used by ``update_locals``; it must
                expose ``searcher``, ``model``, ``attr``, ``adj``, ``labels``,
                ``idx_attack`` and ``_create_perturbed_graph``.
        """
        super().__init__(problem)
        self.mutation_rate = mutation_rate
        if isinstance(idx_attack, torch.Tensor):
            self.idx_attack = idx_attack.clone()
        else:
            self.idx_attack = torch.tensor(idx_attack)
        self.n_nodes = n_nodes
        self.adversary = adversary

    def _fallback_pair(self, candidates):
        """Return an ordered ``(low, high)`` pair used to replace self-loops."""
        first = int(candidates[0])
        second = int(candidates[1]) if candidates.numel() > 1 else first
        if second == first:
            # Only one distinct candidate: pair it with a neighbouring node id.
            second = first + 1 if first + 1 < self.n_nodes else first - 1
        return min(first, second), max(first, second)

    @torch.no_grad()
    def _do(self, batch: SolutionBatch) -> SolutionBatch:
        """Return a mutated deep copy of ``batch``; the input is left untouched."""
        result = deepcopy(batch)
        data = result.access_values()
        mutation_mask = torch.rand(size=data.shape, device=data.device) < self.mutation_rate
        mutant_shape = data[mutation_mask].shape

        # Sampling happens on the CPU; the encoded result is moved to the
        # data's device at the end.
        candidates = torch.as_tensor(self.idx_attack).long().cpu()
        n_candidates = candidates.numel()
        if n_candidates == 0:
            # Every attack node already has a negative margin: nothing to target.
            return result

        # Head of the margin-sorted list. It must not exceed the list length,
        # otherwise the sampled positions index past the end of `candidates`.
        pool_size = min(n_candidates, max(10, n_candidates // 10))

        sources = candidates[torch.randint(0, pool_size, mutant_shape)]
        targets = torch.randint(0, self.n_nodes, mutant_shape)

        # Order every pair as (row <= col) before encoding it.
        rows = torch.minimum(sources, targets)
        cols = torch.maximum(sources, targets)

        # Self-loops are replaced by a fixed pair. The pair is ordered as
        # well: idx_attack is sorted by margin, not by node id.
        self_loops = rows == cols
        if self_loops.any():
            low, high = self._fallback_pair(candidates)
            rows[self_loops] = low
            cols[self_loops] = high

        new_pop = triu_to_linear_idx(self.n_nodes, rows, cols).to(data.device)

        data[mutation_mask] = new_pop
        return result

    @torch.no_grad()
    def update_locals(self, population, **kwargs):
        """Re-rank the attack nodes by classification margin under the best perturbation.

        The margin of a node is the logit of its true label minus the largest
        logit among the other classes. Nodes with a negative margin are
        dropped; the rest are stored in ascending margin order, truncated to
        300 entries. Does nothing when no adversary was supplied.
        """
        adversary = self.adversary
        if adversary is None:
            return

        best_perturbation = adversary.searcher.status["pop_best"].values
        attr_adversary, adj_adversary = adversary._create_perturbed_graph(
            adversary.attr, adversary.adj, best_perturbation
        )
        logits = adversary.model(attr_adversary, adj_adversary)
        labels = adversary.labels

        node_ids = np.arange(logits.size(0))
        class_order = logits.argsort(-1)
        # Drop the true label from every row; the last remaining column is
        # the strongest competing class.
        best_non_target_class = class_order[class_order != labels[:, None]].reshape(logits.size(0), -1)[:, -1]
        margin = logits[node_ids, labels] - logits[node_ids, best_non_target_class]

        margin = margin[adversary.idx_attack]
        negative = margin < 0
        n_keep = min(300, len(margin) - int(negative.sum()))

        # Push negative margins strictly behind every valid one. Filling with
        # +inf (rather than the maximum margin) avoids ties with the largest
        # valid margin and also works when `margin` is empty.
        order = margin.masked_fill(negative, float("inf")).argsort().cpu().numpy()
        self.idx_attack = adversary.idx_attack[order][0:n_keep]
        
