import numpy as np
import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from copy import deepcopy
from typing import Dict, Union, Tuple
from offlinerlkit.policy import BasePolicy
from offlinerlkit.utils.scaler import StandardScaler

class RORLPolicy(BasePolicy):
    """
    Robust Offline Reinforcement Learning (RORL) via Conservative Smoothing.

    A SAC-style policy with a critic ensemble plus three RORL regularizers.
    All of them evaluate the networks on perturbed states ``s + delta`` with
    ``delta ~ Uniform[-eps * obs_std, eps * obs_std]`` and ``num_samples``
    perturbations per state (see ``_get_noised_obs``):

    * Policy smoothing (``policy_smooth_eps``, ``policy_smooth_reg``): the
      worst-case symmetric KL between the Normal distributions built from the
      actor's mean/stddev at ``s`` and at ``s + delta`` is added to the actor
      loss.
    * Q smoothing (``q_smooth_eps``, ``q_smooth_reg``, ``q_smooth_tau``): the
      worst-case asymmetric squared gap ``Q(s + delta, a) - Q(s, a)`` is added
      to the critic loss.
    * OOD conservatism (``q_ood_eps``, ``q_ood_reg``,
      ``q_ood_uncertainty_reg``): on perturbed states ``s_hat`` with policy
      actions ``a_hat``, every ensemble member is regressed towards the
      detached pessimistic target ``Q - beta * std_ensemble(Q)`` where
      ``beta = q_ood_uncertainty_reg``. If ``q_ood_uncertainty_decay > 0``,
      ``beta`` shrinks by that amount after every update, never going below
      ``q_ood_uncertainty_reg_min``.
    """

    # Number of next-state action samples used by the max-Q backup.
    _MAX_Q_BACKUP_SAMPLES = 10

    def __init__(
        self,
        actor: nn.Module,
        critics: nn.ModuleList,
        actor_optim: torch.optim.Optimizer,
        critics_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        max_q_backup: bool = False,
        deterministic_backup: bool = False,
        # === RORL: Consistency / Smoothing Params ===
        num_samples: int = 20,
        policy_smooth_eps: float = 0.0,
        policy_smooth_reg: float = 0.0,
        q_smooth_eps: float = 0.0,
        q_smooth_reg: float = 0.0,
        q_smooth_tau: float = 0.2,
        # === RORL: OOD Conservative Params ===
        q_ood_eps: float = 0.0,
        q_ood_reg: float = 0.0,
        q_ood_uncertainty_reg: float = 0.0,
        q_ood_uncertainty_reg_min: float = 0.0,
        # ===========================================
        obs_std: float = 1.0,
        scaler: StandardScaler = None,
        device: str = "cpu",
        q_ood_uncertainty_decay: float = 0.0,
    ) -> None:
        """
        Args:
            actor: network mapping observations to an action distribution
                that exposes ``mode()``, ``rsample()``, ``log_prob()``,
                ``mean`` and ``stddev``.
            critics: critic ensemble, called as ``critics(obs, act)``.
                NOTE(review): the losses below assume it returns a tensor of
                shape (num_ensemble, batch, 1) - confirm.
            actor_optim: optimizer for the actor parameters.
            critics_optim: optimizer for the critic parameters.
            tau: Polyak coefficient for the target critics.
            gamma: discount factor.
            alpha: fixed entropy temperature, or a tuple
                ``(target_entropy, log_alpha, alpha_optim)`` to tune the
                temperature automatically.
            max_q_backup: back up the max over sampled next actions (per
                ensemble member) instead of the soft value.
            deterministic_backup: drop the entropy term from the soft backup.
            num_samples: perturbations drawn per state for the RORL losses.
            policy_smooth_eps: perturbation radius for policy smoothing.
            policy_smooth_reg: weight of the policy-smoothing loss.
            q_smooth_eps: perturbation radius for Q smoothing.
            q_smooth_reg: weight of the Q-smoothing loss.
            q_smooth_tau: asymmetry of the Q-smoothing loss. Positive gaps
                ``Q(s + delta, a) - Q(s, a)`` are weighted by ``1 - tau`` and
                negative gaps by ``tau``.
            q_ood_eps: perturbation radius for the OOD loss.
            q_ood_reg: weight of the OOD loss.
            q_ood_uncertainty_reg: ``beta``, the ensemble-std penalty inside
                the OOD pseudo-target.
            q_ood_uncertainty_reg_min: lower bound for ``beta`` while it is
                being decayed.
            obs_std: scale of the state perturbations.
            scaler: optional observation scaler applied before the networks.
            device: device used for observation tensors in ``select_action``.
            q_ood_uncertainty_decay: amount subtracted from ``beta`` after
                every ``learn`` call (0 disables the decay).
        """
        super().__init__()
        self.actor = actor
        self.critics = critics
        self.critics_old = deepcopy(critics)
        self.critics_old.eval()

        self.actor_optim = actor_optim
        self.critics_optim = critics_optim

        self._tau = tau
        self._gamma = gamma
        self.device = device

        self._is_auto_alpha = False
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self.alpha_optim = alpha
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._max_q_backup = max_q_backup
        self._deterministic_backup = deterministic_backup

        # RORL smoothing parameters
        self.num_samples = num_samples
        self.policy_smooth_eps = policy_smooth_eps
        self.policy_smooth_reg = policy_smooth_reg
        self.q_smooth_eps = q_smooth_eps
        self.q_smooth_reg = q_smooth_reg
        self.q_smooth_tau = q_smooth_tau

        # RORL OOD parameters (q_ood_uncertainty_reg may decay during training)
        self.q_ood_eps = q_ood_eps
        self.q_ood_reg = q_ood_reg
        self.q_ood_uncertainty_reg = q_ood_uncertainty_reg
        self.q_ood_uncertainty_reg_min = q_ood_uncertainty_reg_min
        self.q_ood_uncertainty_decay = q_ood_uncertainty_decay

        self.obs_std = obs_std
        self.scaler = scaler

    def train(self) -> None:
        """Put the actor and the online critics in training mode."""
        self.actor.train()
        self.critics.train()

    def eval(self) -> None:
        """Put the actor and the online critics in evaluation mode."""
        self.actor.eval()
        self.critics.eval()

    def _sync_weight(self) -> None:
        """Polyak-average online critic weights into the target critics."""
        for o, n in zip(self.critics_old.parameters(), self.critics.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def actforward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the actor on ``obs``.

        Returns:
            (squashed_action, log_prob, dist.mean, dist.stddev). The action is
            the distribution's mode when ``deterministic`` is set, otherwise a
            reparameterized sample.
        """
        dist = self.actor(obs)
        if deterministic:
            squashed_action, raw_action = dist.mode()
        else:
            squashed_action, raw_action = dist.rsample()
        log_prob = dist.log_prob(squashed_action, raw_action)
        return squashed_action, log_prob, dist.mean, dist.stddev

    def select_action(
        self,
        obs: np.ndarray,
        deterministic: bool = False
    ) -> np.ndarray:
        """Return the action for a single (unbatched) observation as a numpy array."""
        if self.scaler is not None:
            obs = self.scaler.transform(obs)

        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
            action, _, _, _ = self.actforward(obs, deterministic)

        return action.cpu().numpy()[0]

    def _get_noised_obs(
        self, obs: torch.Tensor, eps: float
    ) -> Tuple[int, int, torch.Tensor]:
        """
        Tile every observation ``num_samples`` times and add uniform noise.

        Noise is drawn from Uniform[-eps * obs_std, eps * obs_std]. The same
        ``num_samples`` noise vectors are shared by every observation in the
        batch.

        Args:
            obs: (M, obs_dim) observations.
            eps: perturbation radius, in units of ``obs_std``.

        Returns:
            ``(M, size, noised_obs)``: M is the batch size, size is
            ``num_samples``, and ``noised_obs`` has shape
            (M * size, obs_dim). Rows ``m * size`` to ``(m + 1) * size - 1``
            hold the perturbations of ``obs[m]``.
        """
        M, N = obs.shape[0], obs.shape[1]
        size = self.num_samples

        # Created on obs' device/dtype so the addition below cannot mismatch.
        delta = 2 * eps * self.obs_std * (
            torch.rand(size, N, device=obs.device, dtype=obs.dtype) - 0.5
        )
        # Row m * size + s is obs[m] + delta[s].
        noised_obs = obs.repeat_interleave(size, dim=0) + delta.repeat(M, 1)
        return M, size, noised_obs

    def _policy_smooth_loss(
        self,
        obss: torch.Tensor,
        policy_mean: torch.Tensor,
        policy_std: torch.Tensor,
    ) -> torch.Tensor:
        """
        Worst-case symmetric KL between the actor at ``s`` and at ``s + delta``.

        Gradients flow through both the clean and the perturbed statistics.
        The per-state maximum over perturbations is averaged over the batch.
        """
        M, size, noised_obs = self._get_noised_obs(obss, self.policy_smooth_eps)
        _, _, noised_mean, noised_std = self.actforward(noised_obs)

        clean_dist = torch.distributions.Normal(
            policy_mean.repeat_interleave(size, dim=0),
            policy_std.repeat_interleave(size, dim=0),
        )
        noised_dist = torch.distributions.Normal(noised_mean, noised_std)

        kl = kl_divergence(clean_dist, noised_dist).sum(dim=-1) \
            + kl_divergence(noised_dist, clean_dist).sum(dim=-1)
        return kl.reshape(M, size).max(dim=1)[0].mean()

    @torch.no_grad()
    def _compute_target_q(
        self,
        next_obss: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the (gradient-free) TD target using the target critics."""
        if self._max_q_backup:
            batch_size = next_obss.shape[0]
            num_repeat = self._MAX_Q_BACKUP_SAMPLES
            tmp_next_obss = next_obss.repeat_interleave(num_repeat, dim=0)
            tmp_next_actions, _, _, _ = self.actforward(tmp_next_obss)
            # Max over sampled actions per ensemble member, then min over members.
            tmp_next_qs = self.critics_old(tmp_next_obss, tmp_next_actions) \
                .view(-1, batch_size, num_repeat, 1).max(2)[0]
            next_q = tmp_next_qs.min(0)[0]
        else:
            next_actions, next_log_probs, _, _ = self.actforward(next_obss)
            next_q = self.critics_old(next_obss, next_actions).min(0)[0]
            if not self._deterministic_backup:
                next_q = next_q - self._alpha * next_log_probs

        return rewards + self._gamma * (1 - terminals) * next_q

    def _q_smooth_loss(
        self,
        obss: torch.Tensor,
        actions: torch.Tensor,
        qs: torch.Tensor,
    ) -> torch.Tensor:
        """
        Worst-case asymmetric gap between ``Q(s + delta, a)`` and ``Q(s, a)``.

        ``qs`` is the critic output on the clean ``(obss, actions)``.
        Gradients flow through it as well as through the perturbed predictions.
        """
        M, size, noised_obs = self._get_noised_obs(obss, self.q_smooth_eps)
        tiled_actions = actions.repeat_interleave(size, dim=0)
        noised_qs = self.critics(noised_obs, tiled_actions).reshape(-1, M, size)

        diff = noised_qs - qs.reshape(-1, M, 1)
        pos = diff.clamp(min=0)
        neg = diff.clamp(max=0)
        # Over-estimation on perturbed states is weighted by (1 - tau),
        # under-estimation by tau.
        loss = (1 - self.q_smooth_tau) * pos.pow(2) + self.q_smooth_tau * neg.pow(2)
        loss = loss.mean(dim=0)  # average over the ensemble -> (M, size)
        return loss.max(dim=1)[0].mean()

    def _q_ood_loss(self, obss: torch.Tensor) -> torch.Tensor:
        """
        RORL OOD loss on perturbed states with policy actions.

        Each ensemble member is regressed towards the detached pseudo-target
        ``Q - beta * std_ensemble(Q)``. This lowers OOD values in proportion
        to the ensemble's disagreement without minimising the disagreement
        itself. Requires an ensemble of at least 2 members (``std`` over a
        single member is NaN).
        """
        _, _, ood_obs = self._get_noised_obs(obss, self.q_ood_eps)
        # The actor is not trained by the critic loss; skip building its graph.
        with torch.no_grad():
            ood_actions, _, _, _ = self.actforward(ood_obs)
        ood_qs = self.critics(ood_obs, ood_actions)

        ood_target = (
            ood_qs - self.q_ood_uncertainty_reg * ood_qs.std(dim=0, keepdim=True)
        ).detach()
        return (ood_target - ood_qs).pow(2).mean()

    def learn(self, batch: Dict) -> Dict:
        """
        Run one actor, temperature (if automatic) and critic update.

        Args:
            batch: dict with tensors under "observations", "actions",
                "next_observations", "rewards" and "terminals".
                NOTE(review): rewards/terminals are presumably (B, 1) so they
                broadcast with the backed-up value - confirm with the buffer.

        Returns:
            Scalar losses for logging: "loss/actor" and "loss/critics". With
            automatic temperature tuning, also "loss/alpha" and "alpha".
        """
        obss = batch["observations"]
        actions = batch["actions"]
        next_obss = batch["next_observations"]
        rewards = batch["rewards"]
        terminals = batch["terminals"]

        if self.scaler is not None:
            obss = self.scaler.transform_tensor(obss)
            next_obss = self.scaler.transform_tensor(next_obss)

        # -------------------------
        # 1. Actor
        # -------------------------
        a, log_probs, policy_mean, policy_std = self.actforward(obss)
        qas = self.critics(obss, a)
        actor_loss = -qas.min(0)[0].mean() + self._alpha * log_probs.mean()
        if self.policy_smooth_eps > 0 and self.policy_smooth_reg > 0:
            actor_loss = actor_loss + self.policy_smooth_reg * self._policy_smooth_loss(
                obss, policy_mean, policy_std
            )

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # -------------------------
        # 2. Temperature
        # -------------------------
        if self._is_auto_alpha:
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0)

        # -------------------------
        # 3. Critics (target uses the freshly updated actor and alpha)
        # -------------------------
        target_q = self._compute_target_q(next_obss, rewards, terminals)
        qs = self.critics(obss, actions)
        critics_loss = (qs - target_q.unsqueeze(0)).pow(2).mean(dim=(1, 2)).sum()

        if self.q_smooth_eps > 0 and self.q_smooth_reg > 0:
            critics_loss = critics_loss + self.q_smooth_reg * self._q_smooth_loss(
                obss, actions, qs
            )

        # With beta == 0 the pseudo-target equals Q and the OOD loss is zero.
        if self.q_ood_eps > 0 and self.q_ood_reg > 0 and self.q_ood_uncertainty_reg > 0:
            critics_loss = critics_loss + self.q_ood_reg * self._q_ood_loss(obss)

        self.critics_optim.zero_grad()
        critics_loss.backward()
        self.critics_optim.step()

        self._sync_weight()

        if self.q_ood_uncertainty_decay > 0:
            self.q_ood_uncertainty_reg = max(
                self.q_ood_uncertainty_reg - self.q_ood_uncertainty_decay,
                self.q_ood_uncertainty_reg_min,
            )

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critics": critics_loss.item(),
        }

        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()

        return result