import argparse 
import os
import jax
import random
import pickle 
import jax.numpy as jnp
from flax.serialization import to_state_dict
from model import FlaxTransformer, ModelConfig
from datamodule import flax_make_numpy_loader, load_agnews_data
from engine import create_train_state, train_epoch, evaluate_model

def main():
    """Train a Transformer classifier on AG News and save the best parameters.

    Command-line arguments:
        --data-path   Path handed to ``load_agnews_data`` (required).
        --save-dir    Directory the checkpoint is written to (required,
                      created if missing).
        --epochs      Maximum number of training epochs (default 10).
        --seed        Seed for the JAX PRNG key (default 42).
        --batch-size  Batch size for both loaders (default 128).
        --lr          Learning rate passed to ``create_train_state``
                      (default 1e-3).
        --num-layers  Number of encoder layers (default 1).
        --patience    Number of consecutive epochs without validation-loss
                      improvement before stopping early (default 5).

    The parameters from the epoch with the lowest validation loss are pickled,
    together with the model configuration, to
    ``<save-dir>/transformer-layers-<num-layers>.flax`` as a dict with the keys
    ``"flax_params"`` and ``"config"``. If the validation loss never improves
    (for example ``--epochs 0`` or a NaN loss), the final parameters are saved
    instead so the bundle never contains ``None``.
    """
    parser = argparse.ArgumentParser(description="Train Transformer on AGNEWS")
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--patience", type=int, default=5)
    args = parser.parse_args()

    # JAX keys must not be reused: derive a dedicated key for initialisation
    # and keep `rng` only as the source of later per-epoch keys.
    rng = jax.random.PRNGKey(args.seed)
    rng, init_rng = jax.random.split(rng)

    (x_train, y_train), (x_val, y_val), num_classes, PAD = load_agnews_data(args.data_path)
    # Materialise the loaders as lists so the training batches can be
    # shuffled in place and both loaders can be iterated every epoch.
    train_loader = list(flax_make_numpy_loader(x_train, y_train, args.batch_size, pad_id=PAD))
    val_loader = list(flax_make_numpy_loader(x_val, y_val, args.batch_size, pad_id=PAD))

    # NOTE(review): vocabulary size and max_seq_len are hard-coded here;
    # presumably they must match the preprocessing in load_agnews_data - confirm.
    config = ModelConfig(
        encoder_vocab_size=15000,
        d_embed=32,
        d_ff=128,
        h=1,
        N_encoder=args.num_layers,
        max_seq_len=100,
        dropout=0.1,
        moe_idx=-1,
        num_experts=0,
        num_shared_experts=0,
        num_gated_experts=0,
        topk=0,
    )
    model = FlaxTransformer(config=config, num_classes=num_classes)
    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    variables = model.init(init_rng, dummy_input)
    params = variables["params"]
    state = create_train_state(model=model, learning_rate=args.lr, params=params)

    best_val_loss = float("inf")
    wait = 0
    best_params = None
    for epoch in range(args.epochs):
        # One fresh key per consumer: the Python-side shuffle seed and the key
        # handed to train_epoch must not be derived from the same key.
        rng, shuffle_rng, epoch_rng = jax.random.split(rng, 3)
        shuffle_seed = int(jax.random.randint(shuffle_rng, (), 0, 1_000_000))
        # Shuffles the order of the pre-built batches, not the samples inside them.
        random.Random(shuffle_seed).shuffle(train_loader)

        state, train_metrics = train_epoch(state, train_loader, epoch_rng)
        val_metrics = evaluate_model(state, val_loader)
        print(f"Epoch {epoch}: "
            f"Train Loss={train_metrics['loss']:.4f}, Train Acc={train_metrics['accuracy']:.4f}, "
            f"Val Loss={val_metrics['loss']:.4f}, Val Acc={val_metrics['accuracy']:.4f}")

        val_loss = float(val_metrics["loss"])
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Copy to host memory so the snapshot does not depend on device
            # buffers owned by a train state that later steps replace, and so
            # the pickle contains plain NumPy arrays.
            best_params = jax.device_get(to_state_dict(state.params))
            wait = 0
        else:
            wait += 1
            if wait >= args.patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    if best_params is None:
        # No epoch ran or the validation loss was never finite/improving;
        # fall back to the last parameters rather than pickling None.
        print("Warning: validation loss never improved; saving final parameters instead")
        best_params = jax.device_get(to_state_dict(state.params))

    bundle = {"flax_params": best_params, "config": config.__dict__}
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, f"transformer-layers-{args.num_layers}.flax")
    with open(save_path, "wb") as f:
        pickle.dump(bundle, f)
    print(f"Best model saved to {save_path}")
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()