from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TypeVar
from warnings import catch_warnings, filterwarnings

from flax import nnx
from jax import Array
from jax.numpy import ComplexWarning
import numpy as np

from offline import base
from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.base import TargetModel, TrainState, TrainStateWithTarget
from offline.modules.mlp import MLP
from offline.lbp.tc.arguments import Arguments
from offline.lbp.tc.core import (
    behavior_cloning_step,
    compute_batch_embeddings,
    pretrain_critic_step,
    pretrain_critic_step_with_target_update,
    train_actor_critic_step,
    train_classifier_step,
    train_critic_step,
    train_tcae_step,
    v_learning_step,
)
from offline.lbp.tc.modules import LBPTCPolicy, LBPTCTrainState, QCriticEnsemble
from offline.lbp.tc.modules.auto_encoder import TCAutoEncoder
from offline.lbp.tc.types import (
    ActorBatch,
    AssignmentBatch,
    QLearningBatch,
    RegularizerBatch,
    SaBatch,
    TCBatch,
    VLearningBatch,
)
from offline.types import IntArray
from offline.utils.data import TrajectoryDataLoader, TrajectoryDataset
from offline.utils.logger import Logger
from offline.utils.tqdm import tqdm, trange


@dataclass(frozen=True)
class TrainerState(base.TrainerState[None]):
    """Immutable state threaded through successive `train_fn` calls.

    The train state is stored in split form (`graphdef` + `graphstate`);
    `train_fn` returns a copy of this object with `graphstate` replaced.
    """

    # Drawn from by `train_fn` only on steps where
    # ``(step + 1) % args.update_every == 0``.
    actor_iter: Iterator[ActorBatch]
    # Split representation of the LBPTCTrainState; see `policy`.
    graphdef: nnx.GraphDef[LBPTCTrainState]
    graphstate: nnx.GraphState | nnx.VariableState
    # Forwarded unchanged to the training step functions.
    min_target: float
    ood_threshold: float
    # One batch is drawn from each of these on every `train_fn` call.
    q_learning_iter: Iterator[QLearningBatch]
    regularizer_iter: Iterator[RegularizerBatch]
    # Keys handed to the step functions together with the step number;
    # `train_fn` never replaces them.
    train_actor_key: Array
    train_critic_key: Array
    train_regularizer_key_noise: Array
    train_regularizer_key_policy: Array

    @property
    def policy(self) -> LBPTCPolicy:
        """Merge `graphdef` and `graphstate` and return the policy."""
        return nnx.merge(self.graphdef, self.graphstate).policy


# Type variable bounded by TrainerState so that `train_fn` is annotated as
# returning the same state type it is given.
T = TypeVar("T", bound=TrainerState)


def compute_assignments(
    dataset: TrajectoryDataset[TCBatch],
    tc_autoencoder: TCAutoEncoder,
    batch_size: int = 8,
):
    """Return one codebook index per step, concatenated over the dataset.

    For each batch row, the value in column 0 of the returned
    `encoding_indices` is repeated `lengths` times, and the per-batch
    results are concatenated in loader order.

    Note: this calls ``tc_autoencoder.eval(update_ema=False)`` and does not
    switch the module back afterwards; callers that continue training must
    do so themselves.

    Args:
        dataset: Trajectories to encode.
        tc_autoencoder: Auto-encoder used to compute the encodings.
        batch_size: Number of trajectories per batch.

    Returns:
        A 1-D NumPy array of encoding indices.
    """
    tc_autoencoder.eval(update_ema=False)
    graphdef, graphstate = nnx.split(tc_autoencoder)
    # drop_last=False: every trajectory must receive an assignment.
    loader = TrajectoryDataLoader(
        dataset, batch_size=batch_size, drop_last=False
    )
    per_batch = []
    with tqdm(desc="Assign", leave=False, total=len(dataset)) as bar:
        for batch, lengths in loader:
            embedded = compute_batch_embeddings(
                graphdef=graphdef,
                graphstate=graphstate,
                lengths=np.asarray(lengths),
                observations=batch.observations,
                rewards=batch.rewards,
            )
            codes = embedded.encoding_indices[:, 0]
            per_batch.append(np.repeat(codes, lengths))
            bar.update(batch.observations.shape[0])
    return np.concatenate(per_batch)


def filter_codebook(
    dataset: TrajectoryDataset[TCBatch],
    tc_autoencoder: TCAutoEncoder,
    threshold: float,
    batch_size: int = 8,
):
    """Remove rarely used entries from the vector quantizer, in place.

    Assignments are computed for the whole dataset and counted per code.
    A code is kept when its count is at least `threshold` times the total
    number of assignments. The quantizer's `embeddings`,
    `ema_cluster_size.values` and `ema_dw.values` are all sliced with the
    same mask.

    Args:
        dataset: Trajectories used to measure code usage.
        tc_autoencoder: Model whose `vector_quantizer` is modified.
        threshold: Minimum fraction of all assignments a code needs to be
            kept.
        batch_size: Batch size forwarded to `compute_assignments`.
    """
    assignments = compute_assignments(
        dataset=dataset, tc_autoencoder=tc_autoencoder, batch_size=batch_size
    )
    quantizer = tc_autoencoder.vector_quantizer
    codebook_size = quantizer.embeddings.shape[1]
    # minlength makes codes that were never assigned show up with count 0.
    usage = np.bincount(assignments, minlength=codebook_size)
    keep = usage >= usage.sum() * threshold
    quantizer.embeddings.value = quantizer.embeddings[:, keep]
    quantizer.ema_cluster_size.values.value = (
        quantizer.ema_cluster_size.values[keep]
    )
    quantizer.ema_dw.values.value = quantizer.ema_dw.values[:, keep]


def sample_indices(count: int, lengths: IntArray, rng: np.random.Generator):
    """Draw `count` uniform random indices per row, each in ``[0, length)``.

    `lengths` supplies the exclusive upper bounds and must broadcast
    against the output shape ``(lengths.shape[0], count)``; within this
    file it is passed with shape ``(batch, 1)``.

    Args:
        count: Number of indices to draw for each row.
        lengths: Exclusive upper bound(s) for the sampled indices.
        rng: NumPy generator used for sampling.

    Returns:
        An integer array of shape ``(lengths.shape[0], count)``.
    """
    shape = (lengths.shape[0], count)
    return rng.integers(low=0, high=lengths, size=shape)


def behavior_cloning_fn(
    actor: GaussianActorEnsemble,
    logger: Logger,
    optimizer: nnx.Optimizer,
    sa_iter: Iterator[SaBatch],
    steps: int,
) -> GaussianActorEnsemble:
    """Run `steps` behavior-cloning updates and return the trained actor.

    Each step draws one batch from `sa_iter`, applies
    `behavior_cloning_step` to the split train state and logs the
    returned results under the step number.

    Args:
        actor: Actor ensemble to train.
        logger: Receives the per-step results.
        optimizer: Optimizer bundled with the actor in the train state.
        sa_iter: Source of batches; must yield at least `steps` items.
        steps: Number of updates to run.

    Returns:
        The model taken from the merged train state after the last step.
    """
    initial_state = TrainState(model=actor, optimizer=optimizer)
    graphdef, graphstate = nnx.split(initial_state)
    progress = trange(steps, desc="BC")
    for step in progress:
        sa_batch = next(sa_iter)
        graphstate, metrics = behavior_cloning_step(
            graphdef=graphdef,
            graphstate=graphstate,
            observations=sa_batch.observations,
            actions=sa_batch.actions,
            assignments=sa_batch.assignments,
        )
        logger.write(step, **metrics)
    return nnx.merge(graphdef, graphstate).model


def pretrain_critic_fn(
    gamma: float,
    logger: Logger,
    optimizer: nnx.Optimizer,
    qcritic: QCriticEnsemble,
    q_learning_iter: Iterator[QLearningBatch],
    steps: int,
    tau: float,
    update_every: int,
) -> QCriticEnsemble:
    """Pretrain the Q-critic ensemble and return the trained model.

    The train state bundles the critic, its optimizer and a target built
    with ``TargetModel(qcritic)``. On steps where
    ``(step + 1) % update_every == 0`` the step variant that also updates
    the target is used (the only one that receives `tau`); every other
    step uses the plain `pretrain_critic_step`.

    Args:
        gamma: Forwarded to both step functions.
        logger: Receives the per-step results.
        optimizer: Optimizer bundled with the critic.
        qcritic: Critic ensemble to train.
        q_learning_iter: Source of batches; must yield at least `steps`.
        steps: Number of updates to run.
        tau: Forwarded to the target-update step variant.
        update_every: Period, in steps, of the target-update variant.

    Returns:
        The model taken from the merged train state after the last step.
    """
    initial_state = TrainStateWithTarget(
        model=qcritic,
        optimizer=optimizer,
        target=TargetModel(qcritic),
    )
    graphdef, graphstate = nnx.split(initial_state)
    for step in trange(steps, desc="Pretrain"):
        batch = next(q_learning_iter)
        # Keyword arguments common to both step variants.
        shared = dict(
            actions=batch.actions,
            dones=batch.dones,
            gamma=gamma,
            graphdef=graphdef,
            graphstate=graphstate,
            mask=batch.mask,
            next_mask=batch.next_mask,
            next_means=batch.next_means,
            next_observations=batch.next_observations,
            observations=batch.observations,
            rewards=batch.rewards,
        )
        if (step + 1) % update_every != 0:
            graphstate, results = pretrain_critic_step(**shared)
        else:
            graphstate, results = pretrain_critic_step_with_target_update(
                tau=tau, **shared
            )
        logger.write(step, **results)
    return nnx.merge(graphdef, graphstate).model


def _train_tcae_step(
    batch: TCBatch,
    commitment_cost: float,
    graphdef: nnx.GraphDef[TrainState[TCAutoEncoder]],
    graphstate: nnx.GraphState | nnx.VariableState,
    lengths: IntArray,
    reward_weight: float,
    rng: np.random.Generator,
    subsample_decodes: int,
    subsample_latents: int,
    transition_weight: float,
):
    """Sample time indices for a batch and run one `train_tcae_step`.

    Decode indices are drawn first and latent indices second, both
    uniformly in ``[0, length)`` per trajectory. The decode mask is False
    exactly where a decode index equals ``length - 1``.

    Returns:
        The ``(graphstate, results)`` pair produced by `train_tcae_step`.
    """
    # Add a trailing axis so lengths broadcast against (batch, count).
    column_lengths = np.expand_dims(lengths, axis=1)
    # Keep this draw order: it fixes how the generator state is consumed.
    decode_indices = sample_indices(
        count=subsample_decodes, lengths=column_lengths, rng=rng
    )
    latent_indices = sample_indices(
        count=subsample_latents, lengths=column_lengths, rng=rng
    )
    decode_mask = decode_indices < column_lengths - 1
    new_graphstate, metrics = train_tcae_step(
        actions=batch.actions,
        observations=batch.observations,
        rewards=batch.rewards,
        graphdef=graphdef,
        graphstate=graphstate,
        decode_indices=decode_indices,
        decode_mask=decode_mask,
        latent_indices=latent_indices,
        commitment_cost=commitment_cost,
        reward_weight=reward_weight,
        transition_weight=transition_weight,
    )
    return new_graphstate, metrics


def train_classifier_fn(
    classifier: MLP,
    classifier_iter: Iterator[AssignmentBatch],
    logger: Logger,
    optimizer: nnx.Optimizer,
    steps: int,
):
    """Run `steps` classifier updates and return the trained classifier.

    Each step draws one batch from `classifier_iter`, applies
    `train_classifier_step` to the split train state and logs the
    returned results under the step number.

    Args:
        classifier: Model to train.
        classifier_iter: Source of batches; must yield at least `steps`.
        logger: Receives the per-step results.
        optimizer: Optimizer bundled with the classifier.
        steps: Number of updates to run.

    Returns:
        The model taken from the merged train state after the last step.
    """
    graphdef, graphstate = nnx.split(
        TrainState(model=classifier, optimizer=optimizer)
    )
    progress = trange(steps, desc="CLS")
    for step in progress:
        labelled = next(classifier_iter)
        graphstate, metrics = train_classifier_step(
            graphdef=graphdef,
            graphstate=graphstate,
            observations=labelled.observations,
            assignments=labelled.assignments,
        )
        logger.write(step, **metrics)
    return nnx.merge(graphdef, graphstate).model


def train_fn(step: int, args: Arguments, state: T) -> T:
    """Run one training step and return the state with a new graphstate.

    One batch is drawn from `state.q_learning_iter` and one from
    `state.regularizer_iter` on every call. When
    ``(step + 1) % args.update_every == 0`` a batch is also drawn from
    `state.actor_iter` and the joint actor-critic step runs; otherwise
    only the critic step runs. Results are written to `args.logger`.

    Args:
        step: Current step number; also forwarded to the step function.
        args: Hyperparameters and logger.
        state: Current trainer state.

    Returns:
        A copy of `state` in which only `graphstate` is replaced.
    """
    critic_batch = next(state.q_learning_iter)
    regularizer_batch = next(state.regularizer_iter)
    noise_batch_size = int(round(args.batch_size * args.noise_ratio))
    update_actor = (step + 1) % args.update_every == 0
    # The actor iterator is advanced only on actor-update steps.
    actor_batch = next(state.actor_iter) if update_actor else None

    # Keyword arguments common to both step variants.
    shared = dict(
        actions=critic_batch.actions,
        assignments_regularizer=regularizer_batch.assignments,
        baseline_regularizer=regularizer_batch.baseline,
        dones=critic_batch.dones,
        gamma=args.gamma,
        graphdef=state.graphdef,
        graphstate=state.graphstate,
        lipschitz_constant=args.lipschitz_constant,
        mask=critic_batch.mask,
        mask_regularizer=regularizer_batch.mask,
        means_regularizer=regularizer_batch.means,
        min_target=state.min_target,
        next_mask=critic_batch.next_mask,
        next_means=critic_batch.next_means,
        next_observations=critic_batch.next_observations,
        noise_batch_size=noise_batch_size,
        observations=critic_batch.observations,
        observations_regularizer=regularizer_batch.observations,
        ood_threshold=state.ood_threshold,
        regularizer_weight=args.regularizer_weight,
        rewards=critic_batch.rewards,
        std_multiplier=args.std_multiplier,
        stds_regularizer=regularizer_batch.stds,
        step=step,
        train_critic_key=state.train_critic_key,
        train_regularizer_key_noise=state.train_regularizer_key_noise,
        train_regularizer_key_policy=state.train_regularizer_key_policy,
    )
    if update_actor:
        new_graphstate, results = train_actor_critic_step(
            mask_actor=actor_batch.mask,
            means_actor=actor_batch.means,
            observations_actor=actor_batch.observations,
            tau=args.tau,
            train_actor_key=state.train_actor_key,
            **shared,
        )
    else:
        new_graphstate, results = train_critic_step(**shared)
    args.logger.write(step, **results)
    return replace(state, graphstate=new_graphstate)


def train_tcae_fn(
    batch_size: int,
    commitment_cost: float,
    dataset: TrajectoryDataset[TCBatch],
    logger: Logger,
    optimizer: nnx.Optimizer,
    phase_split: float,
    reward_weight: float,
    reweight: bool,
    rng: np.random.Generator,
    steps: int,
    subsample_decodes: int,
    subsample_latents: int,
    tc_autoencoder: TCAutoEncoder,
    threshold: float,
    transition_weight: float,
) -> TCAutoEncoder:
    """Train the TC auto-encoder in two phases and return the model.

    The first ``round(steps * phase_split)`` steps form a warm-up phase.
    If that phase is non-empty, the codebook is then filtered with
    `filter_codebook` (using its default batch size) and the train state
    is re-split before the remaining steps run. `ComplexWarning` is
    ignored for the duration of both phases.

    Args:
        batch_size: Trajectories per training batch.
        commitment_cost: Forwarded to every training step.
        dataset: Trajectories used for training and codebook filtering.
        logger: Receives the per-step results.
        optimizer: Optimizer bundled with the auto-encoder.
        phase_split: Fraction of `steps` spent in the warm-up phase.
        reward_weight: Forwarded to every training step.
        reweight: Forwarded to the data loader.
        rng: Generator shared by the data loader and index sampling.
        steps: Total number of steps over both phases.
        subsample_decodes: Decode indices sampled per trajectory.
        subsample_latents: Latent indices sampled per trajectory.
        tc_autoencoder: Model to train.
        threshold: Usage threshold forwarded to `filter_codebook`.
        transition_weight: Forwarded to every training step.

    Returns:
        The model taken from the merged train state after the last step.
    """
    loader = TrajectoryDataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        reweight=reweight,
        rng=rng,
    )
    iterator = loader.repeat_forever()
    tc_autoencoder.train(update_ema=True)
    graphdef, graphstate = nnx.split(
        TrainState(model=tc_autoencoder, optimizer=optimizer)
    )
    warmup_steps = round(steps * phase_split)

    def run_phase(progress, phase_graphdef, phase_graphstate):
        # Run one update per step yielded by `progress`; return new state.
        for step in progress:
            batch, lengths = next(iterator)
            phase_graphstate, results = _train_tcae_step(
                batch=batch,
                commitment_cost=commitment_cost,
                graphdef=phase_graphdef,
                graphstate=phase_graphstate,
                lengths=np.asarray(lengths),
                reward_weight=reward_weight,
                rng=rng,
                subsample_decodes=subsample_decodes,
                subsample_latents=subsample_latents,
                transition_weight=transition_weight,
            )
            logger.write(step, **results)
        return phase_graphstate

    with catch_warnings():
        filterwarnings("ignore", category=ComplexWarning)
        if warmup_steps > 0:
            graphstate = run_phase(
                trange(warmup_steps, desc="TC-Warmup"), graphdef, graphstate
            )
            warmed_up = nnx.merge(graphdef, graphstate)
            filter_codebook(
                dataset=dataset,
                tc_autoencoder=warmed_up.model,
                threshold=threshold,
            )
            # compute_assignments (called by filter_codebook) switches the
            # model to eval mode, so re-enable training before re-splitting.
            warmed_up.model.train(update_ema=True)
            graphdef, graphstate = nnx.split(warmed_up)
        graphstate = run_phase(
            trange(warmup_steps, steps, desc="TC"), graphdef, graphstate
        )
    return nnx.merge(graphdef, graphstate).model


def v_learning_fn(
    logger: Logger,
    optimizer: nnx.Optimizer,
    steps: int,
    v_learning_iter: Iterator[VLearningBatch],
    vcritic: MLP,
) -> MLP:
    """Run `steps` V-learning updates and return the trained V-critic.

    Each step draws one batch from `v_learning_iter`, applies
    `v_learning_step` to the split train state and logs the returned
    results under the step number.

    Args:
        logger: Receives the per-step results.
        optimizer: Optimizer bundled with the V-critic.
        steps: Number of updates to run.
        v_learning_iter: Source of batches; must yield at least `steps`.
        vcritic: Model to train.

    Returns:
        The model taken from the merged train state after the last step.
    """
    initial_state = TrainState(model=vcritic, optimizer=optimizer)
    graphdef, graphstate = nnx.split(initial_state)
    progress = trange(steps, desc="V")
    for step in progress:
        v_batch = next(v_learning_iter)
        graphstate, metrics = v_learning_step(
            graphdef=graphdef,
            graphstate=graphstate,
            observations=v_batch.observations,
            assignments=v_batch.assignments,
            targets=v_batch.targets,
        )
        logger.write(step, **metrics)
    return nnx.merge(graphdef, graphstate).model
