import dataclasses
import functools
import logging
import platform
from typing import Any, Callable
from PIL import Image

import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


def init_logging():
    """Install a compact log format on the root logger and set its level to INFO.

    Level names are abbreviated to a single letter (for example ``INFO`` becomes ``I``).
    The formatter is attached to the root logger's first handler. If the root logger has
    no handler yet, a ``logging.StreamHandler`` is added first so that the call does not
    fail with ``IndexError``.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Unknown level names (e.g. custom levels) are left untouched.
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Nothing has configured logging yet; create the handler we are about to format.
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.AWRTrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Start (or resume) the Weights & Biases run for this training job.

    Args:
        config: Training config; supplies the checkpoint directory, project name, experiment
            name and the offline flag.
        resuming: When true, the run id stored in ``<checkpoint_dir>/wandb_id.txt`` is read and
            the run is resumed with ``resume="must"``. Otherwise a new run is created and its id
            is written to that file.
        log_code: When true, the directory two levels above this file is uploaded via
            ``wandb.run.log_code``.
        enabled: When false, wandb is initialised in ``"disabled"`` mode and nothing else happens.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    if config.wandb_offline:
        wandb_mode = "offline"
    else:
        wandb_mode = "online"

    checkpoint_dir = config.checkpoint_dir
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {checkpoint_dir} does not exist.")

    # The run id lives next to the checkpoints so a resumed job can re-attach to the same run.
    run_id_file = checkpoint_dir / "wandb_id.txt"
    if resuming:
        stored_run_id = run_id_file.read_text().strip()
        wandb.init(id=stored_run_id, resume="must", project=config.project_name, mode=wandb_mode)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
            mode=wandb_mode,
        )
        run_id_file.write_text(wandb.run.id)

    if log_code:
        code_root = epath.Path(__file__).parent.parent
        wandb.run.log_code(code_root)


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Load weights with ``loader``, check them against ``params_shape``, and keep only real arrays.

    The loaded tree is compared with ``params_shape`` (shapes and dtypes). Leaves that are still
    ``jax.ShapeDtypeStruct`` placeholders are then dropped, so the returned nested dict contains
    only the leaves that are not placeholders.
    """
    loaded = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded, check_shapes=True, check_dtypes=True)

    # Work on the flattened {path: leaf} view so placeholders can be skipped one by one.
    concrete_leaves = {}
    for path, leaf in traverse_util.flatten_dict(loaded).items():
        if isinstance(leaf, jax.ShapeDtypeStruct):
            continue
        concrete_leaves[path] = leaf
    return traverse_util.unflatten_dict(concrete_leaves)


@at.typecheck
def init_train_state(
    config: _config.AWRTrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Build the initial train state and its sharding.

    Args:
        config: Training config (model, optimizer, LR schedule, filters, EMA decay, weight loader).
        init_rng: Key used to initialise the model.
        mesh: Device mesh used to derive the sharding of the train state.
        resume: When true, only the abstract (shape/dtype) train state is returned, without
            loading any weights.

    Returns:
        ``(train_state, state_sharding)``. With ``resume=True`` the first element is the result of
        ``jax.eval_shape`` rather than a materialised state.
    """
    optimizer = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        # Only the second half of the split is consumed; it seeds the model.
        _, model_key = jax.random.split(rng)
        model = config.model.create(model_key)

        if partial_params is not None:
            # Overwrite the freshly initialised values with the loaded ones. This raises if
            # the partial params are not a subset of the model state.
            graphdef, model_state = nnx.split(model)
            model_state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, model_state)

        def to_bfloat16(param):
            return param.replace(param.value.astype(jnp.bfloat16))

        # Params selected by the freeze filter are stored in bfloat16.
        params = nnx_utils.state_map(nnx.state(model), config.freeze_filter, to_bfloat16)

        ema_params = params if config.ema_decay is not None else None
        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=optimizer,
            opt_state=optimizer.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=ema_params,
        )

    # Trace `init` abstractly to get the state's structure without allocating it.
    abstract_state = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(abstract_state, mesh, log=True)

    if resume:
        return abstract_state, state_sharding

    loaded_params = _load_weights_and_validate(config.weight_loader, abstract_state.params.to_pure_dict())
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Run the real initialisation under jit, mixing in the loaded params.
    sharded_init = jax.jit(
        init,
        donate_argnums=(1,),  # the loaded params buffer is donated
        in_shardings=replicated,
        out_shardings=state_sharding,
    )
    return sharded_init(init_rng, loaded_params), state_sharding

@at.typecheck
def init_value_model(
    config: _config.ValueTrainConfig | None,
    rng: at.KeyArrayLike,
) -> _model.BaseModel | None:
    """Create the value model, optionally load pretrained weights, and put it in eval mode.

    Args:
        config: Value-model config, or ``None`` when no value model is configured. The parameter
            is annotated as optional so that the ``None`` case handled below is not rejected by
            annotation-based argument checking.
        rng: Key used to initialise the value model.

    Returns:
        The value model in evaluation mode, or ``None`` if ``config`` is ``None``.
    """
    if config is None:
        return None

    # Only the second half of the split seeds the model.
    _, value_rng = jax.random.split(rng)
    value_model = config.model.create(value_rng)

    if config.pretrained_path is not None:
        value_weight_loader = _weight_loaders.CheckpointWeightLoader(config.pretrained_path)
        # The freshly initialised state serves as the reference tree for validating the checkpoint.
        reference_params = nnx.state(value_model).to_pure_dict()
        loaded_value_params = _load_weights_and_validate(value_weight_loader, reference_params)

        # Write the loaded values into the model state and rebuild the module.
        graphdef, state = nnx.split(value_model)
        state.replace_by_pure_dict(loaded_value_params)
        value_model = nnx.merge(graphdef, state)

        logging.info(f"Loaded value model from {config.pretrained_path}")

    value_model.eval()
    return value_model


@at.typecheck
def train_step(
    config: _config.AWRTrainConfig,
    get_weight: Callable[[_model.Observation, _model.Actions], at.Float[at.Array, "b"]],
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    optimal_batch: tuple[_model.Observation, _model.Actions],
    suboptimal_batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one weighted training step on an optimal and a suboptimal batch.

    The two batches are concatenated along axis 0 (optimal first). Optimal samples get weight 1;
    suboptimal samples get the weight returned by ``get_weight``.

    Args:
        config: Training config; supplies the trainable-parameter filter.
        get_weight: Called as ``get_weight(observation, actions)`` on the suboptimal batch.
            NOTE(review): the declared return shape is ``"b"`` while ``loss_fn`` below requires a
            ``"b a"`` weight; confirm which one is intended.
        rng: Base key; the key actually used is ``fold_in(rng, state.step)``.
        state: Current train state.
        optimal_batch: ``(observation, actions)`` for the optimal data.
        suboptimal_batch: ``(observation, actions)`` for the suboptimal data.

    Returns:
        ``(new_state, info)`` where ``info`` holds weight statistics, gradient and kernel-parameter
        norms, and the mean loss of each half of the combined batch.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()

    optimal_observation, optimal_actions = optimal_batch
    suboptimal_observation, suboptimal_actions = suboptimal_batch
    # Index in the combined batch where the suboptimal samples start. Using the real size (instead
    # of half the combined batch) keeps the per-source losses right when the batch sizes differ.
    num_optimal = optimal_actions.shape[0]

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        actions: _model.Actions,
        weight: at.Float[at.Array, "b a"],
    ):
        loss = model.compute_loss(rng, observation, actions, weight, train=True)
        info = {
            "optimal_loss": jnp.mean(loss[:num_optimal]),
            "suboptimal_loss": jnp.mean(loss[num_optimal:]),
        }
        return jnp.mean(loss), info

    train_rng = jax.random.fold_in(rng, state.step)
    weight = get_weight(suboptimal_observation, suboptimal_actions)

    observation = training_utils.concat_observations([optimal_observation, suboptimal_observation])
    actions = jnp.concatenate([optimal_actions, suboptimal_actions], axis=0)
    # Same concatenation order as above: ones for the optimal part, then the computed weights.
    weight = jnp.concatenate([jnp.ones_like(weight), weight], axis=0)

    # Differentiate only with respect to the trainable part of the model (argument 0).
    diff_state = nnx.DiffState(0, config.trainable_filter)
    grad_fn = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)
    (loss, loss_info), grads = grad_fn(model, train_rng, observation, actions, weight)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and read back the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Parameter norm is reported over multi-dimensional Params whose path does not match the
    # bias / scale / embedding pattern below.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "weight_mean": weight.mean(),
        "weight_min": weight.min(),
        "weight_max": weight.max(),
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    info.update(loss_info)
    return new_state, info


# A loss that is NaN, negative, or above this bound aborts training after a debug dump.
_MAX_SANE_LOSS = 1_000_000_000


def _loss_is_invalid(loss) -> bool:
    """Return True when a scalar loss is NaN, negative, or larger than ``_MAX_SANE_LOSS``."""
    value = float(np.asarray(loss))
    # NaN fails every comparison, so the negated range check reports it as invalid too.
    return not (0 <= value <= _MAX_SANE_LOSS)


def _to_uint8_image(img) -> np.ndarray:
    """Convert one [H, W, C] image into a uint8 array that PIL can write as PNG.

    Three-channel images are kept as RGB; for any other channel count only the first channel is
    used. Arrays that are already uint8 are passed through unchanged.
    """
    arr = np.asarray(img)
    if arr.shape[2] != 3:
        arr = arr[:, :, 0]
    if arr.dtype == np.uint8:
        return arr
    # Assumes non-uint8 images are normalised to [-1, 1] (the range the clip implies) -- TODO confirm.
    # Map [-1, 1] onto [0, 255]; scaling by 255 alone would wrap negative values in the uint8 cast.
    arr = arr.astype(np.float32).clip(-1.0, 1.0)
    return np.round((arr + 1.0) * 127.5).astype(np.uint8)


def _dump_debug_batch(debug_dir, step, optimal_batch, suboptimal_batch, get_weight) -> None:
    """Write the images, states, actions and weights of the current batches under ``debug_dir``."""
    debug_dir.mkdir(parents=True, exist_ok=True)

    optimal_observation, optimal_actions = optimal_batch
    suboptimal_observation, suboptimal_actions = suboptimal_batch
    observation = training_utils.concat_observations([optimal_observation, suboptimal_observation])

    # One sub-directory per image key, one PNG per sample index.
    for name, images in observation.images.items():
        image_dir = debug_dir / f"observation_{name}"
        image_dir.mkdir(parents=True, exist_ok=True)
        for i, img in enumerate(images):
            if img.ndim != 3:  # only [H, W, C] entries are written
                continue
            Image.fromarray(_to_uint8_image(img)).save(image_dir / f"{i}.png")

    np.savez_compressed(
        debug_dir / f"debug_step_{step}.npz",
        optimal_state=np.array(optimal_observation.state),
        suboptimal_state=np.array(suboptimal_observation.state),
        optimal_actions=np.array(optimal_actions),
        suboptimal_actions=np.array(suboptimal_actions),
        weight=np.array(get_weight(suboptimal_observation, suboptimal_actions)),
        step=step,
    )


def main(config: _config.AWRTrainConfig, value_config: _config.ValueTrainConfig):
    """Run the training loop: set up logging, data, model state and value model, then train.

    Every step trains on one optimal and one suboptimal batch. Metrics are averaged and logged
    every ``config.log_interval`` steps, and checkpoints are saved every ``config.save_interval``
    steps and at the last step.

    Raises:
        ValueError: If the batch size is not divisible by the device count, or if no value model
            could be created from ``value_config``.
        RuntimeError: If the optimal or suboptimal loss becomes NaN, negative, or exceeds
            ``_MAX_SANE_LOSS``. The offending batches are dumped to ``debug_outputs/`` first.
    """
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    optimal_data_loader, suboptimal_data_loader = _data_loader.create_awr_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    optimal_data_iter = iter(optimal_data_loader)
    suboptimal_data_iter = iter(suboptimal_data_loader)
    optimal_batch = next(optimal_data_iter)
    suboptimal_batch = next(suboptimal_data_iter)
    logging.info(f"Initialized optimal data loader:\n{training_utils.array_tree_to_info(optimal_batch)}")
    logging.info(f"Initialized suboptimal data loader:\n{training_utils.array_tree_to_info(suboptimal_batch)}")

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, optimal_data_loader)

    value_rng = jax.random.fold_in(init_rng, 0)
    value_model = init_value_model(value_config, value_rng)
    if value_model is None:
        # The sample weights come from the value model, so training cannot proceed without one.
        raise ValueError("A value model is required to compute sample weights, but none was configured.")

    get_weight = nnx_utils.module_jit(value_model.get_weight)

    ptrain_step = jax.jit(
        functools.partial(train_step, config, get_weight),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, optimal_batch, suboptimal_batch)
        infos.append(info)

        if _loss_is_invalid(info["optimal_loss"]) or _loss_is_invalid(info["suboptimal_loss"]):
            logging.error("Invalid loss at step %d: %s", step, info)
            _dump_debug_batch(epath.Path("debug_outputs"), step, optimal_batch, suboptimal_batch, get_weight)
            # An explicit raise (rather than an assert) so the abort also happens under `python -O`.
            raise RuntimeError(f"Invalid loss at step {step}; debug data written to debug_outputs/.")

        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        optimal_batch = next(optimal_data_iter)
        suboptimal_batch = next(suboptimal_data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, optimal_data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    # `cli_awr` supplies two configs, which are forwarded to `main` in the same order.
    awr_config, value_model_config = _config.cli_awr()
    main(awr_config, value_model_config)

