from .loss_functions import LossFunctions
from ..torch_utils import select


class ModelFreeTrainer(LossFunctions):
    """Trainer that combines the TD, OOD-action and tree-search-policy losses.

    The individual loss terms and the accuracy metric come from methods
    reached through ``self`` (``compute_td_loss``, ``compute_ood_loss``,
    ``compute_tree_search_policy_loss``, ``compute_accuracy``); this class
    only wires them together into one weighted objective.
    """

    def compute_loss(self, states, actions, rewards, values):
        """Return the weighted total loss together with scalar metrics.

        Args:
            states: Only ``select(states, index=0)`` is used.
            actions: Only ``select(actions, index=0)`` is used.
            rewards: Not read by this method.
            values: Only ``select(values, index=0)`` is used; it serves as
                the target passed to the TD and tree-search-policy losses.

        Returns:
            A ``(total_loss, metrics)`` tuple. ``total_loss`` is the sum of
            the three loss terms scaled by ``self.td_loss_weight``,
            ``self.ood_loss_weight`` and ``self.tree_loss_weight``.
            ``metrics`` maps ``"loss/td"``, ``"loss/ood"``, ``"loss/tree"``
            and ``"accuracy/policy"`` to the ``.item()`` of each quantity.
        """
        # NOTE(review): index=0 presumably picks the first time step of each
        # input -- confirm against `select`.
        first_states = select(states, index=0)
        first_actions = select(actions, index=0)
        first_values = select(values, index=0)

        # Encode the states, then run the solver on the encodings
        # (encoded=True) to obtain the Q-values.
        encodings = self.solver.f_encoder(first_states)
        q_values = self.solver(encodings, encoded=True, training=True)

        # TD term: Q-values picked out by the actions vs. the value targets.
        chosen_q = select(q_values, first_actions)
        td_term = self.compute_td_loss(chosen_q, first_values)

        # OOD-action term, computed from the full set of Q-values.
        ood_term = self.compute_ood_loss(q_values, first_actions)

        # Tree-search policy term. It is invoked unconditionally here.
        # NOTE(review): any "only for tree-based solvers" gating presumably
        # lives inside the helper -- confirm.
        tree_term = self.compute_tree_search_policy_loss(
            first_actions, first_values
        )

        # Accuracy is reported as a metric only; it is not part of the loss.
        policy_accuracy = self.compute_accuracy(q_values, first_actions)

        total_loss = (
            self.td_loss_weight * td_term
            + self.ood_loss_weight * ood_term
            + self.tree_loss_weight * tree_term
        )

        metrics = {
            key: quantity.item()
            for key, quantity in (
                ("loss/td", td_term),
                ("loss/ood", ood_term),
                ("loss/tree", tree_term),
                ("accuracy/policy", policy_accuracy),
            )
        }
        return total_loss, metrics
