import torch
import torch.utils.data

from .networks_pytorch import TanhNormalPolicy

class BehaviorPolicy(TanhNormalPolicy):
    """Tanh-normal policy fitted to dataset actions by maximum likelihood.

    Extends ``TanhNormalPolicy`` with its own Adam optimizer and an
    ``update`` method that minimizes the negative log-probability the
    policy assigns to the actions stored in an offline dataset.
    """

    def __init__(self, *args, lr, **kwargs):
        """Build the underlying policy and attach an Adam optimizer.

        Args:
            *args: Positional arguments forwarded to ``TanhNormalPolicy``.
            lr: Learning rate of the Adam optimizer (keyword-only).
            **kwargs: Keyword arguments forwarded to ``TanhNormalPolicy``.
        """
        super().__init__(*args, **kwargs)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def update(self, offline_loader):
        """Run one pass over ``offline_loader``, taking one step per batch.

        Args:
            offline_loader: Iterable of batches. Each batch is indexed with
                the keys ``'actions'`` and ``'observations'``; both values
                are assumed to be tensors (they are moved with ``.to``).

        Returns:
            float: Mean over batches of the per-batch negative
            log-likelihood, each measured before that batch's optimizer
            step. Returns ``0.0`` when the loader yields no batches.
        """
        self.train()

        # Use the device the parameters live on, so batches coming from a
        # CPU-side loader match the model. The optimizer built in __init__
        # requires a non-empty parameter list, so next() has an element.
        device = next(self.parameters()).device

        total_loss = 0.0
        num_batches = 0
        for batch in offline_loader:
            actions = batch['actions'].to(device)
            observations = batch['observations'].to(device)

            # Only the pre-tanh action distribution is needed here; the
            # sampled actions and their log-probs are discarded.
            (_, _, _, _, pretanh_action_dist), _ = self._policy_network(
                (observations,)
            )
            # is_pretanh_action=False: dataset actions are passed as-is,
            # presumably already in squashed (post-tanh) action space.
            action_log_prob, _ = self.log_prob(
                pretanh_action_dist, action=actions, is_pretanh_action=False
            )
            batch_loss = -action_log_prob.mean()

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            total_loss += batch_loss.item()
            num_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        if num_batches == 0:
            return 0.0
        return total_loss / num_batches

