from typing import Any, Dict, List, Optional, Tuple

from copy import deepcopy
import hydra
import numpy as np
import torch
import torch.nn.functional as F

from mtrl.agent import sac
from mtrl.agent import utils as agent_utils
from mtrl.agent.components.decoder import make_decoder
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.agent.ds.task_info import TaskInfo
from mtrl.agent.optimizer import PCGrad
from mtrl.logger import Logger
from mtrl.replay_buffer import ReplayBuffer, ReplayBufferSample
from mtrl.env.types import ObsType
from mtrl.utils.types import ConfigType, ModelType, ParameterType, TensorType
from mtrl.utils.utils import mask_extreme_task
from mtrl.utils.utils import get_parameters_num


class Agent(sac.Agent):
    """SAC algorithm + Hybrid MTRL.

    Extends the multi-task SAC agent with a per-task switch from reinforcement
    learning to imitation learning. ``self.il_task_flag`` holds one boolean per
    environment/task:

    - Once the replay buffer reports a success rate above
      ``multitask_cfg.hybrid_threshold`` for a task, that flag is set to
      ``True`` and is never cleared (see ``update``).
    - Tasks whose flag is ``True`` are excluded from the critic, actor and
      alpha SAC losses.
    - Those tasks are instead trained with a mean-squared-error loss between
      the policy action and the replay-buffer action.
    """

    def __init__(
        self,
        env_obs_shape: List[int],
        action_shape: List[int],
        action_range: Tuple[int, int],
        device: torch.device,
        actor_cfg: ConfigType,
        critic_cfg: ConfigType,
        alpha_optimizer_cfg: ConfigType,
        actor_optimizer_cfg: ConfigType,
        critic_optimizer_cfg: ConfigType,
        multitask_cfg: ConfigType,
        discount: float,
        init_temperature: float,
        actor_update_freq: int,
        critic_tau: float,
        critic_target_update_freq: int,
        encoder_tau: float,
        loss_reduction: str = "mean",
        cfg_to_load_model: Optional[ConfigType] = None,
        should_complete_init: bool = True,
        logger: Optional[Logger] = None,
    ):
        super().__init__(
            env_obs_shape,
            action_shape,
            action_range,
            device,
            actor_cfg,
            critic_cfg,
            alpha_optimizer_cfg,
            actor_optimizer_cfg,
            critic_optimizer_cfg,
            multitask_cfg,
            discount,
            init_temperature,
            actor_update_freq,
            critic_tau,
            critic_target_update_freq,
            encoder_tau,
            loss_reduction,
            cfg_to_load_model,
            should_complete_init,
            logger,
        )
        # One flag per task: True means the task is trained by imitation
        # instead of RL. Every task starts in RL mode.
        self.il_task_flag = torch.zeros(
            self.num_envs, dtype=torch.bool, device=self.device
        )

    def _hybrid_use_loss_threshold(self, step: int) -> bool:
        """Return whether extreme-task loss masking applies at ``step``.

        Masking is off before ``multitask_cfg.mask_loss_step``; afterwards it
        follows ``multitask_cfg.use_loss_threshold``.
        """
        if step < self.multitask_cfg.mask_loss_step:
            return False
        return self.multitask_cfg.use_loss_threshold

    def _hybrid_reduce_rl_loss(
        self, loss: TensorType
    ) -> Tuple[TensorType, TensorType]:
        """Reduce a per-sample SAC loss over the RL (non-imitation) tasks only.

        Steps:
        1. Average the loss within each task.
        2. Zero the entries of tasks in imitation mode.
        3. Apply ``self.loss_alpha_weight`` when ``loss_reduction`` is
           ``"alpha_weight"``.
        4. Normalise the sum by the number of RL tasks.

        Args:
            loss: per-sample loss. It is reshaped to ``(num_envs, -1)``, which
                assumes the batch consists of ``num_envs`` equally sized,
                task-contiguous chunks (TODO confirm against the sampler).

        Returns:
            ``(per_task_loss, mean_loss)``: the per-task loss vector of shape
            ``(num_envs,)`` and its scalar mean over the RL tasks.
        """
        rl_task_flag = 1 - self.il_task_flag.to(loss.dtype)
        per_task_loss = loss.reshape(self.multitask_cfg.num_envs, -1).mean(-1)
        assert rl_task_flag.shape == per_task_loss.shape
        per_task_loss = per_task_loss * rl_task_flag
        if self.loss_reduction == "alpha_weight":
            assert self.loss_alpha_weight.shape == per_task_loss.shape
            per_task_loss = self.loss_alpha_weight * per_task_loss
        # The epsilon keeps this finite (0) when every task is in imitation mode.
        mean_loss = per_task_loss.sum() / (rl_task_flag.sum() + 1e-9)
        return per_task_loss, mean_loss

    def update_critic(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> TensorType:
        """Update the critic using only the tasks that are in RL mode.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): extra keyword
                arguments forwarded to ``self._compute_gradient``. This dict
                may be modified in place (see the note below).

        Returns:
            The mask produced by ``mask_extreme_task`` for the critic loss. It
            can be passed to ``update_actor_and_alpha`` as ``critic_mask``.
        """
        with torch.no_grad():
            target_V = self._get_target_V(batch=batch, task_info=task_info)
            target_Q = batch.reward + (batch.not_done * self.discount * target_V)

        # get current Q estimates
        mtobs = MTObs(env_obs=batch.env_obs, task_obs=batch.task_obs, task_info=task_info)
        current_Q1, current_Q2 = self.critic(
            mtobs=mtobs,
            action=batch.action,
            detach_encoder=False,
        )
        assert current_Q1.shape == target_Q.shape

        # Element-wise squared errors of both Q heads. reduction="none" is the
        # non-deprecated spelling of the old reduce=False.
        critic_loss = F.mse_loss(current_Q1, target_Q, reduction="none") + F.mse_loss(
            current_Q2, target_Q, reduction="none"
        )
        critic_loss, _, critic_mask = mask_extreme_task(
            critic_loss,
            batch.task_obs,
            self.multitask_cfg.num_envs,
            self._hybrid_use_loss_threshold(step),
            self.multitask_cfg.mask_loss_threshold,
        )
        critic_loss_i, critic_loss = self._hybrid_reduce_rl_loss(critic_loss)
        logger.log("train/critic_loss", critic_loss, step)

        component_names = ["critic"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            # NOTE(review): this writes into the caller's dict, so
            # retain_graph=True stays set for later users of the same dict.
            # Confirm this is intended.
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss=critic_loss,
            task_loss_list=critic_loss_i,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **kwargs_to_compute_gradient,
        )

        # Optimize the critic
        torch.nn.utils.clip_grad_norm_(parameters, self.multitask_cfg.clip_grad_norm)
        self.critic_optimizer.step()

        return critic_mask

    def update_actor_and_alpha(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
        critic_mask: Optional[TensorType] = None,
    ) -> None:
        """Update the actor with the hybrid SAC/imitation loss, then alpha.

        Actor loss:
        - RL tasks contribute the SAC actor loss (alpha * log_pi - min Q).
        - Imitation tasks contribute the MSE between the sampled policy
          action and the replay-buffer action.
        - The per-task sum of the two is averaged over all tasks.

        The temperature (alpha) loss is computed over RL tasks only.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger ([Logger]): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): extra keyword
                arguments forwarded to ``self._compute_gradient``. This dict
                may be modified in place.
            critic_mask (TensorType, optional): mask forwarded to
                ``mask_extreme_task``, e.g. the one returned by
                ``update_critic``.
        """
        # detach encoder, so we don't update it with the actor loss
        mtobs = MTObs(
            env_obs=batch.env_obs,
            task_obs=batch.task_obs,
            task_info=task_info,
        )
        _, pi, log_pi, log_std = self.actor(mtobs=mtobs, detach_encoder=True, route_explore=True)
        # The critic is queried with a task_info built with compute_grad=False,
        # presumably so the task encoder gets no gradient through this path.
        critic_task_info = self.get_task_info(
            task_info.encoding, compute_grad=False, env_index=task_info.env_index
        )
        critic_mtobs = MTObs(
            env_obs=batch.env_obs,
            task_obs=batch.task_obs,
            task_info=critic_task_info,
        )
        actor_Q1, actor_Q2 = self.critic(
            mtobs=critic_mtobs, action=pi, detach_encoder=True, route_explore=False
        )

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = self.get_alpha(batch.task_obs).detach() * log_pi - actor_Q

        actor_loss, _, _ = mask_extreme_task(
            actor_loss,
            batch.task_obs,
            self.multitask_cfg.num_envs,
            self._hybrid_use_loss_threshold(step),
            self.multitask_cfg.mask_loss_threshold,
            mask=critic_mask,
        )
        actor_loss_i, actor_loss_mean = self._hybrid_reduce_rl_loss(actor_loss)

        il_task_flag_float = self.il_task_flag.to(actor_loss.dtype)
        rl_task_flag_float = 1 - il_task_flag_float

        # Imitation loss: per-sample MSE between the policy action and the
        # replay-buffer action. It is kept only for tasks in imitation mode.
        imitation_loss = (pi - batch.action).pow(2).mean(-1)
        imitation_loss_i = imitation_loss.reshape(self.multitask_cfg.num_envs, -1).mean(-1)
        assert il_task_flag_float.shape == imitation_loss_i.shape
        imitation_loss_i = imitation_loss_i * il_task_flag_float
        imitation_loss_mean = imitation_loss_i.sum() / (il_task_flag_float.sum() + 1e-9)

        # Each task contributes to exactly one of the two terms, because the
        # RL and imitation flags are complementary.
        hybrid_actor_loss_i = actor_loss_i + imitation_loss_i
        hybrid_actor_loss = hybrid_actor_loss_i.mean()

        logger.log("train/actor_loss", actor_loss_mean, step)
        logger.log("train/imitation_loss", imitation_loss_mean, step)
        logger.log("train/hybrid_actor_loss", hybrid_actor_loss, step)

        logger.log("train/actor_target_entropy", self.target_entropy, step)
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
        logger.log("train/actor_entropy", entropy.mean(), step)

        # optimize the actor
        component_names = ["actor"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            # NOTE(review): mutates the caller's dict (see update_critic).
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss=hybrid_actor_loss,
            task_loss_list=hybrid_actor_loss_i,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **kwargs_to_compute_gradient,
        )
        torch.nn.utils.clip_grad_norm_(parameters, self.multitask_cfg.clip_grad_norm)
        self.actor_optimizer.step()

        # Temperature update, restricted to RL tasks.
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = self.get_alpha(batch.task_obs) * (-log_pi - self.target_entropy).detach()
        alpha_loss = alpha_loss.reshape(self.multitask_cfg.num_envs, -1).mean(-1)
        assert rl_task_flag_float.shape == alpha_loss.shape
        # The epsilon prevents a 0/0 NaN (which would poison log_alpha) once
        # every task has switched to imitation.
        alpha_loss = (alpha_loss * rl_task_flag_float).sum() / (rl_task_flag_float.sum() + 1e-9)
        logger.log("train/alpha_loss", alpha_loss, step)
        logger.log("train/imitation_ratio", il_task_flag_float.mean(), step)

        self._compute_gradient(
            loss=alpha_loss,
            task_loss_list=None,
            parameters=self.get_parameters(name="log_alpha"),
            step=step,
            component_names=["log_alpha"],
            alpha_flag=True,
            **kwargs_to_compute_gradient,
        )
        torch.nn.utils.clip_grad_norm_(
            self.get_parameters(name="log_alpha"), self.multitask_cfg.clip_grad_norm
        )
        self.log_alpha_optimizer.step()

        if "max_alpha" in self.multitask_cfg:
            # Cap alpha at max_alpha by clamping log_alpha in place.
            self.log_alpha.data.clamp_(max=float(np.log(self.multitask_cfg.max_alpha)))

    def _cal_loss_alpha_weight(self):
        """Set ``self.loss_alpha_weight`` for the ``"alpha_weight"`` reduction.

        How the weights are built:
        - They are a softmax over ``-log_alpha`` restricted to RL tasks;
          imitation tasks get weight 0.
        - They are rescaled so they sum to the number of RL tasks.
        - If every task is in imitation mode, the weights are all zeros.

        Presumably invoked by the base class before the loss reductions. It is
        not called anywhere in this class.
        """
        logits = -self.log_alpha.detach()
        rl_mask = ~self.il_task_flag

        if not rl_mask.any():
            self.loss_alpha_weight = torch.zeros_like(logits)
            return

        # Push imitation tasks to (numerically) zero probability.
        masked_logits = logits.masked_fill(~rl_mask, -1e9)
        rl_mask_float = rl_mask.to(logits.dtype)

        # Explicit dim: the implicit-dim form is deprecated. log_alpha is
        # assumed 1-D (one entry per task); the shape asserts in the loss
        # reductions compare these weights with a (num_envs,) vector.
        softmax_temp = F.softmax(masked_logits, dim=-1) * rl_mask_float

        self.loss_alpha_weight = softmax_temp * rl_mask_float.sum()

    def update(
        self,
        replay_buffer: ReplayBuffer,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
        buffer_index_to_sample: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Refresh the per-task imitation flags, then run the base SAC update.

        A task switches to imitation once its success rate (from
        ``replay_buffer.get_task_success_rate()``) exceeds
        ``multitask_cfg.hybrid_threshold``. The switch is never undone.

        Returns:
            Whatever ``sac.Agent.update`` returns. The original override
            dropped this value and always returned ``None``.
        """
        task_success_rate = replay_buffer.get_task_success_rate()
        begin_il = task_success_rate.squeeze(-1) > self.multitask_cfg.hybrid_threshold
        # logical_or keeps a flag set once it has been set.
        self.il_task_flag = torch.logical_or(
            self.il_task_flag, torch.as_tensor(begin_il, device=self.device)
        )

        return super().update(
            replay_buffer=replay_buffer,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=kwargs_to_compute_gradient,
            buffer_index_to_sample=buffer_index_to_sample,
        )
