import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import copy
import jax
import json
import jax.numpy as jnp
import optax
import wandb
import numpy as np
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from datamodule import  JAXTextDataset
from torch.utils.data import  DataLoader
from itertools import cycle
from flax.core.frozen_dict import freeze, unfreeze
from tqdm import tqdm 
# NOTE(review): this unconditionally replaces any WANDB_API_KEY already present in
# the process environment with an empty string. Presumably a scrubbed placeholder
# that is meant to be filled in before running -- confirm.
os.environ["WANDB_API_KEY"] = ""

# fmt: off
# fmt: on


# log2(e). The train and eval steps multiply their cross-entropy loss by this
# factor to obtain the value they report as "bpc" / "eval_bpc".
LOG2E = jnp.log2(jnp.e)
def load_token_file(path):
    """Load a whitespace-separated file of integer token ids as a uint8 array.

    Args:
        path: Path of a text file whose whitespace-separated fields are
            base-10 integers.

    Returns:
        A one-dimensional ``np.ndarray`` of dtype ``np.uint8`` holding the
        tokens in file order (empty when the file contains no tokens).

    Raises:
        ValueError: If a field is not an integer, or if any token lies outside
            the ``uint8`` range 0..255.
    """
    with open(path, 'r') as f:
        tokens = [int(t) for t in f.read().split()]
    # Validate before the narrowing cast: depending on the NumPy version an
    # out-of-range value would otherwise wrap around silently or fail with an
    # OverflowError that does not say which file was bad.
    if tokens:
        lowest, highest = min(tokens), max(tokens)
        if lowest < 0 or highest > 255:
            raise ValueError(
                f"token ids in {path!r} must be in the range 0..255, "
                f"found values between {lowest} and {highest}"
            )
    return np.array(tokens, dtype=np.uint8)

# Collate function for the PyTorch DataLoader: stacks each field of the samples into a jax.numpy array and shards the result
def jax_collate(batch):
    """Stack a list of per-sample dicts into one dict of arrays and shard it.

    Args:
        batch: Non-empty sequence of mappings that all share the keys of the
            first element; each value must be stackable with ``jnp.stack``.

    Returns:
        The result of ``shard`` applied to a dict that maps every key of
        ``batch[0]`` to the stacked values of that key across the samples.
    """
    stacked = {}
    # The first sample defines which fields are collated.
    for field in batch[0]:
        per_sample_values = [sample[field] for sample in batch]
        stacked[field] = jnp.stack(per_sample_values)
    return shard(stacked)

def decay_mask_fn(params):
    """Build a boolean pytree marking which parameters the mask selects.

    A parameter is marked ``False`` when its last path component is ``"bias"``
    or when its last two path components are the ``"scale"`` of ``ln_1``,
    ``ln_2`` or ``ln_f``; every other parameter is marked ``True``.

    Args:
        params: Nested parameter dictionary.

    Returns:
        A nested dictionary with the same structure as ``params`` whose leaves
        are booleans.
    """
    excluded_suffixes = [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]
    selected = {}
    for path in flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] in excluded_suffixes
        selected[path] = not (is_bias or is_layer_norm_scale)
    return unflatten_dict(selected)
# ---------- Training Utilities ----------
def get_trainable_mask(params, config):
    """Label every parameter leaf as ``"trainable"`` or ``"frozen"``.

    A leaf is ``"trainable"`` when its path has at least six components and
    starts with ``transformer / h / <config.moe_layer_indices> / mlp``; every
    other leaf is ``"frozen"``.

    Args:
        params: Parameter pytree made of nested dictionaries.
        config: Object exposing ``moe_layer_indices`` (a single integer layer
            index).

    Returns:
        A pytree with the same structure as ``params`` whose leaves are the
        strings ``"trainable"`` or ``"frozen"``.
    """
    def _label(path, _leaf):
        # Each path entry is a dictionary key wrapper; compare on its string form.
        names = [str(entry.key) for entry in path]
        if len(names) < 6:
            return "frozen"
        under_moe_mlp = (
            names[:2] == ["transformer", "h"]
            and names[2] == str(config.moe_layer_indices)
            and names[3] == "mlp"
        )
        if under_moe_mlp:
            return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(_label, params)
def pretrained2finetune_parmas(pretrained_params, finetune_params, config):
    """Copy dense pretrained GPT-2 weights into the parameters of the MoE model.

    Copied from ``pretrained_params`` (as deep copies):
      * the embeddings ``wte`` / ``wpe`` and the final layer norm ``ln_f``;
      * every transformer block other than ``config.moe_layer_indices`` in full;
      * for the block at ``config.moe_layer_indices`` only ``ln_1``, ``attn``
        and ``ln_2``.

    The ``mlp`` subtree of the MoE block is left exactly as it is in
    ``finetune_params``.

    Args:
        pretrained_params: Parameter tree of the dense model.
        finetune_params: Parameter tree of the MoE model to be filled in.
        config: Object exposing ``num_hidden_layers`` and ``moe_layer_indices``
            (a single integer layer index).

    Returns:
        The updated fine-tuning parameter tree, frozen with ``freeze``.
    """
    pretrained_params = unfreeze(pretrained_params)
    finetune_params = unfreeze(finetune_params)
    source = pretrained_params["transformer"]
    target = finetune_params["transformer"]

    # 1. Copy top-level embeddings and final layer norm
    for name in ("wte", "wpe", "ln_f"):
        target[name] = copy.deepcopy(source[name])

    # 2. Copy transformer blocks
    for i in range(config.num_hidden_layers):
        str_i = str(i)
        if i == config.moe_layer_indices:
            # MoE block: only the attention and the two layer norms come from
            # the pretrained block.
            # NOTE(review): code that copied the pretrained c_fc / c_proj into
            # every "routed_experts_*" entry used to be commented out here, so
            # the mlp subtree keeps the values it has in finetune_params.
            # Confirm that leaving it untouched is intended.
            ref_layer = source["h"][str_i]
            target_layer = target["h"][str_i]
            for name in ("ln_1", "attn", "ln_2"):
                target_layer[name] = copy.deepcopy(ref_layer[name])
        else:
            # Standard block, copy all directly
            target["h"][str_i] = copy.deepcopy(source["h"][str_i])
    return freeze(finetune_params)


def _repeat_loader(loader):
    """Yield batches from ``loader`` forever, re-iterating it each time it is exhausted."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An empty loader would otherwise make this loop spin forever.
            return


def main(args: argparse.Namespace):
    """Fine-tune a GPT-2 MoE language model starting from a dense GPT-2 checkpoint.

    The function
      1. starts a wandb run and creates a run-specific checkpoint directory
         (``args.model_save_dir`` is rewritten to that directory);
      2. loads ``train.txt`` / ``valid.txt`` / ``test.txt`` from ``args.data_path``
         and wraps them in data loaders;
      3. restores the dense ``FlaxGPT2LMHeadModel`` checkpoint found at
         ``args.model_path`` and copies its weights into a freshly initialised
         ``FlaxGPT2MoELMHeadModel`` via ``pretrained2finetune_parmas``;
      4. trains with AdamW and a linearly decaying learning rate under
         ``jax.pmap`` for ``args.num_iter * args.nb_per_iter`` steps, logging
         every ``args.logging_frequency`` steps, evaluating on the validation
         loader at the end of every iteration and saving checkpoints every
         ``args.save_frequency`` steps and at the end of every iteration.

    NOTE(review): the optimizer is given every parameter of the model; no
    trainable/frozen partition is applied here.

    Args:
        args: Parsed command-line arguments (see the parser at the bottom of
            this file).
    """
    wandb.init(project=args.wandb_project,name=f"FinetuneMoe-lr{args.learning_rate}-iter{args.num_iter}-size{args.batch_size}", save_code=True)
    args.model_save_dir = os.path.join(
        args.model_save_dir,
        f"lr{args.learning_rate}-topk{args.topk}-shared{args.num_shared_experts}-routed{args.num_routed_experts}-seed{args.seed}"
    )
    os.makedirs(args.model_save_dir,exist_ok=True)
    # Update the active run's config object. Assigning a dict to `wandb.config`
    # would only rebind the module attribute and not record the hyperparameters.
    wandb.config.update(dict(vars(args)))
    train_data = load_token_file(os.path.join(args.data_path,'train.txt'))
    val_data   = load_token_file(os.path.join(args.data_path,'valid.txt'))
    test_data  = load_token_file(os.path.join(args.data_path,'test.txt'))
    # Create loaders
    train_loader = DataLoader(JAXTextDataset(train_data, args.max_sequence_length), batch_size=args.batch_size, shuffle=True, collate_fn=jax_collate)
    val_loader   = DataLoader(JAXTextDataset(val_data, args.max_sequence_length), batch_size=args.batch_size, collate_fn=jax_collate)
    test_loader  = DataLoader(JAXTextDataset(test_data, args.max_sequence_length), batch_size=args.batch_size, collate_fn=jax_collate)

    # Dense reference model: distilgpt2 architecture with a 256-symbol vocabulary.
    model_config = GPT2Config.from_pretrained('distilgpt2')
    model_config.vocab_size = 256                   # enwik8 = 256 ASCII tokens
    model_config.n_positions = args.max_sequence_length
    model_config.n_ctx = args.max_sequence_length
    model_config.bos_token_id = None                # not used
    model_config.eos_token_id = None                # not used
    pretrained_model = FlaxGPT2LMHeadModel(
        model_config,input_shape=(args.batch_size, args.max_sequence_length),seed=0,dtype=jnp.dtype(args.dtype),
    )
    pretrained_params = checkpoints.restore_checkpoint(ckpt_dir=args.model_path, target={"params": pretrained_model.params})["params"]
    # print_model(pretrained_params)
    pretrained_model.params = pretrained_params

    # MoE model: same base config plus the expert settings from the command line.
    config = copy.deepcopy(model_config)
    config.num_routed_experts = args.num_routed_experts
    config.num_shared_experts = args.num_shared_experts
    config.topk = args.topk
    config.moe_layer_indices = args.moe_layer_indices
    model = FlaxGPT2MoELMHeadModel(config,input_shape=(args.batch_size, args.max_sequence_length),seed=args.seed,dtype=jnp.dtype(args.dtype),)
    model.config.save_pretrained(args.model_save_dir)
    model.params = pretrained2finetune_parmas(pretrained_model.params,model.params,config)

    num_train_steps = args.nb_per_iter * args.num_iter
    linear_decay_lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate, end_value=0, transition_steps=num_train_steps
    )
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=args.adamw_beta1,
        b2=args.adamw_beta2,
        eps=args.adamw_eps,
        weight_decay=args.weight_decay_rate,
        mask=decay_mask_fn,
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

    if args.restore_checkpoint_path:
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"train state restored from {args.restore_checkpoint_path}")
        print(f"skip trian step to {state.step}")
    # Number of optimizer steps already contained in `state`; 0 for a fresh run.
    latest_train_step = int(state.step)

    def train_step(state: train_state.TrainState, batch, dropout_rng):
        """Run one optimizer step on one device shard; returns (state, metrics, next rng)."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, num=2)
        def loss_fn(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
            return loss
        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients over the "batch" pmap axis before applying them.
        grad = jax.lax.pmean(grad, axis_name="batch")
        new_state = state.apply_gradients(grads=grad)
        bpc = loss * LOG2E
        metrics = {"loss": loss,"bpc": bpc,"learning_rate": linear_decay_lr_schedule_fn(state.step),}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step(state, batch):
        """Compute loss and bpc for one device shard, averaged over the pmap axis."""
        labels = batch.pop("labels")
        logits = model(**batch, params=state.params, train=False)[0]
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
        bpc = loss * LOG2E
        metrics = {"eval_loss": loss,"eval_bpc": bpc,}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    parallel_train_step = jax.pmap(train_step, axis_name="batch")
    parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
    state = replicate(state)
    rng = jax.random.PRNGKey(args.seed)
    # One dropout key per local device, created once. train_step hands back the
    # advanced keys, so re-splitting `rng` inside the loop would restart the
    # same key stream at every iteration.
    dropout_rngs = jax.random.split(rng, num=jax.local_device_count())
    train_metrics_stack = []
    # Re-iterate the loader for every pass instead of itertools.cycle, which
    # keeps a copy of every batch of the first pass and replays that fixed
    # order (so shuffle=True would only affect the first pass).
    train_loader_iter = _repeat_loader(train_loader)

    # Baseline evaluation before any training step.
    eval_metrics = [parallel_eval_step(state, batch) for batch in val_loader]
    eval_metrics = get_metrics(eval_metrics)
    eval_metrics = jax.tree_util.tree_map(jnp.mean, unreplicate(eval_metrics))
    print(f"[EVAL] eval_loss: {eval_metrics['eval_loss']:.4f} eval_bpc: {eval_metrics['eval_bpc']:.4f}")

    # Start timing after the baseline evaluation so it does not inflate the first ETA.
    last_timestamp = time.time()
    for iteration in range(args.num_iter):
        for i in tqdm(range(args.nb_per_iter), desc=f"Iteration {iteration}"):
            current_train_step = iteration * args.nb_per_iter + i
            if current_train_step < latest_train_step:
                # Already covered by the restored train state.
                continue
            batch = next(train_loader_iter)
            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)
            if current_train_step > 0 and current_train_step % args.logging_frequency == 0:
                # Steps accumulated since the previous log; `duration` covers exactly these.
                steps_in_window = len(train_metrics_stack)
                train_metrics = get_metrics(train_metrics_stack)
                train_metrics = unreplicate(train_metrics)
                train_metrics = jax.tree_util.tree_map(jnp.mean, train_metrics)
                loss = train_metrics["loss"]
                bpc = train_metrics["bpc"]
                duration = int(time.time() - last_timestamp)
                eta_secs = (num_train_steps - current_train_step) * duration // steps_in_window
                eta = timedelta(seconds=eta_secs)
                print(f"[TRAIN] iter: {iteration} step: {current_train_step}/{num_train_steps} loss: {loss:.4f} bpc: {bpc:.4f} ETA {eta}")
                wandb.log({"loss": loss, "bpc": bpc, "iteration": iteration}, step=current_train_step)
                last_timestamp, train_metrics_stack = time.time(), []
            is_end_of_iter = i + 1 == args.nb_per_iter
            if current_train_step > 0 and is_end_of_iter:
                eval_metrics = [parallel_eval_step(state, batch) for batch in val_loader]
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_util.tree_map(jnp.mean, unreplicate(eval_metrics))
                print(f"[EVAL] iter: {iteration} step: {current_train_step}/{num_train_steps} eval_loss: {eval_metrics['eval_loss']:.4f} eval_bpc: {eval_metrics['eval_bpc']:.4f}")
                wandb.log(eval_metrics, step=current_train_step)

            if current_train_step > 0 and (current_train_step % args.save_frequency == 0 or is_end_of_iter):
                checkpoints.save_checkpoint(ckpt_dir=args.model_save_dir, target=unreplicate(state), step=current_train_step, keep=3)
                print(f"Saved checkpoint to {args.model_save_dir}")

if __name__ == "__main__":
    # Command-line entry point: register every option, parse argv, run main().
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in the order listed.
    option_table = [
        # Model / mixture-of-experts setup
        ("--model-path", dict(type=str, default="distilgpt2", help="GPT2 config name (huggingface model hub)")),
        ("--moe-layer-indices", dict(type=int, required=True)),
        ("--num-shared-experts", dict(type=int, required=True)),
        ("--num-routed-experts", dict(type=int, required=True)),
        ("--topk", dict(type=int, required=True)),
        ("--seed", dict(type=int, required=True)),
        # Data and batching
        ("--data-path", dict(type=str, default="data/enwik8", help="train datset paths (multiple paths)")),
        ("--batch-size", dict(type=int, default=48, help="train, eval batch size (batch size will be devided by device count)")),
        ("--max-sequence-length", dict(type=int, default=256, help="sequence lenght of model input")),
        # Schedule length
        ("--num-iter", dict(type=int, default=40, help="number of training iterations")),
        ("--nb-per-iter", dict(type=int, default=1000, help="number of batches per iteration")),
        # Optimizer
        ("--learning-rate", dict(type=float, default=7e-4, help="learning rate")),
        ("--weight-decay-rate", dict(type=float, default=0.01, help="weight deacy rate for lr scheduler")),
        ("--adamw-beta1", dict(type=float, default=0.9)),
        ("--adamw-beta2", dict(type=float, default=0.999)),
        ("--adamw-eps", dict(type=float, default=1e-8)),
        ("--dtype", dict(choices=["float32", "float16", "bfloat16"], default="float32", help="model datatype")),
        # Logging, evaluation and checkpointing
        ("--logging-frequency", dict(type=int, default=100, help="do logging every logging_frequency step")),
        ("--eval-frequency", dict(type=int, default=5000, help="do evalution every eval_frequency step")),
        ("--save-frequency", dict(type=int, default=5000, help="do saving checkpoint every save_frequencey step")),
        ("--model-save-dir", dict(type=str, default="/root/weights/enwik8/finetune", help="checkpoint saving dir")),
        ("--restore-checkpoint-path", dict(type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")),
        ("--wandb-project", dict(default="GPT2-Enwik8", help="wandb project name for logging")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    main(parser.parse_args())