import dataclasses
import functools
import logging
import time
from typing import Any

import tyro
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import tqdm

from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.shared import array_typing as at
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader
from openpi.training import optimizer as _optimizer
from openpi.training import utils as training_utils
from openpi.shared import nnx_utils
from openpi.shared import path_utils



@dataclasses.dataclass(frozen=True)
class ComparatorTrainConfig(_config.TrainConfig):
    """Stage-2 training: freeze VLM + action expert; train comparator + learnable queries.

    The base data loader yields (observation, actions) batches; `build_pairwise_iterator`
    turns them into (observation, action_a, action_b, label) comparison batches using the
    pair-construction settings below.

    Every overridden default is declared as an annotated field. An un-annotated assignment
    in a dataclass body is only a class attribute: the generated ``__init__`` would keep
    using the default inherited from the base config and the override would be ignored.
    """

    name: str = "pi05_comparator"
    # Enable comparator in model config.
    model: _model.BaseModelConfig = dataclasses.field(
        default_factory=lambda: pi0_config.Pi0Config(
            enable_action_comparator=True,
            pi05=True,
            action_horizon=10,
            discrete_state_input=False,
        )
    )
    # Stage-1 weights used to build the model that samples action pairs. Resolved through
    # path_utils.env_path from the two environment variable names; a falsy result becomes None.
    stage1_params_dir: str | None = dataclasses.field(
        default_factory=lambda: path_utils.env_path(
            "OPENPI_COMPARATOR_STAGE1_PARAMS_DIR",
            "OPENPI_STAGE1_PARAMS_DIR",
        )
        or None
    )
    # Pair construction settings.
    good_steps: int = 10  # num_steps passed to sample_actions for the preferred model sample
    bad_steps: int = 3  # num_steps passed to sample_actions for the dispreferred sample
    include_gt_pair: bool = True  # (GT as good) vs (bad_steps sample)
    include_model_pair: bool = True  # (good_steps sample) vs (bad_steps sample)
    randomize_order: bool = True  # randomly swap A/B (and flip the label) per example
    # Dataset: LIBERO in LeRobot format, prompts taken from the task description.
    data: _config.DataConfigFactory = dataclasses.field(
        default_factory=lambda: _config.LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
            base_config=_config.DataConfig(prompt_from_task=True),
            extra_delta_transform=False,
        )
    )
    # peak_lr == decay_lr, so the rate stays at 5e-5 once warmup has finished.
    lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(
        default_factory=lambda: _optimizer.CosineDecaySchedule(
            warmup_steps=10_000,
            peak_lr=5e-5,
            decay_steps=1_000_000,
            decay_lr=5e-5,
        )
    )
    optimizer: _optimizer.OptimizerConfig = dataclasses.field(
        default_factory=lambda: _optimizer.AdamW(clip_gradient_norm=1.0)
    )
    ema_decay: float | None = 0.999
    num_train_steps: int = 30_000
    


def init_train_state(config: ComparatorTrainConfig, mesh) -> tuple[training_utils.TrainState, Any, ComparatorTrainConfig]:
    """Create the model, optimizer and initial train state for comparator training.

    Args:
        config: Training configuration; ``config.seed`` seeds model creation.
        mesh: Accepted for call-site compatibility; not used by this function.

    Returns:
        A ``(train_state, None, train_config)`` tuple. ``train_config`` is a copy of
        ``config`` whose ``freeze_filter`` freezes every path that does not match the
        comparator / comparison-query patterns. The middle element is always ``None``.
    """
    model = config.model.create(jax.random.key(config.seed))

    # Only paths matching one of these patterns are left unfrozen.
    unfrozen_paths = nnx.Any(
        nnx_utils.PathRegex(".*action_comparator.*"),
        nnx_utils.PathRegex(".*comparison_queries.*"),
    )
    train_config = dataclasses.replace(
        config,
        freeze_filter=lambda path, value: not unfrozen_paths(path, value),
    )

    # One split yields both the graph definition and the full parameter state.
    graphdef, params = nnx.split(model)
    tx = _optimizer.create_optimizer(train_config.optimizer, train_config.lr_schedule)

    # The optimizer state only covers the trainable subset of the parameters.
    trainable_params = params.filter(train_config.trainable_filter)
    ema_decay = train_config.ema_decay

    train_state = training_utils.TrainState(
        step=jnp.array(0),
        params=params,
        model_def=graphdef,
        opt_state=tx.init(trainable_params),
        tx=tx,
        ema_decay=ema_decay,
        ema_params=None if ema_decay is None else params,
    )
    return train_state, None, train_config


def bce_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Return the element-wise binary cross-entropy between raw logits and 0/1 targets.

    The loss is computed directly from the logits via ``log_sigmoid`` rather than as
    ``log(sigmoid(x) + eps)``. The latter saturates for large ``|x|``: the loss is capped
    near ``-log(eps)`` and the gradient vanishes exactly when the prediction is
    confidently wrong.

    Args:
        logits: Unnormalised scores; cast to float32 before the loss is evaluated.
        labels: Targets broadcastable against ``logits``; cast to float32.

    Returns:
        Float32 array of per-element losses (no reduction is applied).
    """
    logits = jnp.asarray(logits, dtype=jnp.float32)
    labels = jnp.asarray(labels, dtype=jnp.float32)
    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)), both numerically stable.
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -(labels * log_p + (1.0 - labels) * log_not_p)


def train_step(config: ComparatorTrainConfig, rng, state: training_utils.TrainState, batch):
    """Run one optimisation step on a comparison batch.

    Args:
        config: Training configuration; supplies ``trainable_filter``.
        rng: Base PRNG key; ``state.step`` is folded in before it reaches the loss function.
        state: Current train state (parameters, optimizer state, optional EMA parameters).
        batch: ``(observation, action_a, action_b, label)`` tuple.

    Returns:
        ``(new_state, info)`` where ``info`` is ``{"loss": <mean BCE loss>}``.
    """
    observation, action_a, action_b, label = batch
    trainable = config.trainable_filter

    # Rebuild a live module from the stored graph definition and parameters.
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def comparator_loss(model: _model.BaseModel, rng: at.KeyArrayLike):
        # A single VLM prefill (with queries) provides the context for the comparison.
        context = model._prefill_vlm_with_queries(observation)
        pair_logits = model.compare_actions_with_context(observation, context, action_a, action_b)
        # The last axis of the logits is squeezed so they line up with `label`.
        per_example = bce_loss(pair_logits.squeeze(-1), label)
        return jnp.mean(per_example)

    step_rng = jax.random.fold_in(rng, state.step)
    grad_fn = nnx.value_and_grad(comparator_loss, argnums=nnx.DiffState(0, trainable))
    loss, grads = grad_fn(model, step_rng)
    # Keep the gradient tree aligned with the optimizer state tree.
    grads = grads.filter(trainable)

    trainable_params = state.params.filter(trainable)
    updates, next_opt_state = state.tx.update(grads, state.opt_state, trainable_params)
    # Write the updated trainable subset back, then read the complete parameter state.
    nnx.update(model, optax.apply_updates(trainable_params, updates))
    full_params = nnx.state(model)

    replacements = {
        "step": state.step + 1,
        "params": full_params,
        "opt_state": next_opt_state,
    }
    decay = state.ema_decay
    if decay is not None:
        # Exponential moving average over the full parameter tree.
        replacements["ema_params"] = jax.tree.map(
            lambda old, new: decay * old + (1 - decay) * new,
            state.ema_params,
            full_params,
        )

    return dataclasses.replace(state, **replacements), {"loss": loss}


def build_pairwise_iterator(
    config: ComparatorTrainConfig,
    base_loader,
    model: _model.BaseModel,
):
    """Wrap an (observation, actions) loader to yield (observation, action_a, action_b, label).

    For every batch drawn from ``base_loader`` a "bad" action sample is generated with
    ``config.bad_steps`` sampling steps, then up to two comparison batches are yielded:

      - if ``config.include_gt_pair``: ground-truth actions vs. the bad sample
      - if ``config.include_model_pair``: a ``config.good_steps`` sample vs. the bad sample

    Labels start as 1 with the preferred action in slot A. When ``config.randomize_order``
    is set, A/B are swapped per example with probability 0.5 and the label flipped to match.

    Args:
        config: Supplies the seed and the pair-construction settings.
        base_loader: Iterable of ``(observation, gt_actions)`` batches.
        model: Model whose ``sample_actions(rng, observation, num_steps=...)`` produces
            the action samples.

    Yields:
        ``(observation, action_a, action_b, label)`` tuples.

    Raises:
        ValueError: On the first ``next()`` if both pair kinds are disabled. Without this
            check the generator would consume ``base_loader`` without ever yielding.
    """
    if not (config.include_gt_pair or config.include_model_pair):
        raise ValueError(
            "At least one of include_gt_pair / include_model_pair must be enabled; "
            "otherwise no comparison pairs can be produced."
        )

    rng = jax.random.key(config.seed)

    def maybe_flip(rng_key, a, b, y):
        # `flip` has shape (B,); it is expanded to broadcast over (B, H, D) actions.
        if config.randomize_order:
            flip = jax.random.bernoulli(rng_key, p=0.5, shape=(a.shape[0],))
        else:
            flip = jnp.zeros((a.shape[0],), dtype=bool)
        swap = flip[:, None, None]
        return jnp.where(swap, b, a), jnp.where(swap, a, b), jnp.where(flip, 1 - y, y)

    for observation, gt_actions in base_loader:
        # gt_actions is indexed as (B, H, D) by maybe_flip.
        bsize = gt_actions.shape[0]
        # The key-splitting sequence is independent of which pairs are enabled, so the
        # random stream for the bad sample does not change with the pair flags.
        rng, bad_key, good_key = jax.random.split(rng, 3)
        bad = model.sample_actions(bad_key, observation, num_steps=config.bad_steps)

        # Pair 1: (GT, bad)
        if config.include_gt_pair:
            rng, flip_key = jax.random.split(rng)
            a, b, y = maybe_flip(flip_key, gt_actions, bad, jnp.ones((bsize,), dtype=jnp.float32))
            yield observation, a, b, y

        # Pair 2: (good_model, bad). The good_steps sample is only drawn when this pair
        # is enabled, so the extra sampling pass is skipped otherwise.
        if config.include_model_pair:
            good_model = model.sample_actions(good_key, observation, num_steps=config.good_steps)
            rng, flip_key = jax.random.split(rng)
            a, b, y = maybe_flip(flip_key, good_model, bad, jnp.ones((bsize,), dtype=jnp.float32))
            yield observation, a, b, y


def main(config: ComparatorTrainConfig):
    """Train the action comparator: build state and data, run the loop, log and checkpoint.

    Args:
        config: Comparator training configuration.
    """
    logging.basicConfig(level=logging.INFO, force=True)
    mesh = jax.sharding.Mesh(jax.devices(), ("x",))
    train_state, _, train_config = init_train_state(config, mesh)

    # Base loader yields (Observation, Actions)
    base_loader = _data_loader.create_data_loader(train_config)

    # Model used to sample action pairs: stage-1 weights if provided, else the current model.
    # NOTE(review): the stage-1 weights are only loaded into the sampling model; the frozen
    # VLM / action expert inside `train_state` keep the parameters created by
    # init_train_state. Confirm that this is intended.
    if train_config.stage1_params_dir is not None:
        params = _model.restore_params(train_config.stage1_params_dir, dtype=jnp.bfloat16)
        model_for_sampling = train_config.model.load(params)
    else:
        # Weights include the comparator's random init; sampling uses the action expert path.
        model_for_sampling = nnx.merge(train_state.model_def, train_state.params)

    # Build pairwise iterator
    pairwise_iter = build_pairwise_iterator(train_config, iter(base_loader), model_for_sampling)

    ckpt_mngr, resuming = _checkpoints.initialize_checkpoint_dir(
        train_config.checkpoint_dir,
        keep_period=train_config.keep_period,
        overwrite=train_config.overwrite,
        resume=train_config.resume,
    )
    if resuming:
        train_state = _checkpoints.restore_state(ckpt_mngr, train_state, base_loader)

    # Continue from the restored step so that a resumed run neither repeats earlier step
    # numbers (which would write checkpoints under stale indices) nor trains for a full
    # extra `num_train_steps`.
    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, train_config.num_train_steps),
        initial=start_step,
        total=train_config.num_train_steps,
        dynamic_ncols=True,
    )

    prun = jax.jit(functools.partial(train_step, train_config))
    # One base key is enough: train_step folds the current step into it.
    rng = jax.random.key(train_config.seed)
    infos = []
    for step in pbar:
        batch = next(pairwise_iter)
        train_state, info = prun(rng, train_state, batch)
        infos.append(info)
        if step % train_config.log_interval == 0:
            # Average everything collected since the previous log line.
            reduced = jax.device_get(
                jax.tree.map(lambda *xs: jnp.mean(jnp.stack(xs)), *infos)
            )
            msg = ", ".join(f"{k}={float(v):.4f}" for k, v in reduced.items())
            pbar.write(f"Step {step}: {msg}")
            # Drop logged entries so the buffer never holds more than one window.
            infos = []
        if (step % train_config.save_interval == 0 and step > 0) or step == train_config.num_train_steps - 1:
            _checkpoints.save_state(ckpt_mngr, train_state, base_loader, step)

    # Block until any in-flight checkpoint write has completed before the process exits.
    logging.info("Waiting for checkpoint manager to finish")
    ckpt_mngr.wait_until_finished()


if __name__ == "__main__":
    # Parse the training configuration from the command line, then start training.
    cli_config = tyro.cli(ComparatorTrainConfig)
    main(cli_config)