import logging
import math
import os
import time
from collections import defaultdict
from functools import partial

import jax
import jax.numpy as jnp
import optax
from tqdm.auto import tqdm

import wandb

from ..impl import caching, transformer
from . import train_utils

# Module-level logging setup: INFO and above, each record prefixed with a
# timestamp and level name.
# NOTE(review): configuring the root logger as an import side effect is
# usually the application's job; basicConfig does nothing if the root logger
# already has handlers.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def compute_metrics(logits, targets, mask):
    """Compute masked cross-entropy loss and argmax accuracy.

    Args:
        logits: Unnormalized scores whose last axis indexes the vocabulary.
        targets: Integer labels, one per position of ``logits`` minus its
            last axis. Labels outside ``[0, vocab_size)`` are ignored.
        mask: Combined with the in-vocabulary check via ``&``; positions
            where it is false/zero are ignored.

    Returns:
        Dict with:
          - ``"loss"``: summed cross-entropy over kept positions divided by
            ``num_valid_tokens``;
          - ``"accuracy"``: number of kept positions whose argmax equals the
            label, divided by ``num_valid_tokens``;
          - ``"num_valid_tokens"``: count of kept positions, clamped to at
            least 1 so the divisions never divide by zero.
    """
    in_vocab = (targets >= 0) & (targets < logits.shape[-1])
    keep = (mask & in_vocab).astype(jnp.bool_)
    denom = jnp.maximum(jnp.sum(keep), 1)

    # Ignored labels are swapped for 0 so the label lookup stays in range;
    # their contribution is zeroed by multiplying with `keep` below.
    token_losses = optax.softmax_cross_entropy_with_integer_labels(
        logits, jnp.where(keep, targets, 0)
    )
    hits = (jnp.argmax(logits, axis=-1) == targets) * keep

    return {
        "loss": jnp.sum(token_losses * keep) / denom,
        "accuracy": jnp.sum(hits) / denom,
        "num_valid_tokens": denom,
    }


def _training_cache(positions, batch_size, config):
    """Build the KV-free ``TransformerCache`` used for a full-sequence forward pass."""
    sin, cos = caching.build_rope(
        positions=positions,
        head_dim=config["hidden_size"] // config["num_heads"],
        base=config["rope_base"],
    )
    # One LayerCache per layer, all sharing the same rotary tables; no keys or
    # values are stored because use_kv is False.
    per_layer = [
        caching.LayerCache(
            sin=sin,
            cos=cos,
            positions=positions,
            cached_lens=jnp.zeros((batch_size,), dtype=jnp.int32),
            keys=None,
            values=None,
        )
        for _ in range(config["num_layers"])
    ]
    return caching.TransformerCache(
        use_kv=False,
        full_sin=sin,
        full_cos=cos,
        full_positions=positions,
        residual_layers=[],
        layers=per_layer,
    )


@partial(jax.jit, static_argnums=(3, 4))
def train_step(params, opt_state, batch, config, optimizer):
    """Run one jitted optimization step on ``batch``.

    ``config`` and ``optimizer`` are static jit arguments, so both must be
    hashable, and changing either triggers recompilation.

    Args:
        params: Current model parameters.
        opt_state: Optimizer state matching ``optimizer``.
        batch: Mapping with ``"inputs"`` (2-D, unpacked as batch x sequence),
            ``"targets"`` and ``"positions"``. Positions < 0 are excluded from
            the loss.
        config: Model config read for hidden_size/num_heads/rope_base/num_layers.
        optimizer: optax GradientTransformation.

    Returns:
        ``(new_params, new_opt_state, metrics)``. ``metrics`` is the
        ``compute_metrics`` dict plus ``"grad_norm"`` (global norm of the
        gradients before the update).
    """

    def objective(current_params):
        tokens = batch["inputs"]
        labels = batch["targets"]
        pos = batch["positions"]
        n_seqs, _ = tokens.shape  # also asserts the inputs are 2-D

        cache = _training_cache(pos, n_seqs, config)
        logits, *_ = transformer.run(tokens, cache, current_params, config)
        step_metrics = compute_metrics(logits, labels, pos >= 0)
        return step_metrics["loss"], step_metrics

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (_, metrics), grads = grad_fn(params)

    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    metrics["grad_norm"] = optax.global_norm(grads)
    return new_params, new_opt_state, metrics


def _init_wandb(
    *,
    config,
    project,
    entity,
    run_name,
    resumed_wandb_id,
    extra_config,
):
    """Start (or resume) a wandb run and register the banded training metrics.

    Args:
        config: Model config; its ``_dict`` attribute is sent as the run config.
        project, entity, run_name: Forwarded to ``wandb.init``.
        resumed_wandb_id: Run id recovered from a checkpoint, or None.
        extra_config: Extra values merged into ``wandb.config``.
    """
    wandb.init(
        project=project,
        entity=entity,
        name=run_name,
        config=config._dict,
        # Only ask wandb to resume when a checkpoint supplied a run id.
        resume="allow" if resumed_wandb_id else None,
        id=resumed_wandb_id,
    )

    # Each banded metric is plotted against progress/step; its _min/_max
    # companions are registered hidden.
    for metric_name in ("train/loss", "train/accuracy", "train/grad_norm"):
        wandb.define_metric(metric_name, step_metric="progress/step")
        wandb.define_metric(
            f"{metric_name}_min", step_metric="progress/step", hidden=True
        )
        wandb.define_metric(
            f"{metric_name}_max", step_metric="progress/step", hidden=True
        )

    wandb.config.update(extra_config)
    logging.info(
        f"Wandb initialized. Run ID: {wandb.run.id}, Resumed: {wandb.run.resumed}"
    )


def train_model(
    config,
    data_iterator,
    num_steps,
    wandb_project,
    checkpoint_dir,
    log_every=50,
    save_every=5000,
    resume_from=None,
    seed=2002,
    use_wandb=True,
    wandb_entity=None,
    wandb_run_name=None,
    eval_fn=None,
    eval_every=None,
    optimizer=None,
    learning_rate=None,
    log_stdout=True,
):
    """Train the transformer until ``num_steps`` or until the data runs out.

    Args:
        config: Model config. It is passed to ``train_step`` as a static jit
            argument (so it must be hashable), and its ``_dict`` attribute is
            sent to wandb as the run config.
        data_iterator: Iterator yielding batches with ``"inputs"``,
            ``"targets"`` and ``"positions"`` entries; exhausting it ends
            training early.
        num_steps: Absolute step count to train up to (a resumed checkpoint's
            step counts toward it).
        wandb_project: wandb project name (used only when ``use_wandb``).
        checkpoint_dir: Directory handed to ``train_utils.save_checkpoint``.
        log_every: Log windowed metrics every this many steps; falsy disables
            periodic logging.
        save_every: Checkpoint every this many steps; falsy disables periodic
            checkpoints.
        resume_from: Checkpoint path to resume from. If it does not exist,
            training starts from scratch with a warning.
        seed: Seed for parameter init and for the fixed evaluation key.
        use_wandb: Whether to log to wandb.
        wandb_entity, wandb_run_name: Forwarded to ``wandb.init``.
        eval_fn: Optional callback ``eval_fn(eval_key, step, params)``, run
            every ``eval_every`` steps (if set) and once at the end.
        eval_every: Evaluation period; ``None``/0 means only the final eval.
        optimizer: optax optimizer, passed to ``train_step`` as a static jit
            argument. Defaults to Adam(b1=0.9, b2=0.999, eps=1e-8) with
            ``learning_rate``.
        learning_rate: Learning rate for the default Adam optimizer.
        log_stdout: Forwarded to ``train_utils.log_transformer_metrics``.

    Returns:
        ``(params, opt_state)`` after the last completed step.

    Raises:
        Exceptions raised inside the training loop are logged with their
        traceback and re-raised. The progress bar is closed and the wandb run
        finished (marked failed) before the exception propagates.
    """
    logging.info("Starting training...")
    logging.info(f"Config: {config}")

    rng = jax.random.PRNGKey(seed)
    model_key, eval_key = jax.random.split(rng)

    start_step = 0
    resumed_wandb_id = None
    opt_state = None
    if resume_from and os.path.exists(resume_from):
        params, opt_state, start_step, resumed_wandb_id = (
            train_utils.load_checkpoint(resume_from)
        )
        logging.info(f"Resumed training from step {start_step}")
    else:
        if resume_from:
            logging.warning(
                f"Resume checkpoint not found: {resume_from}. Initializing from scratch."
            )
        else:
            logging.info("Initializing new model parameters.")
        params = transformer.create(model_key, config)

    if optimizer is None:
        optimizer = optax.adam(learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-8)
    if opt_state is None:
        opt_state = optimizer.init(params)

    total_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
    logging.info(f"Total parameters: {total_params:,}")

    if use_wandb:
        _init_wandb(
            config=config,
            project=wandb_project,
            entity=wandb_entity,
            run_name=wandb_run_name,
            resumed_wandb_id=resumed_wandb_id,
            extra_config={
                "learning_rate": learning_rate,
                "num_steps": num_steps,
                "seed": seed,
                "total_parameters": total_params,
                "start_step": start_step,
            },
        )

    # Periodic eval only makes sense with both a callback and a period; a
    # None/0 period previously crashed the modulo below.
    periodic_eval = eval_fn is not None and bool(eval_every)

    # Explicit progress tracking, so cleanup never depends on loop variables
    # that may be unbound (empty loop, failure on the first step, ...).
    completed_step = start_step
    last_metrics = None
    last_saved_step = None
    last_eval_step = None
    failed = False

    running_sums = defaultdict(float)
    running_mins = defaultdict(lambda: math.inf)
    running_maxes = defaultdict(lambda: -math.inf)
    window_count = 0

    pbar = tqdm(
        range(start_step, num_steps),
        initial=start_step,
        total=num_steps,
        desc="Training",
    )

    try:
        for step in pbar:
            step_start_time = time.time()

            # ------------ data fetch ------------
            try:
                batch = next(data_iterator)
            except StopIteration:
                logging.warning("Data iterator exhausted. Stopping training.")
                break
            batch = jax.tree_util.tree_map(jnp.asarray, batch)

            # ------------ train step ------------
            params, opt_state, metrics = train_step(
                params, opt_state, batch, config, optimizer
            )
            completed_step = step + 1
            last_metrics = metrics

            # ------------ logs, checkpoints, evals ------------
            running_sums, running_mins, running_maxes = train_utils.update_metrics(
                metrics, step_start_time, running_sums, running_mins, running_maxes
            )
            window_count += 1

            if log_every and completed_step % log_every == 0:
                running_sums, running_mins, running_maxes = (
                    train_utils.log_transformer_metrics(
                        step=step,
                        num_steps=num_steps,
                        sums=running_sums,
                        mins=running_mins,
                        maxes=running_maxes,
                        window_count=window_count,
                        use_wandb=use_wandb,
                        pbar=pbar,
                        log_stdout=log_stdout,
                    )
                )
                window_count = 0

            if save_every and completed_step % save_every == 0:
                train_utils.save_checkpoint(
                    params, opt_state, completed_step, checkpoint_dir
                )
                last_saved_step = completed_step

            if periodic_eval and completed_step % eval_every == 0:
                eval_fn(eval_key, completed_step, params)
                last_eval_step = completed_step

        # Final checkpoint/eval are labelled with the step actually reached
        # (less than num_steps if the data ran out) and skipped when the
        # periodic ones above already covered that step.
        if completed_step > start_step:
            if completed_step != last_saved_step:
                train_utils.save_checkpoint(
                    params, opt_state, completed_step, checkpoint_dir
                )
            if eval_fn is not None and completed_step != last_eval_step:
                eval_fn(eval_key, completed_step, params)

    except Exception as e:
        # Log with traceback and propagate instead of silently returning
        # partially-trained params as if training had succeeded.
        failed = True
        logging.exception(f"Unexpected exception encountered: {e}")
        raise

    finally:
        pbar.close()
        if use_wandb and wandb.run is not None:
            if last_metrics is not None:
                final_metrics = {f"final_{k}": v for k, v in last_metrics.items()}
                wandb.log(final_metrics, step=completed_step)

            wandb.finish(exit_code=1 if failed else None)
            logging.info("Wandb run finished.")

    return params, opt_state
