import os
import sys
import argparse
from pathlib import Path
import json

# Make the project root (two directory levels above this file) importable so
# the local modules imported further down (train_utils, config, models, data)
# can be resolved.
BASE_DIR = Path(__file__).resolve().parent.parent
sys.path.append(os.fspath(BASE_DIR))

# Command-line options. They are parsed at import time, before JAX is imported
# below — presumably so CUDA_VISIBLE_DEVICES is in place before JAX picks its
# devices. Unrecognised arguments are tolerated and discarded.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, help='GPU or MIG UUID to use for training')
parser.add_argument('--task', type=str, default=None, help='Manual task ID.')
args = parser.parse_known_args()[0]

# Expose only the requested device (when --device is given and non-empty) and
# limit the fraction of GPU memory JAX/XLA preallocates to 45%.
if args.device:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.45'

import jax
import jax.numpy as jnp
import numpy as onp
from tqdm import tqdm
from train_utils import train_step, create_mb_train_state
from config import get_default_config
from models.rbf_mlp import MLP
from data.mb_data import get_mb_dataloader



def _run_epoch(state, dataloader, rng, *, mu, sigma, t0, t1):
    """Run one pass over ``dataloader`` and return the updated state.

    Args:
        state: Training state produced by ``create_mb_train_state``.
        dataloader: Iterable of batches; each batch is converted with
            ``jnp.asarray``.
        rng: JAX PRNG key. It is split for every batch and the carried-over
            key is returned so the caller never reuses a consumed key.
        mu: Mean of the Gaussian used to draw ``x_init``.
        sigma: Scale of the Gaussian used to draw ``x_init``.
        t0: Lower bound of the uniform distribution used to draw ``t``.
        t1: Upper bound of the uniform distribution used to draw ``t``.

    Returns:
        ``(state, rng, mean_loss)``, where ``mean_loss`` is the average of the
        per-batch losses returned by ``train_step``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    n_batches = 0
    for batch in dataloader:
        x_batch = jnp.asarray(batch)

        rng, noise_rng, time_rng = jax.random.split(rng, 3)
        x_init = jax.random.normal(noise_rng, x_batch.shape) * sigma + mu
        # Size t from the batch actually received (not the configured batch
        # size) so a smaller final batch still lines up with x_batch on axis 0.
        t = jax.random.uniform(
            time_rng, shape=(x_batch.shape[0], 1), minval=t0, maxval=t1
        )

        loss, state = train_step(
            state=state,
            x=x_batch,
            x_init=x_init,
            t=t,
        )
        total_loss += loss
        n_batches += 1

    # Count iterated batches instead of relying on len(dataloader): this works
    # for loaders without __len__ and avoids dividing by zero.
    if n_batches == 0:
        raise ValueError("Dataloader yielded no batches; cannot compute epoch loss.")
    return state, rng, total_loss / n_batches


def _save_outputs(output_path, params, losses, config):
    """Write params, the per-epoch loss history and config into ``output_path``.

    Creates ``output_path`` (including parents) if needed. It then writes
    ``final_params.npy``, ``loss.npy`` and ``config.json``, and prints where
    the first two were saved.
    """
    output_path.mkdir(parents=True, exist_ok=True)
    config_file = output_path / "config.json"
    output_file = output_path / "final_params.npy"
    loss_file = output_path / "loss.npy"
    # NOTE(review): if params is a pytree (e.g. a dict), numpy stores it as a
    # pickled 0-d object array; reading it back needs
    # onp.load(..., allow_pickle=True).
    onp.save(output_file, params)
    onp.save(loss_file, losses)
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f"Models saved to {output_file}")
    print(f"Training loss saved to {loss_file}")


def train(task_id):
    """Train the MLP and save its parameters, loss history and config.

    Hyperparameters come from ``get_default_config()``. The model is trained
    for ``config["trainer"]["epochs"]`` epochs on the batches yielded by
    ``get_mb_dataloader(config)``. The function then writes
    ``final_params.npy``, ``loss.npy`` (one mean loss per epoch) and
    ``config.json`` to ``BASE_DIR / "output/mb/<task_id>"``.

    Args:
        task_id: Name of the output sub-directory. It is formatted with an
            f-string, so ``None`` (the CLI default) produces a directory
            literally named "None".

    Raises:
        ValueError: If the dataloader yields no batches in an epoch.
    """
    config = get_default_config()
    general = config["general"]
    mlp_config = config["model"]["MLP"]
    epochs = config["trainer"]["epochs"]

    # Initialize dataloader and model
    dataloader = get_mb_dataloader(config)
    model = MLP(
        hidden_layers=mlp_config["hidden_layers"],
        embedding_layers=mlp_config["embedding_layers"],
        n_layers=mlp_config["n_layers"],
    )

    rng = jax.random.PRNGKey(0)
    # Give state creation its own key. The original passed `rng` here and
    # then split the same `rng` for the first batch, which reuses a key.
    rng, init_rng = jax.random.split(rng)
    state = create_mb_train_state(
        rng=init_rng,
        model=model,
        config=config,
    )

    # Training loop
    epoch_losses = []
    pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    for _ in pbar:
        state, rng, epoch_loss = _run_epoch(
            state,
            dataloader,
            rng,
            mu=general["mu"],
            sigma=general["sigma"],
            t0=general["t0"],
            t1=general["t1"],
        )
        pbar.set_postfix({"loss": f"{float(epoch_loss):.6f}"})
        epoch_losses.append(epoch_loss)

    # jnp.stack rejects an empty sequence, so handle epochs == 0 explicitly.
    train_loss_set = jnp.stack(epoch_losses) if epoch_losses else jnp.zeros((0,))

    # Save parameters
    output_path = BASE_DIR / f"output/mb/{task_id}"
    _save_outputs(output_path, state.params, train_loss_set, config)

if __name__ == "__main__":
    # Script entry point: train using the task ID given via --task
    # (None when the flag is omitted).
    train(args.task)

    