import logging
import math
import os
import pickle
import time
from collections import defaultdict
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from tokenizers import Tokenizer

import wandb
from ..impl import sampler
from ..impl import transformer

# -----------------------------------------------------------------------------#
#  Tokenization
# -----------------------------------------------------------------------------#


@partial(jax.jit, static_argnames=["pad_token_id", "eos_token_id"])
def get_keys(ids, pad_token_id, eos_token_id):
    """Derive next-token-prediction arrays from a 2-D batch of token ids.

    Args:
        ids: Integer token ids indexed as ``ids[row, col]``.
        pad_token_id: Id written into the final input column and excluded
            from the mask (static under ``jit``).
        eos_token_id: Id excluded from the mask (static under ``jit``).

    Returns:
        ``(inputs, targets, token_mask, positions, num_tokens)`` where
        ``targets`` is ``ids`` shifted left by one (the last column wraps
        around to the first token), ``inputs`` is ``ids`` with its last column
        overwritten by ``pad_token_id`` (so that wrapped target is always
        masked), ``token_mask`` is an int32 0/1 array marking inputs that are
        neither pad nor EOS, ``positions`` holds the column index where the
        mask is 1 and -1 elsewhere, and ``num_tokens`` counts unmasked
        positions per row.
    """
    n_rows, seq_len = ids.shape[0], ids.shape[1]

    shifted_targets = jnp.roll(ids, -1, axis=-1)
    model_inputs = ids.at[:, seq_len - 1].set(pad_token_id)

    keep = jnp.logical_and(model_inputs != pad_token_id, model_inputs != eos_token_id)
    mask = keep.astype(jnp.int32)

    # Every row gets the same 0..seq_len-1 ramp before masking.
    column_index = jnp.broadcast_to(
        jnp.arange(seq_len, dtype=np.int32), (n_rows, seq_len)
    )
    pos = jnp.where(mask, column_index, -1)
    count = (pos >= 0).sum(axis=-1)

    return model_inputs, shifted_targets, mask, pos, count


def dict_encode_batch(
    batch, tokenizer: Tokenizer, pad_token_id, eos_token_id, target_column
):
    """Tokenize one text column of ``batch`` and build training arrays.

    The ``.ids`` of every encoding are stacked into a single NumPy array,
    which then goes through ``get_keys``.
    NOTE(review): stacking presumably relies on the tokenizer producing
    equal-length encodings (padding/truncation enabled) — confirm.

    Args:
        batch: Mapping from column name to a sequence of texts.
        tokenizer: Tokenizer exposing ``encode_batch_fast``.
        pad_token_id: Padding id forwarded to ``get_keys``.
        eos_token_id: End-of-sequence id forwarded to ``get_keys``.
        target_column: Key of ``batch`` holding the texts to encode.

    Returns:
        Dict with ``inputs``, ``targets``, ``positions``, ``num_tokens`` and
        ``token_mask`` entries taken from ``get_keys``.
    """
    texts = batch[target_column]
    encodings = tokenizer.encode_batch_fast(texts)
    token_ids = np.asarray([enc.ids for enc in encodings])

    (inputs, targets, token_mask, positions, num_tokens) = get_keys(
        token_ids, pad_token_id, eos_token_id
    )

    return dict(
        inputs=inputs,
        targets=targets,
        positions=positions,
        num_tokens=num_tokens,
        token_mask=token_mask,
    )


# -----------------------------------------------------------------------------#
#  Checkpointing
# -----------------------------------------------------------------------------#


def save_checkpoint(params, opt_state, step, checkpoint_dir):
    """Pickle params and optimizer state to disk, then optionally log to W&B.

    The file is ``<checkpoint_dir>/checkpoint_<step>.pkl``. The directory is
    created if needed. The file is written atomically: it is first written to
    a ``.tmp`` sibling and then renamed into place, so an interrupted save
    cannot corrupt an existing checkpoint. If a W&B run is active, the file is
    also logged as a ``model`` artifact. W&B failures only produce warnings.

    Args:
        params: Parameter pytree (copied to host with ``jax.device_get``).
        opt_state: Optimizer-state pytree (copied to host with
            ``jax.device_get``).
        step: Step identifier. It is stringified for the file/artifact names
            and stored unchanged under the ``"step"`` key.
        checkpoint_dir: Directory to write the checkpoint into.
    """
    step_str = str(step)
    os.makedirs(checkpoint_dir, exist_ok=True)
    ckpt_path = os.path.join(checkpoint_dir, f"checkpoint_{step_str}.pkl")

    # Fetch to host before touching the filesystem, so a device-side failure
    # cannot leave a partially written file behind.
    payload = {
        "params": jax.device_get(params),
        "opt_state": jax.device_get(opt_state),
        # Keep the caller's type. Stringifying it made load_checkpoint return
        # a str step.
        "step": step,
    }

    tmp_path = ckpt_path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, ckpt_path)  # atomic swap on the same filesystem
    finally:
        # The temp file only survives if the write or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logging.info(f"Checkpoint saved at step {step_str}: {ckpt_path}")

    if wandb.run:
        # W&B is best-effort: never let an upload problem abort training.
        # NOTE(review): the artifact name embeds the step, so each checkpoint
        # gets its own collection and the "latest" alias never moves between
        # steps. Consider a step-free name if that alias is meant to track
        # the newest checkpoint.
        try:
            artifact = wandb.Artifact(
                f"checkpoint-{wandb.run.id}-{step_str}", type="model"
            )
            artifact.add_file(ckpt_path)
            wandb.log_artifact(artifact, aliases=[f"step_{step_str}", "latest"])
            logging.info(f"Logged artifact for step {step_str} to W&B.")
        except wandb.errors.Error as e:
            logging.warning(
                f"Could not log artifact to W&B for step {step_str}: {e}. Local save might still be successful."
            )
        except Exception as e:
            logging.warning(
                f"An unexpected error occurred during W&B artifact logging for step {step_str}: {e}. Local save might still be successful."
            )
    else:
        logging.warning(
            f"W&B run not active, skipping artifact logging for step {step_str}."
        )

def load_checkpoint(path):
    """Load a checkpoint pickle produced by ``save_checkpoint``.

    Args:
        path: Path to the ``.pkl`` checkpoint file.

    Returns:
        ``(params, opt_state, step)`` as stored in the file. If ``step`` is
        stored as a string that ``int()`` accepts, it is returned as an
        ``int``; older checkpoints stringified the step. Any other value is
        returned unchanged.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    # SECURITY: pickle.load can execute arbitrary code. Only load
    # checkpoints from trusted sources.
    with open(path, "rb") as f:
        ckpt = pickle.load(f)

    step = ckpt["step"]
    if isinstance(step, str):
        try:
            step = int(step)
        except ValueError:
            pass  # non-numeric tag (e.g. "final"): return it unchanged
    logging.info(f"Loaded checkpoint from step {step}: {path}")
    return ckpt["params"], ckpt["opt_state"], step


# -----------------------------------------------------------------------------#
#  Evaluation Function                                                         #
# -----------------------------------------------------------------------------#


def get_eval_fn(
    tokenizer,
    model_config,
    num_samples,
    batch_size,
    metrics_fn,
    max_tokens=128,
    temp=0.6,
    log_metrics=False,
    return_results=False,
    run_fn=transformer.run,
    verbose=False
):
    """Build an evaluation closure that samples from BOS and scores the output.

    Sampling is split into ``ceil(num_samples / batch_size)`` calls to
    ``sampler.generate``. Each call uses ``batch_size`` prompts, except the
    last, which gets the remainder when ``num_samples`` is not a multiple of
    ``batch_size``. Each prompt is a single ``tokenizer.bos_token_id`` token.

    Args:
        tokenizer: Tokenizer exposing ``bos_token_id``; forwarded to the sampler.
        model_config: Model configuration forwarded to the sampler.
        num_samples: Total number of sequences to generate.
        batch_size: Maximum number of prompts per sampler call.
        metrics_fn: Callable mapping the list of generated texts to a dict.
        max_tokens: ``max_new_tokens`` passed to the sampler.
        temp: Sampling temperature.
        log_metrics: If true, ``wandb.log`` the stats as ``eval/<name>``.
        return_results: If true, ``evaluate`` returns ``(stats, sequences)``.
        run_fn: Model forward function forwarded to the sampler.
        verbose: Forwarded to the sampler.

    Returns:
        ``evaluate(random_key, step, params)``. Each sampler call gets its own
        key, ``random.fold_in(random_key, call_index)``.
    """
    n_loops = -(-num_samples // batch_size)  # ceiling division
    step_batch_sizes = [batch_size] * n_loops
    leftover = num_samples % batch_size
    if leftover > 0:
        step_batch_sizes[-1] = leftover

    def evaluate(random_key, step, params):
        collected = []

        for loop_idx, rows in enumerate(step_batch_sizes):
            # One BOS token per row -> int32 array of shape (rows, 1).
            prompt = jnp.full((rows, 1), tokenizer.bos_token_id, dtype=jnp.int32)
            collected.extend(
                sampler.generate(
                    tokenized_inputs=prompt,
                    max_new_tokens=max_tokens,
                    tokenizer=tokenizer,
                    params=params,
                    temp=temp,
                    config=model_config,
                    random_key=random.fold_in(random_key, loop_idx),
                    return_text=True,
                    run_fn=run_fn,
                    verbose=verbose
                )
            )

        stats = metrics_fn(collected)

        if log_metrics:
            wandb.log({f"eval/{name}": value for name, value in stats.items()}, step=step)
        return (stats, collected) if return_results else stats

    return evaluate


# -----------------------------------------------------------------------------#
#  Logging Utilities                                                           #
# -----------------------------------------------------------------------------#


def update_metrics(metrics_maybe_gpu, t_start, sums, mins, maxes):
    """Fold one step's metrics into the running window accumulators.

    The metrics are fetched to host with ``jax.device_get`` and each value is
    converted with ``.item()``. Step time is measured after that fetch, so it
    covers the wait for the results. Throughput is
    ``num_valid_tokens / step_time`` (0 tokens if that key is absent).

    Args:
        metrics_maybe_gpu: Dict of scalar arrays, possibly still on device.
        t_start: ``time.time()`` value taken when the step began.
        sums, mins, maxes: Accumulator dicts, updated in place. Presumably
            defaultdicts seeded with 0 / +inf / -inf (see the reset in
            ``log_transformer_metrics``) — confirm with callers.

    Returns:
        The same ``(sums, mins, maxes)`` objects, for convenience.
    """
    on_host = jax.device_get(metrics_maybe_gpu)
    scalars = {name: arr.item() for name, arr in on_host.items()}

    elapsed = time.time() - t_start

    def _track_extremes(name, value):
        mins[name] = min(mins[name], value)
        maxes[name] = max(maxes[name], value)

    for name, value in scalars.items():
        sums[name] += float(value)
        _track_extremes(name, value)

    n_tokens = scalars.get("num_valid_tokens", 0)
    throughput = n_tokens / elapsed
    sums["tokens_per_sec"] += throughput
    _track_extremes("tokens_per_sec", throughput)

    sums["tokens"] += n_tokens
    sums["step_time"] += elapsed

    return sums, mins, maxes


def log_transformer_metrics(
    step,
    num_steps,
    sums,
    mins,
    maxes,
    window_count,
    use_wandb=True,
    pbar=None,
    log_stdout=True,
):
    """Report one logging window of training metrics and return fresh accumulators.

    Args:
        step: Zero-based current step; reported as ``step + 1``.
        num_steps: Total step count, shown in the text summary only.
        sums, mins, maxes: Window accumulators. Reads the ``loss``,
            ``accuracy`` and ``grad_norm`` keys, plus ``tokens`` and
            ``step_time`` from ``sums`` and ``tokens_per_sec`` from
            ``mins``/``maxes``.
        window_count: Number of steps in the window. Averages are
            ``sums[k] / window_count``.
        use_wandb: Send metrics to W&B when a run is active.
        pbar: Optional progress bar; ``set_postfix`` receives loss/accuracy.
        log_stdout: Emit the summary via ``logging.info``. It is also emitted
            whenever W&B will not receive the metrics, so they are never
            silently dropped.

    Returns:
        New ``(sums, mins, maxes)`` defaultdicts seeded with 0.0 / +inf / -inf.
    """
    avg_metrics = {k: sums[k] / window_count for k in ("loss", "accuracy", "grad_norm")}
    min_metrics = {k: mins[k] for k in avg_metrics}
    max_metrics = {k: maxes[k] for k in avg_metrics}

    # Window throughput is total tokens / total time (token-weighted), not
    # the mean of the per-step rates.
    tokens_per_sec_avg = sums["tokens"] / sums["step_time"]
    tokens_per_sec_min = mins["tokens_per_sec"]
    tokens_per_sec_max = maxes["tokens_per_sec"]

    log_message = (
        f"[{step + 1}/{num_steps}] "
        f"Loss {avg_metrics['loss']:.4f} "
        f"(min {min_metrics['loss']:.4f}, max {max_metrics['loss']:.4f}) | "
        f"Acc {avg_metrics['accuracy']:.3f} | "
        f"GradNorm {avg_metrics['grad_norm']:.2f} | "
        f"{tokens_per_sec_avg:.1f} tok/s"
    )

    # Short-circuits before touching wandb when use_wandb is False.
    wandb_active = use_wandb and wandb.run is not None
    # Fall back to stdout whenever W&B won't receive the metrics: W&B
    # disabled, or requested but no run initialised.
    if log_stdout or not wandb_active:
        logging.info(log_message)

    if pbar is not None:
        pbar.set_postfix(
            loss=f"{avg_metrics['loss']:.4f}",
            acc=f"{avg_metrics['accuracy']:.3f}",
        )

    if wandb_active:
        wandb.log(
            {
                "train/loss": avg_metrics["loss"],
                "train/loss_min": min_metrics["loss"],
                "train/loss_max": max_metrics["loss"],
                "train/accuracy": avg_metrics["accuracy"],
                "train/accuracy_min": min_metrics["accuracy"],
                "train/accuracy_max": max_metrics["accuracy"],
                "train/grad_norm": avg_metrics["grad_norm"],
                "train/grad_norm_min": min_metrics["grad_norm"],
                "train/grad_norm_max": max_metrics["grad_norm"],
                "perf/tokens_per_sec": tokens_per_sec_avg,
                "perf/tokens_per_sec_min": tokens_per_sec_min,
                "perf/tokens_per_sec_max": tokens_per_sec_max,
                "progress/step": step + 1,
            },
            step=step + 1,
        )

    # Fresh accumulators for the next window; callers must rebind to these.
    sums = defaultdict(float)
    mins = defaultdict(lambda: math.inf)
    maxes = defaultdict(lambda: -math.inf)

    return sums, mins, maxes
