# POSTOP CATE ANALYSIS

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

def theoretical_bound(m, beta, N, delta, OPT):
    """Evaluate the bound ``(1 - (N * ln(2N / delta) / m) ** beta) * OPT``.

    Args:
        m: Sample size (denominator of the ratio term).
        beta: Exponent applied to the ratio term.
        N: Number of groups.
        delta: Value that divides ``2N`` inside the logarithm.
        OPT: Optimal value that the bound is scaled by.

    Returns:
        The bound, or ``0`` when the ratio term is at least 1 (the bound is
        then treated as meaningless).
    """
    ratio = N * np.log(2 * N / delta) / m
    # A ratio of 1 or more makes the bound uninformative, so report 0 instead.
    return 0 if ratio >= 1 else (1 - ratio**beta) * OPT

class TeeOutput:
    """File-like object that mirrors every write to the console and a log file.

    The stream that is ``sys.stdout`` at construction time is kept as the
    console target. The log file is opened for writing (truncating any
    existing file) and is flushed after every write.

    The object can be used as a context manager so the log file is closed
    even when the wrapped code raises::

        with TeeOutput("run.log") as tee:
            tee.write("hello\\n")
    """

    def __init__(self, filename):
        """Open ``filename`` for writing and remember the current stdout."""
        self.terminal = sys.stdout
        # Explicit UTF-8: text printed by this module contains non-ASCII
        # characters (e.g. "→") that a platform-default codec may reject.
        self.log = open(filename, 'w', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the console and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush immediately so the log is complete if the process dies.
        self.log.flush()

    def flush(self):
        """Flush both the console stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.log.close()

    def __enter__(self):
        """Return the tee itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file on exit; exceptions are not suppressed."""
        self.close()
        return False

class MedicalSampleSizeAnalyzer:
    """Medical CATE allocation with sample size analysis and theoretical bounds."""

    def __init__(self, random_seed=42):
        self.random_seed = random_seed
        np.random.seed(random_seed)
        print(f"Medical Sample Size Analyzer initialized with seed {random_seed}")

    def process_medical_data(self, df, outcome_col='postop4hour_throatpain', treatment_col='treat'):
        """Process medical dataset - SAME AS ORIGINAL MEDICAL CODE."""
        print(f"Processing medical data with {len(df)} patients")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        # Check for required columns
        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        # Set up treatment and outcome
        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]

        # Create baseline risk using preoperative pain if available
        if 'preop_pain' in df_processed.columns:
            df_processed['baseline_risk'] = df_processed['preop_pain']
        else:
            df_processed['baseline_risk'] = 0

        # Clean data
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} patients")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (4-hour throat pain) statistics: mean={df_processed['outcome'].mean():.2f}, std={df_processed['outcome'].std():.2f}")

        if 'baseline_risk' in df_processed.columns:
            print(f"Baseline (preop pain) stats: mean={df_processed['baseline_risk'].mean():.2f}, std={df_processed['baseline_risk'].std():.2f}")

        return df_processed

    def create_demographics_groups(self, df, min_size=6):
        """Create groups by demographics - SAME AS MEDICAL CODE."""
        print(f"Creating medical demographics groups")

        demo_features = ['preop_gender', 'preop_smoking', 'preop_asa']
        available_features = [col for col in demo_features if col in df.columns]

        if not available_features:
            print("No demographic variables found")
            return []

        print(f"Using demographic features: {available_features}")
        if len(available_features) > 3:
            available_features = available_features[:3]

        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} patients")

        if len(df_clean) == 0:
            return []

        groups = []
        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        for combo_idx, (idx, combo) in enumerate(unique_combinations.iterrows()):
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            combo_id = "_".join(combo_description)

            if len(indices) >= min_size:
                groups.append({
                    'id': combo_id,
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Create age groups - SAME AS MEDICAL CODE."""
        print(f"Creating age groups (target: {n_groups})")

        if 'preop_age' not in df.columns:
            print("No age variable found")
            return []

        age = df['preop_age'].fillna(df['preop_age'].median())
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(age, percentiles)
        bins = np.digitize(age, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_group_{i}',
                    'indices': indices,
                    'type': 'age'
                })

        print(f"Created {len(groups)} age groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_bmi_groups(self, df, n_groups=30, min_size=6):
        """Create BMI groups - SAME AS MEDICAL CODE."""
        print(f"Creating BMI groups (target: {n_groups})")

        if 'preop_calcbmi' not in df.columns:
            print("No BMI variable found")
            return []

        bmi = df['preop_calcbmi'].fillna(df['preop_calcbmi'].median())
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(bmi, percentiles)
        bins = np.digitize(bmi, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'bmi_group_{i}',
                    'indices': indices,
                    'type': 'bmi'
                })

        print(f"Created {len(groups)} BMI groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_preop_pain_groups(self, df, n_groups=30, min_size=6):
        """Create preop pain groups - SAME AS MEDICAL CODE."""
        print(f"Creating preoperative pain groups (target: {n_groups})")

        if 'preop_pain' not in df.columns:
            print("No preoperative pain data available")
            return []

        pain = df['preop_pain'].fillna(0)

        if (pain == 0).mean() > 0.3:
            zero_pain = df.index[pain == 0].tolist()
            groups = []
            if len(zero_pain) >= min_size:
                groups.append({
                    'id': 'no_preop_pain',
                    'indices': zero_pain,
                    'type': 'preop_pain'
                })

            positive_pain = pain[pain > 0]
            if len(positive_pain) > 0:
                percentiles = np.linspace(0, 100, n_groups)
                cuts = np.percentile(positive_pain, percentiles)

                for i in range(len(cuts) - 1):
                    mask = (pain > cuts[i]) & (pain <= cuts[i + 1])
                    indices = df.index[mask].tolist()
                    if len(indices) >= min_size:
                        groups.append({
                            'id': f'preop_pain_level_{i}',
                            'indices': indices,
                            'type': 'preop_pain'
                        })
        else:
            percentiles = np.linspace(0, 100, n_groups + 1)
            cuts = np.percentile(pain, percentiles)
            bins = np.digitize(pain, cuts) - 1

            groups = []
            for i in range(n_groups):
                indices = df.index[bins == i].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'preop_pain_level_{i}',
                        'indices': indices,
                        'type': 'preop_pain'
                    })

        print(f"Created {len(groups)} preoperative pain groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_asa_groups(self, df, min_size=6):
        """Create ASA groups - SAME AS MEDICAL CODE."""
        print(f"Creating ASA physical status groups")

        if 'preop_asa' not in df.columns:
            print("No ASA physical status data")
            return []

        groups = []
        for asa_status in df['preop_asa'].unique():
            if pd.isna(asa_status):
                continue

            indices = df[df['preop_asa'] == asa_status].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'asa_class_{asa_status}',
                    'indices': indices,
                    'type': 'asa_status'
                })

        print(f"Created {len(groups)} ASA status groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_covariate_forest_groups(self, df, n_groups=30, min_size=6):
        """Create covariate forest groups - SAME AS MEDICAL CODE."""
        print(f"Creating covariate-based forest groups (target: {n_groups})")

        feature_cols = ['preop_age', 'preop_calcbmi', 'preop_gender', 'preop_asa',
                       'preop_mallampati', 'preop_smoking', 'preop_pain', 'intraop_surgerysize']
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for covariate clustering")
            return []

        X = df[available_features].copy()

        for col in X.columns:
            if X[col].dtype == 'object':
                le = LabelEncoder()
                X[col] = X[col].fillna('missing')
                X[col] = le.fit_transform(X[col])
            else:
                if X[col].isna().any():
                    X[col] = X[col].fillna(X[col].median())

        cluster_features = StandardScaler().fit_transform(X.values)
        labels = KMeans(n_clusters=n_groups, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_groups):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'covariate_cluster_{i}',
                    'indices': indices,
                    'type': 'covariate_cluster'
                })

        print(f"Created {len(groups)} covariate-based groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Ensure balance and compute CATE - SAME AS MEDICAL CODE."""
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            # Reverse sign: negative CATE means treatment reduces pain (beneficial)
            cate = -(treated_outcomes.mean() - control_outcomes.mean())

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Min-max scale each group's ``cate`` into ``normalized_cate``.

        The group dicts are modified in place. When all CATE values are equal,
        every group receives ``0.5``.

        Args:
            groups: List of group dicts, each with a ``cate`` entry.

        Returns:
            The same ``groups`` list, for convenience. An empty list is
            returned unchanged.
        """
        # min()/max() raise on an empty sequence, so there is nothing to scale.
        if not groups:
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            span = max_cate - min_cate
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / span
        else:
            # Degenerate range: no group is distinguishable from another.
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def simulate_sampling_trial(self, groups, sample_size, trial_seed):
        """Simulate sampling trial."""
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        # Initialize tau estimates
        tau_estimates = np.zeros(n_groups)
        sample_counts = np.zeros(n_groups)

        # Perform sampling: choose group uniformly, sample Bernoulli(tau(u))
        for _ in range(sample_size):
            group_idx = np.random.randint(n_groups)
            sample = np.random.binomial(1, tau_true[group_idx])

            sample_counts[group_idx] += 1
            if sample_counts[group_idx] == 1:
                tau_estimates[group_idx] = sample
            else:
                tau_estimates[group_idx] = ((sample_counts[group_idx] - 1) * tau_estimates[group_idx] + sample) / sample_counts[group_idx]

        # Groups with no samples get estimate 0
        tau_estimates[sample_counts == 0] = 0

        return tau_estimates, sample_counts

    def analyze_sample_size_performance(self, groups, sample_sizes, budget_percentages, n_trials=50):
        """Measure allocation value as a function of the sampling budget.

        For each sample size, ``n_trials`` sampling trials are simulated. In
        every trial and for every budget, the ``K`` groups with the highest
        estimated effect are selected and scored by the sum of their true
        ``normalized_cate``. ``K`` is ``max(1, int(p * n_groups))`` for budget
        fraction ``p``.

        Args:
            groups: Group dicts carrying a ``normalized_cate`` entry.
            sample_sizes: Iterable of total sample sizes to simulate.
            budget_percentages: Budget fractions; also used as result keys.
            n_trials: Number of simulated trials per sample size.

        Returns:
            Tuple ``(results, optimal_values)``. ``results[p]`` holds parallel
            lists ``sample_sizes``, ``values`` (mean realized value) and
            ``stds`` (its standard deviation across trials);
            ``optimal_values[p]`` is the sum of the ``K`` largest true values.
        """
        print(f"Analyzing sample size performance with {len(groups)} groups")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        budgets = [max(1, int(p * n_groups)) for p in budget_percentages]
        print(f"Budget percentages {budget_percentages} → K values {budgets}")

        # Best achievable value per budget: sum of the K largest true effects.
        true_ranking = np.argsort(tau_true)
        optimal_values = {
            bp: np.sum(tau_true[true_ranking[-K:]])
            for bp, K in zip(budget_percentages, budgets)
        }

        results = {bp: {'sample_sizes': [], 'values': [], 'stds': []} for bp in budget_percentages}

        for sample_size in sample_sizes:
            print(f"  Sample size {sample_size}...")

            realized = {bp: [] for bp in budget_percentages}

            for trial in range(n_trials):
                tau_estimates, _ = self.simulate_sampling_trial(groups, sample_size, trial)
                # The ranking by estimate is shared by all budgets in a trial.
                estimated_ranking = np.argsort(tau_estimates)

                for bp, K in zip(budget_percentages, budgets):
                    # Selection uses the estimates; scoring uses the true values.
                    realized[bp].append(np.sum(tau_true[estimated_ranking[-K:]]))

            for bp in budget_percentages:
                entry = results[bp]
                entry['sample_sizes'].append(sample_size)
                entry['values'].append(np.mean(realized[bp]))
                entry['stds'].append(np.std(realized[bp]))

        return results, optimal_values

    def plot_sample_size_analysis(self, results, optimal_values, method_name, budget_percentages, n_groups):
        """Plot empirical performance against reference curves, one panel per budget.

        Each panel shows the mean realized value (with standard-deviation error
        bars) divided by the optimal value, a horizontal line at 1.0, and two
        ``theoretical_bound`` curves with exponents 0.5 (labelled "FullCATE")
        and 1.0 (labelled "ALLOC"), all as a function of sample size.

        Panels are laid out three per row; the number of rows follows the
        number of budgets (six budgets give the 2x3, 18x12 inch layout), and
        unused panels in the last row are hidden. The figure is saved as a PDF
        in the current working directory and then shown.

        Args:
            results: Per-budget results from ``analyze_sample_size_performance``.
            optimal_values: Per-budget optimal values from the same call.
            method_name: Grouping method name, used in the title and file name.
            budget_percentages: Budget fractions to plot, one panel each.
            n_groups: Number of groups (``N`` in the bound and in the titles).
        """
        n_panels = len(budget_percentages)
        n_cols = 3
        # Ceiling division without importing math; always at least one row.
        n_rows = max(1, (n_panels + n_cols - 1) // n_cols)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
        axes = axes.flatten()

        # Parameter passed as ``delta`` to the theoretical bound
        delta = 0.05

        print(f"\nPlotting {method_name} (N={n_groups})")
        print("="*60)

        for ax, bp in zip(axes, budget_percentages):
            # Get data for this budget
            sample_sizes = results[bp]['sample_sizes']
            values = results[bp]['values']
            stds = results[bp]['stds']
            optimal_val = optimal_values[bp]

            # Normalize all values by the optimal value so 1.0 is the optimum
            values_norm = np.array(values) / optimal_val
            stds_norm = np.array(stds) / optimal_val

            # Empirical performance curve
            ax.errorbar(sample_sizes, values_norm, yerr=stds_norm,
                      marker='o', capsize=5, capthick=3, linewidth=6, markersize=8,
                      label='Empirical data', color='blue', alpha=0.8)

            # Optimal value (normalized to 1)
            ax.axhline(y=1.0, color='black', linestyle=':', linewidth=2,
                      label='Optimal (1.0)', alpha=0.8)

            # Dense grid of sample sizes for smooth reference curves
            m_smooth = np.linspace(min(sample_sizes), max(sample_sizes), 200)

            # Reference curves, normalized the same way as the empirical data
            ref_curve_05 = [theoretical_bound(m, 0.5, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]
            ref_curve_10 = [theoretical_bound(m, 1.0, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]

            ax.plot(m_smooth, ref_curve_05, 'red', linestyle=(0, (3, 2)), linewidth=6,
                  label='FullCATE', alpha=0.8)
            ax.plot(m_smooth, ref_curve_10, 'green', linestyle=(0, (3, 1, 1, 1)), linewidth=6,
                  label='ALLOC', alpha=0.8)

            ax.set_xlabel('Sample size', fontsize=23)
            ax.set_ylabel('Normalized allocation value', fontsize=23)
            ax.set_title(f'Budget = {bp*100:.0f}% (K={max(1, int(bp * n_groups))})', fontsize=24, fontweight='bold')

            ax.legend(fontsize=21, framealpha=0.9)
            ax.grid(True, alpha=0.4, linewidth=1)

            # Larger tick labels
            ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=5)

            y_min = 0.2
            y_max = 1.05  # Slightly above optimal
            ax.set_ylim(y_min, y_max)

            for spine in ax.spines.values():
                spine.set_linewidth(1.5)

        # Hide panels of the last row that have no budget assigned to them
        for unused_ax in axes[n_panels:]:
            unused_ax.set_visible(False)

        plt.suptitle(f'{method_name} (N={n_groups})', fontsize=24, fontweight='bold')
        plt.tight_layout()

        # Build a file-system friendly name from the method name
        clean_name = method_name.replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_')
        pdf_filename = f"{clean_name}_N{n_groups}_sample_size_analysis.pdf"
        plt.savefig(pdf_filename, format='pdf', dpi=300, bbox_inches='tight')
        print(f"Saved plot as: {pdf_filename}")

        plt.show()

        print(f"Plot complete for {method_name}")


def run_medical_sample_size_analysis(df_medical, sample_size_range=None, budget_percentages=None, n_trials=50,
                                   outcome_col='postop4hour_throatpain', treatment_col='treat'):
    """Run the sample-size analysis for every grouping method and plot each one.

    For each grouping method a fresh ``MedicalSampleSizeAnalyzer`` is created,
    the data is processed, groups are built and normalized, the simulation is
    run and the result is plotted. A method yielding fewer than 10 groups is
    skipped. Any exception raised while handling a method is printed and the
    loop moves on to the next method.

    Args:
        df_medical: Raw patient-level DataFrame.
        sample_size_range: Sample sizes to simulate; a built-in list is used
            when ``None``.
        budget_percentages: Budget fractions; a built-in list is used when
            ``None``.
        n_trials: Number of simulated trials per sample size.
        outcome_col: Outcome column name passed to ``process_medical_data``.
        treatment_col: Treatment column name passed to ``process_medical_data``.

    Returns:
        Dict keyed by method name with ``results``, ``optimal_values`` and
        ``n_groups`` for every method that completed.
    """
    sample_size_range = (
        [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]
        if sample_size_range is None else sample_size_range
    )
    budget_percentages = (
        [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]
        if budget_percentages is None else budget_percentages
    )

    banner = "=" * 80

    print("MEDICAL SAMPLE SIZE ANALYSIS - EMPIRICAL VS THEORETICAL BOUNDS")
    print(f"Sample sizes: {sample_size_range}")
    print(f"Budget percentages: {budget_percentages}")
    print(f"Trials per sample size: {n_trials}")
    print(banner)

    # (display name, analyzer method name, keyword arguments for that method)
    strategies = [
        ('Demographics', 'create_demographics_groups', {'min_size': 6}),
        ('Age Groups', 'create_age_groups', {'n_groups': 30, 'min_size': 6}),
        ('BMI Groups', 'create_bmi_groups', {'n_groups': 30, 'min_size': 6}),
        ('Preoperative Pain', 'create_preop_pain_groups', {'n_groups': 30, 'min_size': 6}),
        ('ASA Physical Status', 'create_asa_groups', {'min_size': 6}),
        ('Covariate Forest', 'create_covariate_forest_groups', {'n_groups': 30, 'min_size': 6}),
    ]

    all_results = {}

    for method_name, builder_name, builder_kwargs in strategies:
        print(f"\n{banner}")
        print(f"ANALYZING MEDICAL METHOD: {method_name}")
        print(banner)

        try:
            # A fresh analyzer per method reseeds the random generator.
            analyzer = MedicalSampleSizeAnalyzer()
            df_processed = analyzer.process_medical_data(
                df_medical, outcome_col=outcome_col, treatment_col=treatment_col
            )

            groups = getattr(analyzer, builder_name)(df_processed, **builder_kwargs)
            n_found = len(groups)

            if n_found < 10:
                print(f"Too few groups ({n_found}) for {method_name} - skipping")
                continue

            groups = analyzer.normalize_cates(groups)

            results, optimal_values = analyzer.analyze_sample_size_performance(
                groups, sample_size_range, budget_percentages, n_trials
            )

            all_results[method_name] = {
                'results': results,
                'optimal_values': optimal_values,
                'n_groups': len(groups)
            }

            print(f"Creating plots for {method_name}...")
            analyzer.plot_sample_size_analysis(
                results, optimal_values, method_name, budget_percentages, len(groups)
            )

            print(f"\nSummary for {method_name}:")
            print(f"Number of groups: {len(groups)}")
            print("Optimal values by budget:")
            for bp in budget_percentages:
                print(f"  {bp*100:.0f}%: {optimal_values[bp]:.3f}")

        except Exception as e:
            # Best effort: report the failure and continue with the next method.
            print(f"Error with {method_name}: {e}")
            continue

    return all_results

# Script entry point: run the full analysis on the licorice gargle dataset
if __name__ == "__main__":
    # Simulation settings: total sample sizes and budget fractions to evaluate
    sample_sizes = [
        100, 250, 500, 750, 1000, 1200,
        1500, 2000, 5000, 10000, 20000,
    ]
    budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    # The Stata file is looked up relative to the current working directory
    df_medical = pd.read_stata('Licorice Gargle.dta')

    results = run_medical_sample_size_analysis(
        df_medical,
        n_trials=50,
        budget_percentages=budget_percentages,
        sample_size_range=sample_sizes,
    )