"""
Test for plot_component_importance() visualization function.

This test loads the housing_full.csv dataset, fits an MPF model,
and generates component importance plots.
"""

import os
import sys

import numpy as np
import pytest

# Add parent directory to path to find mpf_py
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from mpf_py import MPF


def get_mpf_hyperparams():
    """Return the fixed keyword arguments used to fit the MPF model in this module.

    A new dict is built on every call, so callers may mutate the result
    without affecting later calls. The values are hard-coded; the seed is
    pinned to 42.
    """
    return dict(
        epochs=6,
        n_trees=200,
        n_iter=132,
        split_try=9,
        colsample_bytree=0.3229815917186011,
        decay=0.9096693926378157,
        min_interval_samples=1,
        refinement_strategy="l2_honest",
        alpha=0.004313176654642212,
        update_clamp=0.7432264054980939,
        split_strategy="random",
        constraint_strategy="all_positive",
        identification_strategy="l2",
        combination_strategy="arith_geom_mean",
        optimize_scaling=False,
        similarity_threshold=0.6543511874100095,
        bagged=True,
        seed=42,
    )


def load_housing_data():
    """Read housing_full.csv and split it into a feature matrix and a target vector.

    The file is looked up at ``<root>/data/housing_full.csv``, where ``<root>``
    is three directory levels above the directory containing this file. If
    the file does not exist, ``pytest.skip`` is called instead of returning.

    Returns:
        A ``(X, y)`` tuple of C-contiguous float64 arrays: ``X`` holds every
        column except the last, ``y`` holds the last column.
    """
    here = os.path.dirname(__file__)
    # Climb three levels from this file's directory to reach the root that
    # holds the data/ folder.
    root = os.path.abspath(os.path.join(here, *([".."] * 3)))
    data_path = os.path.join(root, "data", "housing_full.csv")

    if not os.path.exists(data_path):
        pytest.skip(f"Could not find housing_full.csv at {data_path}")

    # Parsed as comma-separated floats with no rows skipped, so the file is
    # expected to contain numeric rows only.
    table = np.loadtxt(data_path, delimiter=",", dtype=np.float64)

    # Column slices of a 2-D array are views; copy them into contiguous
    # buffers before handing them to MPF.
    features = np.ascontiguousarray(table[:, :-1])
    target = np.ascontiguousarray(table[:, -1])
    return features, target


@pytest.fixture(scope="module")
def fitted_model():
    """Module-scoped fixture: load the housing data and fit one MPF model.

    Returns:
        A ``(model, fit_result, X, y)`` tuple, where ``model`` and
        ``fit_result`` are the two values returned by ``MPF.fit`` and
        ``X``/``y`` are the arrays it was fitted on.
    """
    features, target = load_housing_data()
    model, fit_result = MPF.fit(features, target, **get_mpf_hyperparams())
    return model, fit_result, features, target


def test_plot_component_importance_basic(fitted_model):
    """Smoke-test plot_component_importance() without variance (no data passed)."""
    model, _fit_result, _X, _y = fitted_model

    # The rendered figure is not inspected; the only thing verified is that
    # the call completes without an exception.
    # NOTE(review): the call may open plot windows depending on the plotting
    # backend in use — confirm for headless runs.
    try:
        model.plot_component_importance(show_variance=False)
    except Exception as exc:
        pytest.fail(f"plot_component_importance() raised {type(exc).__name__}: {exc}")


def test_plot_component_importance_with_variance(fitted_model):
    """Smoke-test plot_component_importance() with data and show_variance=True."""
    model, _fit_result, X, y = fitted_model

    # Pass the fitted data positionally and request the variance view; the
    # test fails only if the call raises.
    try:
        model.plot_component_importance(X, y, show_variance=True)
    except Exception as exc:
        pytest.fail(
            f"plot_component_importance() with variance raised {type(exc).__name__}: {exc}"
        )


def test_plot_component_importance_custom_figsize(fitted_model):
    """Smoke-test plot_component_importance() with an explicit figsize."""
    model, _fit_result, X, y = fitted_model
    custom_size = (14, 6)

    # Same call as the variance test, but with a non-default figure size.
    try:
        model.plot_component_importance(
            X, y, figsize=custom_size, show_variance=True
        )
    except Exception as exc:
        pytest.fail(
            f"plot_component_importance() with custom figsize raised {type(exc).__name__}: {exc}"
        )


def test_plot_component_importance_validation(fitted_model):
    """Check the x/y argument validation of plot_component_importance()."""
    model, _fit_result, X, y = fitted_model

    # Rejected case: x supplied without y must raise ValueError with the
    # expected message.
    with pytest.raises(ValueError, match="y must be provided if x is provided"):
        model.plot_component_importance(x=X)

    # Accepted case: supplying both x and y must not raise.
    try:
        model.plot_component_importance(X, y)
    except Exception as exc:
        pytest.fail(
            f"plot_component_importance() with x and y raised {type(exc).__name__}: {exc}"
        )


if __name__ == "__main__":
    # Script entry point for manual inspection: build by hand the tuple the
    # fitted_model fixture would supply, then run every test in order.
    print("Loading data and fitting model (this may take a while)...")
    X, y = load_housing_data()
    model, fit_result = MPF.fit(X, y, **get_mpf_hyperparams())
    bundle = (model, fit_result, X, y)

    print("Running component importance plot tests...")

    # One entry per test: (announcement, test function, success line).
    manual_checks = [
        (
            "\n1. Testing basic component importance plot (scaling factors only)...",
            test_plot_component_importance_basic,
            "   ✓ Basic plot test passed",
        ),
        (
            "\n2. Testing component importance plot with variance explained...",
            test_plot_component_importance_with_variance,
            "   ✓ Variance plot test passed",
        ),
        (
            "\n3. Testing component importance plot with custom figure size...",
            test_plot_component_importance_custom_figsize,
            "   ✓ Custom figsize test passed",
        ),
        (
            "\n4. Testing input validation...",
            test_plot_component_importance_validation,
            "   ✓ Validation test passed",
        ),
    ]
    for announcement, run_check, success_line in manual_checks:
        print(announcement)
        run_check(bundle)
        print(success_line)

    print("\n✅ All tests passed!")
