from emm.data_gen.mixture.mixture_gen import load_test_dataset
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np


def _target_scale(dataset_name, data_dir: str) -> float:
    """Return the target std-dev that a StandardScaler fit on the dataset's ``y`` would use.

    Best-effort: if the dataset cannot be loaded or the scaler cannot be fit,
    the error is printed and ``1.0`` is returned, which yields a zero log-Jacobian
    correction (i.e. the NLL is left unchanged).
    """
    try:
        _, y, _, _ = load_test_dataset(dataset_name, data_dir)

        # Recreate the same scaler that would have been used in training.
        # NOTE(review): y is flattened to a single column, so this assumes a
        # 1D regression target — confirm for multi-output datasets.
        scaler_y = StandardScaler()
        scaler_y.fit(np.asarray(y).reshape(-1, 1))

        # StandardScaler stores the per-feature std in ``scale_``.
        return float(scaler_y.scale_[0])
    except Exception as e:
        # Deliberately best-effort: report and fall back to "no scaling".
        print(f"Error processing dataset {dataset_name}: {e}")
        return 1.0


def adjust_nll_for_scaling(results_df: pd.DataFrame, data_dir: str) -> pd.DataFrame:
    """
    Adjust average NLL values to account for target standardization.

    If a model was trained on targets standardized as ``(y - mu) / sigma``, its
    average per-sample NLL in the scaled space differs from the NLL in the
    original target space by ``log(sigma)`` (the log-Jacobian of the transform),
    so the corrected value is ``nll_loss + log(sigma)``.

    Args:
        results_df: DataFrame with 'dataset' and 'nll_loss' columns
                    where 'nll_loss' is the average NLL per sample
                    (presumably computed on standardized targets).
        data_dir: Directory containing dataset files

    Returns:
        A copy of ``results_df`` where:
          - 'nll_loss' holds the corrected NLL (kept for backward compatibility),
          - 'adjusted_nll' also holds the corrected NLL,
          - 'train_nll' holds the original, uncorrected NLL,
          - 'scaling_factor' holds sigma (1.0 if the dataset could not be processed),
          - 'log_jacobian_term' holds log(sigma) (0.0 if it could not be processed).
    """
    # Create a copy to avoid modifying the original
    df = results_df.copy()

    # Fit each dataset's target scaler only once, even if several rows share it.
    sigma_by_dataset = {}
    sigmas = []
    for dataset_name in df["dataset"]:
        if dataset_name not in sigma_by_dataset:
            sigma_by_dataset[dataset_name] = _target_scale(dataset_name, data_dir)
        sigmas.append(sigma_by_dataset[dataset_name])

    scaling_factor = np.asarray(sigmas, dtype=float)
    # For average NLL, the correction is simply log(sigma).
    log_jacobian_term = np.log(scaling_factor)

    original_nll = df["nll_loss"].copy()
    # Series + ndarray is element-wise by position, so this is independent of
    # whether the DataFrame index is unique.
    adjusted_nll = original_nll + log_jacobian_term

    df["log_jacobian_term"] = log_jacobian_term
    df["adjusted_nll"] = adjusted_nll
    df["scaling_factor"] = scaling_factor
    df["train_nll"] = original_nll
    df["nll_loss"] = adjusted_nll

    return df
