import argparse
import os
import time
from datetime import timedelta
from typing import Any, Dict, List
from jax import debug
import copy
import json
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import wandb
from tqdm import tqdm
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from flax.training import checkpoints, train_state
from flax.training.common_utils import get_metrics, onehot, shard
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.bert.modeling_flax_bert import FlaxBertForSequenceClassification, BertConfig
# Only provide a fallback value; never clobber a key the caller already exported.
os.environ.setdefault("WANDB_API_KEY", "")


def get_trainable_mask(params, config, layer_index=0, modules=("intermediate", "output")):
    """Label every leaf of ``params`` as ``"trainable"`` or ``"frozen"``.

    A leaf is trainable when its path starts with
    ``bert/encoder/layer/<layer_index>/<module>`` for some ``module`` in
    ``modules``; every other leaf is frozen. The returned pytree mirrors
    ``params`` and is suitable as ``param_labels`` for ``optax.multi_transform``.

    Args:
        params: Nested parameter pytree (e.g. a dict of dicts).
        config: Unused; kept so existing callers keep working.
        layer_index: Encoder layer whose sub-modules are trainable (default 0).
        modules: Sub-module names under that layer to train.

    Returns:
        A pytree with the same structure as ``params`` whose leaves are the
        strings ``"trainable"`` or ``"frozen"``.
    """
    prefix = ("bert", "encoder", "layer", str(layer_index))
    trainable_modules = frozenset(modules)

    def path_names(path):
        # Dict entries expose ``.key``, sequence entries ``.idx`` and attribute
        # entries ``.name``; fall back to str() for anything else.
        names = []
        for entry in path:
            for attr in ("key", "idx", "name"):
                if hasattr(entry, attr):
                    names.append(str(getattr(entry, attr)))
                    break
            else:
                names.append(str(entry))
        return names

    def is_trainable(names):
        # The module name sits right after the 4-component prefix, so a path
        # needs at least len(prefix) + 1 components before it can be indexed.
        return (
            len(names) > len(prefix)
            and tuple(names[: len(prefix)]) == prefix
            and names[len(prefix)] in trainable_modules
        )

    def label_fn(path, _):
        return "trainable" if is_trainable(path_names(path)) else "frozen"

    return jax.tree_util.tree_map_with_path(label_fn, params)
def print_model(flax_params, file=None):
    """Print one line per parameter leaf: its ``/``-joined path and its shape.

    Args:
        flax_params: Nested parameter dict; flattened with ``flatten_dict``.
        file: Optional writable text stream. When falsy, lines go to stdout.
    """
    # print(..., file=None) writes to stdout, matching a plain print().
    target = file if file else None
    for key_path, leaf in flatten_dict(flax_params).items():
        print(f"{'/'.join(key_path)} {leaf.shape}", file=target)
# ----------- Collate and Utility Functions -----------
def batch_collate_fn(batch):
    """Collate a list of examples into a ``shard``-ed batch for ``jax.pmap``.

    Each column is stacked into a host-side NumPy array. Device placement is
    left to the pmapped step, so the collate function (which may run in
    DataLoader worker processes) does no device work.

    Args:
        batch: Non-empty list of example dicts that share the same keys.

    Returns:
        The result of ``shard`` applied to ``{column: stacked values}``.
    """
    columns = {key: [example[key] for example in batch] for key in batch[0]}
    collated = {}
    for key, values in columns.items():
        try:
            # Stay on the host: jnp.array would first copy the whole batch to the
            # default device, and pmap would then have to redistribute it.
            collated[key] = np.asarray(values)
        except TypeError:
            # NOTE(review): shard() presumably reshapes every leaf, so a column
            # left as a plain list is likely to fail there — confirm.
            collated[key] = values
    return shard(collated)

def split_batch(batch):
    """Separate model inputs from the target labels.

    Args:
        batch: Mapping of feature name to value; must contain ``"label"``.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is a new dict holding every entry
        of ``batch`` except ``"label"``. ``batch`` itself is not modified.

    Raises:
        KeyError: if ``batch`` has no ``"label"`` entry.
    """
    inputs = dict(batch)
    labels = inputs.pop("label")
    return inputs, labels

# ----------- Training and Evaluation Steps -----------
def create_train_step(apply_fn):
    """Build a per-device training step meant for ``jax.pmap(axis_name="batch")``.

    Args:
        apply_fn: Model call function. It is invoked with the batch inputs as
            keyword arguments plus ``params``, ``dropout_rng`` and
            ``train=True``; element ``[0]`` of its output is used as logits.

    Returns:
        ``train_step(state, batch, dropout_rng) -> (new_state, metrics, next_rng)``.
        Gradients and the ``loss`` / ``accuracy`` metrics are averaged over the
        ``"batch"`` axis before use/return.
    """
    def train_step(state, batch, dropout_rng):
        # One key is consumed now; the other is handed back for the next step.
        step_rng, next_rng = jax.random.split(dropout_rng)

        def compute_loss(params):
            inputs, labels = split_batch(batch)
            outputs = apply_fn(**inputs, params=params, dropout_rng=step_rng, train=True)
            logits = outputs[0]
            targets = onehot(labels, num_classes=logits.shape[-1])
            loss = optax.softmax_cross_entropy(logits, targets).mean()
            accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
            return loss, accuracy

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, accuracy), grads = grad_fn(state.params)
        synced_grads = jax.lax.pmean(grads, axis_name="batch")
        new_state = state.apply_gradients(grads=synced_grads)

        metrics = jax.lax.pmean({"loss": loss, "accuracy": accuracy}, axis_name="batch")
        return new_state, metrics, next_rng

    return train_step

def create_eval_step(apply_fn):
    """Build a per-device evaluation step meant for ``jax.pmap(axis_name="batch")``.

    Args:
        apply_fn: Model call function. It is invoked with the batch inputs as
            keyword arguments plus ``params`` and ``train=False``; element
            ``[0]`` of its output is used as logits.

    Returns:
        ``eval_step(state, batch) -> metrics`` where ``metrics`` holds
        ``eval_loss`` and ``eval_accuracy`` averaged over the ``"batch"`` axis.
    """
    def eval_step(state, batch):
        inputs, labels = split_batch(batch)
        outputs = apply_fn(**inputs, params=state.params, train=False)
        logits = outputs[0]
        num_classes = logits.shape[-1]
        per_example_loss = optax.softmax_cross_entropy(logits, onehot(labels, num_classes))
        correct = jnp.argmax(logits, axis=-1) == labels
        return jax.lax.pmean(
            {"eval_loss": per_example_loss.mean(), "eval_accuracy": jnp.mean(correct)},
            axis_name="batch",
        )

    return eval_step

# ----------- Main Training Loop -----------
def _mean_metrics(device_metrics):
    """Average a list of pmapped metric dicts over time into host scalars.

    ``get_metrics`` already keeps a single replica and stacks the steps along a
    new leading axis, so the result is reduced with a plain mean; applying
    ``unreplicate`` on top of it would keep only the first step.
    """
    stacked = get_metrics(device_metrics)
    return jax.tree_util.tree_map(jnp.mean, stacked)


def _evaluate(parallel_eval_step, state, eval_dataloader):
    """Run the pmapped eval step over every eval batch and return mean metrics."""
    return _mean_metrics([parallel_eval_step(state, batch) for batch in eval_dataloader])


def main(args: argparse.Namespace):
    """Fine-tune a BERT classifier, training only layer 0's feed-forward blocks.

    Loads pretrained parameters from ``args.model_path``. It then re-initialises
    ``bert/encoder/layer/0/{intermediate,output}`` from a freshly seeded model
    and freezes every other parameter via ``optax.multi_transform``. Training
    uses ``jax.pmap`` across local devices.

    Train metrics are logged every ``args.logging_frequency`` steps. Evaluation
    and checkpointing run every ``args.eval_frequency`` /
    ``args.save_frequency`` steps and at the end of each epoch.

    Raises:
        ValueError: if ``args.batch_size`` is not a multiple of the local
            device count.
    """
    num_devices = jax.local_device_count()
    if args.batch_size % num_devices:
        raise ValueError(
            f"--batch-size ({args.batch_size}) must be divisible by the number of "
            f"local devices ({num_devices})"
        )

    # Pass the config to init: assigning to wandb.config replaces the Config
    # object instead of recording the hyper-parameters.
    wandb.init(
        project=args.wandb_project,
        name=f"TrainModel-lr{args.learning_rate}-epochs{args.num_epochs}-batch{args.batch_size}",
        save_code=True,
        config=dict(vars(args)),
    )

    train_dataset = Dataset.from_parquet(args.train_dataset_paths)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, drop_last=True, collate_fn=batch_collate_fn
    )

    eval_dataset = Dataset.from_parquet(args.eval_dataset_paths)
    # shard() splits each batch across local devices, which presumably needs a
    # batch size that is a multiple of the device count; drop the ragged final
    # eval batch only when it would not split evenly.
    eval_tail = len(eval_dataset) % args.batch_size
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        drop_last=eval_tail % num_devices != 0,
        collate_fn=batch_collate_fn,
    )

    num_labels = getattr(args, "num_labels", 219)
    model_config = BertConfig.from_pretrained("bert-base-uncased", num_labels=num_labels)
    input_shape = (args.batch_size, args.max_sequence_length)
    dtype = jnp.dtype(args.dtype)
    pretrained_model = FlaxBertForSequenceClassification(
        model_config, input_shape=input_shape, seed=0, dtype=dtype
    )
    # Freshly seeded model; only its layer-0 feed-forward weights are used below.
    temp_model = FlaxBertForSequenceClassification(
        model_config, input_shape=input_shape, seed=args.random_seed, dtype=dtype
    )
    pretrained_model.params = checkpoints.restore_checkpoint(
        ckpt_dir=args.model_path, target={"params": pretrained_model.params}
    )["params"]

    finetune_model = copy.deepcopy(pretrained_model)
    # Re-initialise the trainable layer-0 feed-forward sub-layers from temp_model.
    finetune_layer0 = finetune_model.params["bert"]["encoder"]["layer"]["0"]
    random_layer0 = temp_model.params["bert"]["encoder"]["layer"]["0"]
    for module in ("intermediate", "output"):
        finetune_layer0[module] = copy.deepcopy(random_layer0[module])

    steps_per_epoch = len(train_dataloader)
    num_train_steps = steps_per_epoch * args.num_epochs
    label_mask = get_trainable_mask(finetune_model.params, model_config)
    print(json.dumps(label_mask, indent=2))

    lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate, end_value=0, transition_steps=num_train_steps
    )
    tx = optax.multi_transform(
        transforms={
            "trainable": optax.adamw(
                learning_rate=lr_schedule_fn,
                b1=args.adamw_beta1,
                b2=args.adamw_beta2,
                eps=args.adamw_eps,
                weight_decay=args.weight_decay_rate,
            ),
            "frozen": optax.set_to_zero(),
        },
        param_labels=label_mask,  # must match the params pytree structure
    )
    state = train_state.TrainState.create(
        apply_fn=finetune_model.__call__, params=finetune_model.params, tx=tx
    )
    if args.restore_checkpoint_path:
        state = checkpoints.restore_checkpoint(args.restore_checkpoint_path, state)
        print(f"Restored from checkpoint at {args.restore_checkpoint_path}, step {int(state.step)}")
    # Updates already applied (non-zero after a restore). Offsetting by it keeps
    # wandb / checkpoint steps equal to state.step and strictly increasing.
    start_step = int(state.step)

    parallel_train_step = jax.pmap(create_train_step(finetune_model.__call__), axis_name="batch")
    parallel_eval_step = jax.pmap(create_eval_step(finetune_model.__call__), axis_name="batch")

    state = replicate(state)
    rng = jax.random.PRNGKey(args.random_seed)
    save_path = os.path.join(
        args.model_save_dir,
        f"lr{args.learning_rate}-epochs{args.num_epochs}-batch{args.batch_size}-seed{args.random_seed}",
    )

    train_metrics_stack = []
    last_timestamp = time.time()
    for epoch in range(args.num_epochs):
        print(f"\nEpoch {epoch + 1}/{args.num_epochs}")
        # Advance the root key each epoch; re-splitting the same key would
        # replay identical per-step dropout keys in every epoch.
        rng, epoch_rng = jax.random.split(rng)
        dropout_rngs = jax.random.split(epoch_rng, num_devices)
        train_progress = tqdm(
            enumerate(train_dataloader), total=steps_per_epoch, desc="Training", leave=False
        )

        for i, batch in train_progress:
            # 0-based index of this step within the run; used only for the
            # frequency checks and is never reassigned.
            current_step = steps_per_epoch * epoch + i
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            train_metrics_stack.append(train_metrics)
            # Equals state.step after this update, without a device sync.
            global_step = start_step + current_step + 1
            is_end_of_epoch = i + 1 == steps_per_epoch

            if current_step > 0 and current_step % args.logging_frequency == 0:
                train_metrics = _mean_metrics(train_metrics_stack)
                elapsed = time.time() - last_timestamp
                remaining = num_train_steps - (current_step + 1)
                eta = timedelta(seconds=int(remaining * elapsed / len(train_metrics_stack)))

                log_metrics = {
                    "train/loss": float(train_metrics["loss"]),
                    "train/acc": float(train_metrics["accuracy"]),
                    "learning_rate": float(lr_schedule_fn(global_step)),
                    "epoch": epoch,
                }
                postfix = {k: f"{v:.5f}" for k, v in log_metrics.items() if k != "epoch"}
                postfix["eta"] = str(eta)
                train_progress.set_postfix(postfix)
                wandb.log(log_metrics, step=global_step)
                train_metrics_stack, last_timestamp = [], time.time()

            if current_step > 0 and (current_step % args.eval_frequency == 0 or is_end_of_epoch):
                eval_metrics = _evaluate(parallel_eval_step, state, eval_dataloader)
                print(f"[EVAL] epoch: {epoch} step: {current_step + 1}/{num_train_steps} "
                      f"loss: {eval_metrics['eval_loss']:.4f} acc: {eval_metrics['eval_accuracy']:.4f}")
                wandb.log({
                    "eval/loss": float(eval_metrics["eval_loss"]),
                    "eval/accuracy": float(eval_metrics["eval_accuracy"]),
                    "epoch": epoch,
                }, step=global_step)

            if current_step > 0 and (current_step % args.save_frequency == 0 or is_end_of_epoch):
                checkpoints.save_checkpoint(
                    ckpt_dir=save_path,
                    target=unreplicate(state),
                    step=global_step,
                    keep=1,
                )
                print(f"Checkpoint saved at {save_path}")

    wandb.finish()

def build_arg_parser():
    """Return the command-line parser for this fine-tuning script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="")
    parser.add_argument("--train-dataset-paths", type=str, required=True)
    parser.add_argument("--eval-dataset-paths", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--random-seed", type=int, default=0)
    parser.add_argument("--max-sequence-length", type=int, default=256)
    parser.add_argument("--num-epochs", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--weight-decay-rate", type=float, default=0.01)
    parser.add_argument("--adamw-beta1", type=float, default=0.9)
    parser.add_argument("--adamw-beta2", type=float, default=0.999)
    parser.add_argument("--adamw-eps", type=float, default=1e-8)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    parser.add_argument("--wandb-project", default="BERT-DBpedia")
    parser.add_argument("--logging-frequency", type=int, default=100)
    parser.add_argument("--eval-frequency", type=int, default=2500)
    parser.add_argument("--save-frequency", type=int, default=2500)
    parser.add_argument("--model-save-dir", type=str, default="./weights/dbpedia/")
    parser.add_argument("--restore-checkpoint-path", type=str, default=None)
    parser.add_argument("--num-labels", type=int, default=219,
                        help="Number of output classes for the classification head.")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())