from .loss_functions import LossFunctions
from ..torch_utils import merge, select


class TreeQNTrainer(LossFunctions):
    """Trainer combining TD, OOD and latent-reward losses for a TreeQN-style solver.

    The solver is used only through ``f_encoder`` (states -> encodings),
    ``__call__(encodings, encoded=True)`` (encodings -> Q-values) and
    ``f_reward(encodings, actions)``. The loss weights
    (``td_loss_weight``, ``ood_loss_weight``, ``reward_loss_weight``), the
    ``compute_*`` helpers and ``predict_latent_state_encodings`` are not
    defined here; presumably they come from ``LossFunctions`` — TODO confirm.
    """

    def compute_loss(
        self,
        states,
        actions,
        rewards,
        values,
    ):
        """Return the weighted training loss together with scalar metrics.

        Args:
            states: Batched state sequence. Dimension 1 must have length > 1.
                Presumably the layout is (batch, time, ...) — TODO confirm.
            actions: Action sequence. It is sliced at index 0 and flattened
                over its first two dimensions, like ``rewards``.
            rewards: Reward targets, flattened over their first two dimensions.
            values: Value targets. Only the index-0 slice is used.

        Returns:
            ``(loss, metrics)``. ``loss`` is
            ``td_loss_weight * td + ood_loss_weight * ood
            + reward_loss_weight * reward``. ``metrics`` maps ``"loss/td"``,
            ``"loss/ood"``, ``"loss/reward"`` and ``"accuracy/policy"`` to the
            ``.item()`` of the corresponding term.

        Raises:
            AssertionError: when ``states.shape[1] <= 1`` (only while Python
                assertions are enabled).
        """
        assert states.shape[1] > 1, "--n-steps must be >= 1 for model-based trainer"

        # Take the index-0 slice of states, actions and values (in that order).
        first_states, first_actions, first_values = (
            select(tensor, index=0) for tensor in (states, actions, values)
        )

        # Encode once; the same encoding feeds both the Q-head and the
        # latent-encoding predictor further down.
        encoded_t0 = self.solver.f_encoder(first_states)
        q_t0 = self.solver(encoded_t0, encoded=True)

        # Value-based losses on the index-0 slice.
        chosen_q = select(q_t0, first_actions)
        td_loss = self.compute_td_loss(chosen_q, first_values)
        ood_loss = self.compute_ood_loss(q_t0, first_actions)

        # Predict latent encodings from the index-0 encoding and the full
        # (unsliced) action sequence.
        predicted_latents = self.predict_latent_state_encodings(encoded_t0, actions)

        # Collapse dims 0 and 1 into one so the reward head runs in a single call.
        flat_actions, flat_rewards, flat_latents = (
            merge(tensor, dims=(0, 1))
            for tensor in (actions, rewards, predicted_latents)
        )
        predicted_rewards = self.solver.f_reward(flat_latents, flat_actions)
        reward_loss = self.compute_reward_loss(predicted_rewards, flat_rewards)

        accuracy = self.compute_accuracy(q_t0, first_actions)

        # Sum left to right, matching the evaluation order of a + b + c.
        weighted = self.td_loss_weight * td_loss
        weighted = weighted + self.ood_loss_weight * ood_loss
        weighted = weighted + self.reward_loss_weight * reward_loss

        metrics = {
            "loss/td": td_loss.item(),
            "loss/ood": ood_loss.item(),
            "loss/reward": reward_loss.item(),
            "accuracy/policy": accuracy.item(),
        }
        return weighted, metrics
