import os
import sys
import argparse
from pathlib import Path

# Project root: two directory levels above this file. It is appended to
# sys.path, presumably so the `models` / `ode_solver_diffrax` imports further
# down resolve when the script is run directly.
BASE_DIR = Path(__file__).resolve().parent.parent
sys.path.append(f"{BASE_DIR}")

# Command-line options. parse_known_args() returns unrecognised flags
# separately (discarded here) instead of erroring on them.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str,
                    help='GPU or MIG UUID to use for training')
parser.add_argument("--task", type=str, default=None, help='Manual task ID.')
args = parser.parse_known_args()[0]

# GPU environment configuration. This runs before `import jax` below so the
# variables are already set when JAX is loaded. CUDA_VISIBLE_DEVICES is only
# touched when --device was given; the memory fraction is always forced.
os.environ.update(
    {"CUDA_VISIBLE_DEVICES": args.device} if args.device else {},
    XLA_PYTHON_CLIENT_MEM_FRACTION="0.45",
)

import jax
import numpy as onp
from models.rbf_mlp import MLP
import json
from collections import OrderedDict
from ode_solver_diffrax import batched_sampler

def simulate(task_id, base_dir=None):
    """Sample from a trained model run and save the samples to disk.

    Reads ``final_params.npy`` and ``config.json`` from the run directory
    ``<base_dir>/output/mb/<task_id>``. It then rebuilds the ``MLP``
    described by the config and draws samples with ``batched_sampler``.
    The samples (key ``R``) and their ``logp`` values are written to
    ``proposed_samples.npz`` in the same directory.

    Args:
        task_id: Name of the run directory under ``output/mb``.
        base_dir: Root that contains the ``output`` directory. Defaults to
            the module-level ``BASE_DIR``.

    Raises:
        ValueError: If ``task_id`` is None (e.g. ``--task`` was not given).
        FileNotFoundError: If the run directory does not exist.
    """
    if task_id is None:
        raise ValueError("task_id is required (pass --task on the command line)")

    root = BASE_DIR if base_dir is None else Path(base_dir)
    output_path = root / f"output/mb/{task_id}"
    # The run directory must already contain the trained params and config,
    # so it is not created here. Creating it only left empty directories
    # behind for mistyped task IDs before the load below failed anyway.
    if not output_path.is_dir():
        raise FileNotFoundError(f"Run directory not found: {output_path}")

    # NOTE: allow_pickle=True unpickles arbitrary objects. Only load
    # parameter files produced by trusted training runs.
    params = onp.load(output_path / "final_params.npy", allow_pickle=True).item()
    with open(output_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f, object_pairs_hook=OrderedDict)

    # Pull every setting up front so a missing config key fails before any
    # model construction or sampling work starts.
    general_cfg = config["general"]
    mlp_cfg = config["model"]["MLP"]
    sim_cfg = config["simulator"]

    t0, t1 = general_cfg["t0"], general_cfg["t1"]
    hidden_layers = mlp_cfg["hidden_layers"]
    embedding_layers = mlp_cfg["embedding_layers"]
    n_layers = mlp_cfg["n_layers"]
    method = sim_cfg["method"]
    num_batches = sim_cfg["num_batches"]
    batch_size = sim_cfg["batch"]
    dt0 = sim_cfg["dt0"]
    num_z = sim_cfg["num_z"]
    mean = sim_cfg["mean"]

    # Input dimension for the CG MB potential.
    x_dim = 1

    model = MLP(
        hidden_layers=hidden_layers,
        embedding_layers=embedding_layers,
        n_layers=n_layers,
    )

    def apply_fn(params, x, t, **kwargs):
        # Extra keyword arguments (presumably supplied by batched_sampler)
        # are accepted but deliberately not forwarded to model.apply.
        return model.apply(params, x, t)

    # Fixed PRNG seed 0; the second half of the split is used for sampling.
    _, rng_sample = jax.random.split(jax.random.PRNGKey(0), 2)
    x, logp = batched_sampler(
        t0=t0,
        t1=t1,
        dt0=dt0,
        params=params,
        apply_fn=apply_fn,
        method=method,
        num_batches=num_batches,
        batch_size=batch_size,
        n_dof=x_dim,
        features=None,
        num_z=num_z,
        mean=mean,
        rng=rng_sample,
        rtol=1e-5,
        atol=1e-5,
    )

    output_file = output_path / "proposed_samples.npz"
    onp.savez(output_file, R=x, logp=logp)
    print(f"Saved {output_file}")
    print(f"Shape of R is {x.shape}")

if __name__ == "__main__":
    # Without --task the run directory would resolve to output/mb/None.
    # Fail early with a usage message instead (parser.error exits).
    if args.task is None:
        parser.error("--task is required")
    simulate(task_id=args.task)