from collections import defaultdict
from typing import List, Optional, Tuple

import numpy as np
import torch as th
from torch.nn import functional as F

from viqs.sb3 import SAC
from viqs.sb3.common.off_policy_algorithm import OffPolicyAlgorithm
from viqs.sb3.common.type_aliases import GymEnv, MaybeCallback
from viqs.sb3.common.utils import polyak_update


class TS2CSAC(SAC):
    """SAC variant trained on intervention-annotated replay data (TS2C).

    Compared with a standard SB3-style SAC update, ``train`` here:

    * regresses the critic on the ``actions_behavior`` field of each replay batch;
    * lowers the TD target by ``intervention_penalty`` on non-terminal transitions
      flagged by ``replay_data.next_intervention_start`` (presumably the step where
      a teacher takes over control -- confirm against the replay buffer that
      produces these fields);
    * averages per-step statistics and logs them under ``train/<name>``.
    """

    def __init__(self, *args, intervention_penalty: float = 2.0, **kwargs):
        """Create the agent.

        :param args: positional arguments forwarded unchanged to ``SAC.__init__``.
        :param intervention_penalty: amount subtracted from the TD target of a
            non-terminal transition whose ``next_intervention_start`` flag is set.
            The default reproduces the previously hard-coded value of 2.
        :param kwargs: keyword arguments forwarded unchanged to ``SAC.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.intervention_penalty = intervention_penalty

    def _setup_model(self) -> None:
        """Build networks and optimizers exactly as the parent class does."""
        super()._setup_model()

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """Run ``gradient_steps`` SAC updates including the TS2C intervention penalty.

        Each step samples a batch from ``self.replay_buffer`` and, in order,
        updates the entropy coefficient (when it is learned), the critic and the
        actor, then Polyak-averages the critic into ``critic_target`` whenever
        ``gradient_step % target_update_interval == 0``. Per-step statistics are
        averaged over all steps and recorded on ``self.logger``.

        :param gradient_steps: number of gradient updates to perform.
        :param batch_size: number of transitions sampled per update.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rates according to the lr schedule
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        stat_recorder = defaultdict(list)

        for gradient_step in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            # ---- Entropy coefficient (temperature / alpha) ----
            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph so that the other
                # losses do not change it,
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                stat_recorder["ent_coef_loss"].append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            stat_recorder["entropy"].append(-log_prob.mean().item())
            stat_recorder["ent_coef"].append(ent_coef.item())

            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            # ---- TD target with TS2C intervention penalty ----
            with th.no_grad():
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Min over all target critics, then add the entropy bonus
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)

                not_done = 1 - replay_data.dones
                target_q_values = replay_data.rewards + not_done * self.gamma * next_q_values
                # Penalise (non-terminal) transitions whose next step starts an intervention
                target_q_values -= (
                    self.intervention_penalty * replay_data.next_intervention_start.reshape(-1, 1) * not_done
                )

            # ---- Critic update ----
            # Only the behavior actions stored in the batch enter the critic loss.
            current_q_values = self.critic(replay_data.observations, replay_data.actions_behavior)
            for i, current_q in enumerate(current_q_values):
                stat_recorder["q_value_{}".format(i)].append(current_q.mean().item())

            critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            stat_recorder["critic_loss"].append(critic_loss.item())

            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # ---- Actor update ----
            # Minimise alpha * log_pi - min_i Q_i(s, pi(s))
            q_values_pi = th.cat(self.critic.forward(replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()

            stat_recorder["q_value_min"].append(min_qf_pi.mean().item())
            stat_recorder["actor_loss"].append(actor_loss.item())

            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # ---- Target network update ----
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)

        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates)
        for key, values in stat_recorder.items():
            self.logger.record("train/{}".format(key), np.mean(values))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "SAC",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:
        """Run the training loop with all arguments forwarded unchanged.

        ``super(SAC, self)`` deliberately starts the method lookup *after*
        ``SAC`` in the MRO, so ``SAC.learn`` itself is skipped and the next
        implementation (the off-policy base class in the usual hierarchy) runs.
        """
        return super(SAC, self).learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _excluded_save_params(self) -> List[str]:
        """Return attribute names that must not be included in the saved parameter data.

        Skips ``SAC``'s own override (``super(SAC, self)``) and appends the network
        modules and buffers explicitly.
        """
        return super(SAC, self)._excluded_save_params() + [
            "actor", "critic", "actor_target", "critic_target", "human_data_buffer", "replay_buffer"
        ]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Return ``(state_dict attribute names, plain torch variable names)`` to save.

        When the entropy coefficient is learned, its log-value tensor and optimizer
        are saved; otherwise the fixed ``ent_coef_tensor`` is saved.
        """
        state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
        if self.ent_coef_optimizer is not None:
            saved_pytorch_variables = ["log_ent_coef"]
            state_dicts.append("ent_coef_optimizer")
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables
