
import os
import sys
import argparse
from pathlib import Path

# Repository root: two directories above this script. It is appended to
# sys.path so the project-local packages imported further down (models,
# data, ode_solver_diffrax) can be resolved when run as a script.
BASE_DIR = Path(__file__).resolve().parent.parent
sys.path.append(str(BASE_DIR))

# Command-line options: which device to expose and which run to sample from.
# parse_known_args tolerates flags this parser does not define instead of
# exiting with an error.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str,
                    help='GPU or MIG UUID to use for training')
parser.add_argument('--task', type=str, default=None,
                    help='Manual task ID.')
args = parser.parse_known_args()[0]

# Environment for the accelerator runtime. This has to happen before the
# `import jax` below so the settings are visible when the backend starts up.
_env_overrides = {'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.95'}
if args.device:
    # Restrict the visible devices only when one was explicitly requested.
    _env_overrides['CUDA_VISIBLE_DEVICES'] = args.device
os.environ.update(_env_overrides)

import jax
import jax.numpy as jnp
import numpy as onp
import json
from collections import OrderedDict
from models.graph_transformer import GraphTransformer
from data.ala2_data import Ala2Dataset
from ode_solver_diffrax import batched_sampler

def simulate(task_id, *, seed=0, rtol=1e-4, atol=1e-4):
    """Draw samples from a trained ala2 GraphTransformer run and save them.

    Loads ``final_params.npy`` and ``config.json`` from
    ``BASE_DIR/output/ala2/<task_id>``, rebuilds the model from the config,
    draws samples with ``batched_sampler`` and writes them to
    ``proposed_samples.npz`` in the same directory.

    Args:
        task_id: Name of the run directory under ``output/ala2/``.
        seed: Seed for the PRNG key handed to the sampler. The default (0)
            reproduces the previously hard-coded ``PRNGKey(0)``.
        rtol: Relative tolerance forwarded to ``batched_sampler``.
        atol: Absolute tolerance forwarded to ``batched_sampler``.

    The saved ``.npz`` contains ``R`` (samples reshaped to
    ``(n_samples, n_nodes, 3)``), ``species`` and ``box`` repeated once per
    sample in ``R``, and ``logp`` exactly as returned by ``batched_sampler``.
    """
    output_path = BASE_DIR / f"output/ala2/{task_id}"

    # allow_pickle=True unpickles the stored object, and .item() unwraps it
    # from the array. Only point this at run directories you trust.
    params = onp.load(output_path / "final_params.npy", allow_pickle=True).item()
    with open(output_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f, object_pairs_hook=OrderedDict)

    # Read every setting up front so a malformed config fails before the
    # dataset is loaded.
    general_cfg = config["general"]
    model_cfg = config["model"]["GraphTransformer"]
    sim_cfg = config["simulator"]

    cg_level = general_cfg["cg_level"]
    t0, t1 = general_cfg["t0"], general_cfg["t1"]
    file_path = config["dataset"]["ala2_datafile"]
    feat_type = model_cfg["feat_type"]
    hidden_layers = model_cfg["hidden_layers"]
    embedding_layers = model_cfg["embedding_layers"]
    rescale_time = model_cfg["rescale_time"]
    clip_time = model_cfg["clip_time"]
    max_z = model_cfg["max_z"]
    n_layers = model_cfg["n_layers"]
    use_intrinsic_coords = model_cfg["use_intrinsic_coords"]
    use_abs_coords = model_cfg["use_abs_coords"]
    use_distances = model_cfg["use_distances"]
    dropout = model_cfg["dropout"]
    method = sim_cfg["method"]
    num_batches = sim_cfg["num_batches"]
    batch_size = sim_cfg["batch"]
    dt0 = sim_cfg["dt0"]
    num_z = sim_cfg["num_z"]
    mean = sim_cfg["mean"]

    # Data loading: species/box are taken from the first entry of the dataset.
    dataset = Ala2Dataset(file_path, feat_type=feat_type, cg_level=cg_level)
    species = dataset.species[0]
    box = dataset.box[0]
    features = dataset.features
    n_nodes = len(species)

    # Model initialization
    model = GraphTransformer(
        t0=t0,
        t1=t1,
        rescale_time=rescale_time,
        clip_time=clip_time,
        hidden_nf=hidden_layers,
        feature_embedding_dim=embedding_layers,
        max_z=max_z,
        n_layers=n_layers,
        use_intrinsic_coords=use_intrinsic_coords,
        use_abs_coords=use_abs_coords,
        use_distances=use_distances,
        dropout=dropout,
    )

    # Only the split-off key is used by the sampler; with seed=0 this is the
    # same key the original hard-coded PRNGKey(0) split produced.
    _, rng_sample = jax.random.split(jax.random.PRNGKey(seed), 2)
    x, logp = batched_sampler(
        t0=t0,
        t1=t1,
        dt0=dt0,
        params=params,
        apply_fn=model.apply,
        method=method,
        num_batches=num_batches,
        batch_size=batch_size,
        n_dof=n_nodes * 3,
        features=features,
        num_z=num_z,
        mean=mean,
        rng=rng_sample,
        rtol=rtol,
        atol=atol,
    )
    x = x.reshape((-1, n_nodes, 3))

    # Tile species/box to the number of samples actually returned rather than
    # the nominal num_batches * batch_size, so every array in the archive has
    # the same leading length as R.
    n_samples = x.shape[0]

    # Save results
    output_file = output_path / "proposed_samples.npz"
    onp.savez(
        output_file,
        R=x,
        species=jnp.tile(species, (n_samples, 1)),
        box=jnp.tile(box, (n_samples, 1, 1)),
        logp=logp,
    )
    print(f"Saved {output_file}")
    print(f"Shape of R is {x.shape}")

if __name__ == "__main__":
    # --task defaults to None. Passing that through would make simulate()
    # look in output/ala2/None and fail with an unhelpful FileNotFoundError,
    # so exit with a usage message instead.
    if args.task is None:
        parser.error("--task is required (ID of the run under output/ala2/ to sample from)")
    simulate(args.task)