#!/usr/bin/env python
"""
Command Line Interface for LapBoost

This module provides a command-line interface to the LapBoost package,
allowing users to train models, make predictions, and visualize results
directly from the terminal.
"""

import argparse
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    classification_report, 
    mean_squared_error, 
    r2_score
)

from lapboost import (
    LapBoostClassifier, 
    LapBoostRegressor,
    IterativeLapBoostClassifier,
    IterativeLapBoostRegressor
)
from lapboost.visualization.plots import (
    plot_decision_boundary, 
    plot_confidence_distribution,
    plot_learning_curves,
    plot_graph_structure
)


def load_data(labeled_path, unlabeled_path=None, target_col=None, test_size=0.2, random_state=42):
    """
    Load data from CSV files for semi-supervised learning.

    The labeled CSV is split into features and target, the target is
    label-encoded when it is non-numeric, and the result is divided into
    train and test partitions. The optional unlabeled CSV is returned as a
    plain feature matrix.

    Parameters
    ----------
    labeled_path : str
        Path to CSV file containing labeled data
    unlabeled_path : str, optional
        Path to CSV file containing unlabeled data
    target_col : str
        Name of target column in labeled data. Must name an existing
        column; the default of None is rejected with a ValueError.
    test_size : float, default=0.2
        Proportion of labeled data to use for testing
    random_state : int, default=42
        Random seed for reproducibility

    Returns
    -------
    tuple
        X_train, X_test, y_train, y_test, X_unlabeled
        (X_unlabeled is None when no unlabeled path is given)

    Raises
    ------
    ValueError
        If ``target_col`` is not a column of the labeled data.
    """
    # Load labeled data
    labeled_df = pd.read_csv(labeled_path)

    if target_col not in labeled_df.columns:
        raise ValueError(f"Target column '{target_col}' not found in labeled data")

    # Split features and target
    X_labeled = labeled_df.drop(columns=[target_col]).values
    y = labeled_df[target_col].values

    # Label encoding maps classes one-to-one onto integers, so the number of
    # distinct target values is the same before and after encoding.
    n_unique = len(np.unique(y))

    # Encode categorical targets if needed. The decision depends on the
    # target's own dtype only: a string-typed *feature* column must not cause
    # a numeric (possibly continuous) target to be turned into class indices.
    if y.dtype == object:
        le = LabelEncoder()
        y = le.fit_transform(y)
        print(f"Target classes encoded: {list(zip(le.classes_, range(len(le.classes_))))}")

    # Split into train and test sets; stratify only when the target has few
    # distinct values, i.e. when it looks like a class label.
    X_train, X_test, y_train, y_test = train_test_split(
        X_labeled,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if n_unique < 10 else None,
    )

    # Load unlabeled data if provided
    if unlabeled_path:
        unlabeled_df = pd.read_csv(unlabeled_path)

        # If target column exists in unlabeled data, drop it
        X_unlabeled = unlabeled_df.drop(columns=[target_col], errors='ignore').values
    else:
        X_unlabeled = None

    print(f"Data loaded: {X_train.shape[0]} training samples, {X_test.shape[0]} test samples")
    if X_unlabeled is not None:
        print(f"Unlabeled data: {X_unlabeled.shape[0]} samples")

    return X_train, X_test, y_train, y_test, X_unlabeled


def _build_model(args, task_type):
    """Instantiate the LapBoost estimator selected by the task type and ``args.iterative``."""
    params = dict(
        k_neighbors=args.k_neighbors,
        gamma=args.gamma,
        confidence_threshold=args.confidence_threshold,
        max_iter=args.max_iter,
        n_estimators=args.n_estimators,
        learning_rate=args.learning_rate,
        verbose=args.verbose,
        random_state=args.random_state,
    )
    # Anything that is not 'classification' is treated as regression.
    is_classification = task_type == 'classification'

    if args.iterative:
        # Only the iterative estimators are given a confidence decay.
        params['confidence_decay'] = args.confidence_decay
        model_cls = IterativeLapBoostClassifier if is_classification else IterativeLapBoostRegressor
    else:
        model_cls = LapBoostClassifier if is_classification else LapBoostRegressor

    return model_cls(**params)


def _save_figure(fig, output_dir, filename):
    """Save ``fig`` as ``output_dir/filename`` at 300 dpi; do nothing when no output directory is set."""
    if not output_dir:
        return
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, filename), dpi=300)


def _plot_learning_curves(model, args):
    """Plot (and optionally save) learning curves when an iterative model exposes a history."""
    if args.iterative and hasattr(model, 'performance_history_'):
        fig = plot_learning_curves(
            model.performance_history_,
            title="LapBoost Learning Curves"
        )
        _save_figure(fig, args.output_dir, "learning_curves.png")


def _save_model(model, path):
    """Pickle ``model`` to ``path``, creating the parent directory when the path has one."""
    import pickle

    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(model, f)


def train_and_evaluate(args):
    """
    Train a LapBoost model and evaluate it.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options. The attributes read here are: labeled_data,
        unlabeled_data, target_column, test_size, random_state, task_type,
        iterative, k_neighbors, gamma, confidence_threshold,
        confidence_decay (iterative models only), max_iter, n_estimators,
        learning_rate, verbose, visualize, output_dir and save_model.

    Returns
    -------
    object
        The fitted model.
    """
    # Load data
    X_train, X_test, y_train, y_test, X_unlabeled = load_data(
        args.labeled_data,
        args.unlabeled_data,
        args.target_column,
        args.test_size,
        args.random_state
    )

    # Determine task type: an explicit --task-type wins, otherwise a target
    # with at most 10 distinct training values is treated as classification.
    n_unique = len(np.unique(y_train))
    task_type = args.task_type or ('classification' if n_unique <= 10 else 'regression')

    print(f"\nTask type: {task_type}")
    print(f"Model type: {'iterative' if args.iterative else 'standard'}")

    # Create model
    model = _build_model(args, task_type)

    # Train model
    print("\nTraining LapBoost model...")
    model.fit(X_train, y_train, X_unlabeled)

    # Make predictions
    y_pred = model.predict(X_test)

    # Evaluate
    print("\nEvaluation results:")
    if task_type == 'classification':
        accuracy = accuracy_score(y_test, y_pred)
        print(f"Test accuracy: {accuracy:.4f}")
        print("\nClassification report:")
        print(classification_report(y_test, y_pred))

        # Classification plots are produced only for models that expose
        # probabilities; the probabilities are computed only when needed.
        if args.visualize and hasattr(model, 'predict_proba'):
            y_proba = model.predict_proba(X_test)
            confidences = np.max(y_proba, axis=1)

            print("\nGenerating visualizations...")

            # Plot confidence distribution
            fig_conf = plot_confidence_distribution(
                confidences, y_test, y_pred,
                title="Prediction Confidence Distribution"
            )
            _save_figure(fig_conf, args.output_dir, "confidence_distribution.png")

            # Plot decision boundary for low-dimensional data.
            # NOTE(review): the condition also admits a single feature column;
            # confirm plot_decision_boundary supports 1-D input.
            if X_test.shape[1] <= 2:
                fig_boundary = plot_decision_boundary(
                    model, X_test, y_test,
                    title="LapBoost Decision Boundary"
                )
                _save_figure(fig_boundary, args.output_dir, "decision_boundary.png")

            # Plot learning curves for iterative model
            _plot_learning_curves(model, args)

            # Without an output directory the figures are shown interactively.
            if not args.output_dir:
                plt.show()
    else:  # regression
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        print(f"Test MSE: {mse:.4f}")
        print(f"Test R²: {r2:.4f}")

        if args.visualize:
            print("\nGenerating visualizations...")

            # Plot true vs predicted, with the identity line as reference
            fig_scatter, ax = plt.subplots(figsize=(10, 6))
            ax.scatter(y_test, y_pred, alpha=0.5)
            ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
            ax.set_xlabel("True Values")
            ax.set_ylabel("Predictions")
            ax.set_title("True vs Predicted Values")
            _save_figure(fig_scatter, args.output_dir, "true_vs_predicted.png")

            # Plot learning curves for iterative model
            _plot_learning_curves(model, args)

            # Without an output directory the figures are shown interactively.
            if not args.output_dir:
                plt.show()

    # Save model if requested
    if args.save_model:
        _save_model(model, args.save_model)
        print(f"\nModel saved to {args.save_model}")

    return model


def main():
    """
    Parse the command line, run training/evaluation, and report the outcome.

    Returns
    -------
    int
        0 when training and evaluation complete, 1 when they raise an
        exception (the message is written to stderr).
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="LapBoost CLI for semi-supervised learning",
    )

    # Each entry is (flags, add_argument keyword options). Registration order
    # is kept as listed because it determines the order of the help output.
    argument_table = [
        # Data arguments
        (("labeled_data",),
         dict(type=str, help="Path to CSV file containing labeled data")),
        (("--unlabeled-data",),
         dict(type=str, help="Path to CSV file containing unlabeled data")),
        (("--target-column", "-t"),
         dict(type=str, required=True, help="Name of target column in labeled data")),
        (("--test-size",),
         dict(type=float, default=0.2, help="Proportion of labeled data to use for testing")),
        # Model arguments
        (("--task-type",),
         dict(type=str, choices=["classification", "regression"],
              help="Task type (auto-detected if not specified)")),
        (("--iterative",),
         dict(action="store_true", help="Use iterative co-training approach")),
        (("--k-neighbors",),
         dict(type=int, default=10, help="Number of neighbors for graph construction")),
        (("--gamma",),
         dict(type=float, default=0.1, help="Graph regularization weight")),
        (("--confidence-threshold",),
         dict(type=float, default=0.7, help="Confidence threshold for pseudo-labeling")),
        (("--confidence-decay",),
         dict(type=float, default=0.95,
              help="Decay rate for confidence threshold in iterative training")),
        (("--max-iter",),
         dict(type=int, default=3, help="Maximum number of iterations")),
        (("--n-estimators",),
         dict(type=int, default=100, help="Number of trees in XGBoost model")),
        (("--learning-rate",),
         dict(type=float, default=0.1, help="Learning rate for XGBoost model")),
        (("--random-state",),
         dict(type=int, default=42, help="Random seed for reproducibility")),
        (("--verbose",),
         dict(action="store_true", help="Enable verbose output")),
        # Output arguments
        (("--visualize",),
         dict(action="store_true", help="Generate visualizations")),
        (("--output-dir",),
         dict(type=str, help="Directory to save visualizations")),
        (("--save-model",),
         dict(type=str, help="Path to save trained model")),
    ]
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args()

    # Top-level boundary: any failure becomes a one-line message and a
    # non-zero status instead of a traceback.
    try:
        train_and_evaluate(parsed)
    except Exception as err:
        print(f"Error: {err}", file=sys.stderr)
        return 1
    else:
        return 0


# Run the CLI only when this file is executed as a script; the integer
# returned by main() is passed to sys.exit() as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
