import os
import math
import time
import copy
import json
import torch
import jax
import argparse
import optax
import wandb
import shutil
import itertools
from tqdm import tqdm
from typing import Any, Dict, List
from copy import deepcopy
from datasets import Dataset
from datetime import timedelta
from data_utils import get_lm_corpus
from flax.jax_utils import replicate, unreplicate
from flax.core.frozen_dict import freeze, unfreeze
from flax.training import train_state, checkpoints
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.training.common_utils import get_metrics, onehot, shard
from transformers.models.llama.configuration_llama import LlamaConfig
from lmc_model import LMCFlaxLlamaForCausalLM, print_model
from data_utils import get_lm_corpus
import jax.numpy as jnp
import numpy as np
def remove_old_dirs_with_prefix(save_path, prefix, keep_step):
    """Delete stale checkpoint directories, keeping only the one for ``keep_step``.

    Every directory directly under ``save_path`` whose name starts with
    ``prefix`` is removed, except the directory named exactly
    ``f"{prefix}{keep_step}"``. Regular files are never touched.

    Args:
        save_path: Directory whose entries are scanned.
        prefix: Name prefix that identifies a checkpoint family (e.g. ``"best_"``).
        keep_step: Step whose directory must survive.
    """
    # Compare against the full name: a suffix test would also spare older
    # checkpoints such as "best_24000" when the step to keep is 4000.
    keep_name = f"{prefix}{keep_step}"
    for fname in os.listdir(save_path):
        if not fname.startswith(prefix) or fname == keep_name:
            continue
        full_path = os.path.join(save_path, fname)
        if os.path.isdir(full_path):
            shutil.rmtree(full_path)
def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Turn a PyTorch language-modeling batch into a device-sharded JAX batch.

    Each tensor is transposed with ``.T``, converted to a ``jnp`` array and
    passed through flax's ``shard`` so it can be fed to a ``pmap``-ed step.

    Args:
        data (torch.Tensor): Input token ids; assumed to be laid out as
            (seq_len, batch) by the corpus iterator.
        target (torch.Tensor): Target token ids with the same layout as ``data``.
    Returns:
        Dict[str, Any]: Dict with keys 'input_ids' and 'target', each holding
            the sharded, batch-major version of the corresponding input.
    """
    # Batch-major views first; the key names are what the model call and the
    # loss code expect ('input_ids' is forwarded to the model, 'target' is not).
    batch_major = {"input_ids": data.T, "target": target.T}
    return {name: shard(jnp.array(tensor)) for name, tensor in batch_major.items()}
 
# ---------- Training Utilities ----------
def get_trainable_mask(params, config):
    """Label every parameter leaf as ``"trainable"`` or ``"frozen"``.

    A leaf is trainable when its path has the form
    ``model/layers/<idx>/<module>/...`` with ``<idx>`` listed in
    ``config.lmc_layer_indices`` and ``<module>`` equal to ``self_attn`` or,
    when ``config.finetune_mlp`` is truthy, ``mlp``. Every other leaf is
    frozen.

    Args:
        params: Nested mapping of parameters (dict-style pytree).
        config: Object exposing ``lmc_layer_indices`` and ``finetune_mlp``.
    Returns:
        A pytree with the same structure as ``params`` whose leaves are the
        strings ``"trainable"`` or ``"frozen"``.
    """
    # Sub-modules of a selected layer that receive gradient updates.
    trainable_modules = {"self_attn"}
    if config.finetune_mlp:
        trainable_modules.add("mlp")

    def is_trainable_param(keys):
        # Paths shorter than model/layers/<idx>/<module> can never match.
        if len(keys) < 4:
            return False
        if keys[0] != "model" or keys[1] != "layers":
            return False
        return int(keys[2]) in config.lmc_layer_indices and keys[3] in trainable_modules

    def label_fn(path, _):
        # Each path entry is a dict-key node; its ``key`` attribute is the name.
        keys = [str(k.key) for k in path]
        return "trainable" if is_trainable_param(keys) else "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)
def pretrained2finetune_params(pretrained_params, finetune_params, config):
    """Seed a fine-tuning parameter tree with weights from a pretrained one.

    Copied from ``pretrained_params`` into ``finetune_params``:
      * ``model/embed_tokens``, ``model/norm`` and ``lm_head``;
      * every layer whose index is NOT in ``config.lmc_layer_indices``, in full;
      * for layers in ``config.lmc_layer_indices``: ``input_layernorm``,
        ``post_attention_layernorm`` and, unless ``config.finetune_mlp`` is
        truthy, ``mlp``. Their ``self_attn`` sub-tree keeps the values already
        present in ``finetune_params``.

    Args:
        pretrained_params: Parameter tree of the pretrained model.
        finetune_params: Parameter tree of the freshly initialised model.
        config: Object exposing ``num_hidden_layers``, ``lmc_layer_indices``
            and ``finetune_mlp``.
    Returns:
        The updated fine-tuning parameters, frozen with ``freeze``.
    """
    pretrained_params = unfreeze(pretrained_params)
    finetune_params = unfreeze(finetune_params)
    src_model = pretrained_params["model"]
    dst_model = finetune_params["model"]

    # 1. Top-level embeddings, final layer norm and output head.
    dst_model["embed_tokens"] = copy.deepcopy(src_model["embed_tokens"])
    dst_model["norm"] = copy.deepcopy(src_model["norm"])
    finetune_params["lm_head"] = copy.deepcopy(pretrained_params["lm_head"])

    # 2. Transformer layers.
    for i in range(config.num_hidden_layers):
        layer_key = str(i)
        src_layer = src_model["layers"][layer_key]
        if i not in config.lmc_layer_indices:
            # Standard block: take the pretrained layer wholesale.
            dst_model["layers"][layer_key] = copy.deepcopy(src_layer)
            continue

        # LMC layer: only the sub-modules that are not being re-learned are
        # inherited. Each sub-tree is copied once, directly from the source.
        dst_layer = dst_model["layers"][layer_key]
        inherited = ["input_layernorm", "post_attention_layernorm"]
        if not config.finetune_mlp:
            inherited.append("mlp")
        for name in inherited:
            dst_layer[name] = copy.deepcopy(src_layer[name])
    return freeze(finetune_params)

def main(args: argparse.Namespace):
    """Fine-tune selected layers of a pretrained LMC Llama model with ``pmap``.

    The function
      1. loads the model config from the directory containing
         ``args.model_path`` and restores the pretrained parameters,
      2. builds a second model whose config carries the LMC overrides from
         ``args`` and seeds it via ``pretrained2finetune_params``,
      3. trains only the parameters labelled ``"trainable"`` by
         ``get_trainable_mask`` (AdamW with a warmup-cosine schedule),
      4. periodically logs, validates and saves ``best_*`` / ``last_*``
         snapshots under ``<args.model_save_dir>/<wandb run name>``.

    Args:
        args: Parsed command-line namespace (see the parser under
            ``if __name__ == "__main__"``).
    Raises:
        FileNotFoundError: If ``args.model_path`` does not exist.
    """
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Pretrained checkpoint path does not exist: {args.model_path}")
    # The config is read from the directory that contains the checkpoint.
    config = LlamaConfig.from_pretrained(os.path.dirname(args.model_path))

    # NOTE(review): args.wandb_run_dir is created but not passed to
    # wandb.init(dir=...); confirm whether that was intended.
    os.makedirs(args.wandb_run_dir, exist_ok=True)
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        name=f"finetune-indice{','.join(str(i) for i in args.lmc_layer_indices)}-heads{args.num_attention_heads}-"
        f"mlp{str(args.finetune_mlp)}-type{config.mlp_type}-seed{args.seed}",
        save_code=True
    )
    run_name = wandb.run.name
    wandb.config.update(vars(args), allow_val_change=True)
    save_path = os.path.join(args.model_save_dir, run_name)
    os.makedirs(save_path, exist_ok=True)

    # Seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Data
    corpus = get_lm_corpus(args.data_path, args.dataset)
    args.n_token = len(corpus.vocab)
    eval_batch_size = 12
    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # Pretrained model: built once from the unmodified config, then filled
    # with the restored checkpoint.
    pretrained_model = LMCFlaxLlamaForCausalLM(config)
    pretrained_params = checkpoints.restore_checkpoint(ckpt_dir=args.model_path, target={"params": pretrained_model.params})["params"]
    pretrained_model.params = pretrained_params

    # LMC overrides live in a copy of the config that is attached to the
    # original config under ``lmc_config``.
    lmc_config = copy.deepcopy(config)
    lmc_config.num_attention_heads = args.num_attention_heads
    lmc_config.num_key_value_heads = args.num_key_value_heads
    lmc_config.head_dim = args.head_dim
    config.lmc_config = lmc_config
    config.finetune_mlp = args.finetune_mlp
    config.lmc_layer_indices = args.lmc_layer_indices

    # --- Initialize fine-tuning model ---
    model = LMCFlaxLlamaForCausalLM(config,input_shape=(1, args.tgt_len),seed=args.seed,dtype=jnp.dtype(args.dtype),)
    model.config.save_pretrained(save_path)
    print_model(model.params)
    model.params = pretrained2finetune_params(pretrained_model.params,model.params,config)
    label_mask = get_trainable_mask(model.params,config)
    print(json.dumps(label_mask, indent=2))

    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=args.learning_rate,
        warmup_steps=args.warmup_step,
        decay_steps=args.max_step,
        end_value=args.eta_min,
    )
    # Frozen parameters get zeroed updates; only the labelled subset is trained.
    tx = optax.multi_transform(
        transforms={
            'trainable': optax.adamw(
                learning_rate=lr_schedule,
                b1=args.adamw_beta1,
                b2=args.adamw_beta2,
                eps=args.adamw_eps,
                weight_decay=args.weight_decay_rate
            ),
            'frozen': optax.set_to_zero()
        },
        param_labels=label_mask
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)

    def train_step_fn(state, batch, dropout_rng):
        """One optimisation step on a per-device batch (runs under pmap)."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        labels = batch["target"]
        inputs = {k: v for k, v in batch.items() if k != "target"}

        def loss_fn(params):
            logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grads = grad_fn(state.params)
        grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=grads)
        metrics = {"loss": loss,"learning_rate": lr_schedule(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step_fn(state, batch):
        """Mean cross-entropy of a per-device batch (runs under pmap)."""
        labels = batch["target"]
        inputs = {k: v for k, v in batch.items() if k != "target"}
        logits = model(**inputs, params=state.params, train=False)[0]
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
        metrics = {"eval_loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    parallel_train_step = jax.pmap(train_step_fn, "batch")
    parallel_eval_step = jax.pmap(eval_step_fn, "batch")
    state = replicate(state)

    def average_metrics(metric_stack):
        """Average a list of per-step replicated metric dicts over the steps."""
        # get_metrics already selects a single device copy and stacks the
        # steps along axis 0. Applying unreplicate on top of that would pick
        # element 0 of the time axis, i.e. only the first recorded step.
        stacked = get_metrics(metric_stack)
        return jax.tree_util.tree_map(lambda x: x.mean(), stacked)

    def run_validation(current_state):
        """Return the mean validation loss over all batches of ``va_iter``."""
        eval_results = []
        for eval_data, eval_target, _ in va_iter:
            eval_batch = prepare_lm_batch(eval_data, eval_target)
            eval_results.append(parallel_eval_step(current_state, eval_batch))
        return float(average_metrics(eval_results)["eval_loss"])

    def save_snapshot(current_state, prefix, step):
        """Write the current parameters to ``<save_path>/<prefix><step>``."""
        model.params = unreplicate(current_state).params
        snapshot_dir = os.path.join(save_path, f"{prefix}{step}")
        model.save_pretrained(snapshot_dir)
        return snapshot_dir

    # Word-level corpora report perplexity, character-level ones bits per
    # character; any other dataset name produces no metric line.
    if args.dataset in ["wt103", "lm1b"]:
        metric_name = "ppl"
    elif args.dataset in ["enwik8", "text8"]:
        metric_name = "bpc"
    else:
        metric_name = None

    def dataset_metric(loss_value):
        """Convert a mean cross-entropy (nats) to the dataset's metric."""
        if metric_name == "ppl":
            return math.exp(loss_value)
        return loss_value / math.log(2)

    rng = jax.random.PRNGKey(args.seed)
    # Split once: the per-device keys returned by each train step are carried
    # forward, so re-deriving them from ``rng`` every epoch would replay the
    # same key sequence in each epoch.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    train_step = int(jax.device_get(unreplicate(state.step)))
    best_val_loss = float("inf")
    log_start_time = time.time()
    eval_start_time = time.time()

    # ###JUST FOR TESTING####
    # Baseline validation before any update has been applied.
    val_loss = run_validation(state)
    val_ppl = math.exp(val_loss)
    print("-" * 100)
    print(
        f"| Eval {train_step // args.eval_frequency:3d} at step {train_step:8d} | "
        f"time: {time.time() - eval_start_time:5.2f}s | "
        f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:9.3f}"
    )
    print("-" * 100)

    # #### START FINETUNING ####
    print("Starting training...")
    print(f"JAX devices: {jax.devices()}")
    print(f"Using {jax.local_device_count()} devices")
    for epoch in itertools.count(start=1):
        print(f"Epoch {epoch}")
        train_iter = tr_iter.get_varlen_iter() if getattr(args, "varlen", False) else tr_iter
        train_metrics_stack = []
        for batch_idx, (data, target, _) in enumerate(tqdm(train_iter)):
            if train_step >= args.max_step:
                break
            # Prepare and shard batch
            batch = prepare_lm_batch(data, target)
            # Run train step
            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)
            train_step += 1

            # Logging
            if train_step % args.logging_frequency == 0:
                train_metrics = average_metrics(train_metrics_stack)
                train_metrics_stack = []
                loss = float(train_metrics["loss"])
                curr_lr = float(lr_schedule(train_step))
                elapsed = time.time() - log_start_time
                if metric_name is not None:
                    metric_value = dataset_metric(loss)
                    print(
                        f"| epoch {epoch:3d} step {train_step:8d} | "
                        f"{batch_idx+1:6d} batches | lr {curr_lr:.3g} "
                        f"| ms/batch {elapsed * 1000 / args.logging_frequency:5.2f} | "
                        f"loss {loss:5.2f} | {metric_name} {metric_value:9.3f}"
                    )
                    wandb.log({"loss": loss, metric_name: metric_value, "learning_rate": curr_lr}, step=train_step)
                log_start_time = time.time()

            # Evaluation
            if train_step % args.eval_frequency == 0:
                val_loss = run_validation(state)
                print("-" * 100)
                if metric_name is not None:
                    val_metric = dataset_metric(val_loss)
                    print(
                        f"| Eval {train_step // args.eval_frequency:3d} at step {train_step:8d} | "
                        f"time: {time.time() - eval_start_time:5.2f}s | "
                        f"valid loss {val_loss:5.2f} | valid {metric_name} {val_metric:9.3f}"
                    )
                    wandb.log({"eval_loss": val_loss, f"eval_{metric_name}": val_metric}, step=train_step)
                print("-" * 100)
                # Save best checkpoint
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_snapshot(state, "best_", train_step)
                    print(f"✅ Best model saved at step {train_step}")
                    remove_old_dirs_with_prefix(save_path, "best_", train_step)
                eval_start_time = time.time()

            # Periodic checkpoint
            if train_step % args.save_frequency == 0:
                last_dir = save_snapshot(state, "last_", train_step)
                print(f"💾 Checkpoint saved at step {train_step}")
                remove_old_dirs_with_prefix(save_path, "last_", train_step)
                print(f"Checkpoint directory: {last_dir}")
        if train_step >= args.max_step:
            print("-" * 100)
            print("End of training")
            break
if __name__ == "__main__":
    # Command-line entry point: declare every option read by ``main`` and run it.
    # ``--mem_len`` and ``--restore-checkpoint-path`` are accepted but are not
    # read by ``main`` beyond being recorded in the wandb config.
    parser = argparse.ArgumentParser()
    # Model
    parser.add_argument("--model-path", type=str, default="", help="Path of Pretrained Model")
    parser.add_argument("--num_attention_heads", type=int, default=12, help="Number of attention heads")
    parser.add_argument("--num_key_value_heads", type=int, default=12, help="KV heads (GQA/MQA).")
    parser.add_argument("--head_dim", type=int, default=64, help="Dimension of each attention head")
    parser.add_argument("--lmc-layer-indices",type=int,nargs="*",default=[],help="List of lmc layer indices (optional, default: empty list)")
    parser.add_argument("--finetune-mlp",action="store_true",help="Enable fine-tuning for the MLP. Default is False.")
    # Data / train
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="")
    parser.add_argument("--dataset", type=str, default="wt103", choices=["wt103", "lm1b", "enwik8", "text8"])
    parser.add_argument("--max_step", type=int, default=500000)
    parser.add_argument("--warmup_step", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--tgt_len", type=int, default=256)
    parser.add_argument("--eval_tgt_len", type=int, default=256)
    parser.add_argument("--ext_len", type=int, default=0)
    parser.add_argument("--mem_len", type=int, default=0)
    # Optim
    parser.add_argument("--learning-rate", type=float, default=0.00025, help="learning rate")
    parser.add_argument("--weight-decay-rate", type=float, default=0.01, help="weight decay rate for the AdamW optimizer")
    parser.add_argument('--eta_min', type=float, default=1.0e-8,help='min learning rate for cosine scheduler')
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    # Logging / ckpt
    parser.add_argument("--wandb-entity", default=None)
    parser.add_argument("--wandb-group", default=None)
    parser.add_argument("--wandb-project", default=None)
    parser.add_argument("--wandb-run-dir", default=".wandb")
    parser.add_argument("--logging-frequency", type=int, default=200)
    parser.add_argument("--eval-frequency", type=int, default=4000)
    parser.add_argument("--save-frequency", type=int, default=4000)
    parser.add_argument("--model-save-dir", type=str, default="artifacts/")
    parser.add_argument("--restore-checkpoint-path", type=str, default=None)
    main(parser.parse_args())