"""Training script for time series predictors."""

import datetime
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List

import pandas as pd
import structlog

from residual_chronos.Regressor import CrossSectionalRegressor, TimeSeriesRegressor, CovariateRegressor
from residual_chronos.train.train_utils import (
    load_config,
    load_m5_data,
    setup_logging,
    prepare_model_path,
    load_time_series_data,
    evaluate_forecast,
    save_model_config,
    plot_forecast_vs_actual
)

# Module-level structlog logger used by every function in this script.
logger = structlog.get_logger()


def train_time_series_model(
    model_config: Dict[str, Any],
    global_config: Dict[str, Any],
    train_data,
    test_data,
    base_output_dir: Path
) -> Dict[str, Any]:
    """Train, evaluate and plot a single time series model with its own predictor.

    The model gets its own output directory (obtained from
    ``prepare_model_path``) that receives its logs, its saved configuration,
    ``evaluation.json`` and the forecast plot.

    Args:
        model_config: Configuration for this specific model. ``name`` is
            required; ``target_column`` (default ``'target'``),
            ``known_covariates`` and ``hyperparameters`` are optional.
        global_config: Global configuration settings. ``data.prediction_length``
            and the ``global_settings`` mapping are read from it.
        train_data: Training data; passed to ``predictor.fit`` and used as the
            history for ``predictor.predict``.
        test_data: Test data; passed as ``known_covariates`` when predicting
            and used as the reference for evaluation and plotting.
        base_output_dir: Base directory for output. Not used by this function:
            the model directory comes from ``prepare_model_path(global_config, ...)``.

    Returns:
        Dictionary with the keys ``model``, ``path`` and ``metrics``.

    Raises:
        KeyError: If ``model_config`` has no ``name`` entry (raised before the
            guarded section starts).
        RuntimeError: If any step of training, forecasting, evaluation or
            plotting fails; the original exception is chained as the cause.
    """
    model_name = model_config['name']

    try:
        # Create the model-specific output directory, then route logging and
        # the configuration snapshot into it.
        model_path = prepare_model_path(global_config, model_name)
        setup_logging(model_path / "logs")
        save_model_config(model_config, model_path)

        # Values needed by several of the steps below are read only once.
        prediction_length = global_config['data']['prediction_length']
        target_column = model_config.get('target_column', 'target')
        hyperparameters = model_config.get('hyperparameters', {})

        # Create predictor with model-specific settings
        predictor = TimeSeriesRegressor(
            model_name=model_name,
            prediction_length=prediction_length,
            target=target_column,
            known_covariates_names=model_config.get('known_covariates', None),
            path=str(model_path),
        )

        logger.info(
            "Training model",
            model=model_name,
            hyperparameters=hyperparameters
        )

        # Train predictor with model-specific hyperparameters and the
        # run-wide fit settings.
        global_settings = global_config['global_settings']
        predictor.fit(
            train_data,
            hyperparameters=hyperparameters,
            enable_ensemble=global_settings.get('enable_ensemble', False),
            time_limit=global_settings.get('time_limit', None),
            verbosity=global_settings.get('verbosity', 0),
            num_val_windows=global_settings.get('num_val_windows', 1),
        )

        # Generate forecast
        logger.info("Generating forecast for evaluation", model=model_name)
        forecast = predictor.predict(train_data, known_covariates=test_data)

        # evaluate_forecast receives the global config extended with the
        # model's target column.
        model_config_with_global = {**global_config, 'target_column': target_column}
        results = evaluate_forecast(test_data, forecast, model_config_with_global)

        # Save evaluation results
        with open(model_path / "evaluation.json", 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        # Plot the forecast against the last two prediction windows of the
        # test data.
        plot_forecast_vs_actual(
            test_data=test_data.slice_by_timestep(-prediction_length * 2, None),
            forecast=forecast,
            metrics=results,
            model_name=model_name,
            save_path=model_path / f"{model_name}_forecast.png",
            target_column=target_column
        )

        logger.info(
            "Model training completed successfully",
            model=model_name,
            path=str(model_path),
            metrics=results
        )

        return {
            "model": model_name,
            "path": str(model_path),
            "metrics": results
        }

    except Exception as e:
        # exc_info=True attaches the traceback, matching the other error
        # handlers in this script; the explicit fields are kept for log
        # consumers that rely on them.
        logger.error(
            "Error during model training",
            model=model_name,
            error=str(e),
            error_type=type(e).__name__,
            exc_info=True
        )
        raise RuntimeError(f"Training failed for model {model_name}: {str(e)}") from e


def train_tabular_model(
    model_config: Dict[str, Any],
    global_config: Dict[str, Any],
    train_data,
    test_data,
    base_output_dir: Path
) -> Dict[str, Any]:
    """Train, evaluate and save a tabular model using CovariateRegressor.

    The model gets its own output directory (obtained from
    ``prepare_model_path``) that receives its logs, its saved configuration,
    the saved regressor, ``evaluation.json`` and, unless the regressor ended
    up disabled, the forecast plot.

    Args:
        model_config: Configuration for this specific model. ``name`` is
            required; hyperparameters, covariate/static feature lists and the
            regressor options are optional and fall back to the defaults
            visible in the body.
        global_config: Global configuration settings. ``global_settings``
            (for ``time_limit``) and ``data.prediction_length`` are read.
        train_data: Training data passed to ``regressor.fit``.
        test_data: Test data; predictions are produced for it and the last
            ``prediction_length`` steps are evaluated.
        base_output_dir: Base directory for output. Not used by this function:
            the model directory comes from ``prepare_model_path(global_config, ...)``.

    Returns:
        Dictionary with the keys ``model``, ``path`` and ``metrics``. When the
        regressor reports itself as disabled after fitting, ``metrics`` is
        ``{"error": "model_disabled"}``.

    Raises:
        KeyError: If ``model_config`` has no ``name`` entry (raised before the
            guarded section starts).
        RuntimeError: If any step fails; the original exception is chained as
            the cause.
    """
    model_name = model_config['name']

    try:
        # Create the model-specific output directory, then route logging and
        # the configuration snapshot into it.
        model_path = prepare_model_path(global_config, model_name)
        setup_logging(model_path / "logs")
        save_model_config(model_config, model_path)

        hyperparameters = model_config.get('hyperparameters', {})
        target_column = model_config.get('target_column', 'target')

        # Imported here rather than at module level, as in the original code.
        from autogluon.timeseries.utils.features import CovariateMetadata
        covariate_metadata = CovariateMetadata(
            static_features_cat=model_config.get('static_features_cat', []),
            static_features_real=model_config.get('static_features_real', []),
            known_covariates_real=model_config.get('known_covariates_real', []),
            known_covariates_cat=model_config.get('known_covariates_cat', []),
            past_covariates_real=model_config.get('past_covariates_real', []),
            past_covariates_cat=model_config.get('past_covariates_cat', [])
        )
        logger.info(f"Covariate metadata: {covariate_metadata}")

        logger.info(
            "Training tabular model",
            model=model_name,
            hyperparameters=hyperparameters
        )

        time_limit = global_config['global_settings'].get('time_limit', None)

        # Create and train covariate regressor
        regressor = CovariateRegressor(
            model_name=model_name,
            model_hyperparameters=hyperparameters,
            target=target_column,
            covariate_metadata=covariate_metadata,
            include_static_features=model_config.get('include_static_features', True),
            include_item_id=model_config.get('include_item_id', True),
            eval_metric=model_config.get('eval_metric', 'mean_absolute_error'),
            validation_fraction=model_config.get('validation_fraction', 0.1),
            fit_time_fraction=model_config.get('fit_time_fraction', 0.5),
        )
        regressor.fit(train_data, time_limit=time_limit)

        if regressor.disabled:
            # Nothing to evaluate; record the reason instead of metrics.
            logger.warning(
                "Regressor was disabled during training",
                model=model_name
            )
            results = {"error": "model_disabled"}
        else:
            prediction_values = regressor._predict(test_data, static_features=test_data.static_features)
            prediction_length = global_config['data']['prediction_length']

            # evaluate_forecast needs a frame with a 'mean' column. assign()
            # already returns a new frame, so test_data is left untouched
            # without an extra explicit copy.
            forecast_df = test_data.assign(mean=prediction_values)[['mean']]

            # Keep only the last prediction_length steps as the forecast.
            forecast = forecast_df.slice_by_timestep(-prediction_length, None)

            # Evaluate using the same function as for time series models
            model_config_with_global = {**global_config, 'target_column': target_column}
            results = evaluate_forecast(test_data, forecast, model_config_with_global)

            # Plot the forecast against the last two prediction windows of
            # the test data.
            plot_forecast_vs_actual(
                test_data=test_data.slice_by_timestep(-prediction_length * 2, None),
                forecast=forecast,
                metrics=results,
                model_name=model_name,
                save_path=model_path / f"{model_name}_forecast.png",
                target_column=target_column
            )

        # The regressor is saved in both cases, disabled or not.
        regressor.save(model_path)

        # Save evaluation results
        with open(model_path / "evaluation.json", 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        logger.info(
            "Tabular model training completed successfully",
            model=model_name,
            path=str(model_path),
            metrics=results
        )

        return {
            "model": model_name,
            "path": str(model_path),
            "metrics": results
        }

    except Exception as e:
        logger.exception(
            "Error during tabular model training",
            model=model_name
        )
        raise RuntimeError(f"Training failed for model {model_name}: {str(e)}") from e


def train_time_series_models(config_path: str) -> List[Dict[str, Any]]:
    """Train every model listed in a configuration file and write a summary.

    Data is loaded once and shared by all models. Each model is dispatched to
    ``train_time_series_model`` or ``train_tabular_model`` depending on which
    ``CrossSectionalRegressor`` model list contains its name. A failure of one
    model is logged and the remaining models are still trained. The collected
    results are written to ``training_summary_<timestamp>.json`` inside
    ``base_output_dir``.

    Args:
        config_path: Path to configuration file.

    Returns:
        List of result dictionaries, one per successfully trained model
        (failed models are omitted).

    Raises:
        ValueError: If a required top-level configuration section is missing.

    Errors raised while loading the configuration or the data, or while
    writing the summary file, are not caught here and propagate to the caller.
    """
    # Load configuration
    config = load_config(config_path)

    # Check for required configuration sections
    required_sections = ('data', 'models', 'base_output_dir', 'evaluation')
    for section in required_sections:
        if section not in config:
            raise ValueError(f"Missing required configuration section: {section}")

    # Sources handled by load_m5_data; every other source goes through
    # load_time_series_data.
    m5_loader_sources = (
        'm5',
        'traffic',
        'exchange',
        'epf',
        'illness',
        'sales_synthetic',
        'nonlinear_sales_synthetic',
        'electricity_synthetic',
    )

    # Load and prepare data once for all models
    if config['data']['source'] in m5_loader_sources:
        train_data, test_data = load_m5_data(config)
    else:
        train_data, test_data = load_time_series_data(config)

    output_dir = Path(config['base_output_dir'])

    # Train each model
    results = []
    for model_config in config['models']:
        try:
            model_name = model_config['name']
            if model_name in CrossSectionalRegressor.AVAILABLE_TS_MODELS:
                train_fn = train_time_series_model
            elif model_name in CrossSectionalRegressor.AVAILABLE_TABULAR_MODELS:
                train_fn = train_tabular_model
            else:
                raise ValueError(f"Invalid model name: {model_name}")

            results.append(
                train_fn(
                    model_config=model_config,
                    global_config=config,
                    train_data=train_data,
                    test_data=test_data,
                    base_output_dir=output_dir
                )
            )
        except Exception as e:
            # Build the label defensively: the failure may be that the entry
            # has no 'name' (or is not a mapping at all), and raising again
            # inside this handler would abort the remaining models.
            if isinstance(model_config, dict):
                failed_name = model_config.get('name', '<unnamed>')
            else:
                failed_name = repr(model_config)
            logger.error(f"Failed to train model {failed_name}: {str(e)}", exc_info=True)
            # Continue with other models even if one fails

    # The output directory may not exist yet when every model failed before
    # creating its own sub-directory, so make sure it is there.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save overall results summary
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    summary_path = output_dir / f"training_summary_{timestamp}.json"
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    logger.info(
        "All model training completed",
        num_models_trained=len(results),
        summary_path=str(summary_path)
    )

    return results


def main():
    """Command-line entry point for the training script.

    Parses the required ``--config`` argument and passes it to
    ``train_time_series_models``. If that call raises, the error is logged
    together with its traceback and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Train multiple time series models with AutoGluon"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML or JSON configuration file"
    )
    args = parser.parse_args()

    try:
        train_time_series_models(args.config)
    except Exception as e:
        # logger.exception records the traceback; without it a failure while
        # loading the configuration or the data would leave only the message.
        logger.exception(f"Training failed: {str(e)}")
        sys.exit(1)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main() 