"""
Implementations of all factuality methods for ablation study.

1. DifferentiableCoherent: Your method with learned weights
2. HardBaseline: Original hard coherent factuality
3. HashimotoIndependent: Independent factuality (single feature)
4. BoostedIndependent: Learned independent factuality (multiple features)
5. XGBoostAccuracy: XGBoost trained for accuracy with post-hoc conformal calibration
"""

import sys
import os

# Add parent directory to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Any, Tuple

from src.ablation.base_method import BaseMethod
from src.differentiable_conformal_factuality import calibrate, predict
from src.models import ForwardScorer, MLPClaimScorer
from src.utilities import compute_quantile, generate_kfold_splits, split_dataset


class DifferentiableCoherent(BaseMethod):
    """
    Differentiable coherent factuality method.

    A claim scorer is trained once (on the first ``calibrate`` call that supplies
    training data) and its per-claim outputs are temporarily written into the
    dataset under the ``'learned_scores'`` key. The conformal threshold is then
    computed with ``r_score`` and predictions are produced with
    ``highest_risk_graph`` (both from ``src.hard.non_conformity``), i.e. the
    calibration/prediction steps themselves are the *hard* conformal procedure.

    Config keys read:
        hyperparams (required): base hyperparameter dict.
        feature_cols (required): feature names; only their count is used here.
        hyperparams_path (optional): JSON file with ``results[alpha]['hyperparams']``
            overrides per alpha.
        dataset (optional at construction, required by calibrate/predict): object
            exposing ``x[idx]['features']`` and ``raw_data['data'][idx]['claims']``.
        training_hyperparams / calibration_hyperparams (optional): when BOTH are
            given ("convergence mode") they replace the base/alpha hyperparameters
            for training and calibration respectively.
    """

    # Defaults filled into any hyperparameter dict that lacks these keys.
    _REQUIRED_DEFAULTS = {
        'temp': 0.2,
        'beta': 20.0,
        'gamma': 1.0,
        'lambda_': 1.0,
        'violation_mode': 'exponential',
        'squash_temp': 1.0,
        'margin': 20.0,
    }

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.base_hyperparams = config['hyperparams']
        self.hyperparams_path = config.get('hyperparams_path', None)
        self.feature_cols = config['feature_cols']
        self.dataset = config.get('dataset', None)
        self.model = None

        # For convergence testing: separate training and calibration hyperparameters
        self.training_hyperparams = config.get('training_hyperparams', None)
        self.calibration_hyperparams = config.get('calibration_hyperparams', None)

        # Load alpha-specific hyperparameters if available
        self.alpha_hyperparams = {}
        if self.hyperparams_path:
            import json
            with open(self.hyperparams_path, 'r') as f:
                data = json.load(f)
            # JSON object keys are strings; store them as floats so they can be
            # matched against the numeric alpha passed to calibrate/predict.
            for alpha_str, alpha_data in data.get('results', {}).items():
                self.alpha_hyperparams[float(alpha_str)] = alpha_data['hyperparams']

    # ------------------------------------------------------------------ helpers

    def _alpha_overrides(self, alpha):
        """Return the hyperparameter overrides stored for ``alpha`` (or ``{}``).

        An exact key match is tried first. If that fails, a tolerant float
        comparison is used so that values such as ``0.15000000000000002``
        (typical of ``np.arange``/``linspace`` grids) still find the ``"0.15"``
        entry loaded from JSON instead of silently falling back to the base
        hyperparameters.
        """
        if alpha is None or not self.alpha_hyperparams:
            return {}
        if alpha in self.alpha_hyperparams:
            return self.alpha_hyperparams[alpha]

        import math
        try:
            alpha_val = float(alpha)
        except (TypeError, ValueError):
            return {}
        for known_alpha, overrides in self.alpha_hyperparams.items():
            if math.isclose(known_alpha, alpha_val, rel_tol=1e-9, abs_tol=1e-12):
                return overrides
        return {}

    def _resolve_hyperparams(self, alpha):
        """Return ``(training_hp, calibration_hp)`` as independent dict copies."""
        if self.training_hyperparams is not None and self.calibration_hyperparams is not None:
            # Convergence mode: separate hyperparameters
            training_hp = self.training_hyperparams.copy()
            calibration_hp = self.calibration_hyperparams.copy()
        else:
            # Normal mode: base hyperparameters (+ alpha-specific overrides) for both
            merged = self.base_hyperparams.copy()
            merged.update(self._alpha_overrides(alpha))
            training_hp = merged.copy()
            calibration_hp = merged.copy()

        # Ensure all required hyperparameters are present with defaults
        for hp_dict in (training_hp, calibration_hp):
            for key, default_val in self._REQUIRED_DEFAULTS.items():
                hp_dict.setdefault(key, default_val)
        return training_hp, calibration_hp

    def _require_graph_inputs(self, indices, indices_name):
        """Fail early, with an actionable message, if graph inputs are missing.

        Both calibration and prediction index into ``self.dataset`` and
        ``self.noise_dict``; without this check a missing value surfaces as an
        opaque ``TypeError``/``AttributeError`` (in calibrate, only after the
        full training run has completed).
        """
        missing = []
        if self.dataset is None:
            missing.append("config['dataset']")
        if indices is None:
            missing.append(indices_name)
        if getattr(self, 'noise_dict', None) is None:
            missing.append('noise_dict')
        if missing:
            raise ValueError(
                f"DifferentiableCoherent requires {', '.join(missing)} "
                f"(graph-based conformal scoring indexes into the dataset)."
            )

    def _inject_learned_scores(self, indices):
        """Write the model's per-claim scores into the dataset under 'learned_scores'."""
        self.model.eval()
        with torch.no_grad():
            for idx in indices:
                scores = self.model(self.dataset.x[idx]['features']).squeeze(-1).cpu().numpy()
                # atleast_1d: a single-claim example squeezes down to a 0-d array
                scores_list = np.atleast_1d(scores).tolist()
                for claim_idx, claim in enumerate(self.dataset.raw_data['data'][idx]['claims']):
                    claim['learned_scores'] = scores_list[claim_idx]

    def _clear_learned_scores(self, indices):
        """Remove any 'learned_scores' entries previously injected for ``indices``."""
        for idx in indices:
            for claim in self.dataset.raw_data['data'][idx]['claims']:
                claim.pop('learned_scores', None)

    def _train_model(self, X_train, Y_train, noise_train, X_cal, Y_cal, noise_cal, alpha, training_hp):
        """Create and fit a ``LogisticClaimScorer``, storing it in ``self.model``.

        ``X_train`` is used for fitting and ``X_cal`` is handed to the trainer as
        its validation split (to monitor training, not for final calibration).
        """
        from src.models import LogisticClaimScorer
        from src.training import Trainer
        import torch.optim as optim

        self.model = LogisticClaimScorer(len(self.feature_cols))

        # Training hyperparameters (use training_hp for model training)
        lr = training_hp.get('lr', 0.015)
        n_epochs = 100  # Match hyperparameter tuning

        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        trainer = Trainer(self.model, optimizer, scheduler, training_hp, debugger=None)

        # Positional noise lookup over the combined [train..., cal...] sequence.
        # Kept under its own name so it cannot be confused with the dataset-indexed
        # ``noise_dict`` that calibrate() receives and stores on ``self``.
        combined_noise = dict(enumerate([*noise_train, *noise_cal]))

        # Minimal dataset wrapper exposing what the Trainer is given here.
        class TempDataset:
            def __init__(self, X_data, Y_data):
                self.x = X_data
                self.y = Y_data
                self.raw_data = {'data': [{} for _ in X_data]}  # Dummy raw data

            def __len__(self):
                return len(self.x)

            def __getitem__(self, idx):
                return self.x[idx], self.y[idx]

        temp_dataset = TempDataset(X_train + X_cal, Y_train + Y_cal)
        n_train = len(X_train)
        train_idx = list(range(n_train))
        val_idx = list(range(n_train, n_train + len(X_cal)))

        trainer.fit(
            dataset=temp_dataset,
            train_idx=train_idx,
            val_idx=val_idx,
            noise_dict=combined_noise,
            n_epochs=n_epochs,
            alpha=alpha,
            cal_ratio=0.5,  # Used internally by trainer
            device='cpu',
            patience=10
        )

    # --------------------------------------------------------------- public API

    def calibrate(self, X_cal, Y_cal, noise_cal, alpha, X_train=None, Y_train=None, noise_train=None, cal_indices=None, noise_dict=None):
        """Train the scorer if needed, then compute the hard-conformal threshold.

        Args:
            X_cal, Y_cal, noise_cal: calibration examples, labels and noise; used
                only as the validation split while training the scorer.
            alpha: target error level, forwarded to the trainer and to
                ``compute_quantile``; also selects alpha-specific hyperparameters.
            X_train, Y_train, noise_train: optional training split. When absent
                (and no model exists yet) a ``ForwardScorer`` on feature 0 is used.
            cal_indices: dataset indices of the calibration examples (required).
            noise_dict: mapping dataset index -> noise, stored for ``predict``
                (required).

        Returns:
            The threshold returned by ``compute_quantile`` over the calibration
            ``r_score`` values.

        Raises:
            ValueError: if the dataset, ``cal_indices`` or ``noise_dict`` is missing.
        """
        from src.hard.non_conformity import r_score

        # Store cal_indices and noise_dict for prediction
        self.cal_indices = cal_indices
        self.noise_dict = noise_dict
        self._require_graph_inputs(cal_indices, 'cal_indices')

        training_hp, calibration_hp = self._resolve_hyperparams(alpha)

        # The scorer is created at most once per instance; later calls (other
        # alphas / trials) reuse it.
        if self.model is None:
            if X_train is not None and len(X_train) > 0:
                self._train_model(X_train, Y_train, noise_train, X_cal, Y_cal, noise_cal, alpha, training_hp)
            else:
                # Fallback: use ForwardScorer if no training data
                self.model = ForwardScorer(feature_index=0)

        # Compute calibration threshold using r_score (hard conformal) with the
        # CALIBRATION hyperparameters. The injected scores are always removed
        # again so the shared dataset is left as it was found (mirrors predict()).
        beta_mix = calibration_hp.get('beta_mix', 0.0)
        try:
            self._inject_learned_scores(cal_indices)
            cal_scores = [
                r_score(self.dataset.raw_data['data'][idx], self.noise_dict[idx], 'simult', 'graph',
                        beta=beta_mix, score_type='learned_scores')
                for idx in cal_indices
            ]
        finally:
            self._clear_learned_scores(cal_indices)

        threshold = compute_quantile(cal_scores, alpha)

        # Store calibration hyperparameters for use in predict()
        self.current_calibration_hp = calibration_hp

        return threshold

    def predict(self, X_test, noise_test, threshold, alpha=None, test_indices=None):
        """Predict using hard conformal prediction (``highest_risk_graph``).

        Args:
            X_test, noise_test: unused here; examples are read from the dataset
                via ``test_indices`` and noise from the stored ``noise_dict``.
            threshold: value returned by ``calibrate``.
            alpha: only consulted for alpha-specific hyperparameters when no
                calibration hyperparameters have been stored yet.
            test_indices: dataset indices of the test examples (required).

        Returns:
            A list with one float tensor per test example, of length
            ``len(claims)``, holding 1.0 at the indices returned by
            ``highest_risk_graph`` and 0.0 elsewhere.

        Raises:
            ValueError: if the dataset, ``test_indices`` or the stored
                ``noise_dict`` is missing.
        """
        from src.hard.non_conformity import highest_risk_graph

        self._require_graph_inputs(test_indices, 'test_indices')

        # Use the hyperparameters stored by the last calibrate() call if any,
        # otherwise base hyperparameters plus alpha-specific overrides.
        stored_hp = getattr(self, 'current_calibration_hp', None)
        if stored_hp is not None:
            hyperparams = stored_hp.copy()
        else:
            hyperparams = self.base_hyperparams.copy()
            hyperparams.update(self._alpha_overrides(alpha))
        beta_mix = hyperparams.get('beta_mix', 0.0)

        use_learned = self.model is not None
        score_type = 'learned_scores' if use_learned else 'frequency-score'

        predictions = []
        try:
            if use_learned:
                self._inject_learned_scores(test_indices)

            for idx in test_indices:
                question = self.dataset.raw_data['data'][idx]
                U_filt = highest_risk_graph(threshold, self.noise_dict[idx], question, 'simult',
                                            beta=beta_mix, score_type=score_type)

                # Convert to binary prediction vector
                pred_vector = torch.zeros(len(question['claims']))
                for i in U_filt:
                    pred_vector[i] = 1.0
                predictions.append(pred_vector)
        finally:
            # Cleanup injected scores even if scoring raised part-way through
            if use_learned:
                self._clear_learned_scores(test_indices)

        return predictions


class HardBaseline(BaseMethod):
    """
    Hard coherent factuality baseline.

    When a dataset and example indices are available, thresholds come from
    ``r_score`` and predictions from ``highest_risk_graph`` using the
    ``'frequency-score'`` claim score mixed with ``beta_mix``. Otherwise a
    per-claim threshold rule on a single feature column is used.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.C = config.get('C', 6.0)
        self.beta_mix = config['hyperparams'].get('beta_mix', 0.5)
        self.feature_index = 0  # Use first feature (frequency-score)
        self.dataset = config.get('dataset', None)
        # Filled in by calibrate(); predict() reads noise_dict.
        self.cal_indices = None
        self.noise_dict = None

    def _claim_risks(self, sample, eps):
        """Per-claim risk ``C - score + noise`` for one example (fallback path)."""
        feats = sample['features']
        # 2-D features: pick the configured column; 1-D features are used as-is.
        base_score = feats[:, self.feature_index] if feats.ndim > 1 else feats
        return self.C - base_score + eps

    def calibrate(self, X_cal, Y_cal, noise_cal, alpha, X_train=None, Y_train=None, noise_train=None, cal_indices=None, noise_dict=None):
        """Return the conformal threshold for ``alpha``.

        No training happens here; the ``*_train`` arguments and ``Y_cal`` are
        accepted for interface compatibility and ignored.
        """
        from src.hard.non_conformity import r_score

        # Remember indices and noise lookup for predict()
        self.cal_indices = cal_indices
        self.noise_dict = noise_dict

        if self.dataset is None or cal_indices is None:
            # No graph information: pool per-claim risks over all calibration
            # examples and take their quantile.
            pooled = torch.cat([
                self._claim_risks(sample, eps).flatten()
                for sample, eps in zip(X_cal, noise_cal)
            ])
            return compute_quantile(pooled.tolist(), alpha)

        # Graph-based path: one r_score per calibration example
        example_scores = [
            r_score(self.dataset.raw_data['data'][idx], self.noise_dict[idx], 'simult', 'graph',
                    beta=self.beta_mix, score_type='frequency-score')
            for idx in cal_indices
        ]
        return compute_quantile(example_scores, alpha)

    def predict(self, X_test, noise_test, threshold, alpha=None, test_indices=None):
        """Return one 0/1 float tensor per test example marking retained claims."""
        from src.hard.non_conformity import highest_risk_graph

        if self.dataset is None or test_indices is None:
            # Fallback: keep claims with risk <= threshold
            return [
                (self._claim_risks(sample, eps) <= threshold).float()
                for sample, eps in zip(X_test, noise_test)
            ]

        outputs = []
        for idx in test_indices:
            question = self.dataset.raw_data['data'][idx]
            selected = highest_risk_graph(threshold, self.noise_dict[idx], question, 'simult',
                                          beta=self.beta_mix, score_type='frequency-score')

            # Indicator vector over this example's claims
            mask = torch.zeros(len(question['claims']))
            for claim_pos in selected:
                mask[claim_pos] = 1.0
            outputs.append(mask)

        return outputs


class HashimotoIndependent(BaseMethod):
    """
    Independent factuality (Hashimoto method).

    Every claim is scored on its own from a single feature column; no graph
    structure is consulted. The threshold is a quantile over all calibration
    claims' risks.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.C = config.get('C', 6.0)
        self.feature_index = 0  # Single feature

    def _claim_risks(self, sample, eps):
        """Per-claim risk ``C - score + noise`` for one example."""
        feats = sample['features']
        # 2-D features: pick the configured column; 1-D features are used as-is.
        base_score = feats[:, self.feature_index] if feats.ndim > 1 else feats
        return self.C - base_score + eps

    def calibrate(self, X_cal, Y_cal, noise_cal, alpha, X_train=None, Y_train=None, noise_train=None, cal_indices=None, noise_dict=None):
        """
        Return the threshold for ``alpha`` from pooled per-claim risks.

        Nothing is trained; only ``X_cal``, ``noise_cal`` and ``alpha`` are used.
        The remaining arguments exist for interface compatibility.
        """
        pooled = torch.cat([
            self._claim_risks(sample, eps).flatten()
            for sample, eps in zip(X_cal, noise_cal)
        ])
        return compute_quantile(pooled.tolist(), alpha)

    def predict(self, X_test, noise_test, threshold, alpha=None, test_indices=None):
        """Return, per example, a 0/1 float tensor: 1 where risk <= threshold."""
        return [
            (self._claim_risks(sample, eps) <= threshold).float()
            for sample, eps in zip(X_test, noise_test)
        ]


class BoostedIndependent(BaseMethod):
    """
    Boosted independent factuality.

    Learns a weight vector ``beta`` over the claim features by gradient descent
    and scores each claim as ``|features @ beta|``. Weights are trained once per
    alpha (on the first ``calibrate`` call for that alpha) and cached; later
    calls with the same alpha only recompute the threshold.

    Config keys read:
        feature_cols (required): only the number of features is used.
        C (optional, default 6.0): offset in ``risk = C - score + noise``.
        learning_config (optional): ``enable_tuning``, the ``*_search`` grids,
            ``n_tuning_folds`` and the fixed ``lr`` / ``n_iters`` / ``l2_reg``.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.C = config.get('C', 6.0)
        self.n_features = len(config['feature_cols'])
        self.learning_config = config.get('learning_config', {})

        # Cache learned weights per alpha
        self.alpha_weights = {}  # {alpha: beta_tensor}
        self.weights_trained = set()  # Track which alphas have been trained

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _features_tensor(x):
        """Return ``x['features']`` as a tensor (float32 when a conversion is needed)."""
        features = x['features']
        if not isinstance(features, torch.Tensor):
            features = torch.tensor(features, dtype=torch.float32)
        return features

    def _claim_risks(self, x, noise, beta):
        """Per-claim risk ``C - |features @ beta| + noise`` for one example."""
        weighted_score = torch.abs(self._features_tensor(x) @ beta)
        return self.C - weighted_score + noise

    # ----------------------------------------------------------------- training

    def _hyperparameter_tuning(self, X_all, Y_all, noise_all, alpha):
        """
        Perform cross-validation hyperparameter tuning (similar to hyperparameter_optimization.py).

        Every (lr, n_iters, l2_reg) combination is trained on the training part
        of each fold and scored by ``_compute_retention`` on the held-out part;
        the combination with the highest mean retention wins.

        Returns:
            best_beta: weights trained on all data with the best hyperparameters
            best_config: dict with keys ``lr``, ``n_iters``, ``l2_reg``

        Raises:
            ValueError: if no configuration produced a comparable retention
                (empty search grid, zero folds, or NaN retentions).
        """
        from itertools import product
        from tqdm import tqdm

        # Hyperparameter search space
        lr_options = self.learning_config.get('lr_search', [0.005, 0.01, 0.02])
        n_iters_options = self.learning_config.get('n_iters_search', [100, 200])
        l2_reg_options = self.learning_config.get('l2_reg_search', [0.0, 0.001])
        n_folds = self.learning_config.get('n_tuning_folds', 5)

        # product() iterates lr outermost and l2_reg innermost
        configs = list(product(lr_options, n_iters_options, l2_reg_options))
        total_configs = len(configs)

        print(f"  [Alpha {alpha}] Hyperparameter tuning: {total_configs} configs × {n_folds} folds")

        # Use the same k-fold splits as the learned model for consistency
        n_total = len(X_all)
        fold_assignments = generate_kfold_splits(n_total, n_folds=n_folds, seed=42)

        # The splits do not depend on the hyperparameters, so materialise them
        # once. A set makes the "not in validation" test O(1) per index.
        folds = []
        for fold_idx in range(n_folds):
            val_indices = fold_assignments[fold_idx]
            val_lookup = {int(i) for i in val_indices}
            train_indices = [i for i in range(n_total) if i not in val_lookup]
            folds.append((
                [X_all[i] for i in train_indices],
                [Y_all[i] for i in train_indices],
                [noise_all[i] for i in train_indices],
                [X_all[i] for i in val_indices],
                [Y_all[i] for i in val_indices],
                [noise_all[i] for i in val_indices],
            ))

        best_retention = -float('inf')
        best_config = None

        for config_idx, (lr, n_iters, l2_reg) in enumerate(configs, start=1):
            fold_retentions = []

            for fold_idx in tqdm(range(n_folds),
                                 desc=f"  [Alpha {alpha}] Config {config_idx}/{total_configs} (lr={lr}, n_iters={n_iters}, l2={l2_reg})",
                                 leave=False):
                X_train, Y_train, noise_train, X_val, Y_val, noise_val = folds[fold_idx]

                # Train with this config
                beta = self._train_weights(X_train, Y_train, noise_train, lr, n_iters, l2_reg, verbose=False)

                # Evaluate retention on validation set
                fold_retentions.append(self._compute_retention(X_val, Y_val, noise_val, beta, alpha))

            # Average retention across folds
            avg_retention = np.mean(fold_retentions)

            if avg_retention > best_retention:
                best_retention = avg_retention
                best_config = {'lr': lr, 'n_iters': n_iters, 'l2_reg': l2_reg}

        if best_config is None:
            raise ValueError(
                f"Hyperparameter tuning for alpha={alpha} found no usable configuration "
                f"({total_configs} configs, {n_folds} folds)."
            )

        print(f"  [Alpha {alpha}] Best config: lr={best_config['lr']}, n_iters={best_config['n_iters']}, l2_reg={best_config['l2_reg']}, avg_retention={best_retention:.4f}")

        # Train final model on all data with best hyperparameters
        best_beta = self._train_weights(X_all, Y_all, noise_all,
                                        best_config['lr'], best_config['n_iters'], best_config['l2_reg'],
                                        verbose=True)

        return best_beta, best_config

    def _train_weights(self, X_data, Y_data, noise_data, lr, n_iters, l2_reg, verbose=False):
        """Train feature weights with given hyperparameters using gradient descent.

        The objective is the mean over examples of
        ``sigmoid(|features @ beta| - C)`` averaged over claims, minus an L2
        penalty on ``beta``.

        NOTE(review): ``Y_data`` and ``noise_data`` are accepted but not used, so
        the objective is label-free — confirm this is intended.

        Returns:
            The learned weight tensor, detached from the autograd graph.
        """
        # Initialize weights uniformly
        beta = torch.ones(self.n_features) / self.n_features
        beta.requires_grad = True

        optimizer = torch.optim.Adam([beta], lr=lr)

        # Feature tensors do not change between iterations; convert them once.
        feature_tensors = [self._features_tensor(x) for x in X_data]
        n_examples = len(feature_tensors)

        for iteration in range(n_iters):
            optimizer.zero_grad()

            # Start from a tensor so the value is well-defined (and printable)
            # even when there are no examples.
            total_retention = torch.zeros(())

            for features in feature_tensors:
                # Weighted score
                weighted_score = torch.abs(features @ beta)

                # Soft retention (sigmoid for differentiability)
                retention_prob = torch.sigmoid(weighted_score - self.C)
                total_retention = total_retention + retention_prob.mean() / n_examples

            # Loss: maximize retention, add L2 regularization
            loss = -total_retention + l2_reg * (beta ** 2).sum()
            loss.backward()
            optimizer.step()

            if verbose and iteration % 50 == 0:
                print(f"    Iteration {iteration}: retention={total_retention.item():.4f}")

        if verbose:
            print(f"    Final beta: {beta.detach().numpy()[:10]}... (showing first 10)")

        return beta.detach()

    def _compute_retention(self, X_data, Y_data, noise_data, beta, alpha):
        """Compute average retention rate on data with given weights and alpha.

        The threshold is the ``alpha`` quantile of the pooled per-claim risks of
        this same data; retention is the mean, over examples, of the fraction of
        claims whose risk is <= that threshold. The same noisy risks are used for
        both steps, matching how ``calibrate`` and ``predict`` compute risk.

        Returns:
            Mean retention in [0, 1], or 0.0 when there is no data.
        """
        if len(X_data) == 0:
            return 0.0

        risks = [self._claim_risks(x, noise, beta) for x, noise in zip(X_data, noise_data)]
        if not risks:
            return 0.0

        all_risks = torch.cat([r.flatten() for r in risks])
        threshold = compute_quantile(all_risks.tolist(), alpha)

        total_retention = sum((risk <= threshold).float().mean().item() for risk in risks)
        return total_retention / len(risks)

    # --------------------------------------------------------------- public API

    def calibrate(self, X_cal, Y_cal, noise_cal, alpha, X_train=None, Y_train=None, noise_train=None, cal_indices=None, noise_dict=None):
        """
        Calibrate with learned weights.

        Weights are trained once per alpha (not per trial) and cached. On the
        first call for an alpha, training + calibration data are combined and
        used either for cross-validated tuning (``enable_tuning``, default True)
        or for a single fit with fixed hyperparameters.

        NOTE(review): because the cached weights are fit on data that includes
        the first call's calibration split, later trials reuse weights that may
        have seen their calibration/test examples — confirm this is acceptable
        for the intended coverage claims.

        Returns:
            The threshold from ``compute_quantile`` over the calibration risks.
        """
        # Check if we've already trained for this alpha
        if alpha not in self.weights_trained:
            print(f"\n  [BoostedIndependent] Training weights for alpha={alpha} (first time)")

            # Combine training + calibration data for hyperparameter tuning
            # This matches what we do in hyperparameter_optimization.py
            if X_train is not None and len(X_train) > 0:
                X_all = X_train + X_cal
                Y_all = Y_train + Y_cal
                noise_all = noise_train + noise_cal
            else:
                X_all = X_cal
                Y_all = Y_cal
                noise_all = noise_cal

            if self.learning_config.get('enable_tuning', True):
                # Perform hyperparameter tuning with cross-validation
                best_beta, _best_config = self._hyperparameter_tuning(X_all, Y_all, noise_all, alpha)
                self.alpha_weights[alpha] = best_beta
            else:
                # Use fixed hyperparameters (no tuning)
                lr = self.learning_config.get('lr', 0.01)
                n_iters = self.learning_config.get('n_iters', 100)
                l2_reg = self.learning_config.get('l2_reg', 0.0)
                self.alpha_weights[alpha] = self._train_weights(
                    X_all, Y_all, noise_all, lr, n_iters, l2_reg, verbose=True)

            self.weights_trained.add(alpha)
            print(f"  [BoostedIndependent] ✓ Weights cached for alpha={alpha}\n")
        else:
            print(f"  [BoostedIndependent] Using cached weights for alpha={alpha}")

        # Use cached weights for this alpha
        beta = self.alpha_weights[alpha]

        # Compute threshold on calibration set
        risks = [self._claim_risks(x, noise, beta) for x, noise in zip(X_cal, noise_cal)]
        all_risks = torch.cat([r.flatten() for r in risks])
        return compute_quantile(all_risks.tolist(), alpha)

    def predict(self, X_test, noise_test, threshold, alpha=None, test_indices=None):
        """Predict using cached learned weights for this alpha.

        Returns:
            One 0/1 float tensor per example: 1 where risk <= threshold.

        Raises:
            ValueError: if ``alpha`` is None or no weights are cached for it.
        """
        if alpha is None or alpha not in self.alpha_weights:
            raise ValueError(f"No trained weights for alpha={alpha}. Must call calibrate first.")

        beta = self.alpha_weights[alpha]

        # Keep claims with risk <= threshold
        return [
            (self._claim_risks(x, noise, beta) <= threshold).float()
            for x, noise in zip(X_test, noise_test)
        ]


class XGBoostAccuracy(BaseMethod):
    """
    XGBoost trained for accuracy with post-hoc conformal calibration.

    This baseline trains XGBoost to predict claim accuracy (binary classification),
    then uses the predicted class-1 probability as the claim score in
    ``risk = C - score + noise`` for conformal thresholding.
    This represents what a reviewer might expect: "why not just use XGBoost?"

    Config keys read:
        feature_cols (required): only the number of features is recorded.
        C (optional, default 6.0): offset in the risk formula.
        learning_config (optional): may contain ``xgb_params``, passed to
            ``xgboost.XGBClassifier``.
        dataset (optional): stored but not used by this class.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.C = config.get('C', 6.0)
        self.n_features = len(config['feature_cols'])
        self.learning_config = config.get('learning_config', {})
        self.dataset = config.get('dataset', None)

        # A single model is shared across alphas (training does not depend on alpha)
        self.model = None
        self.is_trained = False

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array if it is a tensor, else unchanged.

        ``detach().cpu()`` makes the conversion work for tensors that require
        grad or live on a non-CPU device, where a bare ``.numpy()`` raises.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def _prepare_training_data(self, X_data, Y_data):
        """Prepare flat training data for XGBoost (one row per claim).

        Returns:
            ``(X_flat, y_flat)`` NumPy arrays with one entry per claim, in
            example order.
        """
        X_flat = []
        y_flat = []

        for x, y in zip(X_data, Y_data):
            features = self._to_numpy(x['features'])
            labels = self._to_numpy(y)

            # Each claim becomes a row; labels are indexed by claim position so
            # a label vector shorter than the feature matrix fails loudly.
            for i in range(len(features)):
                X_flat.append(features[i])
                y_flat.append(labels[i])

        return np.array(X_flat), np.array(y_flat)

    def _train_xgboost(self, X_train, Y_train):
        """Train XGBoost classifier for binary accuracy prediction.

        Sets ``self.model`` and ``self.is_trained`` and prints the accuracy on
        the training claims.

        Raises:
            ImportError: if the ``xgboost`` package is not installed.
        """
        try:
            import xgboost as xgb
        except ImportError as err:
            raise ImportError("XGBoost not installed. Please run: pip install xgboost") from err

        # Prepare flat training data
        X_flat, y_flat = self._prepare_training_data(X_train, Y_train)

        # XGBoost hyperparameters; copied so the caller's config dict is never
        # shared with (or mutated through) the estimator.
        xgb_params = dict(self.learning_config.get('xgb_params', {
            'max_depth': 4,
            'learning_rate': 0.1,
            'n_estimators': 100,
            'objective': 'binary:logistic',
            'eval_metric': 'logloss',
            'random_state': 42,
            'n_jobs': -1
        }))

        # Train XGBoost
        self.model = xgb.XGBClassifier(**xgb_params)
        self.model.fit(X_flat, y_flat)
        self.is_trained = True

        # Report training accuracy
        train_preds = self.model.predict(X_flat)
        train_acc = np.mean(train_preds == y_flat)
        print(f"    XGBoost training accuracy: {train_acc:.4f} ({len(y_flat)} claims)")

    def _get_scores(self, X_data):
        """Get XGBoost predicted probabilities as scores.

        Returns:
            A list with one float32 tensor per example holding the predicted
            probability of class 1 for each claim.
        """
        all_scores = []

        for x in X_data:
            features = self._to_numpy(x['features'])

            # Get predicted probability of being correct (class 1)
            probs = self.model.predict_proba(features)[:, 1]
            all_scores.append(torch.tensor(probs, dtype=torch.float32))

        return all_scores

    def calibrate(self, X_cal, Y_cal, noise_cal, alpha, X_train=None, Y_train=None, noise_train=None, cal_indices=None, noise_dict=None):
        """
        Train XGBoost (first call only), then calibrate the threshold on ``X_cal``.

        The model is fit on the training split when one is provided, otherwise
        on the calibration split itself. Once trained it is reused by all later
        calls regardless of ``alpha``.

        Returns:
            The threshold from ``compute_quantile`` over pooled calibration risks.
        """
        # Train XGBoost if not already trained
        if not self.is_trained:
            if X_train is not None and len(X_train) > 0:
                print(f"\n  [XGBoostAccuracy] Training XGBoost on {len(X_train)} examples")
                self._train_xgboost(X_train, Y_train)
            else:
                # If no training data, train on calibration data (less ideal)
                print(f"\n  [XGBoostAccuracy] Training XGBoost on calibration data ({len(X_cal)} examples)")
                self._train_xgboost(X_cal, Y_cal)

        # Get XGBoost scores on calibration data
        cal_scores = self._get_scores(X_cal)

        # Lower predicted probability = higher risk
        # risk = C - score + noise (same formula as other methods)
        risks = [self.C - scores + noise for scores, noise in zip(cal_scores, noise_cal)]

        all_risks = torch.cat([r.flatten() for r in risks])
        return compute_quantile(all_risks.tolist(), alpha)

    def predict(self, X_test, noise_test, threshold, alpha=None, test_indices=None):
        """Predict using XGBoost scores and calibrated threshold.

        Returns:
            One 0/1 float tensor per example: 1 where risk <= threshold.

        Raises:
            ValueError: if the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError("XGBoost model not trained. Must call calibrate first.")

        # Get XGBoost scores on test data
        test_scores = self._get_scores(X_test)

        # Keep claims with risk <= threshold
        return [
            (self.C - scores + noise <= threshold).float()
            for scores, noise in zip(test_scores, noise_test)
        ]
