import json
import os
import sys
import argparse
from pathlib import Path

# Directory two levels above this file; appended to the import search path so
# the project-local modules imported further down can be resolved.
BASE_DIR = Path(__file__).resolve().parents[1]
sys.path.append(str(BASE_DIR))

# Command-line options. They are parsed at import time, ahead of the
# `import jax` statement that follows this section.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, help="GPU or MIG UUID to use for training")
parser.add_argument("--task", type=str, default=None, help="Manual task ID.")
# parse_known_args() leaves options this script does not define untouched.
args, _ = parser.parse_known_args()

# Export runtime settings: restrict the visible device only when one was
# requested, and always set the XLA client memory fraction.
if args.device:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

import jax
import jax.numpy as jnp
import numpy as onp
from tqdm import tqdm
from train_utils import train_step, create_ala2_train_state
from config import get_default_config
from models.graph_transformer import GraphTransformer
from data.ala2_data import get_ala2_dataloader, Ala2Dataset

def train(task_id):
    """Train a GraphTransformer on the ala2 data and save the results.

    All hyperparameters are read from ``get_default_config()``. After the
    last epoch three files are written to ``BASE_DIR/output/ala2/<task_id>/``:
    ``final_params.npy`` (``state.params``), ``loss.npy`` (one mean loss value
    per epoch) and ``config.json`` (the configuration that was used).

    Args:
        task_id: Name of the output sub-directory. It is interpolated into
            the path as-is, so ``None`` yields a directory literally named
            ``None``.

    Returns:
        None.
    """
    config = get_default_config()
    general_cfg = config["general"]
    model_cfg = config["model"]["GraphTransformer"]
    trainer_cfg = config["trainer"]

    cg_level = general_cfg["cg_level"]
    mu = general_cfg["mu"]
    sigma = general_cfg["sigma"]
    t0 = general_cfg["t0"]
    t1 = general_cfg["t1"]
    file_path = config["dataset"]["ala2_datafile"]
    feat_type = model_cfg["feat_type"]
    # Keyword arguments for the model, read up front so a missing config key
    # fails before any data is loaded.
    model_kwargs = {
        "t0": t0,
        "t1": t1,
        "rescale_time": model_cfg["rescale_time"],
        "clip_time": model_cfg["clip_time"],
        "hidden_nf": model_cfg["hidden_layers"],
        "feature_embedding_dim": model_cfg["embedding_layers"],
        "max_z": model_cfg["max_z"],
        "n_layers": model_cfg["n_layers"],
        "use_intrinsic_coords": model_cfg["use_intrinsic_coords"],
        "use_abs_coords": model_cfg["use_abs_coords"],
        "use_distances": model_cfg["use_distances"],
        "dropout": model_cfg["dropout"],
    }
    epochs = trainer_cfg["epochs"]
    batch_size = trainer_cfg["batch"]
    ema_decay = trainer_cfg.get("ema_decay", None)

    # Initialize dataloader and model
    dataloader = get_ala2_dataloader(config)
    model = GraphTransformer(**model_kwargs)

    # Create training state. The example inputs are the first dataset sample
    # flattened to one row and repeated batch_size times.
    rng = jax.random.PRNGKey(0)
    dataset = Ala2Dataset(file_path, feat_type=feat_type, cg_level=cg_level)
    first_sample = dataset[0]
    x_example = jnp.tile(first_sample["x"].reshape(1, -1), (batch_size, 1))
    features_example = jnp.tile(
        first_sample["features"].reshape(1, -1), (batch_size, 1)
    )
    t_example = jnp.ones((batch_size, 1))
    state = create_ala2_train_state(
        rng=rng,
        model=model,
        example={
            "x": x_example,
            "features": features_example,
            "t": t_example,
            "training": True,
        },
        config=config,
        ema_decay=ema_decay,
    )

    # Training loop
    train_loss_set = []
    pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    for _ in pbar:
        epoch_loss = 0.0
        for batch in dataloader:
            x_batch = jnp.asarray(batch["x"])
            x_batch = x_batch.reshape(x_batch.shape[0], -1)
            features = jnp.asarray(batch["features"])

            rng, rng1, rng2 = jax.random.split(rng, 3)
            x_init = jax.random.normal(rng1, x_batch.shape) * sigma + mu
            # Draw one time value per sample actually present in this batch.
            # Using the configured batch size here would give `t` a different
            # leading dimension than `x_batch` whenever the loader yields a
            # smaller (e.g. final) batch.
            t = jax.random.uniform(
                rng2, shape=(x_batch.shape[0], 1), minval=t0, maxval=t1
            )

            loss, state = train_step(
                state=state,
                x=x_batch,
                x_init=x_init,
                t=t,
                features=features,
            )
            epoch_loss += loss

        # Mean loss over the batches of this epoch.
        epoch_loss /= len(dataloader)
        pbar.set_postfix({
            "loss": f"{epoch_loss:.6f}"
        })
        train_loss_set.append(epoch_loss)

    train_loss_set = jnp.stack(train_loss_set)

    # Save parameters
    output_path = BASE_DIR / f"output/ala2/{task_id}"
    output_path.mkdir(parents=True, exist_ok=True)
    config_file = output_path / "config.json"
    output_file = output_path / "final_params.npy"
    loss_file = output_path / "loss.npy"
    # NOTE(review): state.params is presumably a nested parameter container
    # rather than a plain array; if so numpy stores it as a pickled object
    # array and loading it back needs allow_pickle=True - confirm with readers
    # of this file.
    onp.save(output_file, state.params)
    onp.save(loss_file, train_loss_set)
    with open(config_file, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Models saved to {output_file}")
    print(f"Training loss saved to {loss_file}")
    print(f"Config saved to {config_file}")
if __name__ == "__main__":
    # Start training only when run as a script; the output sub-directory name
    # is taken from the parsed --task option.
    train(task_id=args.task)

    