# POSTOP CATE Analysis - Adapted for Licorice Gargle Dataset

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

class TeeOutput:
    """File-like object that mirrors every write to the console and a log file.

    The stream that is ``sys.stdout`` at construction time is captured as the
    console side, so the instance can afterwards be assigned to ``sys.stdout``
    without writing to itself.

    The instance can also be used as a context manager; the log file is closed
    on exit. The captured console stream is never closed by this class.
    """

    def __init__(self, filename):
        """Open *filename* for writing (truncating it) and capture ``sys.stdout``."""
        self.terminal = sys.stdout
        # The report text contains non-ASCII symbols (ε, γ, ρ, τ, μ, σ, →).
        # Relying on the locale's default encoding makes those writes raise
        # UnicodeEncodeError on non-UTF-8 platforms, so pin UTF-8 explicitly.
        self.log = open(filename, 'w', encoding='utf-8')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions from the with-body

    def write(self, message):
        """Write *message* to the console and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Ensure immediate write to file

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file only; the console stream is left open."""
        self.log.close()

class MedicalCATEAllocator:
    """Medical CATE allocation algorithm with fixed gamma=0.5 and updated heavy interval threshold."""

    def __init__(self, epsilon=0.1, gamma=0.5, delta=0.05, heavy_multiplier=1.6, random_seed=42):
        self.epsilon = epsilon
        self.gamma = gamma
        self.rho = gamma * np.sqrt(epsilon)
        self.delta = delta
        self.heavy_multiplier = heavy_multiplier  # New parameter for heavy interval threshold
        self.random_seed = random_seed
        np.random.seed(random_seed)

        print(f"Medical CATE Allocation Algorithm")
        print(f"ε = {epsilon}")
        print(f"√ε = {np.sqrt(epsilon):.6f}")
        print(f"γ = {gamma}")
        print(f"ρ = γ√ε = {self.rho:.6f}")
        print(f"Heavy multiplier = {heavy_multiplier}x")
        print(f"δ = {delta}")
        print("="*60)

    def process_medical_data(self, df, outcome_col='postop4hour_throatpain', treatment_col='treat'):
        """Process medical dataset for analysis."""
        print(f"Processing medical data with {len(df)} patients")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        # Check for required columns
        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        # Set up treatment and outcome
        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]

        # Create baseline risk using preoperative pain if available
        if 'preop_pain' in df_processed.columns:
            df_processed['baseline_risk'] = df_processed['preop_pain']
        else:
            df_processed['baseline_risk'] = 0  # Default if no baseline

        # Clean data
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} patients")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (4-hour throat pain) statistics: mean={df_processed['outcome'].mean():.2f}, std={df_processed['outcome'].std():.2f}")

        if 'baseline_risk' in df_processed.columns:
            print(f"Baseline (preop pain) stats: mean={df_processed['baseline_risk'].mean():.2f}, std={df_processed['baseline_risk'].std():.2f}")

        return df_processed

    def create_demographics_groups(self, df, min_size=6):
        """Create groups by key demographic characteristics in medical data."""
        print(f"Creating medical demographics groups")

        # Key medical demographic variables
        demo_features = ['preop_gender', 'preop_smoking', 'preop_asa']

        # Check which features are available
        available_features = [col for col in demo_features if col in df.columns]

        if not available_features:
            print("No demographic variables found")
            return []

        print(f"Using demographic features: {available_features}")

        # Limit to top 3 features to avoid too many combinations
        if len(available_features) > 3:
            available_features = available_features[:3]

        # Remove rows with missing values in these features
        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} patients")

        if len(df_clean) == 0:
            return []

        # Get unique combinations
        groups = []
        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        for combo_idx, (idx, combo) in enumerate(unique_combinations.iterrows()):
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            combo_id = "_".join(combo_description)

            if len(indices) >= min_size:
                groups.append({
                    'id': combo_id,
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on age brackets."""
        print(f"Creating age groups (target: {n_groups})")

        if 'preop_age' not in df.columns:
            print("No age variable found")
            return []

        # Create age-based groups
        age = df['preop_age'].fillna(df['preop_age'].median())

        # Create age brackets
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(age, percentiles)
        bins = np.digitize(age, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_group_{i}',
                    'indices': indices,
                    'type': 'age'
                })

        print(f"Created {len(groups)} age groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_bmi_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on BMI categories."""
        print(f"Creating BMI groups (target: {n_groups})")

        if 'preop_calcbmi' not in df.columns:
            print("No BMI variable found")
            return []

        # Create BMI-based groups
        bmi = df['preop_calcbmi'].fillna(df['preop_calcbmi'].median())

        # Create BMI brackets
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(bmi, percentiles)
        bins = np.digitize(bmi, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'bmi_group_{i}',
                    'indices': indices,
                    'type': 'bmi'
                })

        print(f"Created {len(groups)} BMI groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_mallampati_groups(self, df, min_size=6):
        """Create groups based on Mallampati score (airway assessment)."""
        print(f"Creating Mallampati score groups")

        if 'preop_mallampati' not in df.columns:
            print("No Mallampati score variable found")
            return []

        groups = []
        for mallampati_score in df['preop_mallampati'].unique():
            if pd.isna(mallampati_score):
                continue

            indices = df[df['preop_mallampati'] == mallampati_score].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'mallampati_class_{mallampati_score}',
                    'indices': indices,
                    'type': 'mallampati'
                })

        print(f"Created {len(groups)} Mallampati groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_preop_pain_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on preoperative pain levels."""
        print(f"Creating preoperative pain groups (target: {n_groups})")

        if 'preop_pain' not in df.columns:
            print("No preoperative pain data available")
            return []

        # Create pain-based groups
        pain = df['preop_pain'].fillna(0)  # Fill NaN with 0 for no pain

        # Create pain level brackets including zero pain
        if (pain == 0).mean() > 0.3:  # If >30% have zero pain, create separate zero group
            # Create one group for zero pain patients
            zero_pain = df.index[pain == 0].tolist()
            groups = []
            if len(zero_pain) >= min_size:
                groups.append({
                    'id': 'no_preop_pain',
                    'indices': zero_pain,
                    'type': 'preop_pain'
                })

            # Create groups for patients with pain
            positive_pain = pain[pain > 0]
            if len(positive_pain) > 0:
                percentiles = np.linspace(0, 100, n_groups)
                cuts = np.percentile(positive_pain, percentiles)

                for i in range(len(cuts) - 1):
                    mask = (pain > cuts[i]) & (pain <= cuts[i + 1])
                    indices = df.index[mask].tolist()
                    if len(indices) >= min_size:
                        groups.append({
                            'id': f'preop_pain_level_{i}',
                            'indices': indices,
                            'type': 'preop_pain'
                        })
        else:
            # Standard percentile groups
            percentiles = np.linspace(0, 100, n_groups + 1)
            cuts = np.percentile(pain, percentiles)
            bins = np.digitize(pain, cuts) - 1

            groups = []
            for i in range(n_groups):
                indices = df.index[bins == i].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'preop_pain_level_{i}',
                        'indices': indices,
                        'type': 'preop_pain'
                    })

        print(f"Created {len(groups)} preoperative pain groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_gender_smoking_groups(self, df, min_size=6):
        """Create groups based on gender and smoking status combinations."""
        print(f"Creating gender-smoking groups")

        # Create gender-smoking categories
        def get_gender_smoking(row):
            gender = row.get('preop_gender', 'unknown')
            smoking = row.get('preop_smoking', 'unknown')
            return f"gender_{gender}_smoking_{smoking}"

        if 'preop_gender' in df.columns and 'preop_smoking' in df.columns:
            df['gender_smoking'] = df.apply(get_gender_smoking, axis=1)

            groups = []
            for category in df['gender_smoking'].unique():
                indices = df[df['gender_smoking'] == category].index.tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'{category}',
                        'indices': indices,
                        'type': 'gender_smoking'
                    })

            print(f"Created {len(groups)} gender-smoking groups")
            balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
            return balanced_groups
        else:
            print("No gender or smoking variables found")
            return []

    def create_asa_groups(self, df, min_size=6):
        """Create groups based on ASA (American Society of Anesthesiologists) physical status."""
        print(f"Creating ASA physical status groups")

        if 'preop_asa' not in df.columns:
            print("No ASA physical status data")
            return []

        groups = []
        for asa_status in df['preop_asa'].unique():
            if pd.isna(asa_status):
                continue

            indices = df[df['preop_asa'] == asa_status].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'asa_class_{asa_status}',
                    'indices': indices,
                    'type': 'asa_status'
                })

        print(f"Created {len(groups)} ASA status groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_surgery_size_groups(self, df, min_size=6):
        """Create groups based on surgery size."""
        print(f"Creating surgery size groups")

        if 'intraop_surgerysize' not in df.columns:
            print("No surgery size data available")
            return []

        groups = []
        for surgery_size in df['intraop_surgerysize'].unique():
            if pd.isna(surgery_size):
                continue

            indices = df[df['intraop_surgerysize'] == surgery_size].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'surgery_size_{surgery_size}',
                    'indices': indices,
                    'type': 'surgery_size'
                })

        print(f"Created {len(groups)} surgery size groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """Group patients by K-means clusters of their baseline covariates.

        Despite the method name, no forest is fitted here: the available
        covariates are imputed/encoded, standardised and clustered with
        ``KMeans``. No outcome or treatment information enters the clustering.
        Clusters with at least *min_size* patients are passed to
        ``_ensure_balance_and_compute_cate``.

        Args:
            df: Frame with any of the covariate columns listed below, plus the
                ``treatment``/``outcome`` columns needed for balancing.
            n_groups: Requested number of clusters. It is reduced to the number
                of rows when the frame has fewer rows than that, because
                ``KMeans`` cannot fit more clusters than samples.
            min_size: Minimum number of patients for a cluster to be kept.

        Returns:
            List of balanced group dicts; empty when no covariate column is
            present or the frame has no rows.
        """
        print(f"Creating covariate-based forest groups (target: {n_groups})")

        # Use medical covariates ONLY - no outcome variables
        feature_cols = ['preop_age', 'preop_calcbmi', 'preop_gender', 'preop_asa',
                       'preop_mallampati', 'preop_smoking', 'preop_pain', 'intraop_surgerysize']
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for covariate clustering")
            return []

        X = df[available_features].copy()

        if len(X) == 0:
            # StandardScaler/KMeans both reject an empty sample matrix.
            print("No patients available for covariate clustering")
            return []

        # Handle missing values and encode categorical variables
        for col in X.columns:
            if X[col].dtype == 'object':
                # Encode categorical variables; missing becomes its own level.
                le = LabelEncoder()
                X[col] = X[col].fillna('missing')
                X[col] = le.fit_transform(X[col])
            else:
                if X[col].isna().any():
                    X[col] = X[col].fillna(X[col].median())

        # KMeans raises ValueError when n_clusters exceeds the sample count,
        # so cap the request at the number of patients.
        n_clusters = min(n_groups, len(X))
        if n_clusters < n_groups:
            print(f"Only {len(X)} patients available; using {n_clusters} clusters")

        # Cluster based ONLY on baseline covariates (no outcome-based predictions)
        cluster_features = StandardScaler().fit_transform(X.values)
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_clusters):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'covariate_cluster_{i}',
                    'indices': indices,
                    'type': 'covariate_cluster'
                })

        print(f"Created {len(groups)} covariate-based groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Group patients into quantile strata of an estimated propensity score.

        The propensity score is the out-of-fold predicted probability of the
        second treatment class from a 5-fold cross-validated logistic
        regression of ``treatment`` on the available covariates. Scores are
        split at ``n_groups + 1`` quantile edges into ``n_groups`` strata.
        Strata with at least *min_size* patients are passed to
        ``_ensure_balance_and_compute_cate``.

        Args:
            df: Frame with ``treatment``, ``outcome`` and any of the covariate
                columns listed below.
            n_groups: Number of quantile strata to attempt.
            min_size: Minimum number of patients for a stratum to be kept.

        Returns:
            List of balanced group dicts; empty when no covariate column is
            present or the propensity model cannot be fitted (the error is
            printed, not raised).
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        # Use medical covariates
        feature_cols = ['preop_age', 'preop_calcbmi', 'preop_gender', 'preop_asa',
                       'preop_mallampati', 'preop_smoking', 'preop_pain', 'intraop_surgerysize']
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for propensity scoring")
            return []

        X = df[available_features].copy()

        # Handle missing values and encode categorical variables
        for col in X.columns:
            if X[col].dtype == 'object':
                # Encode categorical variables; missing becomes its own level.
                le = LabelEncoder()
                X[col] = X[col].fillna('missing')
                X[col] = le.fit_transform(X[col])
            else:
                if X[col].isna().any():
                    X[col] = X[col].fillna(X[col].median())

        # Get propensity scores. Failure here is deliberately best-effort:
        # the method is skipped rather than aborting the whole analysis.
        try:
            prop_scores = cross_val_predict(
                LogisticRegression(random_state=self.random_seed, max_iter=1000),
                X, df['treatment'], method='predict_proba', cv=5
            )[:, 1]
        except Exception as e:
            print(f"Error computing propensity scores: {e}")
            return []

        # Create strata
        quantiles = np.linspace(0, 1, n_groups + 1)
        edges = np.quantile(prop_scores, quantiles)
        # np.digitize bins are half-open, so the maximum score lands in bin
        # n_groups and would be skipped by the loop below; clamp it into the
        # top stratum so the highest-propensity patients are not dropped.
        bins = np.clip(np.digitize(prop_scores, edges) - 1, 0, n_groups - 1)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Ensure treatment balance and compute group CATE."""
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            # For this dataset, negative CATE means treatment reduces throat pain (beneficial)
            # We reverse sign to maintain consistency with plots (higher = more beneficial)
            cate = -(treated_outcomes.mean() - control_outcomes.mean())  # Reversed sign

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Normalize CATE values to [0,1]."""
        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def plot_cate_distribution(self, groups, title_suffix=""):
        """Show side-by-side histograms of the raw and normalized group CATEs.

        Args:
            groups: Dicts carrying both ``cate`` and ``normalized_cate``.
            title_suffix: Text appended to both panel titles.
        """
        # (values, bar colour, x-axis label, title) for the left and right panel.
        panels = (
            (
                [g['cate'] for g in groups],
                'skyblue',
                'Original CATE (pain reduction effect)',
                f'Original CATE Distribution{title_suffix}',
            ),
            (
                [g['normalized_cate'] for g in groups],
                'lightcoral',
                'Normalized CATE (τ)',
                f'Normalized CATE Distribution{title_suffix}',
            ),
        )

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        for ax, (values, colour, x_label, title) in zip(axes, panels):
            ax.hist(values, bins=15, alpha=0.7, color=colour, edgecolor='black')
            ax.set_xlabel(x_label)
            ax.set_ylabel('Frequency')
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def estimate_tau(self, true_tau, accuracy):
        """Estimate tau using Hoeffding's inequality with Bernoulli samples."""
        sample_size = int(np.ceil(np.log(2/self.delta) / (2 * accuracy**2)))
        samples = np.random.binomial(1, true_tau, sample_size)
        return np.mean(samples), sample_size

    def run_single_trial(self, groups, epsilon_val, trial_seed):
        """Run allocation algorithm for single trial with fixed gamma."""
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)  # Use fixed gamma

        # Estimate all tau values using rho accuracy
        tau_estimates_rho = []
        for tau in tau_true:
            estimate, _ = self.estimate_tau(tau, rho)
            tau_estimates_rho.append(estimate)
        tau_estimates_rho = np.array(tau_estimates_rho)

        # Also estimate using epsilon accuracy for comparison
        tau_estimates_eps = []
        for tau in tau_true:
            estimate, _ = self.estimate_tau(tau, epsilon_val)
            tau_estimates_eps.append(estimate)
        tau_estimates_eps = np.array(tau_estimates_eps)

        results = []

        for K in range(1, n_groups):
            optimal_indices = np.argsort(tau_true)[-K:]
            optimal_value = np.sum(tau_true[optimal_indices])

            rho_indices = np.argsort(tau_estimates_rho)[-K:]
            rho_value = np.sum(tau_true[rho_indices])

            eps_indices = np.argsort(tau_estimates_eps)[-K:]
            eps_value = np.sum(tau_true[eps_indices])

            rho_ratio = rho_value / optimal_value if optimal_value > 0 else 0
            eps_ratio = eps_value / optimal_value if optimal_value > 0 else 0
            rho_success = rho_ratio >= (1 - epsilon_val)
            eps_success = eps_ratio >= (1 - epsilon_val)

            tau_k_est = tau_estimates_rho[rho_indices[0]]
            a2_lower = tau_k_est
            a2_upper = tau_k_est + 2 * rho
            units_in_a2 = np.sum((tau_estimates_rho >= a2_lower) & (tau_estimates_rho <= a2_upper))
            expected_a2 = 2 * rho * n_groups
            # Updated heavy interval detection with 1.6x multiplier
            is_heavy = units_in_a2 > self.heavy_multiplier * expected_a2

            results.append({
                'K': K,
                'optimal_value': optimal_value,
                'rho_value': rho_value,
                'eps_value': eps_value,
                'rho_ratio': rho_ratio,
                'eps_ratio': eps_ratio,
                'rho_success': rho_success,
                'eps_success': eps_success,
                'is_heavy': is_heavy,
                'tau_k_est': tau_k_est,
                'units_in_a2': units_in_a2
            })

        return results, tau_estimates_rho

    def find_recovery_units(self, K, tau_true, tau_estimates, epsilon_val):
        """Find minimum units needed to achieve 1-epsilon performance."""
        n_groups = len(tau_true)

        # Original allocation (using rho estimates)
        rho_indices = np.argsort(tau_estimates)[-K:]
        optimal_value = np.sum(tau_true[np.argsort(tau_true)[-K:]])

        # Remaining candidates (sorted by estimate, best first)
        remaining_indices = np.argsort(tau_estimates)[:-K][::-1]

        # Test adding 1 to 10 additional units
        for extra in range(1, 11):
            if extra > len(remaining_indices):
                break

            expanded_indices = np.concatenate([rho_indices, remaining_indices[:extra]])
            expanded_value = np.sum(tau_true[expanded_indices])

            if expanded_value / optimal_value >= (1 - epsilon_val):
                return extra

        return None  # Need more than 10 units

    def find_closest_working_budget(self, failed_K, trial_results):
        """Find closest budget that works for a failed budget."""
        working_budgets = [r['K'] for r in trial_results if r['rho_success']]

        if not working_budgets:
            return None, None

        # Distance to any working budget (either direction)
        distances_any = [abs(K - failed_K) for K in working_budgets]
        min_distance_any = min(distances_any)

        # Distance to smaller working budget (underspending)
        smaller_working = [K for K in working_budgets if K < failed_K]
        if smaller_working:
            min_distance_smaller = failed_K - max(smaller_working)
        else:
            min_distance_smaller = None

        return min_distance_any, min_distance_smaller

    def analyze_method(self, groups, epsilon_val, n_trials=30):
        """Analyze single method with fixed gamma and updated heavy threshold."""
        print(f"\nAnalyzing {len(groups)} groups with ε={epsilon_val}, γ={self.gamma}")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        trial_data = []

        for trial in range(n_trials):
            print(f"Trial {trial + 1}/{n_trials}...")

            # Run single trial
            trial_results, tau_estimates = self.run_single_trial(groups, epsilon_val, trial)

            # Analyze failures
            failed_results = [r for r in trial_results if not r['rho_success']]
            failed_budgets = [r['K'] for r in failed_results]

            # Check which failed budgets are heavy with true tau values
            failed_heavy_estimated = []
            failed_heavy_true = []
            rho = self.gamma * np.sqrt(epsilon_val)

            for failed_result in failed_results:
                K = failed_result['K']
                # Heavy with estimated values (already computed)
                failed_heavy_estimated.append(failed_result['is_heavy'])

                # Check heavy with true tau values
                tau_k_true = tau_true[np.argsort(tau_true)[-K:]][0]  # True smallest in top-K
                a2_lower_true = tau_k_true
                a2_upper_true = tau_k_true + 2 * rho
                units_in_a2_true = np.sum((tau_true >= a2_lower_true) & (tau_true <= a2_upper_true))
                expected_a2_true = 2 * rho * n_groups
                is_heavy_true = units_in_a2_true > self.heavy_multiplier * expected_a2_true
                failed_heavy_true.append(is_heavy_true)

            # Print trial summary
            print(f"  Failed budgets: {failed_budgets}")

            # Print heavy vectors
            if len(failed_budgets) > 0:
                estimated_clean = [bool(x) for x in failed_heavy_estimated]
                true_clean = [bool(x) for x in failed_heavy_true]
                print(f"  HEAVY INTERVALS - Estimated: {estimated_clean}")
                print(f"  HEAVY INTERVALS - True τ_K:   {true_clean}")

            # Count total heavy intervals and failed budgets in heavy intervals
            total_heavy = sum(r['is_heavy'] for r in trial_results)
            failed_heavy = sum(r['is_heavy'] for r in failed_results)

            # Recovery analysis
            recovery_units = []
            distances_to_working_any = []
            distances_to_working_smaller = []

            for failed_result in failed_results:
                K = failed_result['K']

                # Find recovery units needed
                recovery = self.find_recovery_units(K, tau_true, tau_estimates, epsilon_val)
                if recovery is not None:
                    recovery_units.append(recovery)

                # Find distances to closest working budgets
                distance_any, distance_smaller = self.find_closest_working_budget(K, trial_results)
                if distance_any is not None:
                    distances_to_working_any.append(distance_any)
                if distance_smaller is not None:
                    distances_to_working_smaller.append(distance_smaller)

            trial_info = {
                'trial': trial,
                'failed_budgets': failed_budgets,
                'num_failures': len(failed_results),
                'total_heavy': total_heavy,
                'failed_heavy': failed_heavy,
                'failed_heavy_estimated': failed_heavy_estimated,
                'failed_heavy_true': failed_heavy_true,
                'recovery_units': recovery_units,
                'distances_to_working_any': distances_to_working_any,
                'distances_to_working_smaller': distances_to_working_smaller
            }

            trial_data.append(trial_info)

            print(f"  Failures: {len(failed_results)}, Total heavy: {total_heavy}, Failed heavy: {failed_heavy}")
            if recovery_units:
                print(f"  Recovery units: μ={np.mean(recovery_units):.1f}, med={np.median(recovery_units):.0f}, max={np.max(recovery_units)}")
            if distances_to_working_any:
                print(f"  Distance any: μ={np.mean(distances_to_working_any):.1f}, med={np.median(distances_to_working_any):.0f}, max={np.max(distances_to_working_any)}")
            if distances_to_working_smaller:
                print(f"  Distance smaller: μ={np.mean(distances_to_working_smaller):.1f}, med={np.median(distances_to_working_smaller):.0f}, max={np.max(distances_to_working_smaller)}")
            else:
                print(f"  Distance smaller: No smaller working budgets found")

        return trial_data

    def print_method_summary(self, method_name, trial_data, n_groups, epsilon_val):
        """Print summary statistics for a method."""
        budget_10pct_threshold = max(1, int(0.1 * n_groups))

        print(f"\n{'='*100}")
        print(f"SUMMARY - {method_name} - ε={epsilon_val} - {n_groups} GROUPS")
        print("="*100)
        print(f"{'Fail μ':<7} {'Fail σ':<7} {'FailR% μ':<9} {'FailR% σ':<9} {'TotHvy':<8} {'FailHvy':<9} {'Rec μ':<7} {'Rec med':<8} {'Rec max':<8} {'DAny μ':<8} {'DAny σ':<10} {'DAny max':<10} {'DSmall μ':<10} {'DSmall σ':<12} {'DSmall max':<12}")
        print("-"*120)

        # Aggregate statistics across all trials - ALL BUDGETS
        all_failures = [t['num_failures'] for t in trial_data]
        all_total_heavy = [t['total_heavy'] for t in trial_data]
        all_failed_heavy = [t['failed_heavy'] for t in trial_data]
        all_recovery = []
        all_distances_any = []
        all_distances_smaller = []

        for t in trial_data:
            all_recovery.extend(t['recovery_units'])
            all_distances_any.extend(t['distances_to_working_any'])
            all_distances_smaller.extend(t['distances_to_working_smaller'])

        avg_failures = np.mean(all_failures)
        std_failures = np.std(all_failures)
        avg_failure_rate = avg_failures / (n_groups - 1) * 100
        std_failure_rate = std_failures / (n_groups - 1) * 100
        avg_total_heavy = np.mean(all_total_heavy)
        avg_failed_heavy = np.mean(all_failed_heavy)

        # Recovery statistics
        if all_recovery:
            recovery_mean = np.mean(all_recovery)
            recovery_med = np.median(all_recovery)
            recovery_max = np.max(all_recovery)
        else:
            recovery_mean = recovery_med = recovery_max = np.nan

        # Distance statistics - any direction
        if all_distances_any:
            distance_any_mean = np.mean(all_distances_any)
            distance_any_std = np.std(all_distances_any)
            distance_any_max = np.max(all_distances_any)
        else:
            distance_any_mean = distance_any_std = distance_any_max = np.nan

        # Distance statistics - smaller only
        if all_distances_smaller:
            distance_smaller_mean = np.mean(all_distances_smaller)
            distance_smaller_std = np.std(all_distances_smaller)
            distance_smaller_max = np.max(all_distances_smaller)
        else:
            distance_smaller_mean = distance_smaller_std = distance_smaller_max = np.nan

        print(f"{avg_failures:<7.1f} {std_failures:<7.1f} {avg_failure_rate:<9.1f} {std_failure_rate:<9.1f} {avg_total_heavy:<8.1f} {avg_failed_heavy:<9.1f} "
              f"{recovery_mean:<7.1f} {recovery_med:<8.0f} {recovery_max:<8.0f} "
              f"{distance_any_mean:<8.1f} {distance_any_std:<10.1f} {distance_any_max:<10.0f} "
              f"{distance_smaller_mean:<10.1f} {distance_smaller_std:<12.1f} {distance_smaller_max:<12.0f}")

        return {
            'avg_failures': avg_failures,
            'failure_rate_pct': avg_failure_rate,
            'avg_recovery': recovery_mean,
            'n_groups': n_groups
        }


def _summary_row(method_name, eps_result):
    """Flatten one per-epsilon result dict into a row for the summary tables."""
    stats = eps_result['stats']
    return {
        'method': method_name,
        'epsilon': eps_result['epsilon'],
        'sqrt_eps': eps_result['sqrt_epsilon'],
        'gamma': eps_result['gamma'],
        'rho': eps_result['rho'],
        'avg_failures': stats['avg_failures'],
        'failure_rate_pct': stats['failure_rate_pct'],
        'avg_recovery': stats['avg_recovery'],
        'n_groups': stats['n_groups'],
    }


def _print_summary_table(rows, method_width=None):
    """Print summary rows as a fixed-width table.

    When ``method_width`` is given, a leading 'Method' column of that width is
    added and the separator rule is 100 dashes instead of 80.
    """
    show_method = method_width is not None

    header = (f"{'ε':<8} {'√ε':<10} {'γ':<6} {'ρ':<10} {'Groups':<8} "
              f"{'Fail μ':<8} {'FailR%':<8} {'Rec μ':<8}")
    if show_method:
        header = f"{'Method':<{method_width}} {header}"
    print(header)
    print("-" * (100 if show_method else 80))

    for row in rows:
        line = (f"{row['epsilon']:<8} {row['sqrt_eps']:<10.6f} {row['gamma']:<6} {row['rho']:<10.6f} "
                f"{row['n_groups']:<8} {row['avg_failures']:<8.1f} {row['failure_rate_pct']:<8.1f} "
                f"{row['avg_recovery']:<8.1f}")
        if show_method:
            line = f"{row['method']:<{method_width}} {line}"
        print(line)


def _print_method_ranking(summary_data, method_names):
    """Print best/worst method and a ranking by mean failure rate over all epsilons."""
    method_performance = {}
    for name in method_names:
        rates = [row['failure_rate_pct'] for row in summary_data if row['method'] == name]
        if rates:
            method_performance[name] = np.mean(rates)

    if not method_performance:
        return

    best_method = min(method_performance, key=method_performance.get)
    worst_method = max(method_performance, key=method_performance.get)

    print(f"BEST PERFORMING MEDICAL METHOD: {best_method}")
    print(f"  Average failure rate: {method_performance[best_method]:.1f}%")

    print(f"\nWORST PERFORMING MEDICAL METHOD: {worst_method}")
    print(f"  Average failure rate: {method_performance[worst_method]:.1f}%")

    print(f"\nMEDICAL METHOD RANKING (by average failure rate):")
    ranked = sorted(method_performance.items(), key=lambda item: item[1])
    for position, (method, rate) in enumerate(ranked, 1):
        print(f"  {position}. {method}: {rate:.1f}%")


def _print_epsilon_effect(summary_data, epsilon_values, gamma):
    """Print the mean failure rate across methods for each epsilon, with rho = gamma*sqrt(eps)."""
    print(f"\nEFFECT OF EPSILON ON MEDICAL DATA:")
    epsilon_performance = {}
    for eps in epsilon_values:
        rates = [row['failure_rate_pct'] for row in summary_data if row['epsilon'] == eps]
        if rates:
            epsilon_performance[eps] = np.mean(rates)

    if not epsilon_performance:
        return

    rho_label = f"ρ = {gamma}√ε"
    print(f"{'Epsilon':<10} {'Avg Failure Rate':<15} {rho_label:<12}")
    print("-" * 40)
    for eps in sorted(epsilon_performance):
        rho = gamma * np.sqrt(eps)
        print(f"{eps:<10} {epsilon_performance[eps]:<15.1f} {rho:<12.6f}")


def run_comprehensive_medical_analysis(df_medical, epsilon_values=None, n_trials=30, log_file=None):
    """Run every grouping method at every epsilon with fixed gamma=0.5 and a 1.6x heavy threshold.

    While the analysis runs, ``sys.stdout`` is replaced by a ``TeeOutput`` so
    that everything printed also lands in ``log_file``; the original stdout is
    restored and the log closed on exit, including on error.

    Args:
        df_medical: Raw dataset handed to ``process_medical_data`` of a fresh
            ``MedicalCATEAllocator`` for each (method, epsilon) pair.
        epsilon_values: Iterable of epsilon values. Defaults to
            ``[0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]``. It is copied
            into a list, so one-shot iterables (generators) are supported.
        n_trials: Trial count forwarded to ``allocator.analyze_method``.
        log_file: Log path. Defaults to
            ``medical_comprehensive_analysis_gamma05_<YYYYmmdd_HHMMSS>.txt``.

    Returns:
        ``(all_results, summary_data)``. ``all_results`` maps each method name
        to a list of per-epsilon dicts with keys ``method``, ``epsilon``,
        ``sqrt_epsilon``, ``gamma``, ``rho``, ``groups``, ``trial_data`` and
        ``stats``. ``summary_data`` is a flat list of row dicts (one per
        successful method/epsilon pair) with keys ``method``, ``epsilon``,
        ``sqrt_eps``, ``gamma``, ``rho``, ``avg_failures``,
        ``failure_rate_pct``, ``avg_recovery`` and ``n_groups``.
    """
    # Single source of truth for the fixed algorithm parameters.
    gamma = 0.5
    heavy_multiplier = 1.6

    if epsilon_values is None:
        epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
    else:
        # The values are iterated once per method and again in the final
        # summary; a generator would be exhausted after the first method.
        epsilon_values = list(epsilon_values)

    # Set up logging if requested
    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = f"medical_comprehensive_analysis_gamma05_{timestamp}.txt"

    # Redirect output to both console and file
    original_stdout = sys.stdout
    tee = TeeOutput(log_file)
    sys.stdout = tee

    try:
        print(f"COMPREHENSIVE MEDICAL ANALYSIS - ALL METHODS, FIXED γ={gamma}, HEAVY THRESHOLD={heavy_multiplier}x")
        print(f"Log file: {log_file}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*100)

        # Define all medical-specific grouping methods
        methods = [
            ('Demographics', lambda allocator, df: allocator.create_demographics_groups(df, min_size=6)),
            ('Gender-Smoking', lambda allocator, df: allocator.create_gender_smoking_groups(df, min_size=6)),
            ('Age Groups', lambda allocator, df: allocator.create_age_groups(df, n_groups=30, min_size=6)),
            ('BMI Groups', lambda allocator, df: allocator.create_bmi_groups(df, n_groups=30, min_size=6)),
            ('Mallampati Score', lambda allocator, df: allocator.create_mallampati_groups(df, min_size=6)),
            ('ASA Physical Status', lambda allocator, df: allocator.create_asa_groups(df, min_size=6)),
            ('Preoperative Pain', lambda allocator, df: allocator.create_preop_pain_groups(df, n_groups=30, min_size=6)),
            ('Surgery Size', lambda allocator, df: allocator.create_surgery_size_groups(df, min_size=6)),
            ('Covariate Forest 30', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=30, min_size=6)),
            ('Covariate Forest 50', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=50, min_size=6)),
            ('Propensity Score', lambda allocator, df: allocator.create_propensity_groups(df, n_groups=50, min_size=6))
        ]

        all_results = {}

        for method_name, method_func in methods:
            print(f"\n{'='*120}")
            print(f"ANALYZING MEDICAL METHOD: {method_name}")
            print("="*120)

            method_results = []

            for eps in epsilon_values:
                print(f"\n{'='*100}")
                print(f"METHOD: {method_name} | EPSILON = {eps}")
                print("="*100)

                # Fresh allocator per (method, epsilon) pair
                allocator = MedicalCATEAllocator(epsilon=eps, gamma=gamma, heavy_multiplier=heavy_multiplier)
                df_processed = allocator.process_medical_data(df_medical)

                # Best-effort: a failure in one (method, epsilon) pair is
                # reported and skipped so the remaining pairs still run.
                try:
                    # Create groups using this method
                    groups = method_func(allocator, df_processed)

                    if len(groups) < 3:
                        print(f"Too few groups ({len(groups)}) for {method_name} with ε = {eps} - skipping")
                        continue

                    groups = allocator.normalize_cates(groups)

                    # Show CATE distribution
                    allocator.plot_cate_distribution(groups, f" ({method_name}, ε={eps})")

                    # Run analysis for this epsilon and method
                    trial_data = allocator.analyze_method(groups, eps, n_trials)

                    # Print method summary
                    stats = allocator.print_method_summary(method_name, trial_data, len(groups), eps)

                    method_results.append({
                        'method': method_name,
                        'epsilon': eps,
                        'sqrt_epsilon': np.sqrt(eps),
                        # Read back from the allocator so the recorded values
                        # always match the parameters actually used.
                        'gamma': allocator.gamma,
                        'rho': allocator.rho,
                        'groups': groups,
                        'trial_data': trial_data,
                        'stats': stats
                    })

                except Exception as e:
                    print(f"Error with {method_name} at ε = {eps}: {e}")

            all_results[method_name] = method_results

            # Add method-specific summary table after all epsilons for this method
            if method_results:
                print(f"\n{'='*120}")
                print(f"METHOD SUMMARY - {method_name} - ALL EPSILON VALUES")
                print("="*120)
                # This table reports the group count from the groups themselves.
                _print_summary_table([
                    dict(_summary_row(method_name, result), n_groups=len(result['groups']))
                    for result in method_results
                ])
                print("="*120)

        # Create comprehensive summary across all methods and epsilon values
        print(f"\n{'='*200}")
        print("COMPREHENSIVE SUMMARY - ALL MEDICAL METHODS AND EPSILON VALUES")
        print("="*200)

        summary_data = []

        for method_name, method_results in all_results.items():
            if not method_results:
                continue

            print(f"\n{'-'*100}")
            print(f"MEDICAL METHOD: {method_name}")
            print("-"*100)

            method_rows = [_summary_row(method_name, result) for result in method_results]
            summary_data.extend(method_rows)
            _print_summary_table(method_rows)

        # Overall summary table
        print(f"\n{'='*200}")
        print("OVERALL SUMMARY TABLE - ALL MEDICAL METHODS COMBINED")
        print("="*200)
        # Some method names exceed 18 characters; size the column to the
        # longest name (never below 18) so the rows stay aligned.
        method_width = max([18] + [len(row['method']) for row in summary_data])
        _print_summary_table(summary_data, method_width=method_width)

        # Medical-specific analysis insights
        print(f"\n{'='*100}")
        print("KEY INSIGHTS FOR MEDICAL DATASET")
        print("="*100)

        _print_method_ranking(summary_data, list(all_results))
        _print_epsilon_effect(summary_data, epsilon_values, gamma)

        return all_results, summary_data

    finally:
        # Restore original stdout and close log file
        sys.stdout = original_stdout
        tee.close()


# Script entry point: run the full analysis on the Licorice Gargle data
if __name__ == "__main__":
    # Read the Stata file from the current working directory
    df_medical = pd.read_stata('Licorice Gargle.dta')

    # Epsilon grid (the original note says these match the NSW runs)
    epsilon_values = [
        0.001, 0.005, 0.01, 0.02,
        0.05, 0.1, 0.15, 0.2,
    ]

    # Positional order follows the function signature:
    # (df_medical, epsilon_values, n_trials, log_file)
    results, summary = run_comprehensive_medical_analysis(
        df_medical,
        epsilon_values,
        30,
        "postop_comprehensive_analysis_gamma05_heavy16.txt",
    )