from copy import deepcopy

import torch as t

from .loss_functions import LossFunctions
from ..torch_utils import merge, select, update_state_dict


class ModelBasedJointTrainer(LossFunctions):
    """Joint trainer that optimises a model-based solver with one objective.

    The objective is a weighted sum of the TD, OOD, reward, latent
    transition-consistency and tree-search policy losses supplied by
    ``LossFunctions``. Transition targets are produced by ``target_encoder``,
    a gradient-free copy of ``solver.f_encoder``. That copy is refreshed from
    the online encoder through ``update_state_dict`` on every
    ``compute_loss`` call.
    """

    def __init__(self, args):
        super().__init__(args)

        # Build a frozen clone of the online encoder; its outputs serve as
        # regression targets, so no gradients should flow into it.
        frozen_encoder = deepcopy(self.solver.f_encoder)
        for weight in frozen_encoder.parameters():
            weight.requires_grad = False
        self.target_encoder = frozen_encoder

        # compute_loss hands ``1 - self.tau`` to update_state_dict. This is
        # presumably a soft (Polyak-style) blend factor — confirm against
        # update_state_dict.
        self.tau = 0.99

    def compute_loss(
        self,
        states,
        actions,
        rewards,
        values,
    ):
        """Return the weighted joint loss together with scalar log metrics.

        Args:
            states: Batched state sequences whose time axis is dim 1 (the
                batch and time dims are merged below via ``dims=(0, 1)``).
                ``states.shape[1]`` must be greater than 1.
            actions: Action sequences aligned with ``states``. The first step
                feeds the Q/reward heads; the whole sequence drives the
                latent transition rollout.
            rewards: Reward sequences; only the first step is used.
            values: Value targets; only the first step is used.

        Returns:
            A ``(loss, metrics)`` tuple:
            - ``loss`` is the weighted sum of the TD, OOD, reward,
              transition and tree-search losses.
            - ``metrics`` maps log keys to ``.item()`` values of each loss
              and of the policy accuracy.

        Raises:
            AssertionError: If ``states.shape[1] <= 1`` (when assertions are
                enabled).
        """
        assert states.shape[1] > 1, "--n-steps must be >= 1 for model-based trainer"

        # Slice out the first time step of every input sequence.
        s0, a0, r0, v0 = (
            select(seq, index=0) for seq in (states, actions, rewards, values)
        )

        # Encode the first-step states once and reuse the encoding for the
        # Q head, the reward head and the transition rollout.
        z0 = self.solver.f_encoder(s0)
        q0 = self.solver(z0, encoded=True, training=True)

        # OOD term (labelled "BCL" in earlier comments); its exact form
        # lives in LossFunctions.
        ood_loss = self.compute_ood_loss(q0, a0)

        # TD error on the Q-values of the actions actually taken.
        taken_q = select(q0, a0)
        td_loss = self.compute_td_loss(taken_q, v0)

        # Reward head evaluated on the latent encoding of the first step.
        predicted_rewards = self.solver.f_reward(z0, a0)
        reward_loss = self.compute_reward_loss(predicted_rewards, r0)

        # Unroll the transition model from the first encoding. Then fold the
        # time axis into the batch axis so every step is scored in one pass.
        latent_rollout = self.predict_latent_state_encodings(z0, actions)
        flat_states = merge(states, dims=(0, 1))
        flat_rollout = merge(latent_rollout, dims=(0, 1))

        # Refresh the frozen target encoder from the online encoder, then
        # encode the flattened states; nothing here is tracked by autograd.
        with t.no_grad():
            update_state_dict(
                self.target_encoder,
                self.solver.f_encoder.state_dict(),
                tau=1 - self.tau,
            )
            consistency_targets = self.target_encoder(flat_states)

        transition_loss = self.compute_transition_loss(
            flat_rollout, consistency_targets
        )

        # Presumably non-zero only for tree-based solvers — that branching
        # is inside LossFunctions; confirm there.
        tree_search_policy_loss = self.compute_tree_search_policy_loss(a0, v0)

        accuracy = self.compute_accuracy(q0, a0)

        # Accumulate strictly left to right so the floating-point summation
        # order is unchanged.
        total = self.td_loss_weight * td_loss
        total = total + self.ood_loss_weight * ood_loss
        total = total + self.reward_loss_weight * reward_loss
        total = total + self.transition_loss_weight * transition_loss
        total = total + self.tree_loss_weight * tree_search_policy_loss

        logged = (
            ("loss/td", td_loss),
            ("loss/ood", ood_loss),
            ("loss/reward", reward_loss),
            ("loss/transition", transition_loss),
            ("loss/tree", tree_search_policy_loss),
            ("accuracy/policy", accuracy),
        )
        metrics = {key: tensor.item() for key, tensor in logged}

        return total, metrics
