from typing import Callable

import haiku as hk
import jax
import jax.numpy as jnp
import optax
from flax import struct
from haiku import Params
from optax import OptState

from medium_rl.config import Config
from medium_rl.data.trajectory import make_dummy_subtrajectory, make_jit_trajectory_buffer
from medium_rl.data.transition import make_dummy_transition, make_jit_transition_buffer
from medium_rl.envs.amp import AMPSequence
from medium_rl.envs.bit_seq import BitSequence
from medium_rl.envs.gfp import GFPSequence
from medium_rl.envs.sequence_env import SequenceEnv, State
from medium_rl.envs.utr import UTRSequence
from medium_rl.network.multi_transformer import MultiTransformer


@struct.dataclass
class RunState:
    """Bundle of mutable training state, built by ``init_run_state``.

    Declared with flax ``struct.dataclass`` (a frozen dataclass registered as a
    JAX pytree), so instances can be passed into and out of jitted functions.
    Field order matters: ``init_run_state`` may construct it positionally.
    """

    params: Params  # online network parameters (from ``forward.init``)
    target_params: Params  # target-network parameters; initialised as a copy of ``params``
    opt_state: OptState  # optimizer state created from ``params``
    env_state: State  # env states, one per parallel env (vmapped ``env.init`` over ``cfg.num_envs`` keys)
    rng: jnp.ndarray  # PRNG key left over after splitting off the init keys


def init_optimizer(base_lr: float, weight_decay: float, grad_clip_norm: float):
    """Create an AdamW optimizer preceded by global-norm gradient clipping.

    Args:
        base_lr: Learning rate passed to AdamW.
        weight_decay: Decoupled weight-decay coefficient for AdamW.
        grad_clip_norm: Maximum global L2 norm of the gradient pytree.

    Returns:
        An ``optax`` gradient transformation chaining clipping then AdamW.
    """
    clip_step = optax.clip_by_global_norm(grad_clip_norm)
    adamw_step = optax.adamw(base_lr, eps=1e-5, weight_decay=weight_decay)
    # Order matters: gradients are clipped before AdamW sees them.
    return optax.chain(clip_step, adamw_step)


def init_run_state(env: SequenceEnv, forward: Callable, optimizer, cfg: Config):
    """Build the initial ``RunState`` for training.

    The seed-derived PRNG key is split three ways:
    - a key that is carried in the returned state,
    - a key for parameter initialisation,
    - a key for environment initialisation.

    Steps performed:
    - Parameters are initialised from a dummy ``(1, 1)`` int32 token batch.
    - The optimizer state is initialised from those parameters.
    - ``cfg.num_envs`` environment states are created with a jitted ``vmap``
      over ``env.init``.
    - Target parameters start as a (shallow) ``.copy()`` of the online
      parameters.

    Args:
        env: Environment exposing an ``init(rng)`` method.
        forward: Haiku-transformed network exposing ``init(rng, x)``.
        optimizer: ``optax`` transformation exposing ``init(params=...)``.
        cfg: Run config; reads ``seed`` and ``num_envs``.

    Returns:
        A freshly initialised ``RunState``.
    """
    root_key = jax.random.PRNGKey(cfg.seed)
    carry_key, param_key, env_key = jax.random.split(root_key, 3)

    # Network parameters and the matching optimizer state.
    dummy_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
    net_params = forward.init(param_key, dummy_tokens)
    optim_state = optimizer.init(params=net_params)

    # One independent key per parallel environment.
    per_env_keys = jax.random.split(env_key, cfg.num_envs)
    batched_env_init = jax.jit(jax.vmap(env.init))
    env_states = batched_env_init(per_env_keys)

    return RunState(
        params=net_params,
        target_params=net_params.copy(),
        opt_state=optim_state,
        env_state=env_states,
        rng=carry_key,
    )


# Maps ``cfg.env.name`` to the environment class that ``init_env`` instantiates.
ENV_CLASSES = dict(
    AMP=AMPSequence,
    GFP=GFPSequence,
    UTR=UTRSequence,
    BitSeq=BitSequence,
)


def init_env(cfg: Config):
    """Instantiate the environment named by ``cfg.env.name``.

    The class is looked up in ``ENV_CLASSES`` and constructed with every field
    of ``cfg.env`` as keyword arguments. Two batched, jitted helpers are then
    attached to the instance:
    - ``reset_fn``: ``jax.jit(jax.vmap(env.init))``
    - ``step_fn``: ``jax.jit(jax.vmap(env.step))``

    Raises:
        KeyError: if ``cfg.env.name`` is not a key of ``ENV_CLASSES``.
    """
    env_cls = ENV_CLASSES[cfg.env.name]
    env_kwargs = cfg.env.model_dump()
    environment = env_cls(**env_kwargs)

    # Vectorised over a leading batch axis and compiled once.
    environment.reset_fn = jax.jit(jax.vmap(environment.init))
    environment.step_fn = jax.jit(jax.vmap(environment.step))
    return environment


def _network_heads(alg_name: str, num_tokens: int):
    """Return the ``networks`` list for ``alg_name``, or ``None`` if the name is not recognised.

    ``None`` means the caller should leave any existing ``networks`` entry untouched.
    """
    if alg_name in ("TGM", "TGMW", "TGMP"):
        return [num_tokens]
    if alg_name == "SAC":
        return [num_tokens, num_tokens, num_tokens]
    if alg_name == "PPO":
        return [num_tokens, 1]
    return None


def _make_replay_buffer(cfg: Config, env: SequenceEnv):
    """Build the replay buffer selected by ``cfg.alg.data_type`` and its initial state.

    Returns:
        ``(buffer, buffer_state)``, where ``buffer_state`` comes from ``buffer.init``
        applied to a dummy item built from ``env``.

    Raises:
        ValueError: if ``cfg.alg.data_type`` is neither ``"trajectory"`` nor ``"transition"``.
    """
    data_type = cfg.alg.data_type
    if data_type == "trajectory":
        make_buffer, make_dummy = make_jit_trajectory_buffer, make_dummy_subtrajectory
    elif data_type == "transition":
        make_buffer, make_dummy = make_jit_transition_buffer, make_dummy_transition
    else:
        # Fail fast here rather than hitting an unbound ``buffer`` name later on.
        raise ValueError(
            f"Unsupported cfg.alg.data_type {data_type!r}; expected 'trajectory' or 'transition'."
        )

    buffer = make_buffer(
        max_length=cfg.replay_buffer_size,
        min_length=cfg.minibatch_size,
        sample_batch_size=cfg.minibatch_size,
        add_batches=True,
    )
    return buffer, buffer.init(make_dummy(env))


def init_cfg(cfg: Config):
    """Build everything needed to start a training run from ``cfg``.

    Steps:
    - Create the environment via ``init_env``.
    - Derive the network config from ``cfg.network``, filling in
      ``num_tokens``, ``pad_token`` and (for known algorithms) ``networks``.
    - Wrap a ``MultiTransformer`` in ``hk.transform``.
    - Build the optimizer.
    - Build the replay buffer and its initial state.
    - Initialise the ``RunState``.

    Returns:
        ``(run_state, env, forward, optimizer, buffer, buffer_state, network_cfg)``.

    Raises:
        ValueError: if ``cfg.alg.data_type`` is not ``"trajectory"`` or ``"transition"``.
    """
    env = init_env(cfg)

    # Forward
    network_cfg = cfg.network.model_dump()
    network_cfg["num_tokens"] = env.num_tokens
    network_cfg["pad_token"] = env.PAD
    heads = _network_heads(cfg.alg.name, env.num_tokens)
    if heads is not None:
        # NOTE(review): one output width per head; presumably policy/Q/value heads — confirm in MultiTransformer.
        network_cfg["networks"] = heads

    forward = hk.transform(lambda x, is_training=False: MultiTransformer(**network_cfg)(x, is_training))
    optimizer = init_optimizer(cfg.lr, cfg.weight_decay, cfg.grad_clip_norm)

    # Buffer
    buffer, buffer_state = _make_replay_buffer(cfg, env)

    run_state = init_run_state(env, forward, optimizer, cfg)
    return run_state, env, forward, optimizer, buffer, buffer_state, network_cfg
