import os
import sys
import argparse

# Add the parent directory of this script to sys.path.
# NOTE(review): presumably this is what lets the project-local imports further
# down (rbf_mlp, mb_data, train_utils) resolve -- confirm where they live.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))

# Argument parser for device selection.
# Only --device is declared; parse_known_args() is used instead of
# parse_args() so unrecognised command-line arguments are returned separately
# (and discarded here via `_`) rather than causing an error.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--device',
    type=str,
    help='GPU or MIG UUID to use for training'
)
args, _ = parser.parse_known_args()
# This runs at import time and prints "None" when --device was not supplied.
print(f"Running on device: {args.device}")

# Set environment variables based on parsed arguments.
# Both variables are written before the `import jax` statement that follows
# this section, so they are already in the environment when JAX is loaded.
# CUDA_VISIBLE_DEVICES is only overridden when --device was given; otherwise
# whatever the caller's environment already contains is left untouched.
if args.device:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
# Fraction of GPU memory the XLA client may preallocate (see the JAX GPU
# memory allocation docs); set unconditionally, overriding any inherited value.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'

import jax
import jax.numpy as jnp
from tqdm import tqdm
import numpy as onp
import matplotlib.pyplot as plt
from rbf_mlp import RBFMLP
from collections import OrderedDict
from mb_data import get_mb_dataloader
from train_utils import train_step, create_train_state, evaluate_loss

def get_default_config():
    """Build the default experiment configuration.

    Returns:
        An ``OrderedDict`` with the sections ``general``, ``dataset``,
        ``model`` and ``trainer`` (in that order), each of which is itself
        an ``OrderedDict`` of settings. A fresh object is created on every
        call, so callers may mutate the result freely.
    """
    # Global switches: seed, temperature-transfer flag and the two kT values.
    general_section = OrderedDict([
        ("cg_level", "high"),  # "low" or "high"
        ("seed", 0),
        ("mlp", True),
        ("T_transfer", False),
        ("kT0", 1.0),
        ("kT1", 50.0),
    ])

    # Data source and loader options.
    dataset_section = OrderedDict([
        ("num_samples", 200_000),
        ("mb_datafile", ""),
        ("shuffle", False),
        ("drop_last", True),
    ])

    # Hidden-layer widths of the network.
    model_section = OrderedDict([
        ("features", [128, 128, 128, 128]),
    ])

    # Optimiser hyper-parameters and training length.
    trainer_section = OrderedDict([
        ("name", "adam"),
        ("learning_rate", 1e-4),
        ("weight_decay", 0.0),
        ("epochs", 500),
        ("batch", 128),
    ])

    return OrderedDict([
        ("general", general_section),
        ("dataset", dataset_section),
        ("model", model_section),
        ("trainer", trainer_section),
    ])

def train():
    """Fit the RBF-MLP on the configured dataset and persist the results.

    All hyper-parameters come from ``get_default_config()``. After every
    epoch the loss returned by ``evaluate_loss`` over the same dataloader is
    recorded and shown in the progress bar.

    Side effects:
        Creates ``output/<num_samples>/`` if needed and writes two files
        there with ``numpy.save``: the parameters (``*.pkl`` file name) and
        the stacked per-epoch loss (``*_loss_*.npy``). When ``T_transfer``
        is enabled the file names carry a ``_trans<kT1>`` suffix.

    Returns:
        Tuple ``(params, loss)`` -- ``state.params`` of the final train
        state and the per-epoch losses stacked with ``jnp.stack``.
    """
    cfg = get_default_config()
    general_cfg = cfg["general"]
    trainer_cfg = cfg["trainer"]
    data_cfg = cfg['dataset']

    n_epochs = trainer_cfg["epochs"]
    batch_size = trainer_cfg["batch"]
    n_samples = data_cfg['num_samples']

    key = jax.random.PRNGKey(general_cfg["seed"])
    loader = get_mb_dataloader(
        batch_size=batch_size,
        filepath=data_cfg['mb_datafile'],
        num_samples=n_samples,
    )
    # RBF hyper-parameters are hard-coded here (not part of the config).
    network = RBFMLP(
        hidden_layers=cfg["model"]["features"],
        num_rbf_centers=100,
        sigma=5.0,
    )

    # Initialise parameters/optimiser from a dummy (batch, 1) input.
    state = create_train_state(
        key, network, jnp.ones((batch_size, 1)), trainer_cfg["learning_rate"]
    )

    # One pass over the loader per epoch, then one evaluation pass.
    epoch_losses = []
    progress = tqdm(range(n_epochs), desc="Training", unit="epoch")
    for _epoch in progress:
        for inputs, targets in loader:
            # Batches expose .numpy(); convert them to JAX arrays.
            state = train_step(
                state,
                jnp.array(inputs.numpy()),
                jnp.array(targets.numpy()),
            )
        epoch_loss = evaluate_loss(state, loader)
        epoch_losses.append(epoch_loss)
        progress.set_postfix({"loss": f"{epoch_loss:.6f}"})
    loss = jnp.stack(epoch_losses)

    out_dir = f"output/{n_samples}"
    os.makedirs(out_dir, exist_ok=True)
    file_prefix = "mb_model_biased"
    if general_cfg["T_transfer"]:
        file_prefix += f"_trans{int(general_cfg['kT1'])}"

    # NOTE: despite the .pkl extension this file is written by numpy.save.
    with open(f"{out_dir}/{file_prefix}_{n_epochs}.pkl", "wb") as params_file:
        onp.save(params_file, state.params)
    with open(f"{out_dir}/{file_prefix}_loss_{n_epochs}.npy", "wb") as loss_file:
        onp.save(loss_file, loss)
    print("Model and loss saved")

    return state.params, loss

def inference(params, loss=None):
    """Plot the learned 1-D energy against the numerically coarse-grained one.

    The model is rebuilt from ``get_default_config()`` and evaluated on 300
    points of x in [0, 50]. The reference curve is obtained by integrating
    the Boltzmann factor of the 2-D potential ``V(x, y)`` over y in [0, 50]
    (trapezoidal rule) at temperature ``kT1`` if ``T_transfer`` is set and
    ``kT0`` otherwise. Both curves are shifted so that their minimum is zero.

    Args:
        params: Model parameters, passed unchanged to ``model.apply``.
        loss: Optional per-epoch loss history (any sequence/array supporting
            ``len``). When ``None`` the loss panel is omitted and only the
            energy comparison is drawn.

    Side effects:
        Writes ``output/<num_samples>/<prefix>_<epochs>.png`` (creating the
        directory if necessary) and prints a confirmation message.
    """
    config = get_default_config()
    T_transfer = config["general"]["T_transfer"]
    kT0 = config["general"]["kT0"]
    kT1 = config["general"]["kT1"]
    epochs = config["trainer"]["epochs"]
    num_samples = config['dataset']['num_samples']
    hidden_layers = config["model"]["features"]

    # Must mirror the architecture used in train() so params fit the model.
    model = RBFMLP(
        hidden_layers=hidden_layers,
        num_rbf_centers=100,
        sigma=5.0
    )

    # trained energy fn
    def energy_fn(x):
        return model.apply(params, x)

    x_vals = jnp.linspace(0, 50, 300)
    y_vals = jnp.linspace(0, 50, 300)
    kT = kT1 if T_transfer else kT0
    # Column vector of x positions: one sample per row for the model.
    x = x_vals.reshape(-1, 1)

    # numerical integrated energy fn
    def V(x, y):
        term1 = -17.3 * jnp.exp(-0.0039 * (x - 48)**2 - 0.0391 * (y - 8)**2)
        term2 = -8.7 * jnp.exp(-0.0039 * (x - 32)**2 - 0.0391 * (y - 16)**2)
        term3 = -14.7 * jnp.exp(-0.0254 * (x - 24)**2 + 0.043 * (x - 24) * (y - 32) - 0.0254 * (y - 32)**2)
        term4 = 1.3 * jnp.exp(0.00273 * (x - 16)**2 + 0.0023 * (x - 16) * (y - 24) + 0.00273 * (y - 24)**2)
        return term1 + term2 + term3 + term4

    def V_cg(xi):
        # Free energy along x: -kT * log( integral over y of exp(-V/kT) ).
        V_y = jax.vmap(lambda y: V(xi, y))(y_vals)
        return -kT * jnp.log(jnp.trapezoid(jnp.exp(- V_y/kT), y_vals))

    V_eff = jax.vmap(V_cg)(x_vals)
    V_eff -= jnp.min(V_eff)
    energy = energy_fn(x)
    energy -= jnp.min(energy)

    file_prefix = "energy_biased"
    if T_transfer:
        file_prefix += f"_trans{int(kT1)}"

    # `loss` defaults to None, so only create the loss panel when a history
    # was actually supplied instead of crashing inside ax1.plot().
    if loss is None:
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 4))
        ax1 = None
    else:
        fig, (ax, ax1) = plt.subplots(nrows=2, ncols=1, figsize=(10, 8))

    ax.plot(x, energy, lw=2, color="blue", label="trained energy")
    ax.plot(x, V_eff, lw=2, color="red", label="cg energy")
    ax.set_xlabel("x")
    ax.set_ylabel("Energy")
    ax.set_title("Energy vs x")
    ax.legend()

    if ax1 is not None:
        # Index by the history's own length so a loss array that does not
        # have exactly `epochs` entries still plots.
        ax1.plot(range(len(loss)), loss)
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")
        ax1.set_title("Loss vs Epochs")

    plt.tight_layout()
    # The directory is otherwise only created by train(); make sure it
    # exists so inference() also works on its own.
    out_dir = f"output/{num_samples}"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{file_prefix}_{epochs}.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Figure Saved")

if __name__ == "__main__":
    # Script entry point: train first, then plot using the fresh results.
    trained_params, loss_history = train()
    inference(trained_params, loss_history)