import argparse 
import os
import jax
import random
import pickle 
import jax.numpy as jnp
from flax.serialization import to_state_dict
from model import FlaxTransformer, ModelConfig
from datamodule import flax_make_numpy_loader, load_imdb_review_data
from engine import create_train_state, train_epoch, evaluate_model

def main():
    """Train a Flax Transformer classifier on IMDB reviews and save the best checkpoint.

    Command-line options:
        --data-path   Path to the IMDB review CSV (required).
        --save-dir    Directory the checkpoint is written to (required).
        --epochs      Maximum number of training epochs (default 10).
        --seed        Seed for the JAX PRNG key (default 42).
        --batch-size  Batch size for the train and validation loaders (default 128).
        --lr          Learning rate passed to ``create_train_state`` (default 5e-4).
        --num-layers  Number of encoder layers (default 1).
        --patience    Epochs without validation-loss improvement before
                      early stopping (default 5).

    The parameters with the lowest validation loss are pickled, together with
    the model configuration, to
    ``<save-dir>/transformer-layers-<num-layers>.flax``.
    """
    parser = argparse.ArgumentParser(description="Train Transformer on IMDB Review")
    parser.add_argument("--data-path", type=str, required=True)   # Path to imdbreview.csv
    parser.add_argument("--save-dir", type=str, required=True)     # Path to save model
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--patience", type=int, default=5)
    args = parser.parse_args()

    # Create the output directory up front so an unusable --save-dir fails
    # before any training time is spent, not after.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, f"transformer-layers-{args.num_layers}.flax")

    # JAX keys must not be reused: derive a dedicated key for parameter
    # initialisation and keep `rng` only as the source for later splits.
    rng = jax.random.PRNGKey(args.seed)
    rng, init_rng = jax.random.split(rng)

    (x_train, y_train), (x_val, y_val), num_classes, PAD = load_imdb_review_data(args.data_path)
    # Batches are materialised into lists so the training batches can be
    # reshuffled in place every epoch.
    train_loader = list(flax_make_numpy_loader(x_train, y_train, args.batch_size, pad_id=PAD))
    val_loader = list(flax_make_numpy_loader(x_val, y_val, args.batch_size, pad_id=PAD))

    # You might want to adjust config a little (longer text)
    config = ModelConfig(
        encoder_vocab_size=15000,
        d_embed=32,
        d_ff=128,
        h=1,
        N_encoder=args.num_layers,
        max_seq_len=256,
        dropout=0.1,
        num_experts=0,
    )

    model = FlaxTransformer(config=config, num_classes=num_classes)
    # A single all-ones sequence of maximum length is enough to trace shapes.
    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    variables = model.init(init_rng, dummy_input)
    params = variables["params"]

    state = create_train_state(model=model, learning_rate=args.lr, params=params)
    best_val_loss = float("inf")
    patience = args.patience
    wait = 0
    best_params = None

    for epoch in range(args.epochs):
        # One key seeds the Python-side batch shuffle, a different one goes to
        # train_epoch; the original reused the same key for both.
        rng, shuffle_rng, epoch_rng = jax.random.split(rng, 3)
        py_rng = random.Random(int(jax.random.randint(shuffle_rng, (), 0, 1_000_000)))
        py_rng.shuffle(train_loader)

        state, train_metrics = train_epoch(state, train_loader, epoch_rng)
        val_metrics = evaluate_model(state, val_loader)

        print(f"Epoch {epoch}: "
              f"Train Loss={train_metrics['loss']:.4f}, Train Acc={train_metrics['accuracy']:.4f}, "
              f"Val Loss={val_metrics['loss']:.4f}, Val Acc={val_metrics['accuracy']:.4f}")

        if val_metrics["loss"] < best_val_loss:
            # New best epoch: snapshot the parameters and reset the counter.
            best_val_loss = val_metrics["loss"]
            best_params = to_state_dict(state.params)
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    if best_params is None:
        # No epoch ever improved on the initial best (e.g. --epochs 0 or a
        # NaN validation loss); save the current parameters instead of
        # silently pickling None.
        best_params = to_state_dict(state.params)

    bundle = {
        "flax_params": best_params,
        "config": config.__dict__,
    }
    with open(save_path, "wb") as f:
        pickle.dump(bundle, f)
    print(f"Best model saved to {save_path}")

# Run training only when this file is executed as a script, so importing the
# module does not trigger argument parsing or a training run.
if __name__ == "__main__":
    main()
