import argparse 
import os
import jax
import copy
import random
import pickle 
import numpy as np
import jax.numpy as jnp
from flax.core import freeze, unfreeze
from model import FlaxTransformer,FlaxTransformerMoE, ModelConfig, load_combined_model, print_model
from datamodule import flax_make_numpy_loader, load_agnews_data
from flax.serialization import to_state_dict, from_state_dict
from flax.traverse_util import flatten_dict, unflatten_dict
from engine import create_train_state, train_epoch, evaluate_model

def copy_weights(transformer_params, moe_params, config, train_classifier=False):
    """Transplant dense-Transformer weights into an MoE parameter tree.

    Written into the MoE tree (taken from the dense tree):
      * ``encoder/tok_embed/embedding`` and ``encoder/pos_embed``;
      * every encoder block ``i != config.moe_idx``: the whole dense
        ``encoder/encoder_blocks_{i}`` sub-tree (deep-copied) replaces
        ``encoder/blocks_{i}``;
      * encoder block ``config.moe_idx``: only its ``residual1``,
        ``residual2`` and ``attention`` entries are replaced. Every other
        entry of that block (such as ``feed_forward``) keeps the value
        already present in ``moe_params``;
      * ``encoder/norm`` and ``classifier``.

    Args:
        transformer_params: Parameter tree of the dense model (blocks named
            ``encoder_blocks_{i}``).
        moe_params: Parameter tree of the MoE model (blocks named
            ``blocks_{i}``), e.g. from ``FlaxTransformerMoE.init``.
        config: Object exposing ``N_encoder`` (number of encoder blocks) and
            ``moe_idx`` (index of the block holding the MoE feed-forward).
        train_classifier: Currently unused; kept for call-site symmetry with
            ``make_param_mask``.

    Returns:
        The updated MoE parameter tree, passed through ``freeze``.
    """
    src = unfreeze(transformer_params)
    dst = unfreeze(moe_params)
    src_enc = src["encoder"]
    dst_enc = dst["encoder"]

    # Embedding tables are shared verbatim.
    dst_enc["tok_embed"]["embedding"] = src_enc["tok_embed"]["embedding"]
    dst_enc["pos_embed"] = src_enc["pos_embed"]

    for idx in range(config.N_encoder):
        dense_block = src_enc[f"encoder_blocks_{idx}"]
        dst_name = f"blocks_{idx}"
        if idx == config.moe_idx:
            # Keep the MoE feed-forward as initialised; transplant the rest.
            # (Seeding the gated experts from the dense FF weights is a
            # possible alternative that is intentionally not done here.)
            target = dst_enc[dst_name]
            for sub in ("residual1", "residual2", "attention"):
                target[sub] = copy.deepcopy(dense_block[sub])
        else:
            dst_enc[dst_name] = copy.deepcopy(dense_block)

    dst_enc["norm"] = src_enc["norm"]
    dst["classifier"] = src["classifier"]
    return freeze(dst)


def make_param_mask(params, config, train_classifier=False):
    """Label every leaf of ``params`` as ``"trainable"`` or ``"frozen"``.

    A leaf is ``"trainable"`` when either of these holds:
      * its path starts with ``encoder / blocks_{config.moe_idx} /
        feed_forward`` and has at least one more level below it. The
        *entire* feed-forward sub-tree of the MoE block is labelled, not
        only selected parameters inside it;
      * ``train_classifier`` is true and its path starts with
        ``classifier``.
    Every other leaf is ``"frozen"``.

    Args:
        params: Parameter pytree (nested dicts / FrozenDicts).
        config: Object exposing ``moe_idx``.
        train_classifier: If true, the classifier head is also trainable.

    Returns:
        The label tree (same structure as ``params``), passed through
        ``freeze``.
    """
    moe_block = f"blocks_{config.moe_idx}"

    def _entry_name(entry):
        # Key-path entries differ by container type: DictKey/FlattenedIndexKey
        # expose ``.key``, SequenceKey ``.idx``, GetAttrKey ``.name``.
        # Reading only ``.key`` would crash on non-dict nodes.
        for attr in ("key", "idx", "name"):
            if hasattr(entry, attr):
                return str(getattr(entry, attr))
        return str(entry)

    def label_fn(path, _):
        keys = [_entry_name(p) for p in path]
        if len(keys) >= 4 and keys[:3] == ["encoder", moe_block, "feed_forward"]:
            return "trainable"
        # Slicing instead of keys[0] so an empty path cannot raise IndexError.
        if train_classifier and keys[:1] == ["classifier"]:
            return "trainable"
        return "frozen"

    return freeze(jax.tree_util.tree_map_with_path(label_fn, params))



def _parse_args():
    """Parse the command-line options for MoE fine-tuning."""
    parser = argparse.ArgumentParser(
        description="Fine-tune a Transformer-MoE initialised from a pretrained Transformer on AG News"
    )
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--num-experts", type=int, required=True)
    parser.add_argument("--num-shared-experts", type=int, required=True)
    parser.add_argument("--num-gated-experts", type=int, required=True)
    parser.add_argument("--moe-idx", type=int, required=True)
    parser.add_argument("--topk", type=int, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument(
        "--train-classifier",
        type=int,
        default=0,
        help="Set to 1 to also train the classifier head",
    )
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--patience", type=int, default=5, help="Early stopping patience")
    parser.add_argument("--save-dir", type=str, required=True, help="Directory to save the best model")
    return parser.parse_args()


def main():
    """Build an MoE model from a dense checkpoint, fine-tune it, and save the best params.

    Steps: load data and the dense checkpoint, create the MoE variant
    (``copy_weights``), freeze everything except the MoE feed-forward (and
    optionally the classifier) via ``make_param_mask``, and train with early
    stopping on validation loss. The parameters with the lowest validation
    loss are pickled into ``--save-dir``.
    """
    args = _parse_args()
    train_classifier = args.train_classifier == 1
    seed = args.seed
    np.random.seed(seed)
    rng = jax.random.PRNGKey(seed)

    (x_train, y_train), (x_val, y_val), num_classes, PAD = load_agnews_data(args.data_path)
    train_loader = list(flax_make_numpy_loader(x_train, y_train, args.batch_size, pad_id=PAD))
    val_loader = list(flax_make_numpy_loader(x_val, y_val, args.batch_size, pad_id=PAD))

    # === Model loading and definition ===
    transformer_params, transformer_config = load_combined_model(args.model_path)
    config = ModelConfig(**transformer_config)
    config.num_experts = args.num_experts
    config.moe_idx = args.moe_idx
    config.num_shared_experts = args.num_shared_experts
    config.num_gated_experts = args.num_gated_experts
    config.topk = args.topk
    transformer_model = FlaxTransformer(config=config, num_classes=num_classes)
    moe_model = FlaxTransformerMoE(config=config, num_classes=num_classes)

    # Initialise the MoE to obtain its parameter structure. Use a dedicated
    # key so ``rng`` is never consumed twice.
    rng, init_rng = jax.random.split(rng)
    dummy_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    dummy_mask = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    variables = moe_model.init(init_rng, dummy_input, pad_mask=dummy_mask, deterministic=True)
    moe_params = variables["params"]
    print_model(moe_params)

    moe_params = copy_weights(transformer_params, moe_params, config, train_classifier=train_classifier)
    mask = make_param_mask(moe_params, config, train_classifier=train_classifier)
    print(mask)

    state = create_train_state(model=moe_model, learning_rate=args.lr, params=moe_params, mask=mask)
    dense_state = create_train_state(model=transformer_model, learning_rate=args.lr, params=transformer_params)

    # Baselines: the upcycled MoE before any training vs. the original dense model.
    val_metrics = evaluate_model(state, val_loader)
    print(f"[MoE, before fine-tuning] Val Loss={val_metrics['loss']:.4f}, Val Acc={val_metrics['accuracy']:.4f}")
    val_metrics = evaluate_model(dense_state, val_loader)
    print(f"[Dense transformer] Val Loss={val_metrics['loss']:.4f}, Val Acc={val_metrics['accuracy']:.4f}")

    best_val_loss = float("inf")
    best_params = None
    wait = 0
    for epoch in range(args.epochs):
        # Separate keys for the batch shuffle and for the training step, so
        # the two random streams are independent.
        rng, shuffle_rng, train_rng = jax.random.split(rng, 3)
        shuffle_seed = int(jax.random.randint(shuffle_rng, (), 0, 1_000_000))
        random.Random(shuffle_seed).shuffle(train_loader)
        state, train_metrics = train_epoch(state, train_loader, train_rng)
        val_metrics = evaluate_model(state, val_loader)
        print(f"Epoch {epoch}: "
              f"Train Loss={train_metrics['loss']:.4f}, Train Acc={train_metrics['accuracy']:.4f}, "
              f"Val Loss={val_metrics['loss']:.4f}, Val Acc={val_metrics['accuracy']:.4f}")
        # Early stopping on validation loss.
        val_loss = float(val_metrics["loss"])
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_params = to_state_dict(state.params)
            wait = 0
        else:
            wait += 1
            if wait >= args.patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    if best_params is None:
        # No epoch beat +inf (e.g. --epochs 0, or the validation loss was
        # NaN every epoch). Save the final parameters rather than a bundle
        # whose params are None.
        print("Warning: validation loss never improved; saving the final parameters instead.")
        best_params = to_state_dict(state.params)

    # === Save best model ===
    bundle = {"flax_params": best_params, "config": config.__dict__}
    file_name = (
        f"idx-{args.moe_idx}-shared-{args.num_shared_experts}"
        f"-gated-{args.num_gated_experts}-topk-{args.topk}-seed-{args.seed}.flax"
    )
    # ``or "."`` because os.makedirs("") raises FileNotFoundError.
    os.makedirs(args.save_dir or ".", exist_ok=True)
    save_path = os.path.join(args.save_dir, file_name)
    with open(save_path, "wb") as f:
        pickle.dump(bundle, f)
    print(f"Best model saved to {save_path}")


if __name__ == "__main__":
    main()