import sys
import argparse
from pathlib import Path

# Directory two levels above this script. It is appended to the import path,
# presumably so the project modules imported below (rbf_mlp, train) resolve.
BASE_DIR = Path(__file__).resolve().parents[1]
sys.path.append(str(BASE_DIR))

# Command-line options. parse_known_args tolerates flags this script does not
# declare instead of exiting with an error.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, help='GPU or MIG UUID to use for training')
parser.add_argument('--task', type=str, default=None, help='Manual task ID.')
args, _unrecognized = parser.parse_known_args()
print(f"Running on device: {args.device}")

import numpy as onp
from rbf_mlp import RBFMLP
from train import get_default_config

def predict(task_id, params):
    config = get_default_config()
    hidden_layers = config["model"]["features"]

    model = RBFMLP(
        hidden_layers=hidden_layers,
        num_rbf_centers=100,
        sigma=5.0
    )

    def energy_fn(x):
        return model.apply(params, x)
    
    dataset = onp.load(BASE_DIR / f"output/mb/{task_id}/proposed_samples.npz")
    x = dataset['R'][..., 0].reshape(-1, 1)
    logp = dataset['logp'] if onp.any(dataset['logp'] != None) else None
    energy = energy_fn(x)

    save_path = BASE_DIR / f"output/mb/{task_id}/predicted_energy.npz"
    onp.savez(save_path, R=x, U=energy, logp=logp)
    print(f"Saved energy to {save_path}")

if __name__ == "__main__":
    # The parameter file is only needed when running as a script, so the option
    # is registered here and the command line is re-parsed to pick it up.
    parser.add_argument(
        '--params',
        type=str,
        default="",
        help='Path to a .npy file holding the pickled model parameters.'
    )
    args, _ = parser.parse_known_args()

    task_id = args.task
    params_path = args.params
    # Fail fast: a missing task would otherwise resolve to output/mb/None/...,
    # and an empty path makes onp.load fail with an unhelpful error.
    if task_id is None:
        parser.error("--task is required")
    if not params_path:
        parser.error("--params is required")

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    params = onp.load(params_path, allow_pickle=True).item()
    predict(task_id, params)
