import copy
import math

import torch
from torch import nn

from .config import BaseConfig, Configurable, Optional
from .defaults import ACTOR_LR, CRITIC_LR
from .policy import BasePolicy, SquashedGaussianPolicy
from .torch_util import device, Module, mlp, update_ema, freeze_module


class TwinCritic(Module):
    """Two separately constructed Q-networks evaluated on the same (state, action) input.

    Both networks are built by ``mlp`` from the same layer-size list:
    ``state_dim + action_dim`` inputs, ``hidden_layers`` hidden layers of
    width ``hidden_dim``, and a single output unit.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, hidden_layers=2):
        super().__init__()
        input_dim = state_dim + action_dim
        layer_sizes = [input_dim] + [hidden_dim] * hidden_layers + [1]
        # squeeze_output is forwarded to mlp; presumably it drops the trailing
        # size-1 output dimension -- confirm against torch_util.mlp.
        self.q1 = mlp(layer_sizes, squeeze_output=True)
        self.q2 = mlp(layer_sizes, squeeze_output=True)

    def both(self, state, action):
        """Return the pair ``(q1, q2)`` evaluated on state and action joined along dim 1."""
        joint = torch.cat((state, action), dim=1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second

    def min(self, state, action):
        """Return the element-wise minimum of the two Q estimates."""
        return torch.min(*self.both(state, action))


class ActorCritic(Configurable, Module, BasePolicy):
    """Container holding an actor and a critic, their optimizers, and optional frozen targets.

    An optimizer is created for the actor (resp. critic) only when it is an
    ``nn.Module``; otherwise the corresponding ``*_optimizer`` attribute is not
    set. ``actor_target`` / ``critic_target`` are deep copies passed through
    ``freeze_module`` when requested, and ``None`` otherwise.
    """

    class Config(BaseConfig):
        actor_lr = ACTOR_LR
        critic_lr = CRITIC_LR

    def __init__(self, config, actor, critic,
                 optimizer_factory=torch.optim.Adam,
                 use_actor_target=False, use_critic_target=False):
        Configurable.__init__(self, config)
        Module.__init__(self)
        self.actor = actor
        self.critic = critic

        # Non-module actors/critics (anything without trainable parameters
        # exposed through nn.Module) get no optimizer.
        if isinstance(actor, nn.Module):
            actor_params = actor.parameters()
            self.actor_optimizer = optimizer_factory(actor_params, lr=self.actor_lr)
        if isinstance(critic, nn.Module):
            critic_params = critic.parameters()
            self.critic_optimizer = optimizer_factory(critic_params, lr=self.critic_lr)

        self.actor_target = self._frozen_copy(actor) if use_actor_target else None
        self.critic_target = self._frozen_copy(critic) if use_critic_target else None

    @staticmethod
    def _frozen_copy(network):
        """Deep-copy ``network`` and pass the copy through ``freeze_module``."""
        clone = copy.deepcopy(network)
        freeze_module(clone)
        return clone

    def act(self, states, eval):
        """Delegate action selection to the actor, forwarding both arguments unchanged."""
        return self.actor.act(states, eval)


class SAC(ActorCritic):
    """Soft actor-critic style agent built on ``ActorCritic``.

    Holds a stochastic actor, a critic with a frozen target copy, and an
    entropy temperature stored as ``log_alpha`` (exposed through ``alpha``).
    The actor, critic and ``log_alpha`` may be passed in so that several
    instances share them; otherwise defaults are constructed from the config.
    """

    class Config(ActorCritic.Config):
        discount = 0.99                     # multiplies the bootstrapped next value in the target
        init_alpha = 1.0                    # log_alpha is initialised to log(init_alpha)
        autotune_alpha = True               # if True, log_alpha gets its own optimizer and loss
        target_entropy = Optional(float)    # when None, __init__ sets it to -action_dim
        use_log_alpha_loss = True           # scale the temperature loss by log_alpha instead of alpha
        deterministic_backup = False        # if True, drop the entropy term from the critic target
        critic_update_multiplier = 1        # critic updates performed per call to update()
        tau = 0.005                         # rate passed to update_ema for the target critic
        batch_size = 256                    # batch size requested from the replay buffer
        hidden_dim = 256                    # hidden width of the default actor and critic networks
        hidden_layers = 2                   # hidden depth of the default actor and critic networks

    def __init__(self, config, state_dim, action_dim,
                 actor=None, critic=None, log_alpha=None,   # for sharing across instances
                 q_bounds=None,     # to constrain the range of the learned Q function
                 criterion=None,
                 optimizer_factory=torch.optim.Adam):
        """Build the agent, creating default networks for anything not supplied.

        Args:
            config: configuration handed to ``Configurable.__init__``.
            state_dim: size of the state input of the default networks.
            action_dim: size of the action; the default actor network outputs
                ``2 * action_dim`` values.
            actor: optional pre-built actor; a ``SquashedGaussianPolicy`` over
                an MLP is created when ``None``.
            critic: optional pre-built critic; a ``TwinCritic`` is created
                when ``None``.
            log_alpha: optional pre-built log-temperature tensor.
            q_bounds: optional ``(q_min, q_max)`` pair used to clamp target
                critic values.
            criterion: loss applied to each Q estimate; ``nn.MSELoss()`` when
                ``None``.
            optimizer_factory: callable used to create every optimizer.
        """
        # Config values (hidden_dim, init_alpha, ...) are needed to build the
        # default networks before the parent constructor runs.
        Configurable.__init__(self, config)
        if actor is None:
            actor_net = mlp([state_dim, *([self.hidden_dim] * self.hidden_layers), action_dim*2])
            actor = SquashedGaussianPolicy(actor_net)
        if critic is None:
            # Use the configured sizes, as the default actor does; previously
            # the critic silently fell back to TwinCritic's own defaults.
            critic = TwinCritic(state_dim, action_dim,
                                hidden_dim=self.hidden_dim,
                                hidden_layers=self.hidden_layers)
        super().__init__(config, actor, critic,
                         optimizer_factory=optimizer_factory,
                         use_actor_target=False, use_critic_target=True)
        if log_alpha is None:
            log_alpha = torch.tensor(math.log(self.init_alpha), device=device, requires_grad=True)
        self.log_alpha = log_alpha
        if self.autotune_alpha:
            # The temperature is optimised with the actor's learning rate.
            self.alpha_optimizer = optimizer_factory([self.log_alpha], lr=self.actor_lr)
        if self.target_entropy is None:
            self.target_entropy = -action_dim   # set target entropy to -dim(A)
        if q_bounds is not None:
            assert len(q_bounds) == 2, 'q_bounds must be a (q_min, q_max) pair'
        self.q_bounds = q_bounds
        self.criterion = nn.MSELoss() if criterion is None else criterion
        # Counts completed calls to update().
        self.register_buffer('total_updates', torch.zeros([]))

    @property
    def alpha(self):
        """Entropy temperature, i.e. ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def critic_value(self, obs, action, target):
        """Return ``critic.min(obs, action)`` from the online or target critic.

        When ``target`` is true and ``q_bounds`` is set, the value is clamped
        to ``[q_min, q_max]``; online-critic values are never clamped.
        """
        critic = self.critic_target if target else self.critic
        q = critic.min(obs, action)
        if target and self.q_bounds is not None:
            q_min, q_max = self.q_bounds
            return q.clamp(q_min, q_max)
        else:
            return q

    def compute_target(self, next_obs, reward, done):
        """Compute the bootstrapped critic target without tracking gradients.

        The next action is sampled from the actor's distribution at
        ``next_obs`` and evaluated by the target critic. Unless
        ``deterministic_backup`` is set, ``alpha * log_prob`` is subtracted
        from that value. The result is
        ``reward + (1 - done) * discount * next_value``.
        """
        with torch.no_grad():
            distr = self.actor.distr(next_obs)
            next_action = distr.sample()
            log_prob = distr.log_prob(next_action)
            next_value = self.critic_value(next_obs, next_action, target=True)
            if not self.deterministic_backup:
                next_value = next_value - self.alpha.detach() * log_prob
            return reward + (1. - done.float()) * self.discount * next_value

    def critic_loss_given_target(self, obs, action, target):
        """Return the sum of ``criterion`` applied to each of the critic's two estimates."""
        Q1, Q2 = self.critic.both(obs, action)
        return self.criterion(Q1, target) + self.criterion(Q2, target)

    def critic_loss(self, obs, action, next_obs, reward, done):
        """Compute the target from the transition and return the critic loss against it."""
        target = self.compute_target(next_obs, reward, done)
        return self.critic_loss_given_target(obs, action, target)

    def update_critic(self, *critic_loss_args):
        """Take one optimizer step on the critic, then move the target critic via ``update_ema``.

        ``critic_loss_args`` are forwarded unchanged to ``critic_loss``.
        """
        critic_loss = self.critic_loss(*critic_loss_args)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        update_ema(self.critic_target, self.critic, self.tau)

    def actor_loss(self, obs, include_alpha=True):
        """Return ``[actor_loss]`` or ``[actor_loss, alpha_loss]`` for a batch of observations.

        The actor loss is ``mean(alpha * log_prob - Q)`` with ``alpha``
        detached, using a reparameterised action sample and the online critic.
        The temperature loss is ``-m * mean(log_prob + target_entropy)`` with
        ``log_prob`` detached, where ``m`` is ``log_alpha`` if
        ``use_log_alpha_loss`` else ``alpha``.
        """
        distr = self.actor.distr(obs)
        action = distr.rsample()
        log_prob = distr.log_prob(action)
        actor_Q = self.critic_value(obs, action, target=False)
        alpha = self.alpha
        actor_loss = torch.mean(alpha.detach() * log_prob - actor_Q)
        if include_alpha:
            multiplier = self.log_alpha if self.use_log_alpha_loss else alpha
            alpha_loss = -multiplier * torch.mean(log_prob.detach() + self.target_entropy)
            return [actor_loss, alpha_loss]
        else:
            return [actor_loss]

    def update_actor_and_alpha(self, obs):
        """Step the actor optimizer and, when ``autotune_alpha`` is set, the temperature optimizer."""
        losses = self.actor_loss(obs, include_alpha=self.autotune_alpha)
        optimizers = [self.actor_optimizer, self.alpha_optimizer] if self.autotune_alpha else \
                     [self.actor_optimizer]
        assert len(losses) == len(optimizers)
        # Each loss is paired with its own optimizer and stepped separately.
        for loss, optimizer in zip(losses, optimizers):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def update(self, replay_buffer):
        """Run one full update from ``replay_buffer``.

        Performs ``critic_update_multiplier`` critic updates, each on a freshly
        sampled batch, then one actor (and temperature) update on the first
        element of the last sampled batch, and finally increments
        ``total_updates``. Each batch is unpacked positionally into
        ``critic_loss``, i.e. ``(obs, action, next_obs, reward, done)``.
        """
        assert self.critic_update_multiplier >= 1, 'critic_update_multiplier must be >= 1'
        for _ in range(self.critic_update_multiplier):
            samples = replay_buffer.sample(self.batch_size)
            self.update_critic(*samples)
        # samples[0] is the observation batch of the last critic update.
        self.update_actor_and_alpha(samples[0])
        self.total_updates += 1