from task.BaseTrainer import BaseTrainer
from sklearn.metrics import f1_score,roc_auc_score
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
import torch
class GIFTrainer(BaseTrainer):
    """
    GIFTrainer class for training and evaluating GNNs in preparation for applying the Graph Influence Function (GIF) unlearning method.

    This class extends BaseTrainer and provides functionality specific to the
    Graph Influence Function (GIF) unlearning method. It evaluates the model
    after its parameters have been replaced by unlearned ones, for the supported
    downstream tasks (node classification, edge prediction and graph
    classification). It also supplies the edge-prediction loss (binary
    cross-entropy with negative sampling).

    Attributes (presumably stored by ``BaseTrainer.__init__``):
        args (dict): Configuration parameters, including model type, dataset specifications,
                     training hyperparameters, unlearning settings, and other relevant settings.
                     This class reads ``args["downstream_task"]``.

        logger (logging.Logger): Logger object used to log evaluation progress, metrics,
                                 and other important information.

        model (torch.nn.Module): The neural network model to be evaluated and unlearned.

        data (torch_geometric.data.Data): The dataset containing graph information for training,
                                         validation, and testing.

    NOTE(review): ``self.decode``, ``self.device`` and ``self.evaluate_graph_model``
    are used below but are not defined in this class. Presumably BaseTrainer
    provides them.
    """

    def __init__(self, args, logger, model, data):
        """
        Initializes the GIFTrainer with the provided configuration, logger, model, and data.

        Args:
            args (dict): Configuration parameters, including model type, dataset specifications,
                        training hyperparameters, unlearning settings, and other relevant settings.

            logger (logging.Logger): Logger object used to log evaluation progress, metrics,
                                     and other important information.

            model (torch.nn.Module): The neural network model to be evaluated and unlearned.

            data (torch_geometric.data.Data): The dataset containing graph information for training,
                                             validation, and testing.
        """
        super().__init__(args, logger, model, data)

    def eval_unlearn(self, new_parameters):
        """
        Load unlearned parameters into the model and evaluate it on the test split.

        Each parameter of ``self.model`` (in ``self.model.parameters()`` order) is
        replaced by the tensor at the same position in ``new_parameters``. The
        method then runs ``self.model.reason_once_unlearn(self.data)`` and scores
        the result according to ``args["downstream_task"]``:

        * ``"node"``: micro-averaged F1 on the nodes selected by ``data['test_mask']``.
        * ``"edge"``: ROC-AUC on the test edges (see ``eval_unlearn_edge``).
        * ``"graph"``: whatever ``self.evaluate_graph_model()`` returns.

        Args:
            new_parameters (Sequence[torch.Tensor]):
                Indexable collection of tensors, one per model parameter, in the
                same order as ``self.model.parameters()``. These are typically
                produced by the GIF unlearning method.

        Returns:
            float: The test metric for the configured downstream task.

        Raises:
            IndexError: If ``new_parameters`` has fewer entries than the model
                has parameters.
            ValueError: If ``args["downstream_task"]`` is not one of
                ``"node"``, ``"edge"`` or ``"graph"``.
        """
        # Assign through `.data` so the swap is not recorded by autograd.
        # Parameters are matched to new_parameters by position.
        for idx, param in enumerate(self.model.parameters()):
            param.data = new_parameters[idx]

        out = self.model.reason_once_unlearn(self.data)

        task = self.args["downstream_task"]
        if task == "node":
            test_mask = self.data['test_mask']
            return f1_score(
                self.data.y[test_mask].cpu().numpy(),
                out[test_mask].argmax(axis=1).cpu().numpy(),
                average="micro",
            )
        if task == "edge":
            return self.eval_unlearn_edge(out)
        if task == "graph":
            return self.evaluate_graph_model()
        # Previously this fell through to an UnboundLocalError on the result
        # variable; fail with an explicit message instead.
        raise ValueError(
            f"Unsupported downstream_task {task!r}; expected 'node', 'edge' or 'graph'."
        )

    def eval_unlearn_edge(self, out):
        """
        Compute the edge-prediction ROC-AUC on the test edges after unlearning.

        ``data.test_edge_index`` supplies the positive edges. An equal number of
        negative edges is requested from ``negative_sampling``. Both sets are
        scored with ``self.decode`` and passed through a sigmoid. The AUC is then
        computed against labels 1 (positive) and 0 (negative).

        Args:
            out (torch.Tensor): Model outputs or node embeddings after unlearning.

        Returns:
            float: The ROC-AUC score for edge prediction after unlearning.
        """
        pos_edge_index = self.data.test_edge_index
        # NOTE(review): negatives only avoid the *test* edges, so they may
        # coincide with training edges. Confirm this is intended.
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=self.data.num_nodes,
            num_neg_samples=pos_edge_index.size(1),
        )

        # Sigmoid scores (not logits). AUC is rank-based, so no threshold is applied.
        # Assumes decode returns positive-edge scores followed by negative-edge
        # scores, matching the label order built below.
        edge_scores = self.decode(
            z=out, pos_edge_index=pos_edge_index, neg_edge_index=neg_edge_index
        ).sigmoid()

        # Size the negative labels from the returned index rather than the
        # requested count, so labels and scores always line up.
        edge_labels = torch.cat((
            torch.ones(pos_edge_index.size(1), dtype=torch.float32),
            torch.zeros(neg_edge_index.size(1), dtype=torch.float32),
        ))
        return roc_auc_score(
            edge_labels.numpy(), edge_scores.detach().cpu().numpy()
        )

    def get_loss(self, out, reduction="none"):
        """
        Compute the binary cross-entropy-with-logits loss for edge prediction.

        ``data.edge_index`` supplies the positive edges. An equal number of
        negative edges is requested from ``negative_sampling``. Both sets are
        scored with ``self.decode``, and the logits are compared against labels
        1 (positive) and 0 (negative).

        Args:
            out (torch.Tensor): Model outputs or node embeddings.

            reduction (str, optional): Specifies the reduction to apply to the output:
                                       'none' | 'mean' | 'sum'. Defaults to "none".

        Returns:
            torch.Tensor: The computed binary cross-entropy loss.
        """
        pos_edge_index = self.data.edge_index
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=self.data.num_nodes,
            num_neg_samples=pos_edge_index.size(1),
        )
        edge_logits = self.decode(
            z=out, pos_edge_index=pos_edge_index, neg_edge_index=neg_edge_index
        )
        # Labels are ordered positives-then-negatives to match the decode call
        # above. They are created directly on the target device.
        edge_labels = torch.cat((
            torch.ones(pos_edge_index.size(1), dtype=torch.float32, device=self.device),
            torch.zeros(neg_edge_index.size(1), dtype=torch.float32, device=self.device),
        ), dim=-1)
        return F.binary_cross_entropy_with_logits(
            edge_logits, edge_labels, reduction=reduction
        )
