import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
from jax import debug

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import wandb
from tqdm import tqdm
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.bert.modeling_flax_bert import FlaxBertForSequenceClassification, BertConfig
# Fallback Weights & Biases credential. `setdefault` keeps a key that is
# already exported in the environment instead of overwriting it with this
# placeholder. Do not commit a real key here.
os.environ.setdefault("WANDB_API_KEY", "")

# ----------- Collate and Utility Functions -----------
def batch_collate_fn(batch):
    """Collate a list of example dicts into one sharded dict of columns.

    The keys of the first example define the columns. Each column is the list
    of that key's values across all examples, converted with ``jnp.array``.
    If the conversion raises ``TypeError`` the plain Python list is kept
    instead. The resulting dict is passed through ``shard`` before returning.
    """
    columns = {}
    for name in batch[0]:
        values = [example[name] for example in batch]
        try:
            columns[name] = jnp.array(values)
        except TypeError:
            # NOTE(review): `shard` presumably expects array leaves; confirm
            # that a list kept by this fallback actually survives it.
            columns[name] = values
    return shard(columns)

def split_batch(batch):
    """Separate the ``"label"`` entry of a batch from all other entries.

    Returns:
        A ``(inputs, labels)`` tuple where ``inputs`` is a new dict holding
        every item of ``batch`` except ``"label"`` (original order kept) and
        ``labels`` is ``batch["label"]``. The input mapping is not modified.
    """
    # Look the label up first so a missing key fails before anything else.
    targets = batch["label"]
    features = {}
    for name, value in batch.items():
        if name == "label":
            continue
        features[name] = value
    return features, targets

# ----------- Training and Evaluation Steps -----------
def create_train_step(apply_fn):
    """Build a single-device training step around ``apply_fn``.

    The returned ``train_step(state, batch, dropout_rng)`` uses collectives
    over an axis named ``"batch"``, so it has to run under a mapping that
    defines that axis. It returns ``(new_state, metrics, new_dropout_rng)``
    where ``metrics`` holds the cross-device mean ``"loss"`` and
    ``"accuracy"``.
    """

    def train_step(state, batch, dropout_rng):
        # One half of the split key drives dropout for this step, the other
        # half is handed back to the caller for the next step.
        step_rng, next_rng = jax.random.split(dropout_rng)
        features, targets = split_batch(batch)

        def loss_fn(params):
            outputs = apply_fn(**features, params=params, dropout_rng=step_rng, train=True)
            logits = outputs[0]
            targets_onehot = onehot(targets, num_classes=logits.shape[-1])
            mean_loss = optax.softmax_cross_entropy(logits, targets_onehot).mean()
            hits = jnp.argmax(logits, axis=-1) == targets
            # Accuracy travels as the auxiliary output; only the loss is differentiated.
            return mean_loss, jnp.mean(hits)

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, accuracy), grads = grad_fn(state.params)

        # Average gradients across devices so every replica applies the same update.
        synced_grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=synced_grads)

        metrics = jax.lax.pmean({"loss": loss, "accuracy": accuracy}, axis_name="batch")
        return new_state, metrics, next_rng

    return train_step

def create_eval_step(apply_fn):
    """Build a single-device evaluation step around ``apply_fn``.

    The returned ``eval_step(state, batch)`` runs the model with
    ``train=False`` and returns a dict with ``"eval_loss"`` and
    ``"eval_accuracy"``, each averaged across the axis named ``"batch"``.
    """

    def eval_step(state, batch):
        features, targets = split_batch(batch)
        outputs = apply_fn(**features, params=state.params, train=False)
        logits = outputs[0]

        targets_onehot = onehot(targets, logits.shape[-1])
        mean_loss = optax.softmax_cross_entropy(logits, targets_onehot).mean()
        hits = jnp.argmax(logits, axis=-1) == targets

        return jax.lax.pmean(
            {"eval_loss": mean_loss, "eval_accuracy": jnp.mean(hits)},
            axis_name="batch",
        )

    return eval_step

# ----------- Main Training Loop -----------
def main(args: argparse.Namespace) -> None:
    """Train a BERT sequence classifier with ``jax.pmap`` and log to W&B.

    Builds train/eval dataloaders from parquet files, constructs a
    ``FlaxBertForSequenceClassification`` from a config (no pretrained weights
    are loaded here), sets up AdamW with a linear learning-rate decay to zero,
    optionally restores a checkpoint, and then trains for ``args.num_epochs``
    epochs.

    ``current_step`` counts completed optimizer steps, including those stored
    in a restored checkpoint. Training metrics are logged whenever it is a
    multiple of ``args.logging_frequency``. Evaluation and checkpointing run
    whenever it is a multiple of ``args.eval_frequency`` /
    ``args.save_frequency`` and at the end of every epoch.

    Args:
        args: Parsed command-line options.

    Raises:
        ValueError: If ``args.batch_size`` is not divisible by the number of
            local devices (each batch is split evenly across them).
    """
    num_devices = jax.local_device_count()
    if args.batch_size % num_devices != 0:
        # Fail with a clear message instead of a reshape error when the
        # first batch is sharded.
        raise ValueError(
            f"--batch-size ({args.batch_size}) must be divisible by the number "
            f"of local devices ({num_devices})"
        )

    # Hyper-parameters go through `init(config=...)`. Assigning to
    # `wandb.config` only rebinds the module attribute and the run never
    # records them.
    wandb.init(
        project=args.wandb_project,
        name=f"TrainModel-lr{args.learning_rate}-epochs{args.num_epochs}-batch{args.batch_size}",
        config=dict(vars(args)),
        save_code=True,
    )

    train_dataset = Dataset.from_parquet(args.train_dataset_paths)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, drop_last=True, collate_fn=batch_collate_fn
    )

    # NOTE(review): with drop_last=False the final eval batch can be smaller
    # than the others. It may not split evenly across devices, and it gets the
    # same weight as a full batch in the averaged eval metrics. Confirm this
    # is acceptable.
    eval_dataset = Dataset.from_parquet(args.eval_dataset_paths)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size, drop_last=False, collate_fn=batch_collate_fn
    )

    # NOTE(review): the label count is hard-coded; presumably it matches the
    # target dataset. Confirm before reusing on other data.
    model_config = BertConfig.from_pretrained(args.model_config_name, num_labels=219)
    model = FlaxBertForSequenceClassification(
        model_config,
        input_shape=(args.batch_size, args.max_sequence_length),
        seed=0,
        dtype=jnp.dtype(args.dtype),
    )

    steps_per_epoch = len(train_dataloader)
    num_train_steps = steps_per_epoch * args.num_epochs
    lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate, end_value=0, transition_steps=num_train_steps
    )

    optimizer = optax.adamw(
        learning_rate=lr_schedule_fn,
        b1=args.adamw_beta1,
        b2=args.adamw_beta2,
        eps=args.adamw_eps,
        weight_decay=args.weight_decay_rate,
    )

    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
    if args.restore_checkpoint_path:
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"Restored from checkpoint at {args.restore_checkpoint_path}, step {int(state.step)}")
    # Steps already taken before this run (0 unless a checkpoint was restored).
    start_step = int(state.step)

    train_step = create_train_step(model.__call__)
    eval_step = create_eval_step(model.__call__)

    parallel_train_step = jax.pmap(train_step, axis_name="batch")
    parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

    state = replicate(state)

    # Create the per-device dropout keys once. The train step returns fresh
    # keys on every call, so re-deriving them from the same seed at each epoch
    # would replay an identical dropout sequence every epoch.
    rng = jax.random.PRNGKey(args.random_seed)
    dropout_rngs = jax.random.split(rng, num_devices)

    save_path = os.path.join(
        args.model_save_dir,
        f"lr{args.learning_rate}-epochs{args.num_epochs}-batch{args.batch_size}",
    )

    train_metrics_stack = []
    last_timestamp = time.time()

    for epoch in range(args.num_epochs):
        print(f"\nEpoch {epoch + 1}/{args.num_epochs}")
        train_progress = tqdm(enumerate(train_dataloader), total=steps_per_epoch, desc="Training", leave=False)

        for i, batch in train_progress:
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metrics)

            # Compute a single step counter once per iteration and never
            # rebind it, so logging, evaluation and checkpointing all see the
            # same value.
            steps_done = steps_per_epoch * epoch + i + 1
            current_step = start_step + steps_done

            if current_step % args.logging_frequency == 0:
                # `get_metrics` already selects one replica and stacks the
                # per-step values, so only the mean over steps is left to
                # take. Calling `unreplicate` here would keep just the first
                # step of the window.
                window_size = len(train_metrics_stack)
                window_metrics = get_metrics(train_metrics_stack)
                window_metrics = jax.tree_util.tree_map(jnp.mean, window_metrics)

                elapsed = time.time() - last_timestamp
                eta = timedelta(seconds=int((num_train_steps - steps_done) * elapsed / window_size))
                lr = lr_schedule_fn(current_step)

                log_metrics = {
                    "train/loss": float(window_metrics["loss"]),
                    "train/acc": float(window_metrics["accuracy"]),
                    "learning_rate": float(lr),
                    "epoch": epoch,
                }
                postfix = {k: f"{v:.5f}" for k, v in log_metrics.items() if k != "epoch"}
                postfix["eta"] = str(eta)
                train_progress.set_postfix(postfix)
                wandb.log(log_metrics, step=current_step)
                train_metrics_stack, last_timestamp = [], time.time()

            is_end_of_epoch = i + 1 == steps_per_epoch
            if current_step % args.eval_frequency == 0 or is_end_of_epoch:
                eval_metrics = [parallel_eval_step(state, eval_batch) for eval_batch in eval_dataloader]
                # Same as above: average over all eval batches, no `unreplicate`.
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
                eval_loss = float(eval_metrics["eval_loss"])
                eval_accuracy = float(eval_metrics["eval_accuracy"])

                print(f"[EVAL] epoch: {epoch} step: {current_step}/{num_train_steps} "
                      f"loss: {eval_loss:.4f} acc: {eval_accuracy:.4f}")

                wandb.log({
                    "eval/loss": eval_loss,
                    "eval/accuracy": eval_accuracy,
                    "epoch": epoch
                }, step=current_step)

            if current_step % args.save_frequency == 0 or is_end_of_epoch:
                checkpoints.save_checkpoint(
                    ckpt_dir=save_path,
                    target=unreplicate(state),
                    step=current_step,
                    keep=3,
                )
                print(f"Checkpoint saved at {save_path}")

if __name__ == "__main__":
    # Command-line options, registered in this order: flag -> add_argument kwargs.
    option_specs = [
        # Model and data.
        ("--model-config-name", dict(type=str, default="bert-base-uncased")),
        ("--train-dataset-paths", dict(type=str, required=True)),
        ("--eval-dataset-paths", dict(type=str, required=True)),
        ("--batch-size", dict(type=int, default=16)),
        ("--random-seed", dict(type=int, default=0)),
        ("--max-sequence-length", dict(type=int, default=256)),
        # Optimization.
        ("--num-epochs", dict(type=int, default=2)),
        ("--learning-rate", dict(type=float, default=2e-5)),
        ("--weight-decay-rate", dict(type=float, default=0.01)),
        ("--adamw-beta1", dict(type=float, default=0.9)),
        ("--adamw-beta2", dict(type=float, default=0.999)),
        ("--adamw-eps", dict(type=float, default=1e-8)),
        ("--dtype", dict(choices=["float32", "float16", "bfloat16"], default="float32")),
        # Logging, evaluation and checkpointing.
        ("--wandb-project", dict(default="BERT-DBpedia")),
        ("--logging-frequency", dict(type=int, default=100)),
        ("--eval-frequency", dict(type=int, default=2500)),
        ("--save-frequency", dict(type=int, default=2500)),
        ("--model-save-dir", dict(type=str, default="./weights/dbpedia/")),
        ("--restore-checkpoint-path", dict(type=str)),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    main(cli.parse_args())