from flax import nnx
import jax.random
import numpy as np

from offline import base
from offline.modules.base import TrainState, get_optimizer
from offline.diffusion.repaint.ddim.arguments import (
    Arguments,
    build_argument_parser,
)
from offline.diffusion.repaint.ddim.modules import RepaintDDIMPolicy
from offline.diffusion.repaint.train import TrainerState, train_fn
from offline.types import OfflineDataWithInfos
from offline.utils.data import ArrayDataLoader
from offline.utils.logger import Logger
from offline.utils.nnx import default_nnx_rngs


def init_fn(args: Arguments, data: OfflineDataWithInfos) -> TrainerState:
    """Build the initial ``TrainerState`` for training a ``RepaintDDIMPolicy``.

    The policy is constructed from the hyperparameters in ``args``, paired
    with an optimizer over its noise predictor inside a ``TrainState``, and
    that state is split with ``nnx.split`` into ``graphdef``/``graphstate``.

    Args:
        args: Run configuration. ``args.seed`` is used for the model RNGs,
            for the data loader's NumPy generator, and for the two training
            keys.
        data: Offline dataset. ``data.data.actions`` and
            ``data.data.observations`` must have at least two dimensions;
            the size of axis 1 is used as the action / observation dimension.

    Returns:
        A ``TrainerState`` holding the split model/optimizer state, an
        endlessly repeating batch iterator, the two training keys, and the
        initial evaluation state ``(jax.random.key(42), 0)``.

    Raises:
        NotImplementedError: If ``args.unsquash`` is truthy.
    """
    if args.unsquash:
        # Fail with an explanation instead of an empty error so the offending
        # option is obvious from the traceback.
        raise NotImplementedError(
            "The 'unsquash' option is not implemented for the repaint DDIM "
            "policy; disable it to run this trainer."
        )

    dataset = data.data
    policy = RepaintDDIMPolicy(
        action_dim=dataset.actions.shape[1],
        beta_schedule=args.beta_schedule,
        clip_sample=args.clip_sample,
        diffusion_steps=args.diffusion_steps,
        eta=args.eta,
        hidden_features=args.hidden_features,
        inference_steps=args.inference_steps,
        jump_length=args.jump_length,
        jump_samples=args.jump_samples,
        layer_norm=args.layer_norm,
        nonlinearity=args.nonlinearity,
        num_layers=args.num_layers,
        observation_dim=dataset.observations.shape[1],
        rngs=default_nnx_rngs(args.seed),
        temperature=args.temperature,
        time_dim=args.time_dim,
        timestep_spacing=args.timestep_spacing,
    )

    # The optimizer is built over the noise predictor only, not the whole
    # policy object.
    train_state = TrainState(
        model=policy,
        optimizer=get_optimizer(
            policy.diffusion.noise_predictor, learning_rate=args.learning_rate
        ),
    )
    graphdef, graphstate = nnx.split(train_state)

    # Two keys derived from the run seed, stored separately on the trainer
    # state as the noise key and the time key.
    train_noise_key, train_time_key = jax.random.split(
        jax.random.key(args.seed), 2
    )

    # Each row is the action followed by the observation (actions first).
    samples = np.concatenate((dataset.actions, dataset.observations), axis=1)
    data_iter = ArrayDataLoader(
        samples,
        batch_size=args.batch_size,
        drop_last=True,
        rng=np.random.default_rng(args.seed),
    ).repeat_forever()

    return TrainerState(
        # Fixed key (42) and a zero counter, independent of ``args.seed``.
        eval_state=(jax.random.key(42), 0),
        data_iter=data_iter,
        graphdef=graphdef,
        graphstate=graphstate,
        train_noise_key=train_noise_key,
        train_time_key=train_time_key,
    )


def load_fn(
    step: int | None,
    action_dim: int,
    beta_schedule: str,
    clip_sample: bool,
    diffusion_steps: int,
    eta: float,
    hidden_features: int,
    inference_steps: int,
    jump_length: int,
    jump_samples: int,
    layer_norm: bool,
    nonlinearity: str,
    num_layers: int,
    logger: Logger,
    observation_dim: int,
    temperature: float,
    time_dim: int,
    timestep_spacing: str,
    **kwargs,
):
    """Obtain a ``RepaintDDIMPolicy`` through ``base.default_load_fn``.

    A zero-argument factory that builds the policy from the given
    hyperparameters is handed to ``base.default_load_fn`` together with
    ``logger`` and ``step``; what that helper does with them (presumably
    restoring saved weights for ``step``) is defined in ``offline.base``.

    Args:
        step: Forwarded unchanged to ``base.default_load_fn``; may be ``None``.
        logger: Forwarded unchanged to ``base.default_load_fn``.
        **kwargs: Accepted for call-site convenience and discarded.

    All remaining parameters are forwarded by name to the
    ``RepaintDDIMPolicy`` constructor.

    Returns:
        A tuple ``(policy, eval_state)`` where ``eval_state`` is
        ``(jax.random.key(42), 0)``.
    """
    # Anything not named in the signature is deliberately discarded.
    del kwargs

    policy_options = dict(
        action_dim=action_dim,
        beta_schedule=beta_schedule,
        clip_sample=clip_sample,
        diffusion_steps=diffusion_steps,
        eta=eta,
        hidden_features=hidden_features,
        inference_steps=inference_steps,
        jump_length=jump_length,
        jump_samples=jump_samples,
        layer_norm=layer_norm,
        nonlinearity=nonlinearity,
        num_layers=num_layers,
        observation_dim=observation_dim,
        temperature=temperature,
        time_dim=time_dim,
        timestep_spacing=timestep_spacing,
    )

    def build_policy():
        # The RNG bundle is created on every call, always from seed 0.
        return RepaintDDIMPolicy(rngs=default_nnx_rngs(0), **policy_options)

    policy = base.default_load_fn(
        logger=logger,
        model_fn=build_policy,
        step=step,
    )
    eval_state = (jax.random.key(42), 0)
    return policy, eval_state


if __name__ == "__main__":
    # Parse the command line, then forward every parsed option to the shared
    # runner as keyword arguments alongside this module's hooks.
    cli_options = vars(build_argument_parser().parse_args())
    base.run(
        arguments_class=Arguments,
        init_fn=init_fn,
        train_fn=train_fn,
        **cli_options,
    )
