import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import jax.numpy as jnp
import optax
import torch
import wandb
import numpy as np
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from flax.core.frozen_dict import freeze, unfreeze
# Only fall back to an empty key when the caller has not exported one;
# unconditionally assigning "" would discard a key set in the environment.
os.environ.setdefault("WANDB_API_KEY", "")

# fmt: off
# fmt: on


def batch_collate_fn(data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate a list of example dicts into one sharded dict of arrays.

    Field names come from the first example. Each field's values are gathered
    in example order, converted with ``jnp.array``, and the resulting dict is
    passed through ``shard`` so it can be fed to the ``jax.pmap``-ed steps.

    Args:
        data_list: examples as produced by the dataset (one dict per row).

    Returns:
        The output of ``shard`` applied to ``{field: jnp.array(values)}``.
    """
    columns: Dict[str, List[Any]] = {}
    for name in data_list[0]:
        columns[name] = []

    for example in data_list:
        for name, value in example.items():
            columns[name].append(value)

    stacked = {}
    for name, values in columns.items():
        stacked[name] = jnp.array(values)
    return shard(stacked)


def decay_mask_fn(params):
    """Return a boolean tree choosing which parameters get weight decay.

    A leaf is excluded (``False``) when its last key is ``"bias"`` or when its
    final two keys are the ``scale`` of ``ln_1``, ``ln_2`` or ``ln_f``. Every
    other leaf is ``True``. The flags are rebuilt with ``unflatten_dict`` so
    they follow the key paths of ``params``. This is used as
    ``optax.adamw(mask=...)``.
    """
    excluded_norm_scales = [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]

    def _decays(key_path):
        if key_path[-1] == "bias":
            return False
        return key_path[-2:] not in excluded_norm_scales

    decay_flags = {key_path: _decays(key_path) for key_path in flatten_dict(params)}
    return unflatten_dict(decay_flags)

# ---------- Training Utilities ----------
def get_trainable_mask(params, config):
    """Label every parameter leaf as ``"trainable"`` or ``"frozen"``.

    A leaf is ``"trainable"`` only when its key path has at least six entries
    and begins with ``transformer / h / <config.moe_layer_indices> / mlp``.
    In other words, it lives under the MLP of the single MoE block. Everything
    else is ``"frozen"``. That includes embeddings, attention, layer norms,
    other blocks, and leaves only five keys deep under that MLP.

    Args:
        params: parameter pytree whose path entries expose ``.key`` (nested dicts).
        config: object whose ``moe_layer_indices`` is a single integer.

    Returns:
        A pytree shaped like ``params`` holding label strings, intended for
        ``optax.multi_transform(param_labels=...)``.
    """
    def _in_moe_mlp(names):
        # Expected layout: transformer/h/{moe_idx}/mlp/<gate|experts...>/<leaf>
        if len(names) < 6:
            return False
        if names[0] != "transformer" or names[1] != "h":
            return False
        if names[2] != str(config.moe_layer_indices):
            return False
        return names[3] == "mlp"

    def _label(path, _leaf):
        names = [str(entry.key) for entry in path]
        if _in_moe_mlp(names):
            return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(_label, params)


def pretrained2finetune_parmas(pretrained_params, finetune_params, config):
    """Initialise MoE fine-tuning parameters from a dense pretrained GPT-2.

    The following are copied from ``pretrained_params`` into the MoE tree:
    the token/position embeddings (``wte``/``wpe``), the final layer norm
    (``ln_f``), and every transformer block other than the MoE one (copied
    wholesale). For the MoE block (index ``config.moe_layer_indices``), only
    ``ln_1``, ``attn`` and ``ln_2`` are copied. Its ``mlp`` subtree keeps
    whatever values ``finetune_params`` was passed with.

    Args:
        pretrained_params: params tree of the dense ``FlaxGPT2LMHeadModel``
            (dict or FrozenDict).
        finetune_params: params tree of the MoE model (dict or FrozenDict).
            Writes go to its ``unfreeze``d form.
        config: provides ``num_hidden_layers`` and ``moe_layer_indices``
            (a single int).

    Returns:
        A FrozenDict holding the merged parameters.
    """
    pretrained_params = unfreeze(pretrained_params)
    finetune_params = unfreeze(finetune_params)
    pretrained_transformer = pretrained_params["transformer"]
    finetune_transformer = finetune_params["transformer"]

    # 1. Top-level embeddings and the final layer norm are shared verbatim.
    for name in ("wte", "wpe", "ln_f"):
        finetune_transformer[name] = copy.deepcopy(pretrained_transformer[name])

    # 2. Transformer blocks.
    for i in range(config.num_hidden_layers):
        str_i = str(i)
        ref_layer = pretrained_transformer["h"][str_i]
        if i == config.moe_layer_indices:
            # MoE block: reuse the pretrained attention and norms only.
            # NOTE: an earlier variant also copied the dense mlp.c_fc / mlp.c_proj
            # into every "routed_experts_*" sub-tree; that is intentionally
            # disabled, so the MoE MLP params stay as passed in.
            target_layer = finetune_transformer["h"][str_i]
            for name in ("ln_1", "attn", "ln_2"):
                target_layer[name] = copy.deepcopy(ref_layer[name])
        else:
            # Standard dense block: identical structure, copy it as a whole.
            finetune_transformer["h"][str_i] = copy.deepcopy(ref_layer)
    return freeze(finetune_params)


def main(args: argparse.Namespace):
    """Fine-tune the MoE block of a GPT-2 initialised from a dense checkpoint.

    Pipeline:
      1. start a W&B run and record ``args`` on it;
      2. build train/eval ``DataLoader``s over parquet files (batches are
         sharded per device by ``batch_collate_fn``);
      3. restore the dense ``FlaxGPT2LMHeadModel`` from ``args.model_path`` and
         copy its weights into a ``FlaxGPT2MoELMHeadModel``;
      4. optimise only the parameters labelled ``"trainable"`` by
         ``get_trainable_mask`` with AdamW; every other leaf gets
         ``optax.set_to_zero``. Optionally resume from
         ``args.restore_checkpoint_path``;
      5. log train metrics every ``logging_frequency`` steps, evaluate every
         ``eval_frequency`` steps and save a checkpoint every ``save_frequency``
         steps. Evaluation and saving also happen at the end of each epoch.

    Args:
        args: parsed command-line arguments (see the ``__main__`` block).
    """
    wandb.init(
        project="GPT2-Wikitext103",
        name=f"FinetuneMoE-idx{args.moe_layer_indices}-lr{args.learning_rate}-topk{args.topk}"
        f"-shared{args.num_shared_experts}-routed{args.num_routed_experts}-seed{args.seed}",
        save_code=True,
    )
    args.model_save_dir = os.path.join(
        args.model_save_dir,
        f"lr{args.learning_rate}-topk{args.topk}-shared{args.num_shared_experts}-routed{args.num_routed_experts}-seed{args.seed}",
    )
    os.makedirs(args.model_save_dir, exist_ok=True)
    # Record the final args on the active run. Rebinding ``wandb.config`` to a
    # plain dict would only replace the module attribute, not the run config.
    wandb.config.update(vars(args))

    # ---------- Data ----------
    train_dataset = Dataset.from_parquet(args.train_dataset_paths)
    # NOTE(review): no shuffling, so every epoch visits batches in the same order.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        collate_fn=batch_collate_fn,
    )
    eval_dataset = Dataset.from_parquet(args.eval_dataset_paths)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        collate_fn=batch_collate_fn,
    )

    # ---------- Models ----------
    config = GPT2Config.from_pretrained("gpt2")
    pretrained_model = FlaxGPT2LMHeadModel(
        config,
        input_shape=(args.batch_size, args.max_sequence_length),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    pretrained_params = checkpoints.restore_checkpoint(
        ckpt_dir=args.model_path, target={"params": pretrained_model.params}
    )["params"]
    pretrained_model.params = pretrained_params

    # Extend a copy of the dense config with the MoE hyper-parameters.
    config = copy.deepcopy(config)
    config.num_routed_experts = args.num_routed_experts
    config.num_shared_experts = args.num_shared_experts
    config.topk = args.topk
    config.moe_layer_indices = args.moe_layer_indices
    model = FlaxGPT2MoELMHeadModel(
        config,
        input_shape=(args.batch_size, args.max_sequence_length),
        seed=args.seed,
        dtype=jnp.dtype(args.dtype),
    )
    model.config.save_pretrained(args.model_save_dir)
    print_model(model.params)
    model.params = pretrained2finetune_parmas(pretrained_model.params, model.params, config)
    num_train_steps = len(train_dataloader) * args.num_epochs

    # ---------- Optimizer ----------
    label_mask = get_trainable_mask(model.params, config)
    # ``unfreeze`` keeps this printable whether the params tree is a dict or a FrozenDict.
    print(json.dumps(unfreeze(label_mask), indent=2))
    linear_decay_lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate,
        end_value=0.0,
        transition_steps=num_train_steps,
    )
    tx = optax.multi_transform(
        transforms={
            "trainable": optax.adamw(
                learning_rate=linear_decay_lr_schedule_fn,
                b1=args.adamw_beta1,
                b2=args.adamw_beta2,
                eps=args.adamw_eps,
                weight_decay=args.weight_decay_rate,
                mask=decay_mask_fn,
            ),
            "frozen": optax.set_to_zero(),
        },
        param_labels=label_mask,  # must mirror the model.params tree structure
    )

    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)
    if args.restore_checkpoint_path:
        # Resume params, optimizer state and step counter. This must happen before ``replicate``.
        state = checkpoints.restore_checkpoint(ckpt_dir=args.restore_checkpoint_path, target=state)
        print(f"restored checkpoint from {args.restore_checkpoint_path} at step {int(state.step)}")
    # Number of optimizer updates already applied. Loop indices below it are skipped.
    latest_train_step = int(state.step)

    # ---------- Step functions ----------
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, num=2)
        # Split labels off without mutating the caller's batch dict.
        labels = batch["labels"]
        inputs = {key: value for key, value in batch.items() if key != "labels"}

        def loss_fn(params):
            pred_logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            return optax.softmax_cross_entropy(pred_logits, onehot(labels, pred_logits.shape[-1])).mean()

        loss, grad = jax.value_and_grad(loss_fn)(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad)
        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step(state, batch):
        inputs = {key: value for key, value in batch.items() if key != "labels"}
        pred_logits = model(**inputs, params=state.params, train=False)[0]
        loss = optax.softmax_cross_entropy(pred_logits, onehot(batch["labels"], pred_logits.shape[-1]))
        return jax.lax.pmean({"eval_loss": loss.mean()}, axis_name="batch")

    parallel_train_step = jax.pmap(train_step, "batch")
    parallel_eval_step = jax.pmap(eval_step, "batch")

    def run_evaluation(state):
        """Evaluate on the whole eval set and return mean eval loss and perplexity."""
        # ``get_metrics`` already takes one device copy and stacks over batches,
        # so averaging its output gives the mean over all eval batches.
        metrics = get_metrics([parallel_eval_step(state, batch) for batch in eval_dataloader])
        metrics = jax.tree_util.tree_map(jnp.mean, metrics)
        metrics["eval_ppl"] = jnp.exp(metrics["eval_loss"])
        return metrics

    # ---------- Training loop ----------
    state = replicate(state)
    rng = jax.random.PRNGKey(args.seed)

    eval_metrics = run_evaluation(state)
    print(f"loss: {eval_metrics['eval_loss']:.4f} ppl: {eval_metrics['eval_ppl']:.2f}")

    train_metrics_stack = []
    last_timestamp = time.time()
    for epoch in range(args.num_epochs):
        # Derive a fresh key per epoch so dropout masks are not repeated across epochs.
        rng, epoch_rng = jax.random.split(rng)
        dropout_rngs = jax.random.split(epoch_rng, num=jax.local_device_count())

        # TODO: device prefetch
        for i, batch in enumerate(train_dataloader):
            current_train_step = len(train_dataloader) * epoch + i
            # On resume, skip the steps already contained in the restored state.
            if current_train_step < latest_train_step:
                continue

            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)

            if current_train_step > 0 and current_train_step % args.logging_frequency == 0:
                # ``get_metrics`` unreplicates and stacks over steps; average the window.
                train_metrics = jax.tree_util.tree_map(jnp.mean, get_metrics(train_metrics_stack))

                loss = train_metrics["loss"]
                ppl = jnp.exp(loss)

                # Seconds per step over this logging window -> remaining time.
                duration = time.time() - last_timestamp
                steps_in_window = len(train_metrics_stack)
                remaining_steps = num_train_steps - current_train_step - 1
                eta = timedelta(seconds=int(remaining_steps * duration / steps_in_window))

                print(
                    f"[TRAIN] epoch: {epoch} step: {current_train_step}/{num_train_steps} "
                    f"loss: {loss:.4f} ppl: {ppl:.2f} ETA {eta}"
                )
                wandb.log(
                    {
                        "loss": loss,
                        "ppl": ppl,
                        "learning_rate": train_metrics["learning_rate"],
                        "epoch": epoch,
                    },
                    step=current_train_step,
                )
                last_timestamp, train_metrics_stack = time.time(), []

            is_end_of_epoch = i + 1 == len(train_dataloader)
            if current_train_step > 0 and (current_train_step % args.eval_frequency == 0 or is_end_of_epoch):
                eval_metrics = run_evaluation(state)
                print(
                    f"[EVAL] epoch: {epoch} step: {current_train_step}/{num_train_steps} "
                    f"loss: {eval_metrics['eval_loss']:.4f} ppl: {eval_metrics['eval_ppl']:.2f}"
                )
                wandb.log(eval_metrics, step=current_train_step)

            if current_train_step > 0 and (current_train_step % args.save_frequency == 0 or is_end_of_epoch):
                checkpoints.save_checkpoint(
                    ckpt_dir=args.model_save_dir,
                    target=unreplicate(state),
                    step=current_train_step,
                    keep=3,
                )
                print(f"save checkpoint to {args.model_save_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune a single MoE block inside a pretrained GPT2 (Flax).")
    # NOTE(review): --model-path is passed to checkpoints.restore_checkpoint as ckpt_dir;
    # the "distilgpt2" default is not a local checkpoint path, so set it explicitly.
    parser.add_argument("--model-path", type=str, default="distilgpt2", help="flax checkpoint (file or directory) of the pretrained dense GPT2")
    parser.add_argument("--moe-layer-indices", type=int, required=True, help="index of the transformer block whose MLP becomes the MoE")
    parser.add_argument("--num-shared-experts", type=int, required=True, help="number of shared experts in the MoE block")
    parser.add_argument("--num-routed-experts", type=int, required=True, help="number of routed experts in the MoE block")
    parser.add_argument("--topk", type=int, required=True, help="top-k value used for expert routing")
    parser.add_argument("--seed", type=int, required=True, help="random seed for model initialisation and dropout")
    parser.add_argument("--train-dataset-paths", type=str, default="dataset/wikitext.train**", help="train dataset parquet path or glob pattern")
    parser.add_argument("--eval-dataset-paths", type=str, default="dataset/wikitext.test**", help="eval dataset parquet path or glob pattern")
    parser.add_argument("--batch-size", type=int, default=16, help="global train/eval batch size (split evenly across local devices)")
    parser.add_argument("--max-sequence-length", type=int, default=256, help="sequence length of model input")
    parser.add_argument("--num-epochs", type=int, default=5, help="number of epochs")
    parser.add_argument("--learning-rate", type=float, default=3e-5, help="peak learning rate (decays linearly to 0)")
    parser.add_argument("--weight-decay-rate", type=float, default=0.01, help="AdamW weight decay rate")
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.98)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32", help="model datatype")
    parser.add_argument("--logging-frequency", type=int, default=100, help="log training metrics every logging_frequency steps")
    parser.add_argument("--eval-frequency", type=int, default=5000, help="run evaluation every eval_frequency steps")
    parser.add_argument("--save-frequency", type=int, default=5000, help="save a checkpoint every save_frequency steps")
    parser.add_argument("--model-save-dir", type=str, default="artifacts/", help="checkpoint saving dir")
    parser.add_argument("--restore-checkpoint-path", type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")
    main(parser.parse_args())