"""
Legacy training script for backward compatibility.

This module provides a simplified training interface that mimics the original
run.py script for backward compatibility.
"""

import os
import torch
from pathlib import Path
from typing import Dict, Any, Optional

from ..models import STNP
from ..data import get_datasets, load_graph_data
from ..active_learning import ActiveLearner, MeanStd, LatentInfoGain
from ..config.settings import load_config_from_yaml
from ..utils import set_seed, get_device


def run_legacy_training(
    config_path: str = "config.yaml",
    device: Optional[str] = None,
    seed: Optional[int] = None,
    num_threads: int = 10,
) -> None:
    """
    Run training using the legacy interface for backward compatibility.

    This function mimics the original run.py script behavior. It:

    - loads the YAML configuration;
    - builds the train/validation/pool datasets and the graph (edge index and weights);
    - constructs an ``STNP`` model with a ``MeanStd`` acquisition function;
    - runs the ``ActiveLearner`` loop for ``active_learner.max_iter`` iterations.

    Args:
        config_path: Path to the YAML configuration file.
        device: Device to use for training. When ``None`` it is chosen by
            ``get_device()``.
        seed: Random seed for reproducibility. When ``None`` no seeding is done.
        num_threads: CPU thread count. It is applied to ``torch.set_num_threads``
            and exported as ``OMP_NUM_THREADS``, ``MKL_NUM_THREADS`` and
            ``NUMEXPR_NUM_THREADS``. Defaults to 10, the value this legacy
            interface always used.

    Raises:
        ValueError: If ``num_threads`` is smaller than 1.
        KeyError: If a required section or key is missing from the config.

    Note:
        This mutates process-wide state: the thread environment variables in
        ``os.environ`` and PyTorch's intra-op thread count.
    """
    # Validate before touching any global state.
    if num_threads < 1:
        raise ValueError(f"num_threads must be a positive integer, got {num_threads}")

    # Setup environment
    if seed is not None:
        set_seed(seed)

    if device is None:
        device = get_device()

    # Exported so that processes spawned later inherit the limit. Native
    # libraries already loaded in this process may have read these variables
    # at import time, so torch.set_num_threads() is what bounds PyTorch's
    # thread pool in the current process.
    threads = str(num_threads)
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
        os.environ[var] = threads
    torch.set_num_threads(num_threads)

    print(f"Num threads={torch.get_num_threads()}")
    print(f"Device: {device}")

    # Load configuration. Resolve every section up front so that a malformed
    # config fails before the (potentially slow) dataset loading starts.
    config = load_config_from_yaml(config_path)
    meta_cfg = config["meta_data"]
    data_cfg = config["data"]
    model_cfg = config["model"]
    train_cfg = config["train"]
    mstd_cfg = config["mstd"]
    al_cfg = config["active_learner"]

    # Get datasets
    train_dataset, val_dataset, pool_dataset = get_datasets(
        meta_path=meta_cfg["metaPath"],
        data_path=meta_cfg["dataPath"],
        src_path=meta_cfg["srcPath"],
        x_col_names=data_cfg["x_col_names"],
        frac_pops_names=data_cfg["frac_pops_names"],
        initial_col_names=data_cfg["initial_col_names"],
        seq_len=model_cfg["seq_len"],
        num_nodes=model_cfg["num_nodes"],
        population_csv_path=meta_cfg["population_csv_path"],
        population_scaler=model_cfg["POPULATION_SCALER"],
    )

    # Load graph data
    edge_index, edge_weight = load_graph_data(meta_cfg["metaPath"])

    # Create model
    model = STNP(
        x_dim=model_cfg["x_dim"],
        xt_dim=model_cfg["xt_dim"],
        y_dim=model_cfg["y_dim"],
        z_dim=model_cfg["z_dim"],
        r_dim=model_cfg["r_dim"],
        seq_len=model_cfg["seq_len"],
        num_nodes=model_cfg["num_nodes"],
        in_channels=model_cfg["in_channels"],
        out_channels=model_cfg["out_channels"],
        embed_out_dim=model_cfg["embed_out_dim"],
        max_diffusion_step=model_cfg["max_diffusion_step"],
        encoder_num_rnn=model_cfg["encoder_num_rnn"],
        decoder_num_rnn=model_cfg["decoder_num_rnn"],
        decoder_hidden_dims=model_cfg["decoder_hidden_dims"],
        num_comp=model_cfg["NUM_COMP"],
        context_percentage=model_cfg["context_percentage"],
        lr=train_cfg["lr"],
        lr_encoder=train_cfg["lr_encoder"],
        lr_decoder=train_cfg["lr_decoder"],
        lr_milestones=train_cfg["lr_milestones"],
        lr_gamma=train_cfg["lr_gamma"],
        edge_index=edge_index,
        edge_weight=edge_weight,
    )

    # Create acquisition function (num_workers is fixed in this legacy interface)
    acquisition = MeanStd(
        output_dim=model_cfg["y_dim"],
        acquisition_size=mstd_cfg["acquisition_size"],
        pool_loader_batch_size=mstd_cfg["pool_loader_batch_size"],
        acquisition_pool_fraction=mstd_cfg["acquisition_pool_fraction"],
        num_workers=4,
        device=device,
    )

    # Create active learner. patience/min_delta are fixed values of this
    # legacy interface, not read from the config.
    learner = ActiveLearner(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        pool_dataset=pool_dataset,
        acquisition_function=acquisition,
        config={
            "initial_samples": al_cfg["initial_train_size"],
            "samples_per_iteration": mstd_cfg["acquisition_size"],
            "max_iterations": al_cfg["max_iter"],
            "min_pool_size": 0,
        },
        training_config={
            "max_epochs": train_cfg["max_epochs"],
            "batch_size": train_cfg["train_batch_size"],
            "lr": train_cfg["lr"],
            "patience": 30,
            "min_delta": 0.001,
        },
        device=device,
    )

    # Initialize and run training
    learner.initialize_data()
    learner.learn(al_cfg["max_iter"])

    print("Legacy training completed successfully!")


def create_legacy_run_script(
    config_path: str = "config.yaml",
    output_path: str = "run_legacy.py"
) -> None:
    """
    Create a legacy run script for backward compatibility.

    The generated script adds its own directory to ``sys.path``, imports
    ``run_legacy_training`` from ``gleam_ai.training.legacy_trainer`` and calls
    it with ``config_path``. It picks ``"cuda"`` when
    ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Args:
        config_path: Path to the configuration file. It is embedded in the
            generated script as a properly escaped Python string literal, so
            backslashes and quotes (e.g. Windows paths) are preserved.
        output_path: Path to save the legacy script. An existing file is
            overwritten.
    """
    # repr() yields a valid, correctly escaped Python literal. Pasting the raw
    # text inside "..." breaks on quotes and mangles backslash sequences.
    config_literal = repr(str(config_path))

    # The generated script uses torch.cuda, so it must import torch itself.
    script_content = f'''#!/usr/bin/env python3
"""
Legacy training script for GLEAM-AI.
This script provides backward compatibility with the original run.py.
"""

import sys
from pathlib import Path

import torch

# Add the current directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

from gleam_ai.training.legacy_trainer import run_legacy_training

def main():
    """Main function for legacy training."""
    run_legacy_training(
        config_path={config_literal},
        device="cuda" if torch.cuda.is_available() else "cpu"
    )

if __name__ == "__main__":
    main()
'''

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(script_content)

    print(f"Legacy run script created at: {output_path}")


# For backward compatibility, provide the same interface as the original run.py
def main() -> None:
    """Entry point matching the original ``run.py``: train with default settings."""
    # No arguments: config.yaml, auto-detected device, and no explicit seed.
    run_legacy_training()


# Script entry point. Because this module uses package-relative imports
# (``from ..models import ...``), it must be run as a module (``python -m``),
# not as a standalone file path.
if __name__ == "__main__":
    main()
