from torch import nn as nn


class BehaviorClone(nn.Module):
    """Policy wrapper trained jointly with reward and transition predictors.

    ``loss`` adds three terms together: the negated mean log-probability
    that ``model`` assigns to the supplied actions, an MSE term for
    ``reward_model`` and an MSE term for ``transition_model``.
    """

    def __init__(
        self,
        model,
        reward_model,
        transition_model,
        observation_dim,
        action_dim,
    ):
        """Store the three sub-models and the two dimension values.

        Args:
            model: Object exposing ``get_log_prob(obs, acts)`` and callable
                as ``model(obs, deterministic=...)``.
            reward_model: Callable invoked as ``reward_model(obs, acts)``.
            transition_model: Callable invoked as
                ``transition_model(obs, acts)``.
            observation_dim: Kept on the instance; not read by this class.
            action_dim: Kept on the instance; not read by this class.
        """
        super().__init__()
        # Sub-models are assigned in this fixed order (policy, reward,
        # transition) so that module registration order stays stable.
        self.model = model
        self.reward_model = reward_model
        self.transition_model = transition_model
        self.observation_dim = observation_dim
        self.action_dim = action_dim

    def loss(self, obs, acts, rews):
        """Compute the combined training loss.

        Args:
            obs: Observations; sliced as ``obs[:, :-1]`` / ``obs[:, 1:]``
                (assumes axis 1 indexes time steps -- TODO confirm).
            acts: Actions; sliced the same way as ``obs`` on axis 1.
            rews: Rewards; averaged over axis 2 (kept as a size-1 axis)
                before being compared with the reward model output.

        Returns:
            A ``(total, info)`` tuple, where ``total`` is the sum of the
            three loss terms and ``info`` maps ``bc_loss``, ``reward_loss``,
            ``transition_loss`` and ``log_prob`` to their values.
        """
        # Sub-models are queried in a fixed order: policy, reward, transition.
        mean_log_prob = self.model.get_log_prob(obs, acts).mean()
        reward_pred = self.reward_model(obs, acts)
        next_obs_pred = self.transition_model(obs[:, :-1], acts[:, :-1])

        # Behaviour-cloning term: negated mean log-probability of the actions.
        bc_term = -mean_log_prob

        mse = nn.MSELoss()
        reward_target = rews.mean(dim=2, keepdim=True)
        reward_term = mse(reward_target, reward_pred)
        # Every observation after the first is the target for the prediction
        # made from the observation/action pair one step earlier on axis 1.
        next_obs_target = obs[:, 1:]
        transition_term = mse(next_obs_target, next_obs_pred)

        total = bc_term + reward_term + transition_term
        info = {
            "bc_loss": bc_term,
            "reward_loss": reward_term,
            "transition_loss": transition_term,
            "log_prob": mean_log_prob,
        }
        return total, info

    def forward(self, obs, deterministic=False):
        """Return the first element of the wrapped policy's output.

        Args:
            obs: Input passed straight through to ``model``.
            deterministic: Forwarded to ``model`` as a keyword argument.
        """
        outputs = self.model(obs, deterministic=deterministic)
        return outputs[0]
