"""
Base interface for ablation study methods.

All methods (Differentiable Coherent, Hard Baseline, Hashimoto, Boosted)
implement this interface for fair comparison.
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Optional

import torch
import numpy as np


def _to_flat_numpy(values: Any) -> np.ndarray:
    """Return *values* as a detached, CPU-resident, 1-D NumPy array.

    Accepts torch tensors (on any device, with or without autograd history),
    NumPy arrays, or plain Python sequences/scalars.
    """
    if isinstance(values, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and also covers non-CUDA devices,
        # whose storage `.numpy()` cannot read directly.
        return values.detach().cpu().numpy().reshape(-1)
    return np.asarray(values).reshape(-1)


class BaseMethod(ABC):
    """Abstract base class for factuality methods.

    Subclasses implement `calibrate` (produce a threshold from calibration
    data) and `predict` (apply that threshold to test data); `evaluate` is a
    shared, method-agnostic scorer for the resulting predictions.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize method with configuration.

        Args:
            config: Method-specific configuration dict (stored by reference)
        """
        self.config = config
        # Human-readable identifier used by __str__; defaults to the subclass name.
        self.name = self.__class__.__name__

    @abstractmethod
    def calibrate(
        self,
        X_cal: List[Dict[str, Any]],
        Y_cal: List[torch.Tensor],
        noise_cal: List[float],
        alpha: float,
        X_train: Optional[List[Dict[str, Any]]] = None,
        Y_train: Optional[List[torch.Tensor]] = None,
        noise_train: Optional[List[float]] = None,
        cal_indices: Optional[List[int]] = None,
        noise_dict: Optional[Dict[int, float]] = None
    ) -> Any:
        """
        Calibrate the method on calibration data.

        Args:
            X_cal: List of graph examples (dicts with 'features', 'adj', 'ancestors')
            Y_cal: List of label tensors [n] for each example
            noise_cal: List of noise values for each example
            alpha: Desired miscoverage level (e.g., 0.1 for 90% coverage)
            X_train: Optional training data for methods that need to train models
            Y_train: Optional training labels
            noise_train: Optional training noise
            cal_indices: Optional list of dataset indices for hard conformal evaluation
            noise_dict: Optional global noise dictionary for hard conformal evaluation

        Returns:
            threshold: Calibrated threshold (method-specific format)
        """
        pass

    @abstractmethod
    def predict(
        self,
        X_test: List[Dict[str, Any]],
        noise_test: List[float],
        threshold: Any,
        alpha: Optional[float] = None,
        test_indices: Optional[List[int]] = None
    ) -> List[torch.Tensor]:
        """
        Make predictions on test data using calibrated threshold.

        Args:
            X_test: List of graph examples
            noise_test: List of noise values
            threshold: Calibrated threshold from calibrate()
            alpha: Alpha value (optional, for methods that need alpha-specific hyperparams)
            test_indices: Optional list of dataset indices for hard conformal evaluation

        Returns:
            predictions: List of prediction tensors [n] (probabilities in [0,1])
        """
        pass

    def evaluate(
        self,
        predictions: List[torch.Tensor],
        Y_test: List[torch.Tensor],
        test_indices: Optional[List[int]] = None,
        dataset: Any = None
    ) -> Dict[str, float]:
        """
        Evaluate predictions against ground truth.

        Each prediction is binarised at 0.5: a claim counts as retained when
        its predicted probability is strictly greater than 0.5.

        Args:
            predictions: List of prediction tensors from predict() (tensors,
                arrays or plain sequences are accepted)
            Y_test: List of ground truth labels (claims equal to 0 are false)
            test_indices: Optional list of dataset indices for graph_annotations lookup
            dataset: Optional dataset object for graph_annotations lookup; read as
                ``dataset.raw_data['data'][idx]['graph_annotations']['y']``

        Returns:
            metrics: Dict with keys:
                - 'coverage': Fraction of examples whose retained set is valid.
                  When both ``test_indices`` and ``dataset`` are given, "valid"
                  means the binary retention vector exactly matches one of the
                  annotated subgraphs; otherwise it means no false claim was
                  retained (vacuously true for an example with no false claims).
                - 'marginal_coverage': Fraction of TRUE claims retained, pooled
                  over all examples (same value as 'true_retention').
                - 'avg_retention': Fraction of all claims retained, pooled over
                  all examples.
                - 'true_retention': Fraction of TRUE claims retained, pooled
                  over all examples.
                - 'n_examples': Number of evaluated examples (int).
            Ratios whose denominator is zero are reported as 0.0.

        Raises:
            ValueError: If ``predictions`` and ``Y_test`` differ in length, or
                if ``test_indices`` (when used with ``dataset``) does too.
        """
        n_examples = len(predictions)
        if len(Y_test) != n_examples:
            # zip() would silently truncate while coverage still divides by
            # len(predictions), producing a wrong score.
            raise ValueError(
                f"predictions and Y_test must have the same length, "
                f"got {n_examples} and {len(Y_test)}"
            )

        use_annotations = test_indices is not None and dataset is not None
        if use_annotations and len(test_indices) != n_examples:
            raise ValueError(
                f"test_indices must have one entry per prediction, "
                f"got {len(test_indices)} for {n_examples} predictions"
            )

        total_covered = 0
        total_claims = 0
        total_true_claims = 0
        retained_claims = 0
        retained_true_claims = 0

        for i, (pred, y_true) in enumerate(zip(predictions, Y_test)):
            pred_np = _to_flat_numpy(pred)
            y_true_np = _to_flat_numpy(y_true)

            # Binary retention (threshold at 0.5 for probabilistic predictions)
            retained = (pred_np > 0.5).astype(int)

            total_claims += len(y_true_np)
            total_true_claims += y_true_np.sum()
            retained_claims += retained.sum()
            retained_true_claims += (retained * y_true_np).sum()

            if use_annotations:
                # Coverage: the retained set must be one of the annotated valid subgraphs.
                question = dataset.raw_data['data'][test_indices[i]]
                annotations = question.get('graph_annotations') or {}
                valid_subgraphs = annotations.get('y')
                if valid_subgraphs is not None and retained.tolist() in valid_subgraphs:
                    total_covered += 1
            else:
                # Fallback coverage: no false claim may be retained. np.all over an
                # empty selection is True, so an example without false claims is
                # (correctly) counted as covered rather than silently dropped.
                false_claims = y_true_np == 0
                total_covered += int(np.all(pred_np[false_claims] <= 0.5))

        true_retention = (
            float(retained_true_claims / total_true_claims) if total_true_claims > 0 else 0.0
        )
        return {
            'coverage': total_covered / n_examples if n_examples > 0 else 0.0,
            # Marginal coverage: fraction of TRUE claims retained
            'marginal_coverage': true_retention,
            'avg_retention': float(retained_claims / total_claims) if total_claims > 0 else 0.0,
            'true_retention': true_retention,
            'n_examples': n_examples,
        }

    def __str__(self) -> str:
        return self.name
