"""
IQL Training Script for LIBERO Environment

This script trains a policy with Implicit Q-Learning (IQL) on batches of transitions
produced by an IQL data loader.
Key features:
- Batches have the form ((obs, actions), (next_obs, next_actions), rewards)
- Q-function trained with a TD loss
- Value function trained with an expectile loss against stop-gradient Q-values
- Policy trained with an advantage-weighted (AWR-style, temperature `beta`) diffusion loss
- Optional EMA of the parameters, periodic wandb logging and checkpointing
"""

import dataclasses
import functools
import logging
import platform
from typing import Any
import time

import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
from jax import lax
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb
import numpy as np
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg

import openpi.models.model as _model
from openpi.models.model import Observation
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
import openpi.training.weight_loaders as weight_loaders
from openpi.training.utils import concat_observations


def init_logging():
    """Configure the root logger with a compact, single-letter-level format.

    Sets the root logger level to INFO and installs a formatter that renders the
    level name as a single letter (e.g. ``INFO`` -> ``I``) on the root logger's
    first handler. If the root logger has no handler yet, a ``StreamHandler`` is
    added first so this call cannot fail with an ``IndexError``.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # A LogRecord is shared by every handler that processes it, so restore the
            # original level name afterwards instead of leaking the abbreviation.
            original_levelname = record.levelname
            record.levelname = level_mapping.get(original_levelname, original_levelname)
            try:
                return super().format(record)
            finally:
                record.levelname = original_levelname

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Nothing has configured logging yet; without a handler, handlers[0] would raise.
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)


def init_wandb(
    config: _config.RLTrainConfig,
    *,
    resuming: bool,
    log_code: bool = False,
    enabled: bool = True,
    entity: str | None = "openpi_dis",
):
    """Initialize (or resume) the Weights & Biases run for this training job.

    Args:
        config: Training config; ``checkpoint_dir``, ``exp_name`` and ``project_name`` are read,
            and the whole config is logged for new runs.
        resuming: If True, resume the run whose id is stored in ``<checkpoint_dir>/wandb_id.txt``.
        log_code: If True, upload the directory two levels above this file as the run's code.
        enabled: If False, initialize wandb in disabled mode and return immediately.
        entity: wandb entity for the run. It is passed both when creating and when resuming,
            so a resumed run is looked up in the same entity it was created in.

    Raises:
        FileNotFoundError: If ``config.checkpoint_dir`` does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    run_id_file = ckpt_dir / "wandb_id.txt"
    if resuming:
        run_id = run_id_file.read_text().strip()
        # Use the creation-time entity; without it wandb falls back to the default entity
        # and resume="must" can fail to find the run.
        wandb.init(id=run_id, resume="must", project=config.project_name, entity=entity)
    else:
        wandb.init(
            entity=entity,
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        # Persist the run id so a later resume can reattach to this run.
        run_id_file.write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Load weights with ``loader`` and return only the subset that was actually loaded.

    Partial loading is allowed: expected keys that are missing from the loaded weights, or
    whose loaded value is still a ``jax.ShapeDtypeStruct`` placeholder, are left out of the
    result, so the caller keeps the model's freshly initialized values for them.

    Args:
        loader: Weight loader; ``loader.load(params_shape)`` is called once.
        params_shape: Nested dict describing the expected parameters (the caller passes the
            ``to_pure_dict()`` of an ``eval_shape``-d state, i.e. ``ShapeDtypeStruct`` leaves).

    Returns:
        Nested dict containing only concrete (non-``ShapeDtypeStruct``) loaded values.

    Raises:
        ValueError: If a loaded parameter's shape differs from the expected shape.
    """
    loaded_params = loader.load(params_shape)

    # Flatten both trees so keys can be compared path by path.
    flat_expected = traverse_util.flatten_dict(params_shape)
    flat_loaded = traverse_util.flatten_dict(loaded_params)

    result = {}
    for k, expected_v in flat_expected.items():
        if k not in flat_loaded:
            # Keep the placeholder; it is filtered out below so the model keeps its init value.
            logging.warning(f"Expected key {k} not found in the loaded params, will be randomly initialized")
            result[k] = expected_v
            continue

        loaded_v = flat_loaded[k]
        expected_shape = getattr(expected_v, "shape", None)
        loaded_shape = getattr(loaded_v, "shape", None)
        # Fail early here instead of letting a wrong-shaped tensor reach the model state.
        if expected_shape is not None and loaded_shape is not None and tuple(expected_shape) != tuple(loaded_shape):
            raise ValueError(
                f"Shape mismatch for key {k}: expected {tuple(expected_shape)}, loaded {tuple(loaded_shape)}"
            )
        result[k] = loaded_v
        logging.info(f"Loading key {k} from loaded params")

    unexpected = flat_loaded.keys() - flat_expected.keys()
    if unexpected:
        # These are not part of the model and are dropped; surface them instead of ignoring silently.
        logging.warning(
            f"Ignoring {len(unexpected)} loaded key(s) not present in the model: "
            f"{sorted(str(k) for k in unexpected)}"
        )

    # Remove jax.ShapeDtypeStruct from the result. This makes sure that only the loaded params are returned.
    return traverse_util.unflatten_dict(
        {k: v for k, v in result.items() if not isinstance(v, jax.ShapeDtypeStruct)}
    )


@at.typecheck
def init_train_state(
    config: _config.RLTrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Create the FSDP-sharded training state and its sharding spec.

    When ``resume`` is True only the abstract (``jax.eval_shape``) state is returned; the real
    values are expected to be restored from a checkpoint by the caller. Otherwise weights are
    loaded from ``config.pretrained_path`` (if set) or ``config.weight_loader`` and merged into
    a freshly created model inside a jitted initializer.

    Returns:
        ``(train_state, state_sharding)``.
    """
    optimizer = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def build_state(key: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        _, model_key = jax.random.split(key)
        model = config.model.create(model_key)

        if partial_params is not None:
            graph_def, model_state = nnx.split(model)
            # Raises if partial_params is not a subset of the model state.
            model_state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graph_def, model_state)

        def to_bf16(param):
            return param.replace(param.value.astype(jnp.bfloat16))

        # Parameters matched by freeze_filter are stored in bfloat16.
        params = nnx_utils.state_map(nnx.state(model), config.freeze_filter, to_bf16)
        ema_params = params if config.ema_decay is not None else None

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=optimizer,
            opt_state=optimizer.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=ema_params,
        )

    state_shape = jax.eval_shape(build_state, init_rng)
    state_sharding = sharding.fsdp_sharding(state_shape, mesh, log=True)

    if resume:
        return state_shape, state_sharding

    loader = (
        weight_loaders.CheckpointWeightLoader(config.pretrained_path)
        if config.pretrained_path is not None
        else config.weight_loader
    )
    partial_params = _load_weights_and_validate(loader, state_shape.params.to_pure_dict())

    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    jitted_build = jax.jit(
        build_state,
        donate_argnums=(1,),  # the partial params buffer can be reused for the output.
        in_shardings=replicated,
        out_shardings=state_sharding,
    )
    return jitted_build(init_rng, partial_params), state_sharding


@at.typecheck
def iql_train_step(
    config: _config.RLTrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple,
    expectile: float = 0.7,
    gamma: float = 0.99,
    beta: float = 3.0,
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one IQL optimization step.

    The optimized loss is the sum of three model-provided terms:

    1. Q-function: TD loss (``model.compute_iql_q_loss``).
    2. Value function: expectile loss against stop-gradient Q-values
       (``model.compute_iql_value_loss``).
    3. Policy: advantage-weighted diffusion loss (``model.compute_iql_policy_loss``).

    Args:
        config: Training config; ``trainable_filter`` selects the parameters that receive gradients.
        rng: Base PRNG key; it is folded with ``state.step`` so each step uses fresh randomness.
        state: Current training state.
        batch: ``((current_obs, current_actions), (next_obs, next_actions), rewards)``.
            ``next_actions`` is unpacked but not used by this step.
        expectile: Expectile parameter for the value loss.
        gamma: Discount factor passed to the Q loss.
        beta: AWR temperature passed to the policy loss.

    Returns:
        The updated training state and an info dict with the model's loss metrics plus
        ``total_loss``, ``grad_norm`` and ``update_norm``.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()

    train_rng = jax.random.fold_in(rng, state.step)

    (current_obs, current_actions), (next_obs, _next_actions), rewards = batch

    def loss_fn(model: _model.BaseModel, train_rng: at.KeyArrayLike):
        """Combined IQL loss: Q (TD) + value (expectile) + policy (advantage-weighted)."""
        # One independent key per stochastic call: reusing a JAX key across calls
        # makes their random draws identical and therefore correlated.
        q_rng, target_rng, value_rng, policy_rng = jax.random.split(train_rng, 4)

        q_loss, q_info = model.compute_iql_q_loss(
            q_rng, current_obs, current_actions, next_obs, rewards, gamma
        )

        # Q-values only serve as a regression target for V, so block gradients into Q here.
        target_values = jax.lax.stop_gradient(
            model.compute_q_values(target_rng, current_obs, current_actions)
        )
        value_loss, value_info = model.compute_iql_value_loss(
            value_rng, current_obs, target_values, expectile
        )

        policy_loss, policy_info = model.compute_iql_policy_loss(
            policy_rng, current_obs, current_actions, beta
        )

        total_loss = q_loss + value_loss + policy_loss
        info = {
            **q_info,
            **value_info,
            **policy_info,
            "total_loss": total_loss,
        }
        return total_loss, info

    # Differentiate only with respect to the trainable subset of the model state.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    (loss, info), grads = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)(
        model, train_rng
    )

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Write the updated trainable params back into the model, then take the full state
    # (trainable + frozen) as the new parameter tree.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(
        state,
        step=state.step + 1,
        params=new_params,
        opt_state=new_opt_state,
    )

    if state.ema_decay is not None:
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new,
                state.ema_params,
                new_params,
            ),
        )

    info["grad_norm"] = optax.global_norm(grads)
    info["update_norm"] = optax.global_norm(updates)

    return new_state, info


def main(config: _config.RLTrainConfig):
    """Run IQL training end to end.

    Sets up logging, the device mesh, checkpointing and wandb, builds the IQL data loader
    and the (optionally restored) train state, then runs the jitted ``iql_train_step`` from
    the current step up to ``config.num_train_steps`` with periodic logging and checkpointing.

    Args:
        config: RL training config.

    Raises:
        ValueError: If ``config.batch_size`` is not divisible by the number of devices.
    """
    init_logging()

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax_new").expanduser()))

    rng = jax.random.key(config.seed)
    # Keep the 3-way split so train/init keys stay identical to previous runs; the third key is unused.
    train_rng, init_rng, _ = jax.random.split(rng, 3)

    mesh = sharding.make_mesh(config.fsdp_devices)
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Initialize data loaders first to get data_sharding
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))

    data_loader = _data_loader.create_iql_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )

    data_iter = iter(data_loader)
    # The first batch is logged for inspection and also used as the first training batch.
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)

    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(iql_train_step, config),
        in_shardings=(
            replicated_sharding,      # rng
            train_state_sharding,     # state
            (
                (data_sharding, data_sharding),  # current tuple
                (data_sharding, data_sharding),  # next tuple
                data_sharding,             # rewards
            ),
        ),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
        # expectile, gamma, beta are Python floats and are baked into the compiled step.
        static_argnums=(3, 4, 5),
    )
    logging.info("JIT-compiled IQL training step")

    expectile = getattr(config, 'expectile', 0.7)
    gamma = getattr(config, 'gamma', 0.99)
    beta = getattr(config, 'beta', 1.0)

    logging.info(f"IQL hyperparameters: expectile={expectile}, gamma={gamma}, beta={beta}")

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []

    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(
                train_rng,
                train_state,
                batch,
                expectile,
                gamma,
                beta,
            )
        # Record this step before logging so the logged window includes the current step.
        infos.append(info)

        if step % config.log_interval == 0 and step != 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []

        is_last_step = step == config.num_train_steps - 1
        if (step % config.save_interval == 0 and step != 0) or is_last_step:
            jax.block_until_ready(train_state)
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
            logging.info(f"Saved checkpoint at step {step}")

        # Prefetch the next batch only when another step will consume it.
        if not is_last_step:
            with sharding.set_mesh(mesh):
                batch = next(data_iter)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    # Parse the RL training config from the command line and start training.
    main(_config.cli())
