"""
Training utilities for GLEAM-AI.

This module contains utility functions for training setup, configuration,
and experiment management.
"""

import torch
import yaml
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union, Tuple
import logging
import numpy as np
import pandas as pd

from ..config.settings import (
    ModelConfig, TrainingConfig, DataConfig, ActiveLearningConfig,
    load_config_from_yaml
)
from ..utils import set_seed, get_device, create_experiment_directory

logger = logging.getLogger(__name__)


def load_training_config(
    config_path: Union[str, Path],
    config_type: str = "yaml"
) -> Tuple[ModelConfig, TrainingConfig, DataConfig, ActiveLearningConfig]:
    """
    Load training configuration from file.

    The file is read as a mapping whose optional top-level sections ``model``,
    ``train``, ``data`` and ``active_learner`` are passed to the matching
    config class's ``from_dict``. Missing, empty or ``null`` sections are
    treated as empty dictionaries, and an empty file is treated as having no
    sections at all.

    Args:
        config_path: Path to configuration file
        config_type: Type of configuration file ("yaml" or "json")

    Returns:
        Tuple of (model, training, data, active learning) configuration objects

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If ``config_type`` is unsupported or the file's top level
            is not a mapping.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    if config_type == "yaml":
        with open(config_path, 'r') as f:
            config_dict = yaml.safe_load(f)
    elif config_type == "json":
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
    else:
        raise ValueError(f"Unsupported config type: {config_type}")

    # yaml.safe_load returns None for an empty document (JSON `null` does too);
    # treat that as "no sections" instead of crashing on None.get(...) below.
    if config_dict is None:
        config_dict = {}
    elif not isinstance(config_dict, dict):
        raise ValueError(
            "Configuration file must contain a mapping at the top level, "
            f"got {type(config_dict).__name__}: {config_path}"
        )

    def _section(name: str) -> Dict[str, Any]:
        # A key present with no value (e.g. `model:` in YAML) parses as None;
        # normalise it so from_dict always receives a dict.
        return config_dict.get(name) or {}

    # Create configuration objects
    model_config = ModelConfig.from_dict(_section("model"))
    training_config = TrainingConfig.from_dict(_section("train"))
    data_config = DataConfig.from_dict(_section("data"))
    active_learning_config = ActiveLearningConfig.from_dict(_section("active_learner"))

    logger.info(f"Configuration loaded from: {config_path}")
    return model_config, training_config, data_config, active_learning_config


def setup_training_environment(
    seed: Optional[int] = None,
    device: Optional[str] = None,
    num_threads: Optional[int] = None
) -> str:
    """
    Setup the training environment.

    Seeds the random number generators, configures CPU threading and resolves
    the device string to train on.

    Args:
        seed: Random seed for reproducibility; skipped when None.
        device: Device to use for training (e.g. "cpu", "cuda", "cuda:1",
            "mps"). When None, ``get_device()`` chooses one. A provided value
            is lower-cased, and a CUDA/MPS request that cannot be satisfied
            falls back to "cpu" with a warning.
        num_threads: Number of threads for PyTorch; skipped when None.

    Returns:
        Device string
    """
    import os

    # Set random seed
    if seed is not None:
        set_seed(seed)
        logger.info(f"Random seed set to: {seed}")

    # Set number of threads
    if num_threads is not None:
        # These variables are typically read when OpenMP/MKL/NumExpr initialise,
        # so they mainly affect libraries or child processes started later;
        # torch's own intra-op pool is configured explicitly below.
        os.environ['OMP_NUM_THREADS'] = str(num_threads)
        os.environ['MKL_NUM_THREADS'] = str(num_threads)
        os.environ['NUMEXPR_NUM_THREADS'] = str(num_threads)

        torch.set_num_threads(num_threads)

        logger.info(f"PyTorch threads set to: {num_threads}")

    # Determine device
    if device is None:
        device = get_device()
    else:
        device = device.lower()
        # Compare only the backend part so indexed forms like "cuda:0" are
        # checked too, not just the bare "cuda".
        backend = device.split(":", 1)[0]
        if backend == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA requested but not available, falling back to CPU")
            device = "cpu"
        elif backend == "mps":
            # torch.backends.mps exists on every platform in recent PyTorch, so
            # the attribute alone proves nothing; ask the backend itself.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is None or not mps_backend.is_available():
                logger.warning("MPS requested but not available, falling back to CPU")
                device = "cpu"

    logger.info(f"Training environment setup complete. Device: {device}")
    return device


def create_experiment_setup(
    base_dir: Union[str, Path],
    experiment_name: str,
    config: Dict[str, Any],
    create_subdirs: bool = True
) -> Path:
    """
    Create experiment directory structure and save configuration.

    The configuration is written twice under ``<exp_dir>/configs/``: once as
    ``config.yaml`` and once as ``config.json``. Values the ``json`` module
    cannot encode natively are stringified with ``str`` in the JSON copy.

    Args:
        base_dir: Base directory for experiments
        experiment_name: Name of the experiment
        config: Configuration dictionary
        create_subdirs: Whether to create subdirectories

    Returns:
        Path to experiment directory
    """
    # Create experiment directory
    exp_dir = create_experiment_directory(base_dir, experiment_name, create_subdirs)

    # The config files always go into "configs/". Whether create_experiment_directory
    # creates that folder when create_subdirs=False is not visible here, so make sure
    # it exists rather than failing with FileNotFoundError on open().
    config_dir = exp_dir / "configs"
    config_dir.mkdir(parents=True, exist_ok=True)

    # Save configuration
    config_path = config_dir / "config.yaml"
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False, indent=2)

    # Save configuration as JSON as well. default=str keeps values such as Path
    # objects from aborting the JSON copy after the YAML copy was already written.
    config_json_path = config_dir / "config.json"
    with open(config_json_path, 'w') as f:
        json.dump(config, f, indent=2, default=str)

    logger.info(f"Experiment setup created at: {exp_dir}")
    logger.info(f"Configuration saved to: {config_path}")

    return exp_dir


def validate_training_config(
    model_config: ModelConfig,
    training_config: TrainingConfig,
    data_config: DataConfig,
    active_learning_config: Optional[ActiveLearningConfig] = None
) -> bool:
    """
    Check a set of training configuration objects for obviously invalid values.

    Numeric sizes and rates (model dimensions, epochs, learning rate, batch
    size and, when given, the active-learning budget) must be strictly
    positive. The data column lists must be non-empty. Every problem found is
    logged at ERROR level, in a fixed order, before the result is returned.

    Args:
        model_config: Model configuration
        training_config: Training configuration
        data_config: Data configuration
        active_learning_config: Active learning configuration (optional); its
            fields are only checked when it is not None

    Returns:
        True when no problem was found, False otherwise
    """
    # (config object, attribute name) pairs that must hold a value > 0.
    must_be_positive = (
        (model_config, "seq_len"),
        (model_config, "num_nodes"),
        (model_config, "x_dim"),
        (model_config, "y_dim"),
        (training_config, "max_epochs"),
        (training_config, "lr"),
        (training_config, "batch_size"),
    )
    problems = [
        f"{attr} must be positive"
        for cfg, attr in must_be_positive
        if getattr(cfg, attr) <= 0
    ]

    problems += [
        f"{attr} cannot be empty"
        for attr in ("x_col_names", "frac_pops_names")
        if not getattr(data_config, attr)
    ]

    if active_learning_config is not None:
        problems += [
            f"{attr} must be positive"
            for attr in ("initial_samples", "samples_per_iteration", "max_iterations")
            if getattr(active_learning_config, attr) <= 0
        ]

    if not problems:
        logger.info("Configuration validation passed")
        return True

    logger.error("Configuration validation failed:")
    for problem in problems:
        logger.error(f"  - {problem}")
    return False


def setup_logging(
    log_dir: Union[str, Path],
    experiment_name: str,
    level: str = "INFO"
) -> None:
    """
    Setup logging for training.

    Configures the root logger to write to ``<log_dir>/<experiment_name>.log``.
    If the root logger has no handlers yet, a console handler is added too.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers. In that case the file handler is attached directly, and the
    handlers that are already installed are left untouched. A repeated call
    for the same log file does not attach a second handler.

    Args:
        log_dir: Directory for log files (created if missing)
        experiment_name: Name of the experiment; used as the log file name stem
        level: Logging level name such as "DEBUG" or "info" (case-insensitive)

    Raises:
        ValueError: If ``level`` is not a known logging level name.
    """
    import os

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Setup logging configuration
    log_file = log_dir / f"{experiment_name}.log"

    # getattr(logging, ...) also resolves non-level attributes (e.g. BASIC_FORMAT,
    # a string), so insist on an integer level and fail with a clear message.
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Unknown logging level: {level}")

    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    root_logger = logging.getLogger()

    if not root_logger.handlers:
        logging.basicConfig(
            level=numeric_level,
            format=log_format,
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
    else:
        # basicConfig() would silently do nothing here, so the log file would
        # never be written. Attach the file handler ourselves unless this exact
        # file is already a target (FileHandler stores an absolute baseFilename).
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target
            for handler in root_logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(logging.Formatter(log_format))
            root_logger.addHandler(file_handler)
        root_logger.setLevel(numeric_level)

    logger.info(f"Logging setup complete. Log file: {log_file}")


def save_training_summary(
    output_dir: Union[str, Path],
    experiment_name: str,
    config: Dict[str, Any],
    results: Dict[str, Any],
    metrics: Optional[Dict[str, float]] = None
) -> None:
    """
    Write a JSON summary of one training run.

    The summary is stored as ``<output_dir>/<experiment_name>_summary.json``.
    It holds the experiment name, the configuration, the results, the metrics
    (``{}`` when none or an empty mapping were given) and a ``pd.Timestamp.now()``
    ISO-8601 timestamp. Values the ``json`` module cannot encode natively are
    written via ``str``.

    Args:
        output_dir: Output directory (created if missing)
        experiment_name: Name of the experiment; used as the file name prefix
        config: Configuration dictionary
        results: Training results
        metrics: Additional metrics
    """
    destination = Path(output_dir)
    destination.mkdir(parents=True, exist_ok=True)

    record: Dict[str, Any] = dict(
        experiment_name=experiment_name,
        config=config,
        results=results,
        metrics=metrics if metrics else {},
        timestamp=pd.Timestamp.now().isoformat(),
    )

    summary_file = destination / f"{experiment_name}_summary.json"
    with summary_file.open('w') as handle:
        json.dump(record, handle, indent=2, default=str)

    logger.info(f"Training summary saved to: {summary_file}")


def load_training_results(
    results_path: Union[str, Path]
) -> Dict[str, Any]:
    """
    Read a JSON results file and return its parsed content.

    Args:
        results_path: Path to results file

    Returns:
        The decoded JSON document (a dictionary for files written by this module)

    Raises:
        FileNotFoundError: If ``results_path`` does not exist.
    """
    path = Path(results_path)
    if not path.exists():
        raise FileNotFoundError(f"Results file not found: {path}")

    loaded = json.loads(path.read_text())
    logger.info(f"Training results loaded from: {path}")
    return loaded


def compare_experiments(
    results_dir: Union[str, Path],
    experiment_names: list,
    metric: str = "val_loss",
    ascending: bool = True
) -> pd.DataFrame:
    """
    Compare multiple experiments.

    For each name, reads ``<results_dir>/<name>_summary.json`` (the naming used
    by ``save_training_summary``) and collects the requested metric, the
    stored config and the timestamp. Experiments without a summary file are
    skipped with a warning.

    Args:
        results_dir: Directory containing experiment results
        experiment_names: List of experiment names to compare
        metric: Metric to compare
        ascending: Sort direction for ``metric``. Keep True (the default) for
            lower-is-better metrics such as losses; pass False for
            higher-is-better metrics such as accuracy.

    Returns:
        DataFrame with columns experiment, metric, config and timestamp, sorted
        best-first by ``metric``. Experiments lacking the metric are placed
        last. An empty DataFrame is returned when no summary was found.
    """
    results_dir = Path(results_dir)
    comparison_data = []

    for exp_name in experiment_names:
        summary_path = results_dir / f"{exp_name}_summary.json"

        if not summary_path.exists():
            logger.warning(f"Results not found for experiment: {exp_name}")
            continue

        with open(summary_path, 'r') as f:
            summary = json.load(f)

        # A summary may store "metrics": null; treat it like an empty mapping
        # instead of crashing on None.get(...).
        metrics = summary.get("metrics") or {}
        comparison_data.append({
            "experiment": exp_name,
            "metric": metrics.get(metric),
            "config": summary.get("config", {}),
            "timestamp": summary.get("timestamp")
        })

    if not comparison_data:
        logger.warning("No experiment results found for comparison")
        return pd.DataFrame()

    df = pd.DataFrame(comparison_data)
    df = df.sort_values("metric", ascending=ascending, na_position="last")

    # Missing metrics sort last, so a missing value in the first row means no
    # experiment reported the metric; don't name a misleading "best" one.
    best = df.iloc[0]
    if pd.isna(best["metric"]):
        logger.warning(
            f"Metric '{metric}' not found in any experiment summary; "
            "cannot determine the best experiment"
        )
    else:
        logger.info(f"Experiment comparison completed. Best experiment: {best['experiment']}")
    return df


def create_training_script(
    config_path: Union[str, Path],
    output_path: Union[str, Path],
    script_type: str = "standard"
) -> None:
    """
    Create a training script from configuration.

    Reads a YAML configuration and writes a generated Python training script
    to ``output_path``, creating its parent directories if needed.

    Args:
        config_path: Path to configuration file (YAML)
        output_path: Path to save the training script
        script_type: Type of script ("standard" or "active_learning")

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If ``script_type`` is unsupported or the YAML top level is
            not a mapping.
    """
    config_path = Path(config_path)
    output_path = Path(output_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    # Reject an unknown script type before reading any file.
    if script_type not in ("standard", "active_learning"):
        raise ValueError(f"Unsupported script type: {script_type}")

    # Load configuration
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # An empty YAML document loads as None; the generators call config.get(...),
    # so normalise it and reject anything that is not a mapping.
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise ValueError(
            "Configuration file must contain a mapping at the top level, "
            f"got {type(config).__name__}: {config_path}"
        )

    # NOTE(review): the generated script loads config.get('config_path', 'config.yaml'),
    # which is not necessarily this file - confirm that is intended.
    if script_type == "standard":
        script_content = _create_standard_training_script(config)
    else:
        script_content = _create_active_learning_script(config)

    # Save script
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        f.write(script_content)

    logger.info(f"Training script created at: {output_path}")


def _create_standard_training_script(config: Dict[str, Any]) -> str:
    """Create a standard training script."""
    return f'''#!/usr/bin/env python3
"""
Auto-generated training script for GLEAM-AI.
"""

import sys
from pathlib import Path

# Add the current directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

from gleam_ai.training import GLEAMTrainer
from gleam_ai.training.utils import (
    load_training_config, setup_training_environment, 
    create_experiment_setup, validate_training_config
)

def main():
    """Main training function."""
    # Load configuration
    model_config, training_config, data_config, active_learning_config = load_training_config(
        "{config.get('config_path', 'config.yaml')}"
    )
    
    # Setup environment
    device = setup_training_environment(
        seed={config.get('seed', 'None')},
        device="{config.get('device', 'None')}",
        num_threads={config.get('num_threads', 'None')}
    )
    
    # Validate configuration
    if not validate_training_config(model_config, training_config, data_config):
        raise ValueError("Configuration validation failed")
    
    # Create experiment setup
    exp_dir = create_experiment_setup(
        base_dir="{config.get('output_dir', './experiments')}",
        experiment_name="{config.get('experiment_name', 'gleam_experiment')}",
        config={config}
    )
    
    # Create trainer
    trainer = GLEAMTrainer(
        model_config=model_config,
        training_config=training_config,
        data_config=data_config,
        device=device
    )
    
    # Setup data
    trainer.setup_data(
        meta_path="{config.get('meta_path', './meta_data')}",
        data_path="{config.get('data_path', './data')}",
        src_path="{config.get('src_path', './src')}",
        population_csv_path="{config.get('population_csv_path', './meta_data/populations.csv')}"
    )
    
    # Setup model
    trainer.setup_model(meta_path="{config.get('meta_path', './meta_data')}")
    
    # Train model
    results = trainer.train(
        output_dir=exp_dir,
        experiment_name="{config.get('experiment_name', 'gleam_experiment')}"
    )
    
    print("Training completed successfully!")
    print(f"Best model saved at: {{results['best_model_path']}}")

if __name__ == "__main__":
    main()
'''


def _create_active_learning_script(config: Dict[str, Any]) -> str:
    """Create an active learning training script."""
    return f'''#!/usr/bin/env python3
"""
Auto-generated active learning training script for GLEAM-AI.
"""

import sys
from pathlib import Path

# Add the current directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

from gleam_ai.training import GLEAMTrainer
from gleam_ai.training.utils import (
    load_training_config, setup_training_environment, 
    create_experiment_setup, validate_training_config
)

def main():
    """Main active learning training function."""
    # Load configuration
    model_config, training_config, data_config, active_learning_config = load_training_config(
        "{config.get('config_path', 'config.yaml')}"
    )
    
    # Setup environment
    device = setup_training_environment(
        seed={config.get('seed', 'None')},
        device="{config.get('device', 'None')}",
        num_threads={config.get('num_threads', 'None')}
    )
    
    # Validate configuration
    if not validate_training_config(model_config, training_config, data_config, active_learning_config):
        raise ValueError("Configuration validation failed")
    
    # Create experiment setup
    exp_dir = create_experiment_setup(
        base_dir="{config.get('output_dir', './experiments')}",
        experiment_name="{config.get('experiment_name', 'gleam_active_learning')}",
        config={config}
    )
    
    # Create trainer
    trainer = GLEAMTrainer(
        model_config=model_config,
        training_config=training_config,
        data_config=data_config,
        active_learning_config=active_learning_config,
        device=device
    )
    
    # Setup data
    trainer.setup_data(
        meta_path="{config.get('meta_path', './meta_data')}",
        data_path="{config.get('data_path', './data')}",
        src_path="{config.get('src_path', './src')}",
        population_csv_path="{config.get('population_csv_path', './meta_data/populations.csv')}"
    )
    
    # Setup model
    trainer.setup_model(meta_path="{config.get('meta_path', './meta_data')}")
    
    # Setup active learning
    trainer.setup_active_learning(
        acquisition_type="{config.get('acquisition_type', 'mean_std')}"
    )
    
    # Train with active learning
    results = trainer.train_with_active_learning(
        output_dir=exp_dir,
        experiment_name="{config.get('experiment_name', 'gleam_active_learning')}"
    )
    
    print("Active learning training completed successfully!")
    print(f"Training history: {{len(results['training_history'])}} iterations")

if __name__ == "__main__":
    main()
'''
