import torch
import torch.nn.functional as F
from replay_memory import ReplayMemory
from network import Twin_Q_net, GaussianPolicy
from temporary_buffer import TemporaryBuffer
from utils import hard_update, soft_update


class BPQLAgent:  # SAC for the base learning algorithm
    """Delayed-observation agent that uses SAC as its base learner.

    The actor is conditioned on *augmented* states. The twin critics and their
    target copy are evaluated on the plain states stored alongside them in the
    replay memory.
    """

    def __init__(self, args, state_dim, action_dim, action_bound, action_space, device):
        """Create buffers, networks, optimizers and the entropy temperature.

        Args:
            args: Configuration namespace. Fields read here:
                ``init_obs_delayed_steps``, ``buffer_size``, ``batch_size``,
                ``gamma``, ``xi``, ``hidden_dims``, ``actor_lr``,
                ``critic_lr``, ``automating_temperature`` and either
                ``temperature_lr`` (automatic temperature) or ``temperature``
                (fixed temperature). The namespace is also forwarded to
                ``GaussianPolicy``.
            state_dim: State dimensionality for the networks and replay memory.
            action_dim: Action dimensionality for the networks and replay memory.
            action_bound: Forwarded to ``GaussianPolicy``.
            action_space: Only ``action_space.shape`` is used, to derive the
                target entropy.
            device: Torch device for the networks and temperature tensors.
        """
        self.args = args
        self.device = device

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_bound = action_bound

        # Replay buffer D, plus temporary buffers B for training and evaluation rollouts.
        self.replay_memory = ReplayMemory(args.init_obs_delayed_steps, state_dim, action_dim, device, args.buffer_size)
        self.temporary_buffer = TemporaryBuffer(args.init_obs_delayed_steps)
        self.eval_temporary_buffer = TemporaryBuffer(args.init_obs_delayed_steps)

        self.batch_size = args.batch_size
        self.gamma = args.gamma
        self.xi = args.xi  # target-critic update rate handed to soft_update

        self.actor = GaussianPolicy(args, args.init_obs_delayed_steps, state_dim, action_dim, action_bound, args.hidden_dims, F.relu, device).to(device)
        self.critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)  # Network for the beta Q-values.
        self.target_critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr)

        # Automated Entropy Adjustment for Maximum Entropy
        if args.automating_temperature is True:
            # Target entropy = -(product of the action-space shape dimensions).
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape)).to(device)
            self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.temperature_lr)
        else:
            # Fixed temperature: nothing to optimize. train_actor checks this sentinel.
            self.log_alpha = torch.log(torch.tensor(args.temperature, device=device, dtype=torch.float32))
            self.alpha_optimizer = None

        # Initialise the target critic from the critic (see utils.hard_update for the copy direction).
        hard_update(self.critic, self.target_critic)

    def get_action(self, state, evaluation=True):
        """Select an action for ``state`` without tracking gradients.

        Args:
            state: Input forwarded to ``self.actor.sample``. Presumably a batched
                augmented-state tensor on ``self.device`` -- TODO confirm at
                call sites.
            evaluation: If True, use the third output of ``actor.sample``
                (presumably the deterministic action -- confirm in
                ``GaussianPolicy``). Otherwise use its first (sampled) output.

        Returns:
            The first row of the action batch, as a NumPy array.
        """
        with torch.no_grad():
            if evaluation:
                _, _, action = self.actor.sample(state)
            else:
                action, _, _ = self.actor.sample(state)
        return action.cpu().numpy()[0]

    def train_actor(self, augmented_states, states, train_alpha=True):
        """Take one policy gradient step and, optionally, one temperature step.

        Args:
            augmented_states: Actor inputs for the batch.
            states: Critic inputs for the same batch.
            train_alpha: Also update ``log_alpha``. This flag has no effect when
                the temperature is fixed (no ``alpha_optimizer``), because
                there is no temperature to learn.
        """
        self.actor_optimizer.zero_grad()
        actions, log_pis, _ = self.actor.sample(augmented_states)
        q_values_A, q_values_B = self.critic(states, actions)
        # Pessimistic estimate: element-wise minimum of the twin critics.
        q_values = torch.min(q_values_A, q_values_B)

        # alpha is detached so the policy step does not move the temperature.
        # This backward also accumulates grads into the critic's parameters;
        # train_critic zeroes them before its own optimizer step.
        actor_loss = (self.log_alpha.exp().detach() * log_pis - q_values).mean()
        actor_loss.backward()
        self.actor_optimizer.step()

        # getattr guards instances where the fixed-temperature branch left no optimizer.
        if train_alpha and getattr(self, "alpha_optimizer", None) is not None:
            self.alpha_optimizer.zero_grad()
            alpha_loss = -(self.log_alpha.exp() * (log_pis + self.target_entropy).detach()).mean()
            alpha_loss.backward()
            self.alpha_optimizer.step()

    def train_critic(self, actions, rewards, next_augmented_states, dones, states, next_states):
        """Take one gradient step on both critics toward the soft Bellman target.

        The target is ``r + (1 - done) * gamma * (min(Q'_A, Q'_B)(s', a') - alpha * log pi(a'))``.
        Here ``a'`` is sampled from the actor on ``next_augmented_states``, and
        ``Q'`` is the target critic evaluated on ``next_states``.
        """
        self.critic_optimizer.zero_grad()
        with torch.no_grad():
            next_actions, next_log_pis, _ = self.actor.sample(next_augmented_states)
            next_q_values_A, next_q_values_B = self.target_critic(next_states, next_actions)
            next_q_values = torch.min(next_q_values_A, next_q_values_B) - self.log_alpha.exp() * next_log_pis
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        q_values_A, q_values_B = self.critic(states, actions)
        critic_loss = ((q_values_A - target_q_values)**2).mean() + ((q_values_B - target_q_values)**2).mean()

        critic_loss.backward()
        self.critic_optimizer.step()

    def train(self):
        """Run one full update from a replay-memory minibatch.

        The order is: critic step, then actor step (plus temperature step when
        ``args.automating_temperature is True``), then the target-critic update
        with rate ``xi``.
        """
        augmented_states, actions, rewards, next_augmented_states, dones, states, next_states = self.replay_memory.sample(self.batch_size)

        self.train_critic(actions, rewards, next_augmented_states, dones, states, next_states)
        self.train_actor(augmented_states, states, train_alpha=self.args.automating_temperature is True)

        # target critic update (see utils.soft_update for the exact rule)
        soft_update(self.critic, self.target_critic, self.xi)
