import os
import torch
from torch.optim import Adam, RMSprop
import torch.nn.functional as F

from .base import Algorithm
from main.network import TwinnedStateActionFunction, GaussianPolicy
from main.utils import disable_gradients, soft_update, update_params, \
    assert_action


class SAC(Algorithm):
    """Soft Actor-Critic with twinned, fraction-proposal quantile critics.

    The critic (``TwinnedStateActionFunction``) exposes two sub-networks,
    ``net1`` and ``net2``. A forward pass returns, per twin, a state-action
    embedding, quantile fractions (``taus``, ``tau_hats``) and an entropy
    term; quantile values are evaluated with ``net.calc_quantiles``.

    Optimizers:
      * Adam on the critics' psi / cosine / quantile layers (quantile Huber
        loss against a clipped double-Q distributional target),
      * RMSprop on the critics' fraction proposal layers (FQF-style fraction
        loss),
      * Adam on the policy and on ``log(alpha)`` as in standard SAC.
    """

    def __init__(self, state_dim, action_dim, device, gamma=0.99, nstep=1,
                 policy_lr=0.0003, q_lr=0.0003, entropy_lr=0.0003,
                 policy_hidden_units=None, q_hidden_units=None,
                 target_update_coef=0.005, log_interval=10, seed=0,
                 frac_lr=2.5e-9):
        """Build the networks and optimizers.

        Args:
            state_dim, action_dim, device, gamma, nstep, log_interval, seed:
                Forwarded unchanged to ``Algorithm.__init__``.
            policy_lr: Adam learning rate for the policy network.
            q_lr: Adam learning rate for the critics' psi, cosine and
                quantile layers.
            entropy_lr: Adam learning rate for ``log(alpha)``.
            policy_hidden_units: Hidden layer sizes of the policy
                (``None`` means ``[256, 256]``).
            q_hidden_units: Hidden layer sizes of each critic
                (``None`` means ``[256, 256]``).
            target_update_coef: Coefficient passed to ``soft_update`` when
                tracking the online critic with the target critic.
            frac_lr: RMSprop learning rate for the fraction proposal layers.
        """
        super().__init__(
            state_dim, action_dim, device, gamma, nstep, log_interval, seed)

        # Fresh lists per instance instead of shared mutable defaults.
        if policy_hidden_units is None:
            policy_hidden_units = [256, 256]
        if q_hidden_units is None:
            q_hidden_units = [256, 256]

        # Build networks.
        self._policy_net = GaussianPolicy(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            hidden_units=policy_hidden_units
            ).to(self._device)
        self._online_q_net = TwinnedStateActionFunction(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            hidden_units=q_hidden_units
            ).to(self._device)
        self._target_q_net = TwinnedStateActionFunction(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            hidden_units=q_hidden_units
            ).to(self._device).eval()

        # Number of quantiles per critic. NOTE(review): this is not passed to
        # TwinnedStateActionFunction, so it must be kept in sync with the
        # network by hand. The losses below read the count from tensor shapes
        # instead; the attribute is kept for external users.
        self.num_quant = 32

        # Copy parameters of the learning network to the target network.
        self._target_q_net.load_state_dict(self._online_q_net.state_dict())

        # Disable gradient calculations of the target network.
        disable_gradients(self._target_q_net)

        # Optimizers. The quantile and fraction parts of each twin are
        # trained by separate optimizers with separate losses.
        self._policy_optim = Adam(self._policy_net.parameters(), lr=policy_lr)
        twins = (self._online_q_net.net1, self._online_q_net.net2)
        quantile_param = [
            param
            for net in twins
            for layer in (net.psi_layer, net.cosine_layer, net.quantile_layer)
            for param in layer.parameters()
        ]
        fraction_param = [
            param
            for net in twins
            for param in net.fraction_prop_layer.parameters()
        ]
        self._q_optim = Adam(quantile_param, lr=q_lr)
        self._frac_optim = RMSprop(fraction_param, lr=frac_lr)

        # Target entropy is -|A|.
        self._target_entropy = -float(self._action_dim)

        # We optimize log(alpha), instead of alpha.
        self._log_alpha = torch.zeros(
            1, device=self._device, requires_grad=True)
        self._alpha = self._log_alpha.detach().exp()
        self._alpha_optim = Adam([self._log_alpha], lr=entropy_lr)

        self._target_update_coef = target_update_coef

    def explore(self, state):
        """Return the policy's first output (the sampled action, per
        ``calc_policy_loss``) for one unbatched ``state`` as a numpy array.
        """
        state = torch.tensor(
            state[None, ...].copy(), dtype=torch.float, device=self._device)
        with torch.no_grad():
            action, _, _ = self._policy_net(state)
        action = action.cpu().numpy()[0]
        assert_action(action)
        return action

    def exploit(self, state):
        """Return the policy's third output for one unbatched ``state``.

        NOTE(review): presumably the deterministic (mean) action of
        ``GaussianPolicy``; confirm against the network definition.
        """
        state = torch.tensor(
            state[None, ...].copy(), dtype=torch.float, device=self._device)
        with torch.no_grad():
            _, _, action = self._policy_net(state)
        action = action.cpu().numpy()[0]
        assert_action(action)
        return action

    def update_target_networks(self):
        """Soft-update the target critic towards the online critic."""
        soft_update(
            self._target_q_net, self._online_q_net, self._target_update_coef)

    def update_online_networks(self, batch, writer):
        """Run one learning step: policy and alpha first, then the critics."""
        self._learning_steps += 1
        self.update_policy_and_entropy(batch, writer)
        self.update_q_functions(batch, writer)

    def update_policy_and_entropy(self, batch, writer):
        """Update the policy, then ``log(alpha)``, and log stats every
        ``log_interval`` learning steps.
        """
        states, actions, rewards, next_states, dones = batch

        # Update policy.
        policy_loss, entropies = self.calc_policy_loss(states)
        update_params(self._policy_optim, policy_loss)

        # Update the entropy coefficient.
        entropy_loss = self.calc_entropy_loss(entropies)
        update_params(self._alpha_optim, entropy_loss)
        self._alpha = self._log_alpha.detach().exp()

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar(
                'loss/policy', policy_loss.detach().item(),
                self._learning_steps)
            writer.add_scalar(
                'loss/entropy', entropy_loss.detach().item(),
                self._learning_steps)
            writer.add_scalar(
                'stats/alpha', self._alpha.item(),
                self._learning_steps)
            writer.add_scalar(
                'stats/entropy', entropies.detach().mean().item(),
                self._learning_steps)

    @staticmethod
    def _expected_q(net, taus, tau_hats, sa_embedding):
        """Return the expected value of one quantile critic, shape (batch, 1).

        Following FQF, each quantile value F^-1(tau_hat_i) is weighted by the
        width of its fraction interval (tau_{i+1} - tau_i). A plain mean is
        only correct when the fractions are uniform. The widths are detached
        so the gradient flows only through the quantile values, as before.
        Assumes ``taus`` is (batch, N + 1) and the quantiles are (batch, N),
        which is how ``calc_fraction_loss`` slices them.
        """
        quantiles = net.calc_quantiles(tau_hats, sa_embedding)
        widths = (taus[:, 1:] - taus[:, :-1]).detach()
        return (widths * quantiles).sum(dim=1, keepdim=True)

    def calc_policy_loss(self, states):
        """Return ``(policy_loss, entropies)`` for freshly sampled actions.

        The policy maximizes ``min(Q1, Q2) + alpha * entropy``. The returned
        entropies are detached in place for the alpha update.
        """
        # Resample actions to calculate expectations of Q.
        sampled_actions, entropies, _ = self._policy_net(states)

        # Expectations of Q with clipped double Q technique.
        (sa_embedding1, taus1, tau_hats1, _,
         sa_embedding2, taus2, tau_hats2, _) = self._online_q_net(
            states, sampled_actions)
        qs1 = self._expected_q(
            self._online_q_net.net1, taus1, tau_hats1, sa_embedding1)
        qs2 = self._expected_q(
            self._online_q_net.net2, taus2, tau_hats2, sa_embedding2)
        qs = torch.min(qs1, qs2)

        # Policy objective is maximization of (Q + alpha * entropy).
        assert qs.shape == entropies.shape
        policy_loss = torch.mean((- qs - self._alpha * entropies))

        return policy_loss, entropies.detach_()

    def calc_entropy_loss(self, entropies):
        """Return the loss for ``log(alpha)`` given detached entropies."""
        assert not entropies.requires_grad

        # Intuitively, we increase alpha when entropy is less than target
        # entropy, and decrease it otherwise.
        entropy_loss = -torch.mean(
            self._log_alpha * (self._target_entropy - entropies))
        return entropy_loss

    def update_q_functions(self, batch, writer, imp_ws1=None, imp_ws2=None):
        """Update the fraction layers, then the quantile layers of both twins.

        Returns:
            ``(curr_qs1, curr_qs2, target_qs)``: the detached current quantile
            values of each twin and the target quantiles (used by DisCor).
        """
        states, actions, rewards, next_states, dones = batch

        # Calculate current quantile values of each twin.
        (sa_embedding1, taus1, tau_hats1, _,
         sa_embedding2, taus2, tau_hats2, _) = self._online_q_net(
            states, actions)
        curr_qs1 = self._online_q_net.net1.calc_quantiles(
            tau_hats1, sa_embedding1)
        # Fix: twin 2 must use its own quantile head (net2), not net1's.
        curr_qs2 = self._online_q_net.net2.calc_quantiles(
            tau_hats2, sa_embedding2)

        fraction_loss = self.calc_fraction_loss(
            sa_embedding1.detach(), curr_qs1.detach(), taus1,
            sa_embedding2.detach(), curr_qs2.detach(), taus2)
        # retain_graph: presumably because the Q loss below backpropagates
        # through parts of the same forward graph.
        update_params(self._frac_optim, fraction_loss, retain_graph=True)

        target_qs = self.calc_target_qs(
            rewards, next_states, dones, tau_hats1, tau_hats2)

        # Update Q functions.
        q_loss, mean_q1, mean_q2 = self.calc_q_loss(
            curr_qs1, curr_qs2, tau_hats1, tau_hats2, target_qs,
            imp_ws1, imp_ws2)
        update_params(self._q_optim, q_loss)

        if self._learning_steps % self._log_interval == 0:
            writer.add_scalar(
                'loss/Q', q_loss.detach().item(),
                self._learning_steps)
            writer.add_scalar(
                'stats/mean_Q1', mean_q1, self._learning_steps)
            writer.add_scalar(
                'stats/mean_Q2', mean_q2, self._learning_steps)

        # Return these values for DisCor algorithm.
        return curr_qs1.detach(), curr_qs2.detach(), target_qs

    def calc_target_qs(self, rewards, next_states, dones, tau_hats1, tau_hats2):
        """Return soft distributional TD targets (no grad).

        Both target twins are evaluated at the next state and a freshly
        sampled next action. Per sample, the whole quantile vector of the
        twin with the smaller mean is kept (clipped double Q at the
        distribution level). The entropy bonus is then added, and the result
        is discounted and combined with rewards and done flags.
        """
        assert not tau_hats1.requires_grad
        assert not tau_hats2.requires_grad

        with torch.no_grad():
            next_actions, next_entropies, _ = self._policy_net(next_states)

            sa_embedding1, _, _, _, sa_embedding2, _, _, _ = \
                self._target_q_net(next_states, next_actions)
            next_qs1 = self._target_q_net.net1.calc_quantiles(
                tau_hats1, sa_embedding1)
            # Fix: twin 2 previously re-evaluated net1 with twin-1 inputs,
            # which made the minimum below a no-op.
            next_qs2 = self._target_q_net.net2.calc_quantiles(
                tau_hats2, sa_embedding2)

            # Ties pick twin 1, matching the first-index rule of torch.min.
            # NOTE(review): this compares unweighted means of the quantiles.
            use_twin1 = (next_qs1.mean(dim=1, keepdim=True)
                         <= next_qs2.mean(dim=1, keepdim=True))
            next_qs = torch.where(use_twin1, next_qs1, next_qs2)
            next_qs = next_qs + self._alpha * next_entropies

        target_qs = rewards + (1.0 - dones) * self._discount * next_qs

        return target_qs

    @staticmethod
    def _quantile_huber_loss(curr_qs, curr_taus, target_qs):
        """Quantile Huber loss of one critic against target quantiles.

        Pairwise errors are laid out as ``[batch, target_i, current_j]``. The
        asymmetric weight ``|tau_j - 1{error < 0}|`` uses the fraction of the
        *current* quantile j, so ``curr_taus`` is broadcast on the last axis.
        The original code broadcast it on the target axis, which gave every
        current quantile the same objective. Assumes ``curr_qs``,
        ``curr_taus`` and ``target_qs`` are (batch, N).
        """
        target_tile = target_qs.unsqueeze(2)      # (batch, N, 1)
        curr_tile = curr_qs.unsqueeze(1)          # (batch, 1, N)
        td_errors = target_tile - curr_tile       # (batch, N, N)

        huber_loss = F.smooth_l1_loss(
            curr_tile.expand_as(td_errors),
            target_tile.expand_as(td_errors).detach(),
            reduction='none')
        weights = (curr_taus.unsqueeze(1)
                   - (td_errors.detach() < 0).float()).abs()
        return (weights * huber_loss).mean(dim=2).sum(dim=1).mean()

    def calc_q_loss(self, curr_qs1, curr_qs2, curr_tau1, curr_tau2, target_qs,
                    imp_ws1=None, imp_ws2=None):
        """Return ``(q1_loss + q2_loss, mean_q1, mean_q2)``.

        Both twins are regressed onto the same ``target_qs`` with a quantile
        Huber loss. ``mean_q1`` and ``mean_q2`` are Python floats for logging.
        """
        assert imp_ws1 is None or imp_ws1.shape == curr_qs1.shape
        assert imp_ws2 is None or imp_ws2.shape == curr_qs2.shape
        assert not target_qs.requires_grad
        assert curr_qs1.shape == target_qs.shape

        # NOTE(review): imp_ws1 / imp_ws2 are shape-checked but not applied
        # to the quantile loss, so DisCor-style weighting has no effect here.
        q1_loss = self._quantile_huber_loss(curr_qs1, curr_tau1, target_qs)
        q2_loss = self._quantile_huber_loss(curr_qs2, curr_tau2, target_qs)

        # Mean Q values for logging.
        mean_q1 = curr_qs1.detach().mean(1).mean().item()
        mean_q2 = curr_qs2.detach().mean(1).mean().item()

        return q1_loss + q2_loss, mean_q1, mean_q2

    @staticmethod
    def _fraction_gradient(net, sa_embedding, sa_quantile_hats, taus):
        """Return the (no-grad) FQF gradient w.r.t. the inner fractions.

        ``taus`` is (batch, N + 1) and ``sa_quantile_hats`` holds the N
        quantile values at ``tau_hats``. The result has the shape of
        ``taus[:, 1:-1]``.
        """
        batch_size = sa_embedding.shape[0]
        num_inner = taus.shape[1] - 2
        with torch.no_grad():
            sa_quantiles = net.calc_quantiles(taus[:, 1:-1], sa_embedding)
            assert sa_quantiles.shape == (batch_size, num_inner)

        values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
        signs_1 = sa_quantiles > torch.cat([
            sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
        assert values_1.shape == signs_1.shape

        values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
        signs_2 = sa_quantiles < torch.cat([
            sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
        assert values_2.shape == signs_2.shape

        gradient_of_taus = (
            torch.where(signs_1, values_1, -values_1)
            + torch.where(signs_2, values_2, -values_2)
        ).view(batch_size, num_inner)
        assert not gradient_of_taus.requires_grad
        assert gradient_of_taus.shape == taus[:, 1:-1].shape
        return gradient_of_taus

    def calc_fraction_loss(self, sa_embedding1, sa_quantile_hats1, taus1,
                           sa_embedding2, sa_quantile_hats2, taus2):
        """Return the fraction proposal loss averaged over both twins.

        Only ``taus1`` / ``taus2`` carry gradients. The embeddings and
        quantile values must already be detached.
        """
        assert not sa_embedding1.requires_grad
        assert not sa_quantile_hats1.requires_grad
        assert not sa_embedding2.requires_grad
        assert not sa_quantile_hats2.requires_grad

        gradient_of_taus1 = self._fraction_gradient(
            self._online_q_net.net1, sa_embedding1, sa_quantile_hats1, taus1)
        gradient_of_taus2 = self._fraction_gradient(
            self._online_q_net.net2, sa_embedding2, sa_quantile_hats2, taus2)

        fraction_loss = (gradient_of_taus1 * taus1[:, 1:-1]
                         + gradient_of_taus2 * taus2[:, 1:-1])
        fraction_loss = (fraction_loss / 2.).sum(dim=1).mean()

        return fraction_loss

    def save_models(self, save_dir):
        """Save base-class state plus policy, online and target critics."""
        super().save_models(save_dir)
        self._policy_net.save(os.path.join(save_dir, 'policy_net.pth'))
        self._online_q_net.save(os.path.join(save_dir, 'online_q_net.pth'))
        self._target_q_net.save(os.path.join(save_dir, 'target_q_net.pth'))
