from __future__ import annotations

import os
import json
import argparse
import warnings
import re
import random

import jax
import jax.numpy as jnp
import webdataset as wds
import numpy as np
import tqdm
import wandb
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from flax.serialization import msgpack_serialize

from dataset import create_dataloaders
from training_common import create_train_state, training_step
from utils import AverageMeter, save_checkpoint_in_background
from misc import load_config

warnings.filterwarnings("ignore")


def _sample_top_k(args) -> int:
    """Choose how many patches to keep for one training step.

    When ``args.top_k_range`` is set, draw uniformly from the *inclusive*
    range ``[top_k_range.min, top_k_range.max]``; otherwise return
    ``args.top_k`` unchanged.
    """
    if args.top_k_range is None:
        return args.top_k
    # random.randint already includes its upper bound, so no "+ 1" here
    # (adding one could produce a k larger than the configured maximum).
    # NOTE(review): each process draws independently; unless Python's
    # `random` is seeded identically on every host, hosts may feed different
    # k values in the same step — confirm this is intended.
    return random.randint(args.top_k_range.min, args.top_k_range.max)


def _replicate_scalar(value, dtype) -> jnp.ndarray:
    """Return ``value`` copied once per local device and sharded for pmap.

    Builds a ``(jax.local_device_count(),)`` array filled with ``value`` and
    passes it through flax's ``shard`` so every local device receives its
    own copy alongside the sharded batch.
    """
    return shard(jnp.full((jax.local_device_count(),), value, dtype=dtype))


def _save_final_artifacts(args, state) -> None:
    """Write the final parameters and the run config under ``args.output_dir``.

    The parameters are handed to ``save_checkpoint_in_background`` with the
    "last" postfix; the config is written as JSON next to it.
    """
    params_bytes = msgpack_serialize(unreplicate(state.params))

    # Replace every run of characters outside [a-zA-Z0-9_.-] with "_" so the
    # run name is safe in local paths and GCS object names alike.
    safe_name = re.sub(r"[^a-zA-Z0-9_.-]+", "_", args.name)

    if not args.output_dir.startswith("gs://"):
        # os.makedirs only makes sense for local paths; gs:// URLs are
        # handled by the checkpoint / gopen writers directly.
        os.makedirs(args.output_dir, exist_ok=True)

    save_checkpoint_in_background(
        args, params_bytes, postfix="last", safe_run_name=safe_name
    )

    # NOTE(review): the file carries a ".msgpack" suffix but its content is
    # JSON; the name is kept as-is for downstream readers.
    config_path = os.path.join(args.output_dir, f"{safe_name}-config.msgpack")
    with wds.gopen(config_path, "wb") as fp:
        fp.write(json.dumps(args.as_dict()).encode("utf-8"))


def main(args: argparse.Namespace):
    """Run the training loop, then save the final parameters and config.

    Args:
        args: Run configuration. Despite the annotation, it must support
            attribute access plus ``.get(key)`` and ``.as_dict()``. The
            fields read here include ``training_steps``, ``grad_accum``,
            ``top_k`` / ``top_k_range``, ``log_interval``,
            ``train_batch_size``, ``name``, ``project``, ``output_dir`` and
            the optional ``gumbel_temperature``, ``gumbel_noise_coeff``,
            ``rl_coeff`` and ``kl_coeff``.

    Raises:
        ValueError: if ``name`` or ``output_dir`` is ``None``. Both are
            needed to write the final checkpoint, so they are checked before
            training instead of failing after the last step.
    """
    missing = [key for key in ("name", "output_dir") if getattr(args, key, None) is None]
    if missing:
        raise ValueError(
            "Missing required settings for saving the final checkpoint: "
            + ", ".join(missing)
        )

    # The validation loader is not used by this script.
    train_dataloader, _ = create_dataloaders(args)
    train_dataloader_iter = iter(train_dataloader)
    state = create_train_state(args).replicate()
    if jax.process_index() == 0:
        wandb.init(name=args.name, project=args.project, config=args)
    average_meter = AverageMeter(use_latest=["learning_rate"])

    gumbel_temperature = args.get("gumbel_temperature")
    gumbel_noise_coeff = args.get("gumbel_noise_coeff")
    rl_coeff = args.get("rl_coeff")
    kl_coeff = args.get("kl_coeff")
    # rl_coeff_start = 0
    # rl_coeff_end = 1
    # rl_coeff = rl_coeff_start

    for step in tqdm.trange(1, args.training_steps + 1, dynamic_ncols=True):
        k_patches = _replicate_scalar(_sample_top_k(args), jnp.int32)
        # Rebuilt every step so the (currently disabled) schedules below take
        # effect as soon as they are re-enabled.
        gumbel_temp = _replicate_scalar(gumbel_temperature, jnp.float32)
        gumbel_noise = _replicate_scalar(gumbel_noise_coeff, jnp.float32)
        rl_coefficient = _replicate_scalar(rl_coeff, jnp.float32)
        kl_coefficient = _replicate_scalar(kl_coeff, jnp.float32)

        for _ in range(args.grad_accum):
            # jax.tree_util.tree_map replaces the deprecated jax.tree_map.
            batch = shard(
                jax.tree_util.tree_map(np.asarray, next(train_dataloader_iter))
            )
            state, metrics = training_step(
                state,
                batch=(
                    *batch,
                    k_patches,
                    gumbel_temp,
                    gumbel_noise,
                    rl_coefficient,
                    kl_coefficient,
                ),
            )
            average_meter.update(**unreplicate(metrics))

        # Update the gumbel temperature and noise
        # if args.patch_selection_method == "gumbel-topk":
        #     gumbel_temperature = max(
        #         args.gumbel_temperature_min,
        #         args.gumbel_temperature
        #         * args.gumbel_temperature_decay
        #         ** (step * (1 / args.gumbel_temperature_decay_steps)),
        #     )
        #     gumbel_noise_coeff = max(
        #         args.gumbel_noise_coeff_min,
        #         args.gumbel_noise_coeff
        #         * args.gumbel_noise_coeff_decay
        #         ** (step * (1 / args.gumbel_noise_coeff_decay_steps)),
        #     )
        # if args.patch_selection_method == "reinforcement:reinforce":
        #     rl_coeff = rl_coeff_start + (rl_coeff_end - rl_coeff_start) * (
        #         step / args.training_steps
        #     )

        if (
            jax.process_index() == 0
            and args.log_interval > 0
            and step % args.log_interval == 0
        ):
            log_metrics = average_meter.summary(prefix="train/")
            # NOTE(review): counts train_batch_size once per step; if that is
            # the per-micro-batch size, each step consumes grad_accum
            # micro-batches — confirm the intended unit.
            log_metrics["processed_samples"] = step * args.train_batch_size
            log_metrics["gumbel_temperature"] = gumbel_temperature
            log_metrics["gumbel_noise_coeff"] = gumbel_noise_coeff
            log_metrics["rl_coeff"] = rl_coeff
            log_metrics["kl_coeff"] = kl_coeff
            wandb.log(log_metrics, step)

    if jax.process_index() == 0:
        _save_final_artifacts(args, state)
        wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", default="config/imagenet.yaml")
    parser.add_argument("--train-dataset-shards", help="Overrides the config value when given.")
    parser.add_argument("--valid-dataset-shards", help="Overrides the config value when given.")

    parser.add_argument("--name", help="Run name; overrides the config value when given.")
    parser.add_argument("--ipaddr", help="Overrides the config value when given.")
    parser.add_argument("--hostname", help="Overrides the config value when given.")
    parser.add_argument("--output-dir", help="Overrides the config value when given.")
    config_in = parser.parse_args()

    args = load_config(
        config_path=os.path.abspath(config_in.config_path),
    )

    # Command-line flags override the YAML config. An omitted flag is None
    # (argparse default) and must not clobber a value the config already
    # supplies; it is still written through when the config has no value,
    # matching the previous unconditional assignment in that case.
    for key in (
        "train_dataset_shards",
        "valid_dataset_shards",
        "name",
        "ipaddr",
        "hostname",
        "output_dir",
    ):
        cli_value = getattr(config_in, key)
        if cli_value is not None or args.get(key) is None:
            setattr(args, key, cli_value)

    print(args)
    main(args)
