import numpy as np
import pickle
from pathlib import Path
from scipy.stats import norm, uniform, gamma, expon
import matplotlib.pyplot as plt
import torch
from typing import List, Dict, Any
from emm.algorithms.mixture import get_log_prob

from sklearn.mixture import GaussianMixture


def add_R2_oracle(df, data_dir: str, **kwargs):
    """
    Add oracle-normalised R^2 columns to a results DataFrame.

    ``r2_oracle = (test_nll - baseline_nll) / (oracle_nll - baseline_nll)``,
    i.e. the fraction of the gap between the target-only GMM baseline
    (``calculate_gmm_baseline_nll``) and the true generating model
    (``calculate_true_nll``) that a model closes: 0 means no better than the
    baseline, 1 means as good as the oracle.

    Args:
        df: DataFrame with a ``dataset`` column (dataset directory name) and a
            ``test_nll`` column. It is modified in place and also returned.
        data_dir: Directory searched for ``dataset.pkl`` files.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        ``df`` with ``baseline_nll``, ``oracle_nll`` and ``r2_oracle`` columns.
        Rows whose dataset could not be evaluated by either helper get NaN.
    """
    baseline_dict = calculate_gmm_baseline_nll(data_dir)
    nll_dict = calculate_true_nll(data_dir)
    # Series.map yields NaN for missing keys instead of raising KeyError, so a
    # dataset the helpers skipped (they print the error and continue) does not
    # abort the whole table.
    df["baseline_nll"] = df["dataset"].map(baseline_dict)
    df["oracle_nll"] = df["dataset"].map(nll_dict)
    df["r2_oracle"] = (df["test_nll"] - df["baseline_nll"]) / (
        df["oracle_nll"] - df["baseline_nll"]
    )
    return df


def calculate_gmm_baseline_nll(
    data_dir: str, n_components: int = 20
) -> Dict[str, float]:
    """
    Calculate a feature-agnostic baseline NLL by fitting a GMM to the target only.

    For every ``dataset.pkl`` found (recursively) under ``data_dir``, a 1-D
    Gaussian Mixture Model is fit on ``y[train_indices]`` -- the features ``X``
    are ignored -- and scored on ``y[test_indices]``. This gives a consistent
    unconditional baseline to compare conditional models against.

    Args:
        data_dir: Directory searched recursively for ``dataset.pkl`` files.
        n_components: Maximum number of GMM components. It is capped at the
            number of training samples, because GaussianMixture requires
            ``n_samples >= n_components``.

    Returns:
        A dictionary mapping each dataset directory name to the *mean*
        per-sample NLL of its test targets under the fitted GMM. Datasets that
        fail to load or fit are reported on stdout and omitted.
    """
    data_path = Path(data_dir)
    # glob only yields files that exist, so no separate existence check is needed.
    dataset_paths = sorted(p.parent for p in data_path.glob("**/dataset.pkl"))
    results: Dict[str, float] = {}

    print(f"Calculating GMM baseline with {n_components} components...")

    for dataset_path in dataset_paths:
        dataset_name = dataset_path.name
        try:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(dataset_path / "dataset.pkl", "rb") as f:
                dataset = pickle.load(f)

            # sklearn expects a 2-D (n_samples, n_features) array.
            y = np.asarray(dataset["y"]).reshape(-1, 1)
            y_train = y[dataset["train_indices"]]
            y_test = y[dataset["test_indices"]]

            gmm = GaussianMixture(
                # Cap the component count so small datasets still get a baseline.
                n_components=max(1, min(n_components, len(y_train))),
                random_state=0,  # reproducible initialisation
                init_params="k-means++",
                max_iter=1000,
                n_init=3,
            )
            gmm.fit(y_train)

            # GaussianMixture.score returns the *mean* log-likelihood per
            # sample, keeping this baseline on the same per-sample scale as the
            # mean NLL returned by calculate_true_nll.
            mean_log_likelihood = gmm.score(y_test)
            results[dataset_name] = -mean_log_likelihood

        except Exception as e:
            # Best effort: one broken dataset should not abort the whole sweep.
            print(f"Error processing dataset {dataset_name}: {str(e)}")

    return results


def calculate_true_nll(
    data_dir: str, dataset_names=None, on_test=True
) -> Dict[str, float]:
    """
    Calculate the negative log-likelihood (NLL) of the true model for each dataset.

    Each ``dataset.pkl`` found (recursively) under ``data_dir`` is loaded and
    its targets are scored under their known generating components via
    ``calculate_dataset_nll``.

    Args:
        data_dir: Directory searched recursively for ``dataset.pkl`` files.
        dataset_names: Optional collection of dataset directory names to
            evaluate (a single string is also accepted). If None, every dataset
            found is evaluated. Requested names with no matching dataset are
            reported on stdout and skipped.
        on_test: If True (default), score only the rows selected by the
            dataset's ``test_indices``; otherwise score every row.

    Returns:
        Dictionary mapping dataset directory names to the value returned by
        ``calculate_dataset_nll``. Datasets that fail to load or evaluate are
        reported on stdout and omitted.
    """
    data_path = Path(data_dir)
    dataset_paths = sorted(p.parent for p in data_path.glob("**/dataset.pkl"))

    if dataset_names is not None:
        # Materialise once so iterators are not consumed twice.
        requested = (
            [dataset_names] if isinstance(dataset_names, str) else list(dataset_names)
        )
        available = {p.name for p in dataset_paths}
        for name in requested:
            if name not in available:
                print(f"Warning: Dataset {name} not found, skipping")
        wanted = set(requested)
        dataset_paths = [p for p in dataset_paths if p.name in wanted]

    results = {}

    for dataset_path in dataset_paths:
        dataset_name = dataset_path.name
        try:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(dataset_path / "dataset.pkl", "rb") as f:
                dataset = pickle.load(f)

            X = dataset["X"]
            y = np.asarray(dataset["y"]).reshape(-1)  # ensure a 1-D target
            component_labels = np.asarray(dataset["component_labels"])
            if on_test:
                # test_indices is only required when restricting to the test split.
                test_indices = dataset["test_indices"]
                X = X[test_indices]
                y = y[test_indices]
                component_labels = component_labels[test_indices]

            results[dataset_name] = calculate_dataset_nll(
                X, y, component_labels, dataset["config"].components
            )

        except Exception as e:
            # Best effort: report and continue with the remaining datasets.
            print(f"Error processing dataset {dataset_name}: {str(e)}")

    return results


def calculate_dataset_nll(
    X: np.ndarray, y: np.ndarray, component_labels: np.ndarray, components: List
) -> float:
    """
    Calculate the mean NLL of ``y`` given the known generating component of each sample.

    Each sample is scored only under the component that generated it,
    ``components[component_labels[i]]`` (no marginalisation over components).

    Args:
        X: Feature matrix. Unused; kept for signature parity with related functions.
        y: Target values, one per sample (flattened to 1-D).
        component_labels: Index into ``components`` for every sample.
        components: Component objects exposing ``distribution`` (one of
            "normal", "uniform", "exponential", "gamma"; case-insensitive) and
            ``dist_params`` (a dict of that distribution's parameters).

    Returns:
        The mean negative log-likelihood. Per-sample log-likelihoods are
        floored at -1000, so a sample outside its component's support
        (log-density -inf) adds a large finite penalty instead of making the
        NLL infinite. NaN for empty input.

    Raises:
        ValueError: If a component used by some sample has an unsupported
            distribution name.
    """
    y = np.asarray(y, dtype=float).reshape(-1)
    labels = np.asarray(component_labels).reshape(-1)
    if y.size == 0:
        return float("nan")

    log_likelihoods = np.empty(y.size)

    # Evaluate each component once on all of its samples instead of per sample.
    for label in np.unique(labels):
        mask = labels == label
        component = components[label]
        distribution = component.distribution.lower()
        params = component.dist_params
        y_k = y[mask]

        if distribution == "normal":
            log_pdf = norm.logpdf(y_k, loc=params["loc"], scale=params["scale"])
        elif distribution == "uniform":
            log_pdf = uniform.logpdf(
                y_k, loc=params["low"], scale=params["high"] - params["low"]
            )
        elif distribution == "exponential":
            log_pdf = expon.logpdf(
                y_k, scale=params["scale"], loc=params.get("loc", 0)
            )
        elif distribution == "gamma":
            log_pdf = gamma.logpdf(
                y_k,
                a=params["shape"],
                scale=params["scale"],
                loc=params.get("loc", 0),
            )
        else:
            # Previously such samples silently kept log-likelihood 0
            # (probability 1), which biased the NLL downwards.
            raise ValueError(
                f"Unsupported distribution {component.distribution!r} "
                f"for component {label}"
            )
        log_likelihoods[mask] = log_pdf

    # Floor very negative values (including -inf) at -1000.
    np.clip(log_likelihoods, -1000, None, out=log_likelihoods)
    return -np.mean(log_likelihoods)


from typing import List, Tuple, Dict  # Assuming Any for Component type for now
import numpy as np
from scipy.special import logsumexp  # Required new import

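# Component objects passed to the functions below are accessed by attribute
# (not by dict key) and are expected to expose at least:
#   component.distribution: str   (e.g. "normal", "uniform", "exponential", "gamma")
#   component.dist_params: Dict[str, float]
#   component.weight: float       (mixing weight; used by calculate_marginal_dataset_nll)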


def calculate_marginal_dataset_nll(
    X: np.ndarray, y: np.ndarray, components: List[Any]
) -> Tuple[float, np.ndarray]:
    """
    Calculate the marginal negative log-likelihood for a dataset under a mixture model.

    The component assignment is marginalised out for every sample:
    ``log P(y_i) = logsumexp_k(log P(y_i | C_k) + log w_k)``, and the returned
    NLL is ``-(1/n_samples) * sum_i log P(y_i)``.

    Args:
        X: Feature matrix. Unused (the mixture does not depend on features);
            kept for signature parity with related functions.
        y: Target values; flattened to 1-D.
        components: Component objects, each exposing ``distribution`` (str),
            ``dist_params`` (dict) and ``weight`` (float). Components whose
            weight is not strictly positive are ignored. Unrecognised
            distribution names contribute zero density.

    Returns:
        A tuple ``(nll, sample_log_likelihoods)``:
            - nll: the mean negative log-likelihood; 0.0 for empty ``y`` and
              ``inf`` when any sample has zero mixture density (including when
              there are no components or no positive weights).
            - sample_log_likelihoods: per-sample marginal log-likelihoods,
              shape ``(n_samples,)``.
    """
    y = np.asarray(y, dtype=float).reshape(-1)
    n_samples = y.size

    if n_samples == 0:
        return 0.0, np.array([])

    if not components:
        # No components: every sample has likelihood 0, so NLL is +inf.
        return np.inf, np.full(n_samples, -np.inf)

    # One row of log(w_k) + log P(y | C_k) per positively-weighted component.
    weighted_log_pdfs = []
    for component in components:
        # `not w > 0` (rather than `w <= 0`) also skips NaN weights.
        if not component.weight > 0:
            continue

        distribution = component.distribution.lower()
        params = component.dist_params
        if distribution == "normal":
            log_pdf = norm.logpdf(y, loc=params["loc"], scale=params["scale"])
        elif distribution == "uniform":
            # scipy returns -inf outside [low, high].
            log_pdf = uniform.logpdf(
                y, loc=params["low"], scale=params["high"] - params["low"]
            )
        elif distribution == "exponential":
            log_pdf = expon.logpdf(y, scale=params["scale"], loc=params.get("loc", 0))
        elif distribution == "gamma":
            log_pdf = gamma.logpdf(
                y,
                a=params["shape"],
                scale=params["scale"],
                loc=params.get("loc", 0),
            )
        else:
            # Unknown distribution: contributes exp(-inf) = 0 density.
            # Consider raising ValueError(f"Unknown distribution: {distribution}").
            log_pdf = np.full(n_samples, -np.inf)

        weighted_log_pdfs.append(log_pdf + np.log(component.weight))

    if not weighted_log_pdfs:
        # All weights were zero: likelihood 0, log-likelihood -inf.
        sample_log_likelihoods = np.full(n_samples, -np.inf)
    else:
        # Stable log of the weighted sum over components, per sample.
        sample_log_likelihoods = logsumexp(np.vstack(weighted_log_pdfs), axis=0)

    nll = -np.mean(sample_log_likelihoods)
    return nll, sample_log_likelihoods


def get_nll_report(data_dir: str, dataset_names: List[str] = None) -> str:
    """
    Render the true-model NLL of each dataset as a Markdown-style table.

    Args:
        data_dir: Directory containing the dataset pickle files.
        dataset_names: Forwarded unchanged to ``calculate_true_nll``; None means
            every dataset found under ``data_dir``.

    Returns:
        A title, an underline, and a two-column table (name, NLL to 4 decimals)
        with rows sorted by dataset name; every line ends with a newline.
    """
    nll_by_dataset = calculate_true_nll(data_dir, dataset_names)

    lines = [
        "Dataset True NLL Results",
        "======================",
        "",
        "| Dataset Name | True NLL |",
        "|-------------|--------:|",
    ]
    lines.extend(
        f"| {name} | {nll_by_dataset[name]:.4f} |" for name in sorted(nll_by_dataset)
    )
    return "\n".join(lines) + "\n"


def plot_true_and_fitted_densities(
    X: np.ndarray,
    Y: np.ndarray,
    component_flows: List,
    scaler_y,
    dataset_config: Any,
    title: str = "True vs Fitted Densities",
    show_legend: bool = True,
):
    """
    Plot the true component densities next to the fitted normalizing-flow densities.

    Flows are evaluated in the scaler's space and mapped back to the original
    ``y`` space with the change-of-variables formula
    ``p_y(y) = p_z(T(y)) * |dT/dy|``, where ``T = scaler_y.transform``.

    Args:
        X: Feature matrix. Unused; kept for signature parity.
        Y: Target values (used for the plotting range and the histogram).
        component_flows: Fitted normalizing flows, one per component.
        scaler_y: Fitted scaler exposing ``transform`` for 2-D (n, 1) input.
            Assumed monotone in a single output dimension.
        dataset_config: Object with a ``components`` attribute (each exposing
            ``distribution`` and ``dist_params``) and optionally
            ``mixing_weights``.
        title: Plot title.
        show_legend: Whether to show the legend.
    """
    # Grid of y values (original space) on which all densities are evaluated.
    y_grid = np.linspace(Y.min() - 2, Y.max() + 2, 1000).reshape(-1, 1)
    y_flat = y_grid.reshape(-1)
    z_grid = np.asarray(scaler_y.transform(y_grid), dtype=np.float64)
    z_tensor = torch.tensor(z_grid, dtype=torch.float64)

    # |dT/dy| by finite differences on the grid: exact for affine scalers
    # (StandardScaler -> 1/std, MinMaxScaler -> scale_, with_std=False -> 1)
    # and a close approximation for smooth monotone transforms. Computed once.
    jacobian = np.abs(np.gradient(z_grid.reshape(-1), y_flat))

    plt.figure(figsize=(12, 6))

    # Fitted densities from the flows, mapped back to y space.
    for i, flow in enumerate(component_flows):
        with torch.no_grad():
            scaled_density = get_log_prob(flow, z_tensor).exp().cpu().numpy()
        density = scaled_density.reshape(-1) * jacobian

        plt.plot(
            y_flat,
            density,
            label=f"Fitted Component {i+1}",
            linestyle="-",
            linewidth=2,
            zorder=2,
        )

    # True densities of the generating components.
    for i, component in enumerate(dataset_config.components):
        distribution = component.distribution.lower()
        params = component.dist_params

        if distribution == "normal":
            density = norm.pdf(y_flat, loc=params["loc"], scale=params["scale"])
        elif distribution == "uniform":
            density = uniform.pdf(
                y_flat, loc=params["low"], scale=params["high"] - params["low"]
            )
        elif distribution == "exponential":
            # Honour an optional shift, as the NLL computations in this module do.
            density = expon.pdf(
                y_flat, scale=params["scale"], loc=params.get("loc", 0)
            )
        elif distribution == "gamma":
            density = gamma.pdf(
                y_flat,
                a=params["shape"],
                scale=params["scale"],
                loc=params.get("loc", 0),
            )
        else:
            continue  # Skip unknown distributions

        plt.plot(
            y_flat,
            density,
            label=f"True {distribution.capitalize()} {i+1}",
            linestyle="--",
            linewidth=2,
            zorder=1,
        )

    # Add component mixing weights if available
    if hasattr(dataset_config, "mixing_weights"):
        mixing_info = f"Mixing weights: {', '.join([f'{w:.2f}' for w in dataset_config.mixing_weights])}"
        plt.annotate(
            mixing_info,
            xy=(0.5, 0.01),
            xycoords="axes fraction",
            ha="center",
            fontsize=10,
        )

    # Histogram of the observed data (low alpha) for context.
    plt.hist(Y, bins=50, density=True, alpha=0.3, color="gray", label="Data histogram")

    plt.title(title)
    plt.xlabel("y value")
    plt.ylabel("Density")
    if show_legend:
        plt.legend(loc="best")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
