"""Utility functions for training models."""

import os
import yaml
import json
import datetime
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Union, List, Tuple
import numpy as np

import pandas as pd
import structlog
from autogluon.timeseries import TimeSeriesDataFrame
from gift_eval.wrapper import GiftEvalWrapper, M5Wrapper, NonLinearSaleSyntheticWrapper, SaleSyntheticWrapper, TrafficWrapper, ExchangeWrapper, EPFWrapper, IllnessWrapper, ElectricityWrapper

# Module-level structlog logger used by the helper functions in this module.
logger = structlog.get_logger()

def setup_logging(log_dir: Path) -> None:
    """Configure structlog and the stdlib root logger for a training run.

    structlog events are handed to the standard library ``logging`` module.
    The root logger then renders every record twice: as JSON into
    ``<log_dir>/training.log`` and as colored text on the console.

    Handlers that are already attached to the root logger are removed and
    closed first, so calling this function repeatedly (for example once per
    model) neither duplicates output nor keeps old log files open.

    Args:
        log_dir: Directory to save log files. It is created when missing.
            A plain string path is accepted as well.

    Returns:
        None
    """
    # Coerce so that the "/" operator below also works for string paths.
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Configure structlog
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S"),
            structlog.processors.format_exc_info,  # Add exception formatting
            # Must stay last: defers rendering to the ProcessorFormatter
            # instances attached to the stdlib handlers below.
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Set up both file and console handlers for stdlib logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Replace existing handlers. Closing them releases resources such as the
    # log file opened by a previous call to this function.
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
        handler.close()

    # File handler: one JSON document per log record
    file_handler = logging.FileHandler(log_dir / "training.log")
    file_handler.setFormatter(
        structlog.stdlib.ProcessorFormatter(
            processor=structlog.processors.JSONRenderer(),
        )
    )
    root_logger.addHandler(file_handler)

    # Console handler: human-readable, colored output
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        structlog.stdlib.ProcessorFormatter(
            processor=structlog.dev.ConsoleRenderer(colors=True),
        )
    )
    root_logger.addHandler(console_handler)


def load_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load configuration from YAML or JSON file.

    The format is chosen from the file extension (case-insensitive):
    ``.yaml`` / ``.yml`` are parsed with ``yaml.safe_load`` and ``.json``
    with ``json.load``. Files are read as UTF-8.

    Args:
        config_path: Path to configuration file.

    Returns:
        Configuration as dictionary. An empty document (YAML file without
        content, JSON ``null``) yields an empty dictionary instead of
        ``None``.

    Raises:
        ValueError: If file format is not supported or file doesn't exist.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise ValueError(f"Configuration file not found: {config_path}")

    suffix = config_path.suffix.lower()
    if suffix in ('.yaml', '.yml'):
        parse = yaml.safe_load
    elif suffix == '.json':
        parse = json.load
    else:
        raise ValueError(f"Unsupported configuration format: {config_path.suffix}")

    # Explicit encoding: do not depend on the platform's locale default.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = parse(f)

    # An empty document parses to None; callers index the result like a dict.
    if config is None:
        config = {}

    logger.info("Configuration loaded", config_path=str(config_path))
    return config


def prepare_model_path(config: Dict[str, Any], model_name: str) -> Path:
    """Create a timestamped output directory for one model.

    The directory is named ``<model_name>_<YYYYmmdd_HHMMSS>`` and is placed
    under ``config['base_output_dir']``. A ``logs`` sub-directory is created
    inside it.

    Args:
        config: Configuration dictionary; must contain ``base_output_dir``.
        model_name: Name of the model.

    Returns:
        Path to model output directory.
    """
    output_root = Path(config['base_output_dir'])
    run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = output_root / f"{model_name}_{run_stamp}"

    # Model directory first, then its "logs" child.
    for directory in (model_path, model_path / "logs"):
        os.makedirs(directory, exist_ok=True)

    logger.info("Model directory prepared", model=model_name, path=str(model_path))
    return model_path


def apply_scalers(train_data, test_data, config):
    """Scale the target and selected feature columns as configured.

    Settings are read from ``config['data']['preprocessing']``:

    * ``target_scaler``: name of one of AutoGluon's local target scalers.
      The scaler is fitted on the training data, applied to both data sets
      and attached to the returned training data as ``_scaler``.
    * ``feature_scalers``: list of ``{'name': ..., 'scaler': ...}`` entries
      using scikit-learn's ``StandardScaler`` or ``MinMaxScaler``. Each
      scaler is fitted on the training column and reused for the test
      column when that column exists there.

    Args:
        train_data: Training data as TimeSeriesDataFrame.
        test_data: Test data as TimeSeriesDataFrame.
        config: Configuration dictionary.

    Returns:
        Tuple of (train_data, test_data) with scalers applied.

    Raises:
        ValueError: If an unsupported scaler is specified.
    """
    from autogluon.timeseries.transforms.target_scaler import (
        LocalMeanAbsScaler,
        LocalMinMaxScaler,
        LocalRobustScaler,
        LocalStandardScaler,
    )

    def _as_column_vector(frame, column):
        # scikit-learn scalers expect 2-D input of shape (n_samples, 1).
        return frame[column].values.reshape(-1, 1)

    data_section = config['data']
    preprocessing = data_section.get('preprocessing', {})

    # --- target column -----------------------------------------------------
    target_scaler_name = preprocessing.get('target_scaler')
    target_column = data_section.get('target_column', 'target')

    if target_scaler_name:
        logger.info(f"Applying {target_scaler_name} to target column {target_column}")

        known_target_scalers = (
            ("LocalStandardScaler", LocalStandardScaler),
            ("LocalMeanAbsScaler", LocalMeanAbsScaler),
            ("LocalMinMaxScaler", LocalMinMaxScaler),
            ("LocalRobustScaler", LocalRobustScaler),
        )
        for known_name, scaler_cls in known_target_scalers:
            if target_scaler_name == known_name:
                target_scaler = scaler_cls(target=target_column)
                break
        else:
            raise ValueError(f"Unsupported target scaler: {target_scaler_name}")

        # Statistics come from the training data only.
        train_data = target_scaler.fit_transform(train_data)
        test_data = target_scaler.transform(test_data)

        # Keep the fitted scaler reachable from the training data.
        train_data._scaler = target_scaler

    # --- feature columns ---------------------------------------------------
    for spec in preprocessing.get('feature_scalers', []):
        feature_name = spec['name']
        scaler_name = spec.get('scaler')

        # Entries without a scaler are silently ignored.
        if not scaler_name:
            continue

        logger.info(f"Applying {scaler_name} to feature {feature_name}")

        if feature_name not in train_data.columns:
            logger.warning(f"Feature {feature_name} not found in data, skipping scaling")
            continue

        # scikit-learn is imported only when a feature scaler is requested.
        if scaler_name == "StandardScaler":
            from sklearn.preprocessing import StandardScaler as SklearnStandardScaler
            feature_scaler = SklearnStandardScaler()
        elif scaler_name == "MinMaxScaler":
            from sklearn.preprocessing import MinMaxScaler
            feature_scaler = MinMaxScaler()
        else:
            raise ValueError(f"Unsupported feature scaler: {scaler_name}")

        scaled_train = feature_scaler.fit_transform(_as_column_vector(train_data, feature_name))
        train_data[feature_name] = scaled_train.flatten()

        if feature_name in test_data.columns:
            scaled_test = feature_scaler.transform(_as_column_vector(test_data, feature_name))
            test_data[feature_name] = scaled_test.flatten()

    return train_data, test_data


def load_time_series_data_v0(config: Dict[str, Any]) -> Tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]:
    """Load time series data from a path and split it for training/testing.

    Reads ``config['data']['source']`` with ``TimeSeriesDataFrame.from_path``
    and splits it with ``train_test_split`` using
    ``config['data']['prediction_length']``. When ``config['data']`` has a
    ``preprocessing`` section, the scalers configured there are applied via
    ``apply_scalers``.

    A warning is logged when the shortest training series has fewer than 25
    observations.

    Args:
        config: Configuration dictionary.

    Returns:
        Tuple of (train_data, test_data) as TimeSeriesDataFrame objects.

    Raises:
        ValueError: If data loading fails, including when the training split
            contains no time series.
    """
    data_config = config['data']
    source = data_config['source']
    recommended_minimum = 25

    try:
        # Load data
        logger.info("Loading time series data", source=source)
        data = TimeSeriesDataFrame.from_path(source)

        # Split data for training and testing
        prediction_length = data_config['prediction_length']
        train_data, test_data = data.train_test_split(prediction_length=prediction_length)

        # Verify we have enough data. The check uses the shortest series:
        # an average over all items would hide a single short series
        # behind long ones.
        series_lengths = train_data.num_timesteps_per_item()
        if len(series_lengths) == 0:
            raise ValueError("training split contains no time series")
        min_points = int(series_lengths.min())
        if min_points < recommended_minimum:
            logger.warning(
                "Limited training data available",
                min_points_per_item=min_points,
                recommended_minimum=recommended_minimum
            )

        # Apply scalers if configured
        if 'preprocessing' in data_config:
            logger.info("Applying data preprocessing and scaling")
            train_data, test_data = apply_scalers(train_data, test_data, config)

        logger.info(
            "Data preparation complete",
            train_size=len(train_data),
            test_size=len(test_data),
            num_items=len(train_data.item_ids)
        )

        return train_data, test_data

    except Exception as e:
        # Single error type for callers; the original exception stays
        # attached as __cause__ and is written to the log with traceback.
        logger.exception("Failed to load and preprocess time series data")
        raise ValueError(f"Data loading failed: {str(e)}") from e


def load_time_series_data(config: Dict[str, Any]) -> Tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]:
    """Load train and validation data through ``GiftEvalWrapper``.

    Keys read from ``config['data']``:

    * ``source``, ``term``, ``to_univariate`` (required) and ``target_idx``
      (default ``0``) are forwarded to ``GiftEvalWrapper``.
    * ``slice_start`` (default ``-5000``) is passed as the start argument of
      ``slice_by_timestep`` for both data sets.
    * ``add_date_features`` (default ``False``) is forwarded to the
      wrapper's data getters.
    * ``preprocessing`` (optional): when present, ``apply_scalers`` is run.

    Args:
        config: Configuration dictionary.

    Returns:
        Tuple of (train_data, test_data) as TimeSeriesDataFrame objects.

    Raises:
        ValueError: If data loading fails.
    """
    data_config = config['data']

    # Required settings
    source = data_config['source']
    term = data_config['term']
    to_univariate = data_config['to_univariate']

    # Optional settings
    target_idx = data_config.get('target_idx', 0)
    slice_start = data_config.get('slice_start', -5000)
    add_date_features = data_config.get('add_date_features', False)

    try:
        logger.info("Loading time series data", source=source)
        wrapper = GiftEvalWrapper(
            name=source,
            term=term,
            to_univariate=to_univariate,
            target_idx=target_idx,
        )

        # The value is not used below; the attribute is still read here.
        # NOTE(review): presumably safe to drop - confirm that the access
        # has no side effects in GiftEvalWrapper.
        _ = wrapper.dataset.prediction_length

        full_train = wrapper.get_train_data(add_date_features=add_date_features)
        full_test = wrapper.get_validation_data(add_date_features=add_date_features)

        # Keep only the timesteps from slice_start onward.
        train_data = full_train.slice_by_timestep(slice_start, None)
        test_data = full_test.slice_by_timestep(slice_start, None)
        logger.info(f"length of train_data: {len(train_data)}")
        logger.info(f"length of test_data: {len(test_data)}")

        if 'preprocessing' in data_config:
            logger.info("Applying data preprocessing and scaling")
            train_data, test_data = apply_scalers(train_data, test_data, config)

        logger.info(
            "Data preparation complete",
            train_size=len(train_data),
            test_size=len(test_data),
            num_items=len(train_data.item_ids),
        )
    except Exception as e:
        logger.exception("Failed to load and preprocess time series data")
        raise ValueError(f"Data loading failed: {str(e)}") from e

    return train_data, test_data


def load_m5_data(config: dict) -> Tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]:
    """Load train and validation data for one of the named datasets.

    Despite its name, this loader is not limited to M5. The dataset is
    selected by ``config['data']['source']`` and loaded through the matching
    wrapper class.

    Args:
        config: Configuration dictionary. Keys read from ``config['data']``:
            - source: Dataset name. One of ``m5``, ``traffic``, ``exchange``,
              ``epf``, ``illness``, ``sales_synthetic``,
              ``nonlinear_sales_synthetic``, ``electricity_synthetic``.
            - term: Passed to the wrapper. Required for every dataset
              except ``m5``, whose wrapper does not take it.
            - slice_start (optional): When set, both data sets are cut with
              ``slice_by_timestep(slice_start, None)``.
            - preprocessing (optional): When present, ``apply_scalers`` is
              run on both data sets.

    Returns:
        Tuple of (train_data, test_data) as TimeSeriesDataFrame objects;
        test_data is the wrapper's validation data.

    Raises:
        ValueError: If the dataset name is not supported, a required
            configuration key is missing, or loading fails.
    """
    # Wrappers that are built as Wrapper(name=<source>, term=<term>).
    term_wrappers = {
        "traffic": TrafficWrapper,
        "exchange": ExchangeWrapper,
        "epf": EPFWrapper,
        "illness": IllnessWrapper,
        "sales_synthetic": SaleSyntheticWrapper,
        "nonlinear_sales_synthetic": NonLinearSaleSyntheticWrapper,
        "electricity_synthetic": ElectricityWrapper,
    }

    try:
        data_config = config['data']
        data_name = data_config['source']

        if data_name == "m5":
            # M5Wrapper takes no term, so 'term' is not required here.
            dataset = M5Wrapper(name="m5")
        elif data_name in term_wrappers:
            dataset = term_wrappers[data_name](name=data_name, term=data_config['term'])
        else:
            raise ValueError(f"Unsupported dataset: {data_name}")

        train_data = dataset.get_train_data()
        test_data = dataset.get_validation_data()

        slice_start = data_config.get('slice_start', None)
        if slice_start is not None:
            train_data = train_data.slice_by_timestep(slice_start, None)
            test_data = test_data.slice_by_timestep(slice_start, None)

        logger.info("====================== Before applying data preprocessing and scaling ======================")
        # Apply scalers if configured
        if 'preprocessing' in data_config:
            logger.info("=====================================================================================")
            logger.info("====================== Applying data preprocessing and scaling ======================")
            logger.info("=====================================================================================")
            train_data, test_data = apply_scalers(train_data, test_data, config)
        return train_data, test_data

    except Exception as e:
        # Single error type for callers; the original exception stays
        # attached as __cause__ and is written to the log with traceback.
        logger.exception("Failed to load and process M5 dataset")
        raise ValueError(f"M5 data loading failed: {str(e)}") from e


def evaluate_forecast(test_data, forecast, config: Dict[str, Any]) -> Dict[str, float]:
    """Evaluate forecast using specified metrics.

    Every name in ``config['evaluation']['metrics']`` is looked up in
    AutoGluon's ``AVAILABLE_METRICS``, instantiated and called on the data.

    The target column is taken from the top-level ``target_column`` key when
    present, otherwise from ``config['data']['target_column']`` (the key
    ``apply_scalers`` uses), otherwise ``'target'``.

    Args:
        test_data: Test data as TimeSeriesDataFrame.
        forecast: Forecast as TimeSeriesDataFrame.
        config: Configuration dictionary. Must provide
            ``config['data']['prediction_length']`` and
            ``config['evaluation']['metrics']``.

    Returns:
        Dictionary of metric names and values.

    Raises:
        ValueError: If a configured metric name is not available.
    """
    from autogluon.timeseries.metrics import AVAILABLE_METRICS

    prediction_length = config['data']['prediction_length']

    # Resolved once: the value does not depend on the metric.
    if 'target_column' in config:
        target_column = config['target_column']
    else:
        target_column = config['data'].get('target_column', 'target')

    metrics = {}
    for metric_name in config['evaluation']['metrics']:
        metric_cls = AVAILABLE_METRICS.get(metric_name)
        if metric_cls is None:
            raise ValueError(
                f"Unsupported evaluation metric: {metric_name!r}. "
                f"Available metrics: {sorted(AVAILABLE_METRICS)}"
            )
        metric = metric_cls()

        # NOTE(review): AutoGluon scorers presumably return scores in
        # "higher is better" orientation (error metrics negated) - confirm
        # before comparing the values with externally computed errors.
        score = metric(
            data=test_data,
            predictions=forecast,
            prediction_length=prediction_length,
            target=target_column
        )
        metrics[metric_name] = float(score)
        logger.info("Evaluation metric", metric=metric_name, score=score)

    return metrics


def plot_forecast_vs_actual(
    test_data: TimeSeriesDataFrame,
    forecast: TimeSeriesDataFrame,
    metrics: Dict[str, float],
    model_name: str,
    save_path: Path,
    max_items: int = 3,
    figsize: Tuple[int, int] = (15, 10),
    target_column: str = "target"
) -> None:
    """Plot forecasts against actual values for selected items.

    One subplot is drawn per item. Items are the first ``max_items`` IDs of
    ``test_data`` that also occur in ``forecast``, in ``test_data`` order, so
    repeated runs plot the same items. The forecast line uses the ``mean``
    column when present, otherwise the first forecast column.

    The figure title shows one metric: ``MASE`` when available, otherwise
    the first entry of ``metrics`` under its own name.

    Args:
        test_data: Test data containing actual values.
        forecast: Forecast data from the model.
        metrics: Dictionary of evaluation metrics.
        model_name: Name of the model for plot title.
        save_path: Path to save the plot. Missing parent directories are
            created.
        max_items: Maximum number of items to plot (default: 3).
        figsize: Figure size as (width, height) tuple.
        target_column: Column of ``test_data`` holding the actual values.

    Returns:
        None
    """
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    from matplotlib.ticker import MaxNLocator

    # Common item IDs, kept in test_data order (a set intersection would
    # make the selection below arbitrary).
    forecast_item_ids = set(forecast.item_ids)
    common_item_ids = [item_id for item_id in test_data.item_ids if item_id in forecast_item_ids]

    if not common_item_ids:
        logger.warning("No common item IDs found between test data and forecast")
        return

    # Limit to the first max_items
    plot_item_ids = common_item_ids[:max(max_items, 0)]
    if not plot_item_ids:
        # plt.subplots() rejects zero rows.
        logger.warning("No items selected for plotting", max_items=max_items)
        return

    # Headline metric: MASE if available, otherwise the first metric,
    # labelled with its real name.
    if "MASE" in metrics:
        metric_name, metric_value = "MASE", metrics["MASE"]
    elif metrics:
        metric_name, metric_value = next(iter(metrics.items()))
    else:
        metric_name, metric_value = None, None
    metric_str = f"{metric_name}: {metric_value:.4f}" if metric_value is not None else ""

    forecast_column = 'mean' if 'mean' in forecast.columns else forecast.columns[0]

    # squeeze=False always yields a 2-D axes array, also for a single item.
    fig, axes = plt.subplots(len(plot_item_ids), 1, figsize=figsize, sharex=False, squeeze=False)

    try:
        for ax, item_id in zip(axes[:, 0], plot_item_ids):
            # Get data for this item
            item_test = test_data.loc[item_id]
            if isinstance(item_test, pd.DataFrame):
                item_test = item_test[target_column]
            item_forecast = forecast.loc[item_id][forecast_column]

            # Plot test data and forecast
            actual_timestamps = item_test.index
            forecast_timestamps = item_forecast.index
            ax.plot(actual_timestamps, item_test.values, 'b-', label='Actual', linewidth=2)
            ax.plot(forecast_timestamps, item_forecast.values, 'r--', label='Forecast', linewidth=2)

            # Add grid and legend
            ax.grid(True, linestyle='--', alpha=0.7)
            ax.legend(loc='best')

            # Format x-axis to show dates nicely
            if pd.api.types.is_datetime64_any_dtype(actual_timestamps):
                ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
                plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

            # Set y-axis to have reasonable number of ticks
            ax.yaxis.set_major_locator(MaxNLocator(nbins=5))

            ax.set_title(f"Item ID: {item_id}", fontsize=10)
            ax.set_ylabel(target_column)

        # Add overall title with metrics
        fig.suptitle(f"Model: {model_name} | {metric_str}", fontsize=14)

        # Adjust layout
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)  # Make room for suptitle

        # Save figure
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    logger.info(f"Plot saved to {save_path}")


def save_model_config(model_config: Dict[str, Any], model_path: Path) -> None:
    """Save model configuration to the model path.

    The configuration is written as indented JSON to
    ``<model_path>/config.json``.

    Args:
        model_config: Model configuration. Must be JSON-serializable.
        model_path: Directory to save configuration in. A plain string path
            is accepted as well.

    Returns:
        None

    Raises:
        TypeError: If ``model_config`` contains values that ``json`` cannot
            serialize. No file is written in that case.
    """
    config_file = Path(model_path) / "config.json"

    # Serialize before opening the file, so a serialization error cannot
    # leave a truncated config.json behind.
    serialized = json.dumps(model_config, indent=2)

    with open(config_file, 'w', encoding='utf-8') as f:
        f.write(serialized)

    logger.info("Model configuration saved", path=str(config_file))