import argparse
import jax
import os
import optax
import torch
import wandb
import math
import time
import itertools
import numpy as np
import jax.numpy as jnp
from datetime import timedelta
from typing import Any, Dict, List
from datasets import Dataset
from tqdm import tqdm
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.llama.configuration_llama import LlamaConfig
from lmc_model import LMCFlaxLlamaForCausalLM
from data_utils import get_lm_corpus
from jax import debug


def prepare_lm_batch(data: torch.Tensor, target: torch.Tensor) -> Dict[str, Any]:
    """
    Convert and shard a language modeling batch from PyTorch to JAX.
    Args:
        data (torch.Tensor): Input data of shape (seq_len, batch)
        target (torch.Tensor): Target data of shape (seq_len, batch)
    Returns:
        Dict[str, jnp.ndarray]: Dict with 'input_ids' and 'target', both sharded
            with shape (n_devices, batch_per_device, seq_len)
    """
    def _to_batch_major(tensor):
        # Go through host NumPy explicitly for torch tensors so that tensors
        # living on an accelerator (or tracking gradients) convert cleanly;
        # any other array-like is handed to JAX unchanged.
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.detach().cpu().numpy()
        # Transpose (seq_len, batch) -> (batch, seq_len).
        return jnp.asarray(tensor.T)

    # Shard across devices. `shard` splits the leading (batch) axis over the
    # local devices, so the batch size must be divisible by the device count.
    return {
        'input_ids': shard(_to_batch_major(data)),
        'target': shard(_to_batch_major(target)),
    }
 

def main(args: argparse.Namespace):
    """Train the LMC Flax Llama language model and periodically evaluate it.

    Sets up Weights & Biases logging, builds the corpus iterators and the
    model from ``args``, then runs a ``jax.pmap`` data-parallel training loop
    until ``args.max_step`` steps have been taken. Validation runs every
    ``args.eval_frequency`` steps; the best-validation train state and the
    most recent train state are checkpointed under
    ``<args.model_save_dir>/<run_name>``.

    Args:
        args: Parsed command-line namespace. ``args.n_token`` is set here
            from the corpus vocabulary size.
    """
    os.makedirs(args.wandb_run_dir, exist_ok=True)
    run_name = (
        f"lr{args.learning_rate}-step{args.max_step}-warm{args.warmup_step}-batch{args.batch_size}"
        f"-layer{args.num_hidden_layers}-hidden{args.hidden_size}-heads{args.num_attention_heads}-type{args.mlp_type}-seed{args.seed}"
    )
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        name=run_name,
        save_code=True,
        dir=args.wandb_run_dir,
    )
    wandb.config.update(vars(args), allow_val_change=True)
    save_path = os.path.join(args.model_save_dir, run_name)
    os.makedirs(save_path, exist_ok=True)

    # Seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    rng = jax.random.PRNGKey(args.seed)

    # Data
    corpus = get_lm_corpus(args.data_path, args.dataset)
    ntokens = len(corpus.vocab)
    args.n_token = ntokens
    # NOTE(review): batches are sharded over the local devices, so both
    # args.batch_size and this fixed eval batch size presumably have to be
    # divisible by jax.local_device_count() -- confirm against the iterator.
    eval_batch_size = 12
    tr_iter = corpus.get_iterator("train", args.batch_size, args.tgt_len, ext_len=args.ext_len)
    va_iter = corpus.get_iterator("valid", eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)
    # NOTE(review): the test iterator is built but never consumed below.
    te_iter = corpus.get_iterator("test", eval_batch_size, args.eval_tgt_len, ext_len=args.ext_len)

    # Model config
    model_config = LlamaConfig()
    model_config.vocab_size = args.n_token
    model_config.max_position_embeddings = args.tgt_len
    model_config.bos_token_id = args.n_token
    model_config.eos_token_id = args.n_token
    model_config.hidden_size = args.hidden_size
    model_config.intermediate_size = args.intermediate_size
    model_config.num_hidden_layers = args.num_hidden_layers
    model_config.num_attention_heads = args.num_attention_heads
    model_config.num_key_value_heads = args.num_key_value_heads
    model_config.hidden_act = args.hidden_act
    model_config.head_dim = args.head_dim
    model_config.attention_bias = args.attention_bias
    model_config.mlp_type = args.mlp_type
    model_config.lmc_layer_indices = args.lmc_layer_indices
    # NOTE(review): parameter initialisation uses a fixed seed of 0 rather
    # than args.seed -- confirm this is intentional.
    model = LMCFlaxLlamaForCausalLM(
        model_config,
        input_shape=(1, args.tgt_len),
        seed=0,
        dtype=jnp.dtype(args.dtype),
    )
    model.config.save_pretrained(save_path)

    # Optimizer: linear warmup followed by cosine decay, driving AdamW.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=args.learning_rate,
        warmup_steps=args.warmup_step,
        decay_steps=args.max_step,
        end_value=args.eta_min,
    )
    tx = optax.adamw(
        learning_rate=lr_schedule,
        # Previously omitted, which made the --adamw-beta1 flag a no-op.
        b1=args.adamw_beta1,
        b2=args.adamw_beta2,
        eps=args.adamw_eps,
        weight_decay=args.weight_decay_rate,
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx)
    if args.restore_checkpoint_path:
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"train state restored from {args.restore_checkpoint_path}")
        print(f"skip train step to {state.step}")

    def split_batch(batch):
        # Separate labels from model inputs without mutating the caller's dict.
        labels = batch["target"]
        inputs = {key: value for key, value in batch.items() if key != "target"}
        return inputs, labels

    def train_step(state, batch, dropout_rng):
        # One optimisation step on a per-device shard; gradients and metrics
        # are averaged across devices over the "batch" pmap axis.
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        inputs, labels = split_batch(batch)

        def loss_fn(params):
            logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grads = grad_fn(state.params)
        grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=grads)
        metrics = {"loss": loss, "learning_rate": lr_schedule(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step(state, batch):
        # Forward pass only; returns the cross-device mean loss.
        inputs, labels = split_batch(batch)
        logits = model(**inputs, params=state.params, train=False)[0]
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
        metrics = {"eval_loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    def average_metrics(per_step_metrics):
        # get_metrics already selects a single device copy and stacks the
        # entries along a leading step axis, so only a mean over that axis is
        # needed. (Applying unreplicate on top of it would pick out the first
        # step instead of averaging.) Cast to float32 so low-precision losses
        # are accumulated accurately.
        stacked = get_metrics(per_step_metrics)
        return jax.tree_util.tree_map(
            lambda x: float(np.asarray(x, dtype=np.float32).mean()), stacked
        )

    parallel_train_step = jax.pmap(train_step, "batch")
    parallel_eval_step = jax.pmap(eval_step, "batch")
    state = replicate(state)

    train_metrics_stack = []
    # Host-side step counter (kept distinct from the train_step function).
    global_step = int(jax.device_get(unreplicate(state.step)))
    best_val_loss = float("inf")
    log_start_time = time.time()
    eval_start_time = time.time()
    # Split once: the per-device keys are then threaded through every step,
    # so the dropout key stream does not restart at each epoch.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    print("Starting training...")
    print(f"JAX devices: {jax.devices()}")
    print(f"Using {jax.local_device_count()} devices")
    for epoch in itertools.count(start=1):
        print(f"Epoch {epoch}")
        train_iter = tr_iter.get_varlen_iter() if getattr(args, "varlen", False) else tr_iter
        for batch_idx, (data, target, _seq_len) in enumerate(tqdm(train_iter)):
            if global_step >= args.max_step:
                break
            # Prepare and shard batch
            batch = prepare_lm_batch(data, target)
            # Run train step
            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)
            global_step += 1

            # Logging
            if global_step % args.logging_frequency == 0:
                train_metrics = average_metrics(train_metrics_stack)
                train_metrics_stack = []
                loss = float(train_metrics["loss"])
                ppl = math.exp(loss)
                bpc = loss / math.log(2)
                curr_lr = float(lr_schedule(global_step))
                elapsed = time.time() - log_start_time
                if args.dataset in ["wt103", "lm1b"]:
                    print(
                        f"| epoch {epoch:3d} step {global_step:8d} | "
                        f"{batch_idx+1:6d} batches | lr {curr_lr:.3g} "
                        f"| ms/batch {elapsed * 1000 / args.logging_frequency:5.2f} | "
                        f"loss {loss:5.2f} | ppl {ppl:9.3f}"
                    )
                    wandb.log({"loss": loss, "ppl": ppl, "learning_rate": curr_lr}, step=global_step)
                elif args.dataset in ["enwik8", "text8"]:
                    print(
                        f"| epoch {epoch:3d} step {global_step:8d} | "
                        f"{batch_idx+1:6d} batches | lr {curr_lr:.3g} "
                        f"| ms/batch {elapsed * 1000 / args.logging_frequency:5.2f} | "
                        f"loss {loss:5.2f} | bpc {bpc:9.3f}"
                    )
                    wandb.log({"loss": loss, "bpc": bpc, "learning_rate": curr_lr}, step=global_step)
                log_start_time = time.time()

            # Evaluation
            if global_step % args.eval_frequency == 0:
                eval_results = []
                for eval_data, eval_target, _ in va_iter:
                    eval_batch = prepare_lm_batch(eval_data, eval_target)
                    eval_results.append(parallel_eval_step(state, eval_batch))
                # Mean over every validation batch, not just the first one.
                eval_metrics = average_metrics(eval_results)

                val_loss = float(eval_metrics["eval_loss"])
                val_ppl = math.exp(val_loss)
                val_bpc = val_loss / math.log(2)
                print("-" * 100)
                if args.dataset in ["wt103", "lm1b"]:
                    print(
                        f"| Eval {global_step // args.eval_frequency:3d} at step {global_step:8d} | "
                        f"time: {time.time() - eval_start_time:5.2f}s | "
                        f"valid loss {val_loss:5.2f} | valid ppl {val_ppl:9.3f}"
                    )
                    wandb.log({"eval_loss": val_loss, "eval_ppl": val_ppl}, step=global_step)
                elif args.dataset in ["enwik8", "text8"]:
                    print(
                        f"| Eval {global_step // args.eval_frequency:3d} at step {global_step:8d} | "
                        f"time: {time.time() - eval_start_time:5.2f}s | "
                        f"valid loss {val_loss:5.2f} | valid bpc {val_bpc:9.3f}"
                    )
                    wandb.log({"eval_loss": val_loss, "eval_bpc": val_bpc}, step=global_step)
                print("-" * 100)
                # Save best checkpoint
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    checkpoints.save_checkpoint(
                        ckpt_dir=save_path,
                        target=unreplicate(state),
                        step=global_step,
                        prefix="best_",
                        keep=1,
                    )
                    print(f"Best model saved at step {global_step}")
                eval_start_time = time.time()

            # Periodic checkpoint
            if global_step % args.save_frequency == 0:
                checkpoints.save_checkpoint(
                    ckpt_dir=save_path,
                    target=unreplicate(state),
                    step=global_step,
                    prefix="last_",
                    keep=1,
                )
                print(f"Checkpoint saved at step {global_step} in {save_path}")
        if global_step >= args.max_step:
            print("-" * 100)
            print("End of training")
            break

if __name__ == "__main__":
    # Command-line entry point: declare every hyper-parameter consumed by
    # main() and launch training with the parsed namespace.
    parser = argparse.ArgumentParser()
    # Model
    parser.add_argument("--mlp_type", type=str, default="", help="MLP type")
    parser.add_argument("--hidden_size", type=int, default=768, help="Model hidden size")
    parser.add_argument("--head_dim", type=int, default=64, help="Attention head dimension")
    parser.add_argument("--intermediate_size", type=int, default=3072, help="MLP intermediate size")
    parser.add_argument("--num_hidden_layers", type=int, default=12, help="Number of transformer layers")
    parser.add_argument("--num_attention_heads", type=int, default=12, help="Number of attention heads")
    parser.add_argument("--num_key_value_heads", type=int, default=12, help="KV heads (GQA/MQA).")
    parser.add_argument("--hidden_act", type=str, default="gelu", help="Activation function")
    parser.add_argument("--attention_bias", action="store_true", help="Use bias in attention projections")
    parser.add_argument(
        "--lmc-layer-indices",
        type=int,
        nargs="*",
        default=[],
        help="List of lmc layer indices (optional, default: empty list)",
    )
    # Data / train
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--data-path", type=str, default="")
    parser.add_argument("--dataset", type=str, default="wt103", choices=["wt103", "lm1b", "enwik8", "text8"])
    parser.add_argument("--max_step", type=int, default=500000)
    parser.add_argument("--warmup_step", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=96)
    parser.add_argument("--tgt_len", type=int, default=256)
    parser.add_argument("--eval_tgt_len", type=int, default=256)
    parser.add_argument("--ext_len", type=int, default=0)
    parser.add_argument("--mem_len", type=int, default=0)
    # Optim
    parser.add_argument("--learning-rate", type=float, default=0.00025, help="learning rate")
    parser.add_argument("--weight-decay-rate", type=float, default=0.01, help="weight decay rate for AdamW")
    parser.add_argument("--eta_min", type=float, default=1.0e-8, help="min learning rate for cosine scheduler")
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16", help="model datatype")
    # Logging / ckpt
    parser.add_argument("--wandb-entity", default=None)
    parser.add_argument("--wandb-group", default=None)
    parser.add_argument("--wandb-project", default=None)
    parser.add_argument("--wandb-run-dir", default=".wandb")
    parser.add_argument("--logging-frequency", type=int, default=200)
    parser.add_argument("--eval-frequency", type=int, default=4000)
    parser.add_argument("--save-frequency", type=int, default=4000)
    parser.add_argument("--model-save-dir", type=str, default="artifacts/")
    parser.add_argument("--restore-checkpoint-path", type=str, default=None)
    main(parser.parse_args())