import torch
import torchmetrics
# NOTE: the unused `torchmetrics.Metric` import has been removed.
from typing import Dict, List

class EpiformerMetrics(torch.nn.Module):
    def __init__(self, epi_threshold=0.3, para_threshold=0.3, _edge_cutoff=3.39, edge_threshold=None):
        """Build node-level (epitope/paratope) and edge-level classification metrics.

        - AUROC and AUPRC are computed from probabilities and true labels.
        - F1, precision, recall, accuracy and MCC binarize the probabilities with a
          configurable threshold before scoring; each confusion matrix uses the same
          threshold as its metric collection.

        Args:
            epi_threshold: decision threshold for epitope metrics.
            para_threshold: decision threshold for paratope metrics.
            _edge_cutoff: edge cutoff for epitope classification via edge-sum
                thresholding. It is only stored here as ``self._edge_cutoff``; none of
                the metrics built in this constructor use it.
            edge_threshold: decision threshold for the edge (interaction) metrics.
                ``None`` (the default) reuses ``epi_threshold``, which is what the
                edge metrics always used previously.
        """
        super().__init__()
        if edge_threshold is None:
            edge_threshold = epi_threshold

        self.epi_threshold = epi_threshold
        self.para_threshold = para_threshold
        self.edge_threshold = edge_threshold
        self._edge_cutoff = _edge_cutoff

        # Node-level metrics for epitope
        self.epitope_metrics = torchmetrics.MetricCollection({
            'epitope_auc': torchmetrics.AUROC(task='binary'),
            'epitope_auprc': torchmetrics.AveragePrecision(task='binary'),
            'epitope_f1': torchmetrics.F1Score(task='binary', threshold=epi_threshold),
            'epitope_precision': torchmetrics.Precision(task='binary', threshold=epi_threshold),
            'epitope_recall': torchmetrics.Recall(task='binary', threshold=epi_threshold),
            'epitope_accuracy': torchmetrics.Accuracy(task='binary', threshold=epi_threshold),
            'epitope_mcc': torchmetrics.MatthewsCorrCoef(task='binary', num_classes=2, threshold=epi_threshold),
        })
        self.epitope_confmat = torchmetrics.ConfusionMatrix(task='binary', num_classes=2, threshold=epi_threshold)

        # Node-level metrics for paratope
        self.paratope_metrics = torchmetrics.MetricCollection({
            'paratope_auc': torchmetrics.AUROC(task='binary'),
            'paratope_auprc': torchmetrics.AveragePrecision(task='binary'),
            'paratope_f1': torchmetrics.F1Score(task='binary', threshold=para_threshold),
            'paratope_precision': torchmetrics.Precision(task='binary', threshold=para_threshold),
            'paratope_recall': torchmetrics.Recall(task='binary', threshold=para_threshold),
            'paratope_accuracy': torchmetrics.Accuracy(task='binary', threshold=para_threshold),
            'paratope_mcc': torchmetrics.MatthewsCorrCoef(task='binary', num_classes=2, threshold=para_threshold),
        })
        self.paratope_confmat = torchmetrics.ConfusionMatrix(task='binary', num_classes=2, threshold=para_threshold)

        # Edge-level metrics for interaction prediction, binarized at edge_threshold
        # (equal to epi_threshold unless overridden).
        self.edge_metrics = torchmetrics.MetricCollection({
            'edge_auc': torchmetrics.AUROC(task='binary'),
            'edge_auprc': torchmetrics.AveragePrecision(task='binary'),
            'edge_f1': torchmetrics.F1Score(task='binary', threshold=edge_threshold),
            'edge_precision': torchmetrics.Precision(task='binary', threshold=edge_threshold),
            'edge_recall': torchmetrics.Recall(task='binary', threshold=edge_threshold),
            'edge_accuracy': torchmetrics.Accuracy(task='binary', threshold=edge_threshold),
        })
        self.edge_confmat = torchmetrics.ConfusionMatrix(task='binary', num_classes=2, threshold=edge_threshold)

        # Per-complex results are kept in plain Python lists (not tensors/modules).
        self.per_complex_epitope_metrics = []
        self.per_complex_paratope_metrics = []

        # Start every metric (and the per-complex lists) from a clean state.
        self.reset()

    def to(self, *args, **kwargs):
        """Move and/or cast every metric, accepting the same arguments as ``nn.Module.to``.

        The previous override only accepted a single positional ``device`` argument.
        That broke calls such as ``.to(dtype=...)`` or ``.to(device, non_blocking=True)``,
        which the base class supports. The metric collections and confusion matrices
        are assigned as attributes of this module in ``__init__``, so they are
        registered submodules. The base implementation therefore reaches all of them.

        Returns:
            ``self``, so calls can be chained.
        """
        return super().to(*args, **kwargs)

    # TODO: Implement a per-complex MCC (an MCC computed for each complex separately).
    # The current MCC pools every residue across complexes into one global value, so it is biased toward large complexes.
    def update(self, outputs, batch):
        """Accumulate one batch into every metric.

        Args:
            outputs: model outputs. ``'epitope_prob'`` and ``'paratope_prob'`` are
                always read; ``'interaction_matrix'`` is read (through
                ``_extract_edge_data_from_batch``) only when present.
            batch: batched graph data. ``batch['ag_res'].y`` and ``batch['ab_res'].y``
                are cast to ``long`` and used as epitope and paratope targets.

        Edge metrics are updated only when edge predictions and labels could be
        extracted. The per-complex bookkeeping is delegated to
        ``self._update_per_complex_metrics``, which is defined outside this method.
        """
        epitope_args = (outputs['epitope_prob'], batch['ag_res'].y.long())
        paratope_args = (outputs['paratope_prob'], batch['ab_res'].y.long())

        # Scores and confusion matrices both receive raw probabilities; the
        # thresholded metrics binarize internally.
        self.epitope_metrics.update(*epitope_args)
        self.paratope_metrics.update(*paratope_args)
        self.epitope_confmat.update(*epitope_args)
        self.paratope_confmat.update(*paratope_args)

        edge_preds, edge_labels = self._extract_edge_data_from_batch(outputs, batch)
        if not (edge_preds is None or edge_labels is None):
            edge_args = (edge_preds, edge_labels.long())
            self.edge_metrics.update(*edge_args)
            self.edge_confmat.update(*edge_args)

        self._update_per_complex_metrics(outputs, batch)
    
    

    def compute(self) -> Dict[str, torch.Tensor]:
        """Compute every accumulated metric.

        Returns:
            A dict mapping metric name to tensor, containing:
            - every entry of the epitope and paratope metric collections;
            - ``{epitope,paratope}_{tn,fp,fn,tp}`` confusion-matrix counts (float);
            - only when at least one edge sample was accumulated: the edge collection
              entries, ``edge_{tn,fp,fn,tp}`` and ``edge_mcc``.

        Previously, an empty edge confusion matrix did not raise, so zero-valued
        ``edge_*`` counts and ``edge_mcc`` were reported even when no edge data had
        been collected. The edge MCC was also computed in float32, where the product
        of four epoch-level counts can overflow to ``inf`` and silently yield 0.
        """

        def _counts(prefix, cm):
            # Binary confusion matrix layout: [[TN, FP], [FN, TP]]
            return {
                f'{prefix}_tn': cm[0, 0].float(),
                f'{prefix}_fp': cm[0, 1].float(),
                f'{prefix}_fn': cm[1, 0].float(),
                f'{prefix}_tp': cm[1, 1].float(),
            }

        def _binary_mcc(cm):
            # MCC is computed by hand from the confusion matrix (the original code
            # cites a torchmetrics bug). float64 keeps
            # (tp+fp)(tp+fn)(tn+fp)(tn+fn) finite for large edge counts.
            c = cm.double()
            tn, fp, fn, tp = c[0, 0], c[0, 1], c[1, 0], c[1, 1]
            denom = torch.sqrt((tp + fp) * (tp + fn)) * torch.sqrt((tn + fp) * (tn + fn))
            if denom > 0:
                mcc = (tp * tn - fp * fn) / denom
            else:
                # A zero marginal makes MCC undefined; report 0 as before, on cm's device.
                mcc = torch.zeros_like(denom)
            return mcc.float()

        metrics: Dict[str, torch.Tensor] = {}
        metrics.update(self.epitope_metrics.compute())
        metrics.update(self.paratope_metrics.compute())

        # An all-zero edge confusion matrix means no edge sample was accumulated
        # since the last reset (edge metrics and confmat are always updated together).
        edge_cm = self.edge_confmat.compute()
        has_edge_data = bool(edge_cm.sum() > 0)
        if has_edge_data:
            try:
                metrics.update(self.edge_metrics.compute())
            except ValueError:
                # Best effort: keep node metrics even if an edge score cannot be computed.
                pass

        metrics.update(_counts('epitope', self.epitope_confmat.compute()))
        metrics.update(_counts('paratope', self.paratope_confmat.compute()))

        if has_edge_data:
            metrics.update(_counts('edge', edge_cm))
            metrics['edge_mcc'] = _binary_mcc(edge_cm)

        return metrics

    def reset(self):
        self.epitope_metrics.reset()
        self.paratope_metrics.reset()
        self.edge_metrics.reset()
        self.epitope_confmat.reset()
        self.paratope_confmat.reset()
        self.edge_confmat.reset()
        #  FIX: Reset per-complex storage
        self.per_complex_epitope_metrics.clear()
        self.per_complex_paratope_metrics.clear()
        # Reset -specific metrics storage


    def _extract_edge_data_from_batch(self, outputs, batch):
        """Extract flattened per-complex edge predictions and ground-truth labels.

        For every complex id ``i`` in ``0 .. max(batch['ag_res'].batch)``, the block
        of ``outputs['interaction_matrix']`` whose rows are the antigen residues of
        ``i`` and whose columns are the antibody residues of ``i`` is flattened
        row-major. A pair's label is 1 when it appears in the
        ``('ag_res', 'interacts', 'ab_res')`` edge_index and 0 otherwise. Pairs that
        span two different complexes are never emitted.

        NOTE(review): assumes ``interaction_matrix`` is a 2-D matrix indexed by the
        global (batched) antigen row and antibody column, which is how it is indexed
        below. Confirm against the model.

        Args:
            outputs: dict of model outputs.
            batch: batch exposing ``batch['ag_res'].batch``, ``batch['ab_res'].batch``
                and ``batch[('ag_res', 'interacts', 'ab_res')].edge_index``.

        Returns:
            ``(preds, labels)``: two 1-D tensors of equal length, with labels in the
            prediction dtype. Returns ``(None, None)`` when ``outputs`` has no
            ``'interaction_matrix'``.
        """
        if 'interaction_matrix' not in outputs:
            return None, None

        interaction = outputs['interaction_matrix']
        device = interaction.device
        ag_batch = batch['ag_res'].batch
        ab_batch = batch['ab_res'].batch
        edge_index = batch[('ag_res', 'interacts', 'ab_res')].edge_index

        # Build the dense ground-truth adjacency once with a single vectorized
        # scatter. This replaces a per-complex isin plus a per-edge Python loop that
        # did a host sync (.item()) for every edge. Cross-complex entries may be set
        # here, but they are never read: each complex only indexes its own block.
        adjacency = torch.zeros_like(interaction)
        if edge_index.numel() > 0:
            edge_index = edge_index.to(device)
            adjacency[edge_index[0], edge_index[1]] = 1.0

        num_complexes = int(ag_batch.max().item()) + 1 if ag_batch.numel() > 0 else 1

        preds, labels = [], []
        for complex_id in range(num_complexes):
            ag_idx = torch.where(ag_batch == complex_id)[0]
            ab_idx = torch.where(ab_batch == complex_id)[0]
            rows, cols = torch.meshgrid(ag_idx, ab_idx, indexing='ij')
            # Same grid for predictions and labels keeps them aligned element-wise.
            preds.append(interaction[rows, cols].flatten())
            labels.append(adjacency[rows, cols].flatten())

        # num_complexes >= 1, so both lists are non-empty.
        return torch.cat(preds), torch.cat(labels)

    
