"""
Example usage of FIRE and baseline methods for multi-fidelity regression.

This example demonstrates:
1. Generating synthetic multi-fidelity data
2. Training FIRE and baseline methods
3. Evaluating predictions

Note: Some methods require specific hardware (GPU) or packages.
"""

import numpy as np
import torch
import mf2
from src.util import *


def generate_synthetic_data(n_low=200, n_high=10, n_test=100, seed=42, function=None):
    """
    Generate synthetic 2-fidelity data (mf2 borehole function by default).

    Inputs are drawn uniformly from the unit hypercube [0, 1)^dim and are
    returned in that unit scale. They are mapped onto the function's domain
    ``[l_bound, u_bound]`` only when the function is evaluated.

    Args:
        n_low: Number of low-fidelity samples.
        n_high: Number of high-fidelity samples. They are drawn independently
            of the low-fidelity points (disjoint, non-nested design).
        n_test: Number of test samples, labelled with the high-fidelity function.
        seed: Seed passed to ``np.random.seed``. Note that this reseeds
            NumPy's *global* random state.
        function: Two-fidelity benchmark object exposing ``l_bound``,
            ``u_bound``, ``low(X)`` and ``high(X)``. Defaults to
            ``mf2.borehole``.

    Returns:
        Tuple ``(X_lf, y_lf, X_hf, y_hf, X_test, y_test)`` of NumPy arrays.
    """
    if function is None:
        function = mf2.borehole

    # Global reseed is kept so that results stay reproducible for callers
    # relying on NumPy's global RNG state after this call.
    np.random.seed(seed)

    dim = len(function.l_bound)
    lower = np.asarray(function.l_bound)
    span = np.asarray(function.u_bound) - lower

    def to_domain(X):
        # Affine map from the unit hypercube onto [l_bound, u_bound].
        return X * span + lower

    # Low-fidelity data (previously hard-coded to 200 rows, ignoring n_low)
    X_lf = np.random.rand(n_low, dim)
    y_lf = function.low(to_domain(X_lf))

    # High-fidelity data (disjoint, non-nested design)
    X_hf = np.random.rand(n_high, dim)
    y_hf = function.high(to_domain(X_hf))

    # Test data, labelled with the high-fidelity function
    X_test = np.random.rand(n_test, dim)
    y_test = function.high(to_domain(X_test))

    return X_lf, y_lf, X_hf, y_hf, X_test, y_test


def _evaluate_method(step, name, module_name, args, kwargs, X_test, y_test, results):
    """Build one model, predict on the test set, and record its metrics.

    The model class is looked up as ``src.<module_name>.<name>`` and built as
    ``cls(*args, **kwargs)``. The constructor receives the training data.

    Any exception raised while importing, building, predicting or scoring is
    printed, and the method is skipped. This is deliberate best-effort behaviour:
    one unavailable method (missing package, no GPU, ...) must not abort the
    whole comparison.

    Args:
        step: Step number shown in the progress label.
        name: Class name, also used as the key in ``results``.
        module_name: Module under ``src`` that defines the class.
        args: Positional data arguments for the constructor.
        kwargs: Keyword arguments for the constructor.
        X_test: Test inputs passed to ``model.predict``.
        y_test: Ground-truth test targets passed to ``get_metrics``.
        results: Dict updated in place with ``{'R2', 'NRMSE', 'NLL'}`` on success.
    """
    import importlib

    print(f"\n[{step}] Training {name}...")
    try:
        model_cls = getattr(importlib.import_module(f"src.{module_name}"), name)
        model = model_cls(*args, **kwargs)
        y_pred, y_var = model.predict(X_test)
        r2, nrmse, nll = get_metrics(y_test, y_pred, y_var)
    except Exception as e:
        print(f"    Error ({type(e).__name__}): {e}")
        return
    results[name] = {'R2': r2, 'NRMSE': nrmse, 'NLL': nll}
    print(f"    R²: {r2:.4f}, NRMSE: {nrmse:.4f}, NLL: {nll:.4f}")


def main():
    """Run FIRE and the baseline methods on synthetic data and print a summary table."""
    print("=" * 60)
    print("FIRE: Multi-fidelity Regression Example")
    print("=" * 60)

    # Generate synthetic data
    print("\n[0] Generating synthetic multi-fidelity data...")
    X_lf, y_lf, X_hf, y_hf, X_test, y_test = generate_synthetic_data(
        n_low=100, n_high=10, n_test=50, seed=42
    )

    print(f"    Low-fidelity:  {X_lf.shape[0]} samples")
    print(f"    High-fidelity: {X_hf.shape[0]} samples")
    print(f"    Test:          {X_test.shape[0]} samples")

    # Convert to float32 tensors; the NumPy test targets are kept for scoring.
    X_lf = torch.tensor(X_lf, dtype=torch.float32)
    y_lf = torch.tensor(y_lf, dtype=torch.float32)
    X_hf = torch.tensor(X_hf, dtype=torch.float32)
    y_hf = torch.tensor(y_hf, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    y_test_np = y_test.copy()

    # Encode data for general methods (with fidelity column).
    # NOTE(review): these methods are built with fidelity_col_idx=-1, which
    # presumably means the fidelity indicator is the last column - confirm
    # against encode_2fidelity_data.
    train_X, train_y, _, _ = encode_2fidelity_data(
        X_lf.numpy(), y_lf.numpy(), X_hf.numpy(), y_hf.numpy(),
        X_test.numpy(), y_test_np, preprocess_X=False, preprocess_Y=False
    )

    # Two input layouts: per-fidelity tensors, or the stacked/encoded form.
    per_fidelity = (X_lf, y_lf, X_hf, y_hf)
    encoded = (train_X, train_y)
    encoded_kwargs = {"fidelity_col_idx": -1, "train_iter": 100}

    methods = [
        # (class name, module under src/, positional data, extra constructor kwargs)
        ("FIRE", "FIRE", per_fidelity, {}),
        ("FIRE_GP", "FIRE", per_fidelity, {}),
        ("ResGP", "ResGP", per_fidelity, {"train_iter": 100}),
        ("NARGP", "NARGP", per_fidelity, {"train_iter": 100}),
        ("AR1", "AR1", encoded, encoded_kwargs),
        ("ContinuAR", "ContinuAR", encoded, encoded_kwargs),
        ("MFRNP", "MFRNP", encoded, encoded_kwargs),
    ]

    results = {}
    for step, (name, module_name, data, extra) in enumerate(methods, start=1):
        kwargs = {"device": "cpu", "seed": 42, **extra}
        _evaluate_method(step, name, module_name, data, kwargs,
                         X_test, y_test_np, results)

    # Summary
    header = f"{'Method':<15} {'R²':>10} {'NRMSE':>10} {'NLL':>10}"
    print("\n" + "=" * 60)
    print("Summary of Results")
    print("=" * 60)
    print(header)
    print("-" * len(header))
    for method, metrics in results.items():
        print(f"{method:<15} {metrics['R2']:>10.4f} {metrics['NRMSE']:>10.4f} {metrics['NLL']:>10.4f}")

    print("\n" + "=" * 60)
    print("Notes:")
    print("- FIRE_TFM requires GPU and TabPFN package")
    print("- MFKG requires GPU and BoTorch package")
    print("- MFBNN requires mfbml package")
    print("=" * 60)


if __name__ == "__main__":
    main()
