import argparse
import os
import jax
import json
import time
import math
import copy
import optax
import torch
import wandb
import numpy as np
import itertools
from tqdm import tqdm
import jax.numpy as jnp
from datasets import Dataset
from datetime import timedelta
from typing import Any, Dict, List
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from model import print_model, FlaxGPT2MoELMHeadModel
from data_utils import get_lm_corpus
from flax.core.frozen_dict import freeze, unfreeze

# Placeholder for the Weights & Biases API key. `setdefault` keeps a key that
# the caller already exported in the environment instead of overwriting it
# with the empty placeholder.
os.environ.setdefault("WANDB_API_KEY", "")

# Command-line interface for the MoE fine-tuning script.
# Option names mix dashes and underscores; argparse exposes all of them as
# underscore attributes on the parsed namespace (e.g. `args.max_step`,
# `args.batch_size`).
# fmt: off
parser = argparse.ArgumentParser()
# --- model ---
parser.add_argument("--model-config-name", type=str, default="gpt2", help="GPT2 config name (huggingface model hub)")
parser.add_argument("--model-path", type=str, default="", help="Path of Pretrained Model")
parser.add_argument("--random-seed", type=int, default=0, help="random seed for RNG state")
parser.add_argument("--moe-layer-indices", type=int, required = True, help="index of the transformer block that holds the MoE layer (single integer)")
parser.add_argument("--num-shared-experts", type=int, required = True, help="number of shared experts (MoE model config)")
parser.add_argument("--num-routed-experts", type=int, required = True, help="number of routed experts (MoE model config)")
parser.add_argument("--topk", type=int, required = True, help="top-k value (MoE model config)")
# --- data ---
parser.add_argument("--data-path", type=str, default="./data/lm1b", help="dataset directory path")
parser.add_argument('--dataset', type=str, default='lm1b',choices=['wt103', 'lm1b', 'enwik8', 'text8'],help='dataset name')
parser.add_argument('--max_step', type=int, default=80000,help='maximum number of training steps')
parser.add_argument("--batch-size", type=int, default=48, help="train batch size (batch size will be divided by device count)")
parser.add_argument('--tgt_len', type=int, default=256,help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=256,help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,help='length of the retained previous heads')
# --- optimizer ---
parser.add_argument("--learning-rate", type=float, default=0.00025, help="learning rate")
parser.add_argument("--weight-decay-rate", type=float, default=0.01, help="weight decay rate for AdamW")
parser.add_argument('--eta_min', type=float, default=1.0e-8,help='min learning rate for cosine scheduler')
parser.add_argument("--adamw-beta1", type=float, default=0.9)
parser.add_argument("--adamw-beta2", type=float, default=0.999)
parser.add_argument("--adamw-eps", type=float, default=1e-8)
parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32", help="model datatype")
# --- logging / checkpointing ---
parser.add_argument("--wandb-username", default="codertimo", help="wandb username for logging")
parser.add_argument("--wandb-project", default="GPT2-OneBillionWord", help="wandb project name for logging")
parser.add_argument("--wandb-run-dir", default=".wandb", help="wandb run dir")
parser.add_argument("--logging-frequency", type=int, default=200, help="do logging every logging_frequency step")
parser.add_argument("--eval-frequency", type=int, default=4000, help="do evaluation every eval_frequency step")
parser.add_argument("--save-frequency", type=int, default=4000, help="do saving checkpoint every save_frequency step")
parser.add_argument("--model-save-dir", type=str, default="artifacts/", help="checkpoint saving dir")
parser.add_argument("--restore-checkpoint-path", type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")
# fmt: on



def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Turn one PyTorch language-modeling batch into a device-sharded JAX batch.
    Args:
        data (torch.Tensor): Input tokens, documented as (seq_len, batch)
        target (torch.Tensor): Target tokens, documented as (seq_len, batch)
    Returns:
        Dict[str, Any]: Dict with keys 'input_ids' and 'target'. Each value is
            the transposed tensor converted with ``jnp.array`` and passed
            through ``shard``.
    """
    tensors_by_key = {'input_ids': data, 'target': target}
    # Transpose each tensor, convert it to a jnp array, then split it across devices
    return {key: shard(jnp.array(tensor.T)) for key, tensor in tensors_by_key.items()}

def decay_mask_fn(params):
    """
    Build a boolean mask with the same nesting as ``params``.

    A leaf is marked ``False`` when its path ends in "bias" or when its last
    two path components are the "scale" of ``ln_1``, ``ln_2`` or ``ln_f``;
    every other leaf is marked ``True``.
    """
    norm_scale_suffixes = [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]
    mask_by_path = {}
    for param_path in flatten_dict(params):
        is_bias = param_path[-1] == "bias"
        is_norm_scale = param_path[-2:] in norm_scale_suffixes
        mask_by_path[param_path] = not (is_bias or is_norm_scale)
    return unflatten_dict(mask_by_path)
# ---------- Training Utilities ----------
def get_trainable_mask(params, config):
    """
    Label every leaf of ``params`` as "trainable" or "frozen".

    A leaf is "trainable" only when its path has at least six components and
    starts with transformer / h / <config.moe_layer_indices> / mlp, i.e. it
    sits under the MLP of the single block selected by
    ``config.moe_layer_indices``. Everything else is "frozen".
    The result has the same tree structure as ``params``.
    """
    def label_fn(path, _leaf):
        names = [str(entry.key) for entry in path]
        under_moe_mlp = (
            len(names) >= 6
            and names[:2] == ["transformer", "h"]
            and names[2] == str(config.moe_layer_indices)
            and names[3] == "mlp"
        )
        if under_moe_mlp:
            return "trainable"
        return "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)
def pretrained2finetune_parmas(pretrained_params, finetune_params, config):
    """
    Copy dense pretrained weights into the parameter tree of the MoE model.

    Copied from ``pretrained_params`` (deep copies):
      * the top-level ``wte``, ``wpe`` and ``ln_f`` entries;
      * every transformer block other than ``config.moe_layer_indices``, whole;
      * for the block at ``config.moe_layer_indices``: only ``ln_1``, ``attn``
        and ``ln_2``.

    The ``mlp`` subtree of the MoE block is left exactly as it is in
    ``finetune_params``; the dense ``c_fc``/``c_proj`` weights are NOT copied
    into the experts here.

    Args:
        pretrained_params: Parameter tree of the dense model (dict or FrozenDict).
        finetune_params: Parameter tree of the MoE model (dict or FrozenDict).
        config: Object exposing ``num_hidden_layers`` and ``moe_layer_indices``.

    Returns:
        The updated fine-tuning parameter tree, frozen with ``freeze``.
    """
    pretrained_params = unfreeze(pretrained_params)
    finetune_params = unfreeze(finetune_params)
    source = pretrained_params["transformer"]
    destination = finetune_params["transformer"]

    # 1. Copy top-level embeddings and final layer norm
    for name in ("wte", "wpe", "ln_f"):
        destination[name] = copy.deepcopy(source[name])

    # 2. Copy transformer blocks
    for i in range(config.num_hidden_layers):
        str_i = str(i)
        if i == config.moe_layer_indices:
            # MoE block: take attention and both layer norms from the dense
            # model; its mlp subtree keeps the values already in finetune_params.
            for name in ("ln_1", "attn", "ln_2"):
                destination["h"][str_i][name] = copy.deepcopy(source["h"][str_i][name])
        else:
            # Standard block: copy it whole
            destination["h"][str_i] = copy.deepcopy(source["h"][str_i])
    return freeze(finetune_params)

def main(args: argparse.Namespace):
    """
    Fine-tune the MoE block of a GPT-2 language model.

    Steps performed:
      1. Seed the torch / NumPy RNGs and start a wandb run.
      2. Build train / valid / test iterators from ``get_lm_corpus``.
      3. Restore a dense ``FlaxGPT2LMHeadModel`` checkpoint from
         ``args.model_path`` and transplant its weights into a new
         ``FlaxGPT2MoELMHeadModel`` via ``pretrained2finetune_parmas``.
      4. Train with AdamW on the leaves labelled "trainable" by
         ``get_trainable_mask``; all other leaves receive zero updates.
      5. Evaluate on the validation iterator once before training and then
         every ``args.eval_frequency`` steps, saving "best_" and "last_"
         checkpoints under ``args.model_save_dir``.

    Args:
        args: Parsed command-line namespace from the module-level ``parser``.
            ``args.n_token`` is set here from the corpus vocabulary size.
    """
    os.makedirs(args.wandb_run_dir, exist_ok=True)
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    wandb.init(
        project=args.wandb_project,
        name=f"FinetuneMoE-idx{args.moe_layer_indices}-lr{args.learning_rate}-topk{args.topk}"
            f"-shared{args.num_shared_experts}-routed{args.num_routed_experts}-batch{args.batch_size}-seed{args.random_seed}",
        save_code=True,
        # Hand the hyper-parameters to init(): assigning `wandb.config = {...}`
        # afterwards only rebinds the module attribute and is not recorded on
        # the run.
        config=dict(vars(args)),
    )
    save_path = os.path.join(
        args.model_save_dir,
        f"lr{args.learning_rate}-topk{args.topk}-shared{args.num_shared_experts}-routed{args.num_routed_experts}-batch{args.batch_size}-seed{args.random_seed}"
    )

    # ---------- Data ----------
    corpus = get_lm_corpus(args.data_path, args.dataset)
    args.n_token = len(corpus.vocab)
    n_devices = jax.local_device_count()
    # shard() splits the batch axis evenly across local devices, so round the
    # fixed evaluation batch size (10) up to a multiple of the device count.
    # NOTE(review): assumes the iterators yield exactly `batch_size` columns.
    eval_batch_size = math.ceil(10 / n_devices) * n_devices
    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    # The test iterator is built but not consumed anywhere in this function.
    te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # ---------- Models ----------
    model_config = GPT2Config.from_pretrained(args.model_config_name)
    model_config.vocab_size = args.n_token
    model_config.n_positions = args.tgt_len
    model_config.n_ctx = args.tgt_len
    pretrained_model = FlaxGPT2LMHeadModel(model_config,input_shape=(1, args.tgt_len),seed=0,dtype=jnp.dtype(args.dtype),)
    pretrained_params = checkpoints.restore_checkpoint(ckpt_dir=args.model_path, target={"params": pretrained_model.params})["params"]
    pretrained_model.params = pretrained_params
    config = copy.deepcopy(model_config)
    config.num_routed_experts = args.num_routed_experts
    config.num_shared_experts = args.num_shared_experts
    config.topk = args.topk
    config.moe_layer_indices = args.moe_layer_indices
    model = FlaxGPT2MoELMHeadModel(config,input_shape=(1, args.tgt_len),seed=args.random_seed,dtype=jnp.dtype(args.dtype),)
    model.config.save_pretrained(save_path)
    print_model(model.params)
    model.params = pretrained2finetune_parmas(pretrained_model.params,model.params,config)

    # ---------- Optimizer ----------
    label_mask = get_trainable_mask(model.params,config)
    print(json.dumps(label_mask, indent=2))
    lr_schedule =optax.warmup_cosine_decay_schedule(init_value=0.0,peak_value=args.learning_rate,warmup_steps=2000,decay_steps=args.max_step,end_value=args.eta_min,)
    tx = optax.multi_transform(
        transforms={
            'trainable': optax.adamw(
                learning_rate=lr_schedule,
                b1=args.adamw_beta1,
                b2=args.adamw_beta2,
                eps=args.adamw_eps,
                weight_decay=args.weight_decay_rate
            ),
            'frozen': optax.set_to_zero()
        },
        param_labels=label_mask
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)

    # ---------- Per-device step functions ----------
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            # Read the batch without mutating it (no dict.pop inside the traced function).
            logits = state.apply_fn(input_ids=batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True)[0]
            labels = batch["target"]
            return optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=grads)
        metrics = {"loss": loss,"learning_rate": lr_schedule(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step(state, batch):
        logits = model(input_ids=batch["input_ids"], params=state.params, train=False)[0]
        labels = batch["target"]
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
        metrics = {"eval_loss": loss}
        return jax.lax.pmean(metrics, axis_name="batch")

    parallel_train_step = jax.pmap(train_step, "batch")
    parallel_eval_step = jax.pmap(eval_step, "batch")

    def mean_over_steps(metrics_stack):
        # get_metrics() already keeps a single device copy of every metric and
        # stacks the per-step values, so the result is averaged directly.
        # Calling unreplicate() on it would keep only the first step's value.
        stacked = get_metrics(metrics_stack)
        return jax.tree_util.tree_map(lambda x: x.mean(), stacked)

    def run_validation(current_state):
        # Evaluate the whole validation iterator; returns (loss, perplexity).
        eval_results = []
        for eval_data, eval_target, _ in va_iter:
            eval_batch = prepare_lm_batch(eval_data, eval_target)
            eval_results.append(parallel_eval_step(current_state, eval_batch))
        loss_value = float(mean_over_steps(eval_results)["eval_loss"])
        return loss_value, math.exp(loss_value)

    def report_validation(step, started_at, val_loss, val_ppl):
        # Print the validation summary banner.
        print("-" * 100)
        print(
            f"| Eval {step // args.eval_frequency:3d} at step {step:8d} | "
            f"time: {time.time() - started_at:5.2f}s | "
            f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:9.3f}"
        )
        print("-" * 100)

    state = replicate(state)
    rng = jax.random.PRNGKey(args.random_seed)
    # Split the dropout keys once. Each train step returns fresh keys, so
    # re-splitting from the same `rng` every epoch would replay the same
    # dropout key sequence in every epoch.
    dropout_rngs = jax.random.split(rng, n_devices)
    # Host-side step counter (kept separate from the `train_step` function).
    global_step = int(jax.device_get(unreplicate(state.step)))
    best_val_loss = float("inf")
    log_start_time = time.time()
    eval_start_time = time.time()

    # Baseline validation of the transplanted model before any update.
    val_loss, val_ppl = run_validation(state)
    report_validation(global_step, eval_start_time, val_loss, val_ppl)

    #### START FINETUNING ####
    print("Starting Finetuning...")
    print(f"JAX devices: {jax.devices()}")
    print(f"Using {jax.local_device_count()} devices")
    for epoch in itertools.count(start=1):
        print(f"Epoch {epoch}")
        train_iter = tr_iter.get_varlen_iter() if getattr(args, "varlen", False) else tr_iter
        train_metrics_stack = []
        for batch_idx, (data, target, _) in enumerate(tqdm(train_iter)):
            if global_step >= args.max_step:
                break
            # Prepare and shard batch
            batch = prepare_lm_batch(data, target)
            # Run train step
            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)
            global_step += 1
            # Logging
            if global_step % args.logging_frequency == 0:
                train_metrics = mean_over_steps(train_metrics_stack)
                train_metrics_stack = []
                loss = float(train_metrics["loss"])
                ppl = math.exp(loss)
                curr_lr = float(lr_schedule(global_step))
                elapsed = time.time() - log_start_time
                print(
                    f"| epoch {epoch:3d} step {global_step:8d} | "
                    f"{batch_idx+1:6d} batches | lr {curr_lr:.3g} "
                    f"| ms/batch {elapsed * 1000 / args.logging_frequency:5.2f} | "
                    f"loss {loss:5.2f} | ppl {ppl:9.3f}"
                )
                wandb.log({"loss": loss,"ppl": ppl,"learning_rate": curr_lr}, step=global_step)
                log_start_time = time.time()
            # Evaluation
            if global_step % args.eval_frequency == 0:
                val_loss, val_ppl = run_validation(state)
                report_validation(global_step, eval_start_time, val_loss, val_ppl)
                wandb.log({"eval_loss": val_loss,"eval_ppl": val_ppl}, step=global_step)
                # Save best checkpoint
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    checkpoints.save_checkpoint(ckpt_dir=save_path,target=unreplicate(state),step=global_step,prefix="best_",keep=1)
                    print(f"Best model saved at step {global_step}")
                eval_start_time = time.time()
            # Periodic checkpoint
            if global_step % args.save_frequency == 0:
                checkpoints.save_checkpoint(ckpt_dir=save_path,target=unreplicate(state),step=global_step,prefix="last_",keep=1)
                print(f"Checkpoint saved at step {global_step} in {save_path}")
        if global_step >= args.max_step:
            print("-" * 100)
            print("End of training")
            break


# Script entry point: parse the command line and hand the namespace to main().
if __name__ == "__main__":
    cli_args = parser.parse_args()
    main(cli_args)