import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
import numpy as np 
import jax
import jax.numpy as jnp
import optax
import torch
import wandb
from tqdm import tqdm
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config
from datamodule import  JAXTextDataset
from torch.utils.data import  DataLoader
from itertools import cycle
# Only provide a placeholder when the user has not exported a key: assigning
# unconditionally would clobber a real WANDB_API_KEY from the environment.
# Never hard-code an actual key in this file.
os.environ.setdefault("WANDB_API_KEY", "")
# Command-line interface of the training script. Flag names, types and
# defaults are part of the script's interface; the help texts describe what
# main() actually does with each value.
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--model-config-name", type=str, default="gpt2", help="GPT2 config name (huggingface model hub)")
parser.add_argument("--data-path", type=str, default="data/enwik8", help="directory containing train.txt, valid.txt and test.txt token files")
parser.add_argument("--batch-size", type=int, default=48, help="train, eval batch size (batch size will be divided by device count)")
parser.add_argument("--random-seed", type=int, default=0, help="random seed for RNG state")
parser.add_argument("--max-sequence-length", type=int, default=256, help="sequence length of model input")
parser.add_argument("--num-epochs", type=int, default=10, help="number of epochs (currently unused; training length is num-iter * nb-per-iter steps)")
parser.add_argument("--num-iter", type=int, default=80, help="number of training iterations")
parser.add_argument("--nb-per-iter", type=int, default=1000, help="number of batches per iteration")
parser.add_argument("--learning-rate", type=float, default=7.0e-4, help="initial learning rate (linearly decayed to 0)")
parser.add_argument("--weight-decay-rate", type=float, default=0.01, help="weight decay rate for the AdamW optimizer")
parser.add_argument("--adamw-beta1", type=float, default=0.9)
parser.add_argument("--adamw-beta2", type=float, default=0.999)
parser.add_argument("--adamw-eps", type=float, default=1e-8)
parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32", help="model datatype")
parser.add_argument("--wandb-project", default="GPT2-Enwik8", help="wandb project name for logging")
parser.add_argument("--wandb-run-dir", default=".wandb", help="wandb run dir")
parser.add_argument("--logging-frequency", type=int, default=100, help="log training metrics every logging_frequency steps")
parser.add_argument("--eval-frequency", type=int, default=1000, help="evaluation frequency in steps (currently unused; evaluation runs at the end of every iteration)")
parser.add_argument("--save-frequency", type=int, default=1000, help="save a checkpoint every save_frequency steps (and at the end of every iteration)")
parser.add_argument("--model-save-dir", type=str, default="artifacts/", help="checkpoint saving dir")
parser.add_argument("--restore-checkpoint-path", type=str, help="if you want to restart from specific checkpoint, set this arg to checkpoint path")
# fmt: on


# log2(e): conversion factor from nats to bits. The training and evaluation
# steps multiply the cross-entropy loss by this constant to report
# bits-per-character ("bpc").
LOG2E = jnp.log2(jnp.e) 
def load_token_file(path):
    """Load a whitespace-separated file of integer token ids as a uint8 array.

    Args:
        path: Path to a text file containing decimal integers separated by
            arbitrary whitespace (spaces and/or newlines).

    Returns:
        A 1-D ``np.ndarray`` of dtype ``uint8`` with the tokens in file order.
        An empty file yields an empty array.

    Raises:
        ValueError: If a field is not an integer, or a token is outside the
            ``uint8`` range ``[0, 255]``.
    """
    # The file holds only ASCII digits and whitespace; pin the encoding so the
    # result does not depend on the platform's default locale.
    with open(path, "r", encoding="utf-8") as f:
        tokens = [int(t) for t in f.read().split()]

    # Validate in a wide dtype first: casting straight to uint8 either wraps
    # silently or raises OverflowError, depending on the NumPy version.
    try:
        values = np.asarray(tokens, dtype=np.int64)
    except OverflowError as err:
        raise ValueError(f"token out of uint8 range [0, 255] in {path!r}") from err
    if values.size and (values.min() < 0 or values.max() > 255):
        raise ValueError(
            f"token out of uint8 range [0, 255] in {path!r}: "
            f"min={int(values.min())}, max={int(values.max())}"
        )
    return values.astype(np.uint8)

# Collate function for the PyTorch DataLoader: turns a list of samples into
# one dict of stacked jax.numpy arrays, split across devices by `shard`.
def jax_collate(batch):
    """Stack per-sample fields into batched arrays and shard them.

    Args:
        batch: Non-empty list of mapping-like samples that all expose the keys
            of the first sample (presumably the items produced by
            ``JAXTextDataset`` — verify against that class).

    Returns:
        The result of ``shard`` applied to a dict mapping each key to the
        ``jnp.stack`` of that field over all samples.
    """
    stacked = {}
    for field in batch[0]:
        per_sample = [sample[field] for sample in batch]
        stacked[field] = jnp.stack(per_sample)
    return shard(stacked)

def decay_mask_fn(params):
    """Build a boolean mask with the same tree structure as ``params``.

    A leaf is ``True`` unless its path ends in ``"bias"`` or it is the
    ``"scale"`` of one of the ``ln_1`` / ``ln_2`` / ``ln_f`` modules. ``main``
    passes this function as the ``mask`` argument of ``optax.adamw``.

    Args:
        params: Nested parameter dict.

    Returns:
        Nested dict of booleans mirroring ``params``.
    """
    layer_norm_scales = [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]

    def _selected(path):
        # Bias terms are always excluded.
        if path[-1] == "bias":
            return False
        # So are the layer-norm scale parameters.
        return path[-2:] not in layer_norm_scales

    mask = {path: _selected(path) for path in flatten_dict(params)}
    return unflatten_dict(mask)

def _repeat_loader(loader):
    """Yield batches from ``loader`` forever, re-iterating it after each pass.

    Unlike ``itertools.cycle`` this does not keep a copy of every batch it has
    yielded, and a shuffling loader is re-shuffled on every pass.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Avoid spinning forever on a loader that never yields anything.
            raise ValueError("data loader produced no batches")


def main(args: argparse.Namespace):
    """Train a GPT-2 language model on pre-tokenised data with data parallelism.

    Reads ``train.txt`` / ``valid.txt`` / ``test.txt`` from ``args.data_path``,
    trains for ``args.num_iter * args.nb_per_iter`` steps with AdamW and a
    linearly decaying learning rate, evaluates on the validation split at the
    end of every iteration, logs to wandb and writes checkpoints below
    ``args.model_save_dir``.

    Args:
        args: Parsed command-line arguments (see the module-level ``parser``).

    Raises:
        ValueError: If ``args.batch_size`` is not divisible by the number of
            local devices, or the training loader yields no batches.
    """
    num_devices = jax.local_device_count()
    if args.batch_size % num_devices != 0:
        # `shard` reshapes the batch axis to (devices, per_device, ...), which
        # only works for an evenly divisible batch; fail early and clearly.
        raise ValueError(
            f"--batch-size ({args.batch_size}) must be divisible by the local device count ({num_devices})"
        )

    os.makedirs(args.wandb_run_dir, exist_ok=True)
    # Pass the hyper-parameters through `config=`: rebinding `wandb.config`
    # afterwards would replace the config object instead of recording values.
    wandb.init(
        project=args.wandb_project,
        name=f"TrainModel-lr{args.learning_rate}-iter{args.num_iter}-size{args.batch_size}",
        dir=args.wandb_run_dir,
        config=vars(args),
        save_code=True,
    )

    # Load all 3 sets
    train_data = load_token_file(os.path.join(args.data_path, 'train.txt'))
    val_data = load_token_file(os.path.join(args.data_path, 'valid.txt'))
    test_data = load_token_file(os.path.join(args.data_path, 'test.txt'))

    # Create loaders. `drop_last=True` keeps every batch at exactly
    # `batch_size` samples so it can always be sharded across devices and the
    # pmapped steps are not recompiled for a smaller trailing batch.
    train_loader = DataLoader(
        JAXTextDataset(train_data, args.max_sequence_length),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=jax_collate,
    )
    val_loader = DataLoader(
        JAXTextDataset(val_data, args.max_sequence_length),
        batch_size=args.batch_size,
        drop_last=True,
        collate_fn=jax_collate,
    )
    # NOTE(review): the test loader is built but never consumed below.
    test_loader = DataLoader(
        JAXTextDataset(test_data, args.max_sequence_length),
        batch_size=args.batch_size,
        collate_fn=jax_collate,
    )

    model_config = GPT2Config.from_pretrained(args.model_config_name)
    model_config.vocab_size = 256                   # enwik8 = 256 ASCII tokens
    model_config.n_positions = args.max_sequence_length
    model_config.n_ctx = args.max_sequence_length
    model_config.bos_token_id = None                # not used
    model_config.eos_token_id = None                # not used
    model = FlaxGPT2LMHeadModel(
        model_config,
        input_shape=(args.batch_size, args.max_sequence_length),
        seed=args.random_seed,  # honour --random-seed (default 0, as before)
        dtype=jnp.dtype(args.dtype),
    )

    num_train_steps = args.nb_per_iter * args.num_iter
    linear_decay_lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate, end_value=0, transition_steps=num_train_steps
    )
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=args.adamw_beta1,
        b2=args.adamw_beta2,
        eps=args.adamw_eps,
        weight_decay=args.weight_decay_rate,
        mask=decay_mask_fn,
    )
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

    if args.restore_checkpoint_path:
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"train state restored from {args.restore_checkpoint_path}")
        print(f"skip train step to {state.step}")
    # Number of optimizer updates already applied == index of the next step.
    latest_train_step = int(state.step)

    def train_step(state: train_state.TrainState, batch, dropout_rng):
        """Run one optimisation step on a per-device batch (called under pmap)."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, num=2)
        # Separate labels from the model inputs without mutating `batch`.
        inputs = {name: value for name, value in batch.items() if name != "labels"}
        labels = batch["labels"]

        def loss_fn(params):
            logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
            return optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

        loss, grad = jax.value_and_grad(loss_fn)(state.params)
        # Average gradients over devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, axis_name="batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = {
            "loss": loss,
            "bpc": loss * LOG2E,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    def eval_step(state, batch):
        """Compute loss and bits-per-character on a per-device batch (under pmap)."""
        inputs = {name: value for name, value in batch.items() if name != "labels"}
        logits = model(**inputs, params=state.params, train=False)[0]
        loss = optax.softmax_cross_entropy(logits, onehot(batch["labels"], logits.shape[-1])).mean()
        metrics = {"eval_loss": loss, "eval_bpc": loss * LOG2E}
        return jax.lax.pmean(metrics, axis_name="batch")

    parallel_train_step = jax.pmap(train_step, axis_name="batch")
    parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
    state = replicate(state)

    rng = jax.random.PRNGKey(args.random_seed)
    # Split once, outside the loop: the keys returned by each train step are
    # threaded into the next one. Re-splitting `rng` every iteration would
    # restart the same dropout key sequence each time.
    dropout_rngs = jax.random.split(rng, num=num_devices)

    train_metrics_stack = []
    last_timestamp = time.time()
    train_loader_iter = _repeat_loader(train_loader)
    save_path = os.path.join(
        args.model_save_dir, f"Main-lr{args.learning_rate}-iter{args.num_iter}-size{args.batch_size}"
    )

    # Resume where a restored checkpoint left off (both values are 0 for a
    # fresh run).
    first_iteration, first_offset = divmod(latest_train_step, max(args.nb_per_iter, 1))
    for iteration in range(first_iteration, args.num_iter):
        start_i = first_offset if iteration == first_iteration else 0
        for i in tqdm(range(start_i, args.nb_per_iter), desc=f"Iteration {iteration}"):
            batch = next(train_loader_iter)
            current_train_step = iteration * args.nb_per_iter + i
            state, train_metric, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metric)

            if current_train_step > 0 and current_train_step % args.logging_frequency == 0:
                steps_timed = len(train_metrics_stack)
                train_metrics = get_metrics(train_metrics_stack)
                train_metrics = unreplicate(train_metrics)
                train_metrics = jax.tree_util.tree_map(jnp.mean, train_metrics)
                loss = train_metrics["loss"]
                bpc = train_metrics["bpc"]
                # ETA = remaining steps * measured seconds per step, using the
                # number of steps actually timed since the last log.
                elapsed = time.time() - last_timestamp
                remaining_steps = num_train_steps - current_train_step
                eta = timedelta(seconds=int(remaining_steps * elapsed / steps_timed))
                print(f"[TRAIN] iter: {iteration} step: {current_train_step}/{num_train_steps} loss: {loss:.4f} bpc: {bpc:.4f} ETA {eta}")
                wandb.log(
                    {
                        "loss": loss,
                        "bpc": bpc,
                        "learning_rate": train_metrics["learning_rate"],
                        "iteration": iteration,
                    },
                    step=current_train_step,
                )
                last_timestamp, train_metrics_stack = time.time(), []

            is_end_of_iter = i + 1 == args.nb_per_iter
            if current_train_step > 0 and is_end_of_iter:
                eval_metrics = [parallel_eval_step(state, val_batch) for val_batch in val_loader]
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_util.tree_map(jnp.mean, unreplicate(eval_metrics))
                print(f"[EVAL] iter: {iteration} step: {current_train_step}/{num_train_steps} eval_loss: {eval_metrics['eval_loss']:.4f} eval_bpc: {eval_metrics['eval_bpc']:.4f}")
                wandb.log(eval_metrics, step=current_train_step)

            if current_train_step > 0 and (current_train_step % args.save_frequency == 0 or is_end_of_iter):
                checkpoints.save_checkpoint(ckpt_dir=save_path, target=unreplicate(state), step=current_train_step, keep=3)
                print(f"Saved checkpoint to {save_path}")


# Script entry point: parse the command line and start training.
if __name__ == "__main__":
    cli_args = parser.parse_args()
    main(cli_args)