import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class SoftQNetwork(nn.Module):
    """Three-layer MLP that scores an (observation, action) pair.

    The observation and action tensors are concatenated along dim 1, passed
    through two 256-unit ReLU layers, and projected to one output value per
    row.
    """

    def __init__(self, obs_size, act_size):
        """Build the layers.

        Args:
            obs_size: Number of observation features per row.
            act_size: Number of action features per row.
        """
        super().__init__()
        in_features = obs_size + act_size
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        """Return the network output for observations ``x`` and actions ``a``.

        Args:
            x: Observation tensor; joined with ``a`` along dim 1.
            a: Action tensor; must be concatenable with ``x`` along dim 1.

        Returns:
            Tensor whose last dimension has size 1 (no output activation).
        """
        hidden = torch.cat((x, a), dim=1)
        # Both hidden layers share the same activation; the head stays linear.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


# Bounds of the interval that Actor.forward maps its log-std output into.
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """MLP that outputs a mean and a bounded log-std per action dimension.

    A shared trunk of two 256-unit ReLU layers feeds two linear heads: one for
    the mean and one for the log standard deviation.
    """

    def __init__(self, obs_size, act_size):
        """Build the trunk and the two output heads.

        Args:
            obs_size: Number of observation features per row.
            act_size: Number of action dimensions (width of each head).
        """
        super().__init__()
        self.fc1 = nn.Linear(obs_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_size)
        self.fc_logstd = nn.Linear(256, act_size)

    def forward(self, x):
        """Compute ``(mean, log_std)`` for observations ``x``.

        Args:
            x: Observation tensor whose last dimension matches ``obs_size``.

        Returns:
            Tuple ``(mean, log_std)``. ``mean`` is unbounded; ``log_std`` lies
            in ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        features = x
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))

        mean = self.fc_mean(features)

        # tanh squashes the raw head output into [-1, 1]; the affine map below
        # rescales that interval onto [LOG_STD_MIN, LOG_STD_MAX].
        squashed = torch.tanh(self.fc_logstd(features))
        half_range = 0.5 * (LOG_STD_MAX - LOG_STD_MIN)
        log_std = LOG_STD_MIN + half_range * (squashed + 1)
        return mean, log_std


class SACAgent(nn.Module):
    """Container for the actor, twin critics, and their target copies.

    Also sets up the entropy coefficient ``alpha``, either as a fixed value or
    from a learnable ``log_alpha`` when ``args.autotune`` is truthy.
    """

    def __init__(self, obs_size, act_space, args):
        """Create all networks on ``args.device`` and initialise ``alpha``.

        Args:
            obs_size: Number of observation features fed to each network.
            act_space: Object exposing a ``shape`` attribute; the flattened
                action size is the product of that shape.
            args: Configuration object providing ``device``, ``autotune`` and,
                when ``autotune`` is falsy, ``alpha``.
        """
        super().__init__()
        # np.prod returns a NumPy scalar (and the float 1.0 for an empty
        # shape). Cast once so every layer size below is a built-in int.
        act_size = int(np.prod(act_space.shape))

        self.actor = Actor(obs_size, act_size).to(args.device)
        self.qf1 = SoftQNetwork(obs_size, act_size).to(args.device)
        self.qf2 = SoftQNetwork(obs_size, act_size).to(args.device)
        self.qf1_target = SoftQNetwork(obs_size, act_size).to(args.device)
        self.qf2_target = SoftQNetwork(obs_size, act_size).to(args.device)
        # Target critics start as exact copies of their online counterparts.
        self.qf1_target.load_state_dict(self.qf1.state_dict())
        self.qf2_target.load_state_dict(self.qf2.state_dict())

        if args.autotune:
            # Entropy target is minus the action dimensionality.
            self.target_entropy = torch.tensor(-act_size, device=args.device, dtype=torch.float32)
            # log_alpha is a plain tensor (not an nn.Parameter), so it is not
            # part of self.parameters() or the module's state_dict.
            self.log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
            # Float snapshot of exp(log_alpha) (1.0 at init); it does not track
            # later changes to log_alpha by itself.
            self.alpha = self.log_alpha.exp().item()
        else:
            # NOTE: target_entropy and log_alpha are not defined in this mode.
            self.alpha = args.alpha
