import torch
import numpy as np
import pandas as pd
import warnings
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# --- MODIFICATION 1: Import the new SD_VI class ---
# from OCR_AVI import OCR_AVI
from SD_VI import SD_VI

warnings.filterwarnings('ignore')


# The `load_data` function remains unchanged as it's for data preprocessing.
def _parse_libsvm_lines(lines):
    """Turn libsvm-formatted text lines into a dense (X, y) pair of arrays.

    Each line is expected to look like ``<label> <idx>:<val> <idx>:<val> ...``
    with 1-based feature indices. Missing indices are filled with 0.0.
    Malformed ``idx:val`` pairs are ignored; lines whose label is not a
    number are skipped with a message.

    Raises:
        ValueError: if no line yields a usable sample.
    """
    # First pass: the largest feature index seen fixes the vector width.
    max_feature_idx = 0
    for line in lines:
        for part in line.split()[1:]:
            if ':' not in part:
                continue
            try:
                idx_str, _ = part.split(':')
                max_feature_idx = max(max_feature_idx, int(idx_str))
            except ValueError:
                # Extra colons or a non-integer index: ignore this pair.
                continue

    n_features = max_feature_idx
    print(f"Detected {n_features} features in libsvm format")

    # Second pass: build one dense row per valid line.
    data = []
    labels = []
    for line_num, line in enumerate(lines, start=1):
        parts = line.split()
        if not parts:
            continue

        try:
            label = float(parts[0])
        except ValueError:
            print(f"Skipping invalid line {line_num}: {line.strip()[:50]}...")
            continue

        features = {}
        for part in parts[1:]:
            if ':' not in part:
                continue
            try:
                idx_str, val_str = part.split(':')
                features[int(idx_str)] = float(val_str)
            except ValueError:
                continue

        labels.append(label)
        data.append([features.get(i, 0.0) for i in range(1, n_features + 1)])

    print(f"Successfully parsed {len(data)} samples from libsvm format")

    # Mirror the delimited path: fail early and clearly when nothing parsed.
    if not data:
        raise ValueError("No valid data remaining after cleaning")

    return np.array(data, dtype=np.float64), np.array(labels, dtype=np.float64)


def _read_delimited(file_path):
    """Read a header-less delimited file; the last column is the label.

    Tries space, tab, comma and semicolon in that order and keeps the first
    delimiter that yields more than one column. Non-numeric cells become NaN
    and rows containing NaN are dropped.

    Raises:
        ValueError: if no delimiter works or no rows survive cleaning.
    """
    print("Trying regular delimited format...")
    df = None
    for delimiter in [' ', '\t', ',', ';']:
        try:
            df = pd.read_csv(file_path, delimiter=delimiter, header=None)
        except Exception:
            # Best-effort probing: a parse failure just means "try the next one".
            continue
        if df.shape[1] > 1:  # Valid data
            print(f"Successfully loaded data with delimiter '{delimiter}'")
            break

    if df is None or df.shape[1] <= 1:
        raise ValueError("Could not parse file in any recognized format")

    print(f"Initial data shape: {df.shape}")
    print("Cleaning data...")

    # Replace common missing value indicators, then force everything numeric.
    df = df.replace(['?', '', 'NA', 'na', 'NaN'], np.nan)
    for col in df.columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')

    print(f"NaN count after conversion: {df.isnull().sum().sum()}")

    initial_rows = len(df)
    df = df.dropna()
    final_rows = len(df)

    if final_rows < initial_rows:
        print(f"Dropped {initial_rows - final_rows} rows with missing values")

    if len(df) == 0:
        raise ValueError("No valid data remaining after cleaning")

    X = df.iloc[:, :-1].values.astype(np.float64)
    y = df.iloc[:, -1].values.astype(np.float64)
    return X, y


def _binarize_labels(y):
    """Map a float64 label tensor onto {0.0, 1.0}.

    Two distinct values: the larger becomes 1. More than two: values strictly
    above the median become 1. If only one class remains afterwards, the
    labels are replaced by random 0/1 values (with a warning).
    """
    unique_labels = torch.unique(y)
    print(f"Unique labels before mapping: {unique_labels.tolist()}")

    if len(unique_labels) == 2:
        # torch.unique returns sorted values, so index 1 is the larger label.
        y = (y == unique_labels[1]).to(torch.float64)
    elif len(unique_labels) > 2:
        # Multi-class to binary: split at median. Stay in float64 so the
        # labels match the dtype used on every other path.
        y = (y > torch.median(y)).to(torch.float64)

    # unique(return_counts=True) instead of bincount: bincount rejects
    # negative values, which a single-class file (e.g. all -1) would contain.
    classes, counts = torch.unique(y, return_counts=True)
    print(f"After processing - Labels: {classes.tolist()}")
    print(f"Class distribution: {counts.tolist()}")

    if len(classes) < 2:
        print("Warning: Only one class found, creating balanced synthetic labels")
        y = torch.randint(0, 2, (len(y),), dtype=torch.float64)

    return y


def _standardize(X_train, X_test):
    """Fit a StandardScaler on the training split and apply it to both splits.

    Returns (X_train_scaled, X_test_scaled, scaler); both matrices are
    float64 tensors.
    """
    scaler = StandardScaler()
    X_train_scaled = torch.tensor(
        scaler.fit_transform(X_train.numpy()), dtype=torch.float64
    )
    X_test_scaled = torch.tensor(
        scaler.transform(X_test.numpy()), dtype=torch.float64
    )
    return X_train_scaled, X_test_scaled, scaler


def load_data(file_path="breast-cancer.txt", test_size=0.2, random_state=42):
    """Load a binary-classification dataset, split it and standardize features.

    The file is treated as libsvm format when the first line contains an
    ``idx:val`` token after the label; otherwise it is read as a header-less
    delimited table whose last column is the label. Labels are mapped to
    0/1, the data is split (stratified when possible) and features are
    standardized with statistics from the training split only.

    If the file does not exist, a synthetic 300-sample, 13-feature dataset
    from ``sklearn.datasets.make_classification`` is used instead. That
    dataset is generated with a fixed seed; only the split depends on
    ``random_state``.

    Args:
        file_path: Path of the data file.
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        Tuple ``(X_train_scaled, X_test_scaled, y_train, y_test, scaler)``.
        The feature matrices are float64 tensors and ``scaler`` is the fitted
        ``StandardScaler``. The label splits are whatever ``train_test_split``
        returns for a float64 tensor input.

    Raises:
        ValueError: if the file cannot be parsed or no samples remain.
    """
    print(f"Attempting to load data from: {file_path}")

    try:
        # Peek at the first line to decide between libsvm and delimited text.
        with open(file_path, 'r') as f:
            first_line = f.readline().strip()

        if any(':' in part for part in first_line.split()[1:]):
            print("Detected libsvm format, parsing accordingly...")
            with open(file_path, 'r') as f:
                lines = f.readlines()
            X, y = _parse_libsvm_lines(lines)
        else:
            X, y = _read_delimited(file_path)

    except FileNotFoundError:
        print(f"File {file_path} not found. Using synthetic heart disease data...")
        # Generate synthetic data similar to heart disease dataset
        from sklearn.datasets import make_classification

        X, y = make_classification(
            n_samples=300, n_features=13, n_informative=8, n_redundant=2,
            n_clusters_per_class=1, random_state=42, class_sep=1.2
        )
        X = torch.tensor(X, dtype=torch.float64)
        y = torch.tensor(y, dtype=torch.float64)

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
        X_train_scaled, X_test_scaled, scaler = _standardize(X_train, X_test)

        print(f"Generated synthetic heart disease data")
        print(f"Training set: {X_train_scaled.shape[0]} samples")
        print(f"Test set: {X_test_scaled.shape[0]} samples")
        print(f"Features: {X_train_scaled.shape[1]}")

        return X_train_scaled, X_test_scaled, y_train, y_test, scaler

    print(f"Final data shape: X={X.shape}, y={y.shape}")
    print(f"Unique labels: {np.unique(y)}")
    print(f"Data types: X={X.dtype}, y={y.dtype}")

    # Verify data quality: replace any remaining non-finite entries.
    if not np.all(np.isfinite(X)):
        print("Warning: X contains NaN or inf values, cleaning...")
        X = np.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6)

    if not np.all(np.isfinite(y)):
        print("Warning: y contains NaN or inf values, cleaning...")
        y = np.nan_to_num(y, nan=0.0, posinf=1.0, neginf=0.0)

    X = torch.tensor(X, dtype=torch.float64)
    y = _binarize_labels(torch.tensor(y, dtype=torch.float64))

    # Prefer a stratified split; fall back to a plain random split when
    # stratification is impossible (e.g. a class with a single member).
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state,
            stratify=y if len(torch.unique(y)) > 1 else None
        )
    except Exception as e:
        print(f"Stratified split failed: {e}, using random split")
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

    X_train_scaled, X_test_scaled, scaler = _standardize(X_train, X_test)

    print(f"Final split - Training: {X_train_scaled.shape[0]} samples, Test: {X_test_scaled.shape[0]} samples")
    print(f"Features: {X_train_scaled.shape[1]}")

    return X_train_scaled, X_test_scaled, y_train, y_test, scaler


# `compute_metrics` also reports the model's final objective value (previously reported as the final ELBO).
def _safe_metric(label, compute, fallback, errors=Exception):
    """Run ``compute()``; on ``errors`` print a message and return ``fallback``."""
    try:
        return compute()
    except errors as exc:
        print(f"{label} computation failed: {exc}")
        return fallback


def compute_metrics(model, X_train, X_test, y_train, y_test):
    """Compute AUC, MSE, ECE, and the final objective value.

    Each metric is guarded separately for each split, so a failure on one
    split no longer discards the value already obtained for the other.

    Args:
        model: Object providing ``predict_proba(X)`` (whose result has a
            ``.numpy()`` method), ``compute_mse(X, y)`` and
            ``compute_ece(X, y, n_bins=...)``. An optional
            ``objective_values`` sequence supplies the final objective.
        X_train, X_test: Feature inputs passed straight to the model.
        y_train, y_test: Label tensors (``.numpy()`` is called on them).

    Returns:
        dict with keys ``train_auc``, ``test_auc``, ``train_mse``,
        ``test_mse``, ``train_ece``, ``test_ece`` and ``final_objective``.
        AUC falls back to 0.5 when ``roc_auc_score`` raises ``ValueError``;
        MSE and ECE fall back to NaN on any exception; ``final_objective``
        is NaN when the model has recorded no objective values.
    """
    results = {}
    nan = float('nan')

    with torch.no_grad():
        # Predictions
        y_train_pred_proba = model.predict_proba(X_train).numpy()
        y_test_pred_proba = model.predict_proba(X_test).numpy()

        # Convert labels to numpy
        y_train_np = y_train.numpy()
        y_test_np = y_test.numpy()

        # AUC: 0.5 (chance level) is the fallback when the score is undefined.
        results['train_auc'] = _safe_metric(
            'AUC', lambda: roc_auc_score(y_train_np, y_train_pred_proba),
            0.5, ValueError)
        results['test_auc'] = _safe_metric(
            'AUC', lambda: roc_auc_score(y_test_np, y_test_pred_proba),
            0.5, ValueError)

        # MSE
        results['train_mse'] = _safe_metric(
            'MSE', lambda: model.compute_mse(X_train, y_train), nan)
        results['test_mse'] = _safe_metric(
            'MSE', lambda: model.compute_mse(X_test, y_test), nan)

        # ECE
        results['train_ece'] = _safe_metric(
            'ECE', lambda: model.compute_ece(X_train, y_train, n_bins=10), nan)
        results['test_ece'] = _safe_metric(
            'ECE', lambda: model.compute_ece(X_test, y_test, n_bins=10), nan)

        # Last recorded objective value, if the model tracked any.
        if hasattr(model, 'objective_values') and len(model.objective_values) > 0:
            results['final_objective'] = model.objective_values[-1]
        else:
            results['final_objective'] = np.nan

    return results


def run_multiple_experiments(data_file, n_runs=5, random_seed=42, model_params=None):
    """Run multiple independent experiments, each with a new data split.

    Run ``k`` (0-based) uses seed ``random_seed + k`` for torch, numpy, the
    data split and the model. After all runs a summary table is printed.

    Args:
        data_file: Path forwarded to ``load_data``.
        n_runs: Number of repetitions.
        random_seed: Base seed.
        model_params: Extra keyword arguments for ``SD_VI``; ``None`` means
            "no extra arguments".

    Returns:
        Tuple ``(all_results, summary)``. ``all_results`` is a list with one
        metrics dict per run (plus a 1-based ``'run'`` key). ``summary`` maps
        each reported metric to a dict with ``median``, ``q2_5``, ``q97_5``
        and the list of non-NaN ``values``; metrics with no valid value are
        left out.
    """
    # The default is None, which cannot be **-expanded; treat it as "no extras".
    extra_params = {} if model_params is None else dict(model_params)

    print(f"\nRunning {n_runs} independent experiments (with new data split each time)...")
    print("=" * 60)

    all_results = []

    for run in range(n_runs):
        run_seed = random_seed + run
        print(f"Run {run + 1}/{n_runs} with random_seed = {run_seed}...")

        torch.manual_seed(run_seed)
        np.random.seed(run_seed)
        X_train, X_test, y_train, y_test, _ = load_data(
            file_path=data_file, random_state=run_seed
        )

        model = SD_VI(
            n_features=X_train.shape[1],
            random_state=run_seed,
            **extra_params
        )

        model.fit(X_train, y_train, verbose=False)
        results = compute_metrics(model, X_train, X_test, y_train, y_test)
        results['run'] = run + 1
        all_results.append(results)

    print("\n" + "=" * 60)
    print("RESULTS - Median (2.5%, 97.5% quantiles)")
    print("=" * 60)

    metrics = ['test_auc', 'test_mse', 'test_ece', 'final_objective']
    summary = {}
    for metric in metrics:
        # NaN marks a metric that could not be computed for that run.
        values = [r[metric] for r in all_results if not np.isnan(r[metric])]
        if len(values) > 0:
            percentiles = np.percentile(values, [2.5, 50, 97.5])
            summary[metric] = {'median': percentiles[1], 'q2_5': percentiles[0], 'q97_5': percentiles[2],
                               'values': values}

    metric_names = {
        'test_auc': 'Test AUC',
        'test_mse': 'Test MSE',
        'test_ece': 'Test ECE',
        'final_objective': 'Final Objective'
    }

    def format_metric_value(median, q2_5, q97_5):
        # Fewer decimals for larger magnitudes keeps the table readable.
        if abs(median) >= 10:
            return f"{median:.1f} ({q2_5:.1f}, {q97_5:.1f})"
        elif abs(median) >= 1:
            return f"{median:.2f} ({q2_5:.2f}, {q97_5:.2f})"
        else:
            return f"{median:.4f} ({q2_5:.4f}, {q97_5:.4f})"

    for metric in metrics:
        if metric in summary:
            stats = summary[metric]
            formatted = format_metric_value(stats['median'], stats['q2_5'], stats['q97_5'])
            print(f"{metric_names[metric]:<17}: {formatted}")

    return all_results, summary


# --- MODIFICATION 5: Update the main function signature and parameters ---
def main(data_file="heart.txt", n_runs=5, lr_mu=0.01, lr_S=0.001,
         max_iter=1000, lambda1=1e-2, random_seed=42):
    """Print the banner and launch the repeated SD-VI experiments.

    The four optimiser settings (``lr_mu``, ``lr_S``, ``max_iter``,
    ``lambda1``) are bundled into a dict and handed to
    ``run_multiple_experiments`` as ``model_params``; ``data_file``,
    ``n_runs`` and ``random_seed`` are passed through unchanged.

    Returns:
        The ``(all_results, summary)`` pair produced by
        ``run_multiple_experiments``.
    """
    print("SD-VI Classification (Proximal Spectral Optimization)")
    print("=" * 60)

    # Keyword arguments destined for the SD_VI model.
    hyperparameters = dict(
        lr_mu=lr_mu,
        lr_S=lr_S,
        max_iter=max_iter,
        lambda1=lambda1,
    )

    return run_multiple_experiments(
        data_file=data_file,
        n_runs=n_runs,
        random_seed=random_seed,
        model_params=hyperparameters,
    )


if __name__ == "__main__":
    # Experiment setup: input file, number of repetitions, base seed.
    DATA_FILE = "heart.txt"
    N_RUNS = 5
    RANDOM_SEED = 3407

    # SD-VI hyperparameters; the values below are starting points and are
    # expected to need tuning.
    LR_MU = 0.01      # learning rate for the mean vector
    LR_S = 0.001      # step size (eta_S) for the covariance matrix update
    LAMBDA_1 = 0.01   # L1 spectral penalty strength (soft-thresholding level)
    MAX_ITER = 800    # cap on optimization iterations

    # Gather everything in one place, then run the experiments.
    run_config = dict(
        data_file=DATA_FILE,
        n_runs=N_RUNS,
        random_seed=RANDOM_SEED,
        lr_mu=LR_MU,
        lr_S=LR_S,
        lambda1=LAMBDA_1,
        max_iter=MAX_ITER,
    )
    results, summary = main(**run_config)