import sys
from pathlib import Path

# Make the project root importable so the `source.*` imports below resolve
# when this script is run directly. The root is the parent of the directory
# containing this file (the "scripts" directory). It is prepended only once,
# to avoid duplicate sys.path entries.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path[:0] = [_root_entry]
del _root_entry

import numpy as np
import torch
from source.utils.uncertainty_measures import (
    calculate_uncertainties_crps,
    calculate_uncertainties_log,
    calculate_uncertainties_quadratic,
    calculate_uncertainties_mse,
)
def print_test_result(test_name, condition, error_msg=""):
    """Print the outcome of a single check instead of raising on failure.

    Args:
        test_name: Human-readable label for the check.
        condition: Truthy if the check passed. Only its truth value is used,
            so Python bools, NumPy bools and single-element tensors all work.
        error_msg: Extra detail appended to the line when the check fails.
    """
    line = (
        f"✓ PASSED: {test_name}"
        if condition
        else f"✗ FAILED: {test_name} - {error_msg}"
    )
    print(line)

def _all_geq(lhs, rhs, atol):
    """Return True if ``lhs >= rhs`` holds elementwise, allowing round-off.

    An element passes when the inequality holds exactly, or when ``lhs`` is
    ``np.isclose`` to ``rhs`` under ``atol`` (with NumPy's default ``rtol``).
    This is the same closeness criterion the equality checks apply via
    ``np.allclose``. NaN entries never pass.
    """
    lhs = np.asarray(lhs)
    rhs = np.asarray(rhs)
    return bool(np.all((lhs >= rhs) | np.isclose(lhs, rhs, atol=atol)))


def _as_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from any autograd graph."""
    return tensor.detach().cpu().numpy()


def test_uncertainty_measure_additivity(uncertainty_func, func_name, *, atol=1e-6):
    """Check the additivity properties of one uncertainty measure.

    Builds a fixed (seeded) random ensemble, passes it to ``uncertainty_func``
    and reports on stdout, via ``print_test_result``, whether:

    * ``excess_1_1 >= 0`` and ``excess_2_1 >= 0``,
    * ``total_1_1 == bayes_1 + excess_1_1``,
    * ``total_2_1 == bayes_2 + excess_2_1`` (skipped if ``bayes_2`` has NaN),
    * ``total_1_1 == total_2_1``,
    * ``excess_1_1 >= excess_2_1`` (skipped if ``bayes_2`` has NaN).

    Any exception raised while computing or checking the measure is printed
    instead of propagated, so the caller can continue with other measures.

    Args:
        uncertainty_func: Callable taking ``(means, variances)`` and returning
            a mapping of tensors with keys ``"total_1_1"``, ``"total_2_1"``,
            ``"excess_1_1"``, ``"excess_2_1"``, ``"bayes_1"`` and ``"bayes_2"``.
        func_name: Label used in the printed report.
        atol: Absolute tolerance shared by the equality checks and the
            inequality checks. Sharing it means float32 round-off (e.g. an
            excess of -1e-8) is not reported as a violation.

    NOTE(review): the name starts with ``test_`` but the function takes
    non-fixture arguments. If pytest ever collects this file, it will error on
    missing fixtures; confirm the script is only run directly.
    """
    print(f"\n=== Testing {func_name} ===")

    # Re-seed on every call so each measure is evaluated on identical data.
    torch.manual_seed(42)
    N_members, N_objects = 10, 5

    # Per-member means and strictly positive variances (>= 0.1),
    # shape (N_objects, N_members).
    means = torch.randn(N_objects, N_members) * 2.0 + 1.0
    variances = torch.abs(torch.randn(N_objects, N_members)) * 0.5 + 0.1

    try:
        results = uncertainty_func(means, variances)

        # Only the first/second-order terms are checked; keys with "3" are ignored.
        R_tot_1_1 = _as_numpy(results["total_1_1"])
        R_tot_2_1 = _as_numpy(results["total_2_1"])

        R_exc_1_1 = _as_numpy(results["excess_1_1"])
        R_exc_2_1 = _as_numpy(results["excess_2_1"])

        R_bay_1 = _as_numpy(results["bayes_1"])
        R_bay_2 = _as_numpy(results["bayes_2"])

        # Identities involving bayes_2 cannot be evaluated where it is NaN.
        bay_2_has_nan = bool(np.isnan(R_bay_2).any())

        print_test_result(
            f"{func_name}: R_exc_1_1 >= 0",
            _all_geq(R_exc_1_1, 0.0, atol),
            f"Min value: {np.min(R_exc_1_1)}",
        )

        print_test_result(
            f"{func_name}: R_exc_2_1 >= 0",
            _all_geq(R_exc_2_1, 0.0, atol),
            f"Min value: {np.min(R_exc_2_1)}",
        )

        print_test_result(
            f"{func_name}: R_tot_1_1 = R_bay_1 + R_exc_1_1",
            np.allclose(R_tot_1_1, R_bay_1 + R_exc_1_1, atol=atol),
            f"Max diff: {np.max(np.abs(R_tot_1_1 - (R_bay_1 + R_exc_1_1)))}",
        )

        if bay_2_has_nan:
            print(
                f"⚠ SKIPPED: {func_name}: R_tot_2_1 = R_bay_2 + R_exc_2_1 "
                f"(R_bay_2 contains NaN)"
            )
        else:
            print_test_result(
                f"{func_name}: R_tot_2_1 = R_bay_2 + R_exc_2_1",
                np.allclose(R_tot_2_1, R_bay_2 + R_exc_2_1, atol=atol),
                f"Max diff: {np.max(np.abs(R_tot_2_1 - (R_bay_2 + R_exc_2_1)))}",
            )

        print_test_result(
            f"{func_name}: R_tot_1_1 = R_tot_2_1",
            np.allclose(R_tot_1_1, R_tot_2_1, atol=atol),
            f"Max diff: {np.max(np.abs(R_tot_1_1 - R_tot_2_1))}",
        )

        if bay_2_has_nan:
            print(f"⚠ SKIPPED: {func_name}: R_exc_1_1 >= R_exc_2_1 (R_bay_2 contains NaN)")
        else:
            print_test_result(
                f"{func_name}: R_exc_1_1 >= R_exc_2_1 (outer >= inner)",
                _all_geq(R_exc_1_1, R_exc_2_1, atol),
                f"Max violation: {np.max(R_exc_2_1 - R_exc_1_1)}",
            )

    except Exception as e:
        # Deliberately best-effort: report (with the exception type, so e.g. a
        # missing key is recognisable) and let the caller move on.
        print(f"✗ ERROR in {func_name}: {type(e).__name__}: {e}")

def test_cross_measure_consistency():
    """Check that the CRPS and MSE measures return finite first-order terms.

    Runs ``calculate_uncertainties_crps`` and ``calculate_uncertainties_mse`` on
    the same seeded random ensemble. Reports via ``print_test_result`` whether
    each of ``total_1_1``, ``bayes_1`` and ``excess_1_1`` contains only finite
    values for both measures. Only finiteness is checked here, not the
    decomposition identities. Any exception is printed instead of propagated.
    """
    print("\n=== Cross-measure consistency tests ===")
    torch.manual_seed(42)
    N_members, N_objects = 8, 4

    # Per-member means and strictly positive variances (>= 0.05),
    # shape (N_objects, N_members).
    means = torch.randn(N_objects, N_members) * 1.5 + 0.5
    variances = torch.abs(torch.randn(N_objects, N_members)) * 0.3 + 0.05

    try:
        crps_results = calculate_uncertainties_crps(means, variances)
        mse_results = calculate_uncertainties_mse(means, variances)

        for key in ("total_1_1", "bayes_1", "excess_1_1"):
            for label, results in (("CRPS", crps_results), ("MSE", mse_results)):
                finite = torch.isfinite(results[key])
                n_bad = int((~finite).sum())
                print_test_result(
                    f"{label} {key} is finite",
                    bool(finite.all()),
                    f"Contains non-finite values ({n_bad} of {finite.numel()})",
                )

    except Exception as e:
        print(f"✗ ERROR in cross-measure consistency: {type(e).__name__}: {e}")

if __name__ == "__main__":
    # Run every check in sequence. Failed checks are printed rather than
    # raised, so one failing measure does not stop the remaining ones.
    banner = "=" * 50

    print("Testing Uncertainty Measure Additivity Properties")
    print(banner)

    # (measure function, display name) pairs for the additivity checks.
    uncertainty_functions = [
        (calculate_uncertainties_crps, "CRPS"),
        (calculate_uncertainties_log, "Log Score"),
        (calculate_uncertainties_mse, "MSE"),
        (calculate_uncertainties_quadratic, "Quadratic Score"),
    ]

    for func, name in uncertainty_functions:
        test_uncertainty_measure_additivity(func, name)

    # Finiteness checks comparing the CRPS and MSE outputs.
    test_cross_measure_consistency()

    print("\n" + banner)
    print("Testing completed. Check output above for any failures.")
