# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import copy
# from typing import Dict, Union, Tuple
# from offlinerlkit.policy import BasePolicy
# from offlinerlkit.modules.tracer_module import (
#     VectorizedCritic, 
#     DistributionalValueFunction, 
#     ObservationModel, 
#     qr_loss, 
#     get_tau
# )
# from offlinerlkit.utils.scaler import StandardScaler

# class TRACERPolicy(BasePolicy):
#     def __init__(
#         self,
#         actor: nn.Module,
#         critic: VectorizedCritic,
#         value_net: DistributionalValueFunction,
#         obser_model: ObservationModel,
#         actor_optim: torch.optim.Optimizer,
#         critic_optim: torch.optim.Optimizer,
#         value_optim: torch.optim.Optimizer,
#         obser_optim: torch.optim.Optimizer,
#         tau: float = 0.005,
#         gamma: float = 0.99,
#         beta: float = 3.0,
#         quantile: float = 0.25, 
#         iql_tau: float = 0.7,   
#         obser_sigma: float = 0.3, 
#         num_quantiles: int = 32,
#         num_q: int = 2,
#         device: str = "cpu",
#         scaler: StandardScaler = None,  # Fix 1: Accept Scaler
#         enable_entropy: bool = False,
#         lower_bound: float = 0.9,
#         upper_bound: float = 1.0,
#         obser_beta_start: float = 0.0001,
#         obser_beta_end: float = 0.01,
#         total_steps: int = 1000000
#     ) -> None:
#         super().__init__()
        
#         self.actor = actor
#         self.critic = critic
#         self.critic_target = copy.deepcopy(self.critic)
#         self.value_net = value_net
#         self.obser_model = obser_model
        
#         self.actor_optim = actor_optim
#         self.critic_optim = critic_optim
#         self.value_optim = value_optim
#         self.obser_optim = obser_optim
        
#         self._tau = tau
#         self._gamma = gamma
#         self.device = device
#         self.scaler = scaler # Store scaler
        
#         self.beta = beta
#         self.quantile = quantile
#         self.iql_tau = iql_tau
#         self.obser_sigma = obser_sigma
#         self.num_quantiles = num_quantiles
#         self.num_q = num_q
#         self.clip_score = 100.0

#         self.enable_entropy = enable_entropy
#         self.lower_bound = lower_bound
#         self.upper_bound = upper_bound
        
#         self.obser_beta_start = obser_beta_start
#         self.obser_beta_end = obser_beta_end
#         self.total_steps = total_steps
#         self.current_step = 0

#     def train(self) -> None:
#         self.actor.train()
#         self.critic.train()
#         self.value_net.train()
#         self.obser_model.train()

#     def eval(self) -> None:
#         self.actor.eval()
#         self.critic.eval()
#         self.value_net.eval()
#         self.obser_model.eval()

#     def _sync_weight(self) -> None:
#         for o, n in zip(self.critic_target.parameters(), self.critic.parameters()):
#             o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

#     def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
#         # Fix 1: Normalize during inference
#         if self.scaler is not None:
#             obs = self.scaler.transform(obs)
            
#         with torch.no_grad():
#             obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
#             dist = self.actor(obs)
#             if deterministic:
#                 action, _ = dist.mode()
#             else:
#                 action, _ = dist.rsample()
#         return action.cpu().numpy()[0]

#     def d_entropy(self, d_vals: torch.Tensor, presum_tau: torch.Tensor):
#         size1, size2 = d_vals.size(0), d_vals.size(1)
#         d = d_vals.view(-1, self.num_quantiles)
#         pdf = presum_tau.view(-1, self.num_quantiles)
        
#         sorted_d, indices = torch.sort(d, dim=1)
#         sorted_pdf = torch.gather(pdf, 1, indices)
        
#         sorted_d_diff = sorted_d[:, 1:] - sorted_d[:, :-1]
#         mid_pdf = (sorted_pdf[:, 1:] + sorted_pdf[:, :-1]) / 2
        
#         log_mid_pdf = torch.log(mid_pdf + 1e-10)
#         entropy = -torch.sum(sorted_d_diff * mid_pdf * log_mid_pdf, dim=1)
#         entropy = entropy.view(size1, size2, 1)
        
#         weight = torch.exp(- entropy / (sorted_d.mean().abs() + 1e-6))
        
#         # Batch-wise Min-Max Normalization
#         w_min = weight.min(dim=0, keepdim=True)[0].min(dim=1, keepdim=True)[0]
#         w_max = weight.max(dim=0, keepdim=True)[0].max(dim=1, keepdim=True)[0]
#         norm_weight = (weight - w_min) / (w_max - w_min + 1e-6) * (self.upper_bound - self.lower_bound) + self.lower_bound
        
#         return entropy, norm_weight

#     def learn(self, batch: Dict) -> Dict:
#         self.current_step += 1
#         obss, actions, next_obss, rewards, terminals = \
#             batch["observations"], batch["actions"], batch["next_observations"], batch["rewards"], batch["terminals"]
        
#         # Fix 1: Dynamic Normalization of States during training
#         if self.scaler is not None:
#             obss = self.scaler.transform_tensor(obss)
#             next_obss = self.scaler.transform_tensor(next_obss)

#         batch_size = obss.shape[0]

#         # ==================== 1. Update Critic (Quantile Regression) ====================
#         # Fix 2: Critic Input must be (s, a, r)
#         critic_input = torch.cat([obss, actions, rewards], dim=-1)
        
#         with torch.no_grad():
#             tau_hat_next, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
#             z_pi_targ = self.value_net(next_obss, tau_hat_next).mean(dim=0)
#             d_targ = rewards + (1.0 - terminals) * self._gamma * z_pi_targ
#             d_targ = torch.clamp(d_targ, -100, 1000)

#         tau, tau_hat, presum_tau = get_tau(batch_size, self.num_quantiles, self.device)
#         tau_hat_in = tau_hat.unsqueeze(0).repeat(self.num_q, 1, 1)
        
#         d_vals = self.critic(critic_input, tau_hat_in)
        
#         loss_q = qr_loss(d_vals, d_targ, tau_hat_in, presum_tau, sigma=self.obser_sigma)

#         self.critic_optim.zero_grad()
#         loss_q.backward()
#         self.critic_optim.step()

#         # ==================== 2. Update Value Net (IQL) ====================
#         with torch.no_grad():
#             target_tau_hat, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
#             target_tau_in = target_tau_hat.unsqueeze(0).repeat(self.num_q, 1, 1)
#             target_d_all = self.critic_target(critic_input, target_tau_in)
#             target_d = torch.quantile(target_d_all, self.quantile, dim=0)

#         tau_hat_v, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
#         z_vals = self.value_net(obss, tau_hat_v)
        
#         adv = target_d - z_vals
#         loss_v = torch.mean(torch.abs(self.iql_tau - (adv < 0).float()) * adv.pow(2))

#         self.value_optim.zero_grad()
#         loss_v.backward()
#         self.value_optim.step()

#         # ==================== 3. Update Actor (AWR) ====================
#         adv_mean = adv.mean(dim=-1).detach()
#         exp_adv = torch.exp(self.beta * adv_mean).clamp(max=self.clip_score)
        
#         dist = self.actor(obss)
#         if hasattr(self.actor, 'log_prob'):
#              loss_bc = -dist.log_prob(actions)
#         else:
#              loss_bc = -dist.log_prob(actions).sum(-1)
        
#         loss_pi = torch.mean(exp_adv * loss_bc)

#         self.actor_optim.zero_grad()
#         loss_pi.backward()
#         self.actor_optim.step()

#         # ==================== 4. Update Observation Model (VAE & Consistency) ====================
#         with torch.no_grad():
#             tau_map, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
#             tau_map_in = tau_map.unsqueeze(0).repeat(self.num_q, 1, 1)
#             d_vals_map = self.critic(critic_input, tau_map_in)
#             d_cond = torch.quantile(d_vals_map, self.quantile, dim=0)

#         loss_recon_data, _ = self.obser_model.get_loss(obss, actions, rewards, next_obss, d_cond, self.obser_sigma)

#         # === Consistency Check (Fix 3: 3-way check) ===
#         # VAE forward to get reconstructed distributions
#         mean_recon, log_std_recon = self.obser_model(obss, actions, rewards, next_obss, d_cond.detach())
#         std_recon = torch.exp(log_std_recon)
#         eps = torch.randn_like(mean_recon)
#         samples = mean_recon + std_recon * eps
        
#         # Dimensions for slicing
#         s_dim = obss.shape[1]
#         a_dim = actions.shape[1]
        
#         q_s = samples[0, :, :s_dim]            # Model 0: Predict S
#         q_a = samples[1, :, s_dim:s_dim+a_dim] # Model 1: Predict A
#         q_r = samples[2, :, s_dim+a_dim:]      # Model 2: Predict R

#         # Construct 3 data batches
#         # Data 1: (q_s, a, r)
#         data1 = torch.cat([q_s, actions, rewards], dim=-1)
#         # Data 2: (s, q_a, r)
#         data2 = torch.cat([obss, q_a, rewards], dim=-1)
#         # Data 3: (s, a, q_r) - Was missing in previous code
#         data3 = torch.cat([obss, actions, q_r], dim=-1)
        
#         d_in = torch.cat([data1, data2, data3], dim=0) # [batch*3, total_dim]

#         tau_cons, tau_hat_cons, presum_tau_cons = get_tau(batch_size * 3, self.num_quantiles, self.device)
#         tau_hat_cons_in = tau_hat_cons.unsqueeze(0).repeat(self.num_q, 1, 1)
#         presum_tau_cons_in = presum_tau_cons.unsqueeze(0).repeat(self.num_q, 1, 1)
        
#         # Check Consistency via Critic
#         d_sample = self.critic(d_in, tau_hat_cons_in)
#         d_target_cons = d_vals.detach().repeat(1, 3, 1)
        
#         with torch.no_grad():
#             entropy, norm_weight = self.d_entropy(d_sample, presum_tau_cons_in)
#             d_weight = norm_weight if self.enable_entropy else 1.0

#         loss_d_elem = qr_loss(d_sample, d_target_cons, tau_hat_cons_in, presum_tau_cons, sigma=self.obser_sigma, reduction='none')
#         loss_recon_d = (loss_d_elem * d_weight).mean()

#         # Beta Scheduling
#         slope = (self.obser_beta_end - self.obser_beta_start) / self.total_steps
#         current_beta = self.obser_beta_start + slope * self.current_step
#         current_beta = min(max(current_beta, self.obser_beta_start), self.obser_beta_end)

#         loss_map = loss_recon_data + current_beta * loss_recon_d
        
#         self.obser_optim.zero_grad()
#         self.critic_optim.zero_grad()
#         loss_map.backward()
#         self.obser_optim.step()
#         self.critic_optim.step()

#         self._sync_weight()

#         return {
#             "loss/actor": loss_pi.item(),
#             "loss/critic": loss_q.item(),
#             "loss/value": loss_v.item(),
#             "loss/recon_data": loss_recon_data.item(),
#             "loss/recon_d": loss_recon_d.item(),
#             "metrics/adv_mean": adv_mean.mean().item(),
#             "metrics/beta": current_beta
#         }

import copy
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from offlinerlkit.policy import BasePolicy
from offlinerlkit.modules.tracer_module import (
    VectorizedCritic, 
    DistributionalValueFunction, 
    ObservationModel, 
    qr_loss, 
    get_tau
)
from offlinerlkit.utils.scaler import StandardScaler

class TRACERPolicy(BasePolicy):
    """Policy that trains an actor, a quantile critic, a value net and an observation model.

    Each call to :meth:`learn` runs four optimisation phases in a fixed order:

    1. critic update with a quantile-regression loss (``qr_loss``),
    2. value-network update with an asymmetric (expectile-style) squared loss,
    3. actor update with an exponentially advantage-weighted log-likelihood,
    4. observation-model update (reconstruction loss plus a critic
       consistency term), which also takes a second critic optimiser step.

    Afterwards the target critic is moved towards the online critic with a
    soft (Polyak) update.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic: VectorizedCritic,
        value_net: DistributionalValueFunction,
        obser_model: ObservationModel,
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        value_optim: torch.optim.Optimizer,
        obser_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        beta: float = 3.0,
        quantile: float = 0.25,
        iql_tau: float = 0.7,
        sigma: float = 0.1,
        obser_sigma: float = 0.1,
        num_quantiles: int = 32,
        num_q: int = 2,
        device: str = "cpu",
        scaler: Optional[StandardScaler] = None,
        enable_entropy: bool = False,
        lower_bound: float = 0.9,
        upper_bound: float = 1.0,
        obser_beta_start: float = 0.0001,
        obser_beta_end: float = 0.01,
        total_steps: int = 1000000
    ) -> None:
        """Store the networks, optimisers and hyper-parameters.

        Args:
            actor: Module mapping observations to an action distribution.
            critic: Online critic; a deep copy of it becomes the target critic.
            value_net: Value network evaluated on (observation, tau) pairs.
            obser_model: Observation model providing ``get_loss`` and a
                forward pass returning ``(mean, log_std)``.
            actor_optim: Optimiser stepped in the actor phase.
            critic_optim: Optimiser stepped in the critic phase and again in
                the observation-model phase.
            value_optim: Optimiser stepped in the value phase.
            obser_optim: Optimiser stepped in the observation-model phase.
            tau: Interpolation coefficient of the soft target-critic update.
            gamma: Discount factor used in the critic target.
            beta: Multiplier inside ``exp(beta * advantage)`` for actor weights.
            quantile: ``q`` passed to ``torch.quantile`` over dim 0 of the
                critic output.
            iql_tau: Asymmetry coefficient of the value-network loss.
            sigma: ``sigma`` passed to ``qr_loss`` in the critic phase.
            obser_sigma: ``sigma`` passed to ``obser_model.get_loss`` and to
                ``qr_loss`` in the consistency term.
            num_quantiles: Number of quantile fractions requested from
                ``get_tau``.
            num_q: Number of times the fractions are repeated along a new
                leading dimension before being fed to the critic
                (presumably the critic ensemble size -- confirm).
            device: Device used for sampled fractions and inference inputs.
            scaler: Optional observation scaler; applied when not ``None``.
            enable_entropy: If true, weight the consistency loss with the
                normalised weights from :meth:`d_entropy`.
            lower_bound: Lower end of the normalised entropy-weight range.
            upper_bound: Upper end of the normalised entropy-weight range.
            obser_beta_start: Consistency-loss coefficient at step 0.
            obser_beta_end: Consistency-loss coefficient once ``total_steps``
                is reached.
            total_steps: Number of :meth:`learn` calls over which the
                coefficient is linearly interpolated.
        """
        super().__init__()

        self.actor = actor
        self.critic = critic
        self.critic_target = copy.deepcopy(self.critic)
        self.value_net = value_net
        self.obser_model = obser_model

        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.value_optim = value_optim
        self.obser_optim = obser_optim

        self._tau = tau
        self._gamma = gamma
        self.device = device
        self.scaler = scaler

        self.beta = beta
        self.quantile = quantile
        self.iql_tau = iql_tau
        self.sigma = sigma              # sigma for the critic QR loss
        self.obser_sigma = obser_sigma  # sigma for the observation-model losses
        self.num_quantiles = num_quantiles
        self.num_q = num_q
        # Upper clamp for the exponentiated advantage in the actor phase.
        self.clip_score = 100.0

        self.enable_entropy = enable_entropy
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.obser_beta_start = obser_beta_start
        self.obser_beta_end = obser_beta_end
        self.total_steps = total_steps
        self.current_step = 0

    def train(self) -> None:
        """Put actor, critic, value net and observation model in training mode."""
        self.actor.train()
        self.critic.train()
        self.value_net.train()
        self.obser_model.train()

    def eval(self) -> None:
        """Put actor, critic, value net and observation model in evaluation mode."""
        self.actor.eval()
        self.critic.eval()
        self.value_net.eval()
        self.obser_model.eval()

    @torch.no_grad()
    def _sync_weight(self) -> None:
        """Soft-update the target critic: ``target <- (1 - tau) * target + tau * online``."""
        for target_param, online_param in zip(
            self.critic_target.parameters(), self.critic.parameters()
        ):
            # Same arithmetic as before, but under no_grad instead of via `.data`.
            target_param.copy_(
                target_param * (1.0 - self._tau) + online_param * self._tau
            )

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Return an action for a single observation.

        Args:
            obs: One unbatched observation; it is scaled with ``self.scaler``
                when a scaler is set, then a batch dimension is prepended.
            deterministic: Use ``dist.mode()`` when true, ``dist.rsample()``
                otherwise.

        Returns:
            The action for ``obs`` as a NumPy array (batch dimension removed).
        """
        if self.scaler is not None:
            obs = self.scaler.transform(obs)

        with torch.no_grad():
            obs_t = torch.as_tensor(
                obs, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            dist = self.actor(obs_t)
            # Both calls are expected to return a pair whose first item is the action.
            if deterministic:
                action, _ = dist.mode()
            else:
                action, _ = dist.rsample()
        return action.cpu().numpy()[0]

    def d_entropy(
        self, d_vals: torch.Tensor, presum_tau: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute an entropy estimate per quantile set and a normalised weight.

        Args:
            d_vals: Quantile values whose last dimension has
                ``num_quantiles`` entries and which has at least two leading
                dimensions.
            presum_tau: Tensor with the same number of elements as ``d_vals``;
                used as the per-quantile probability mass.

        Returns:
            ``(entropy, norm_weight)``, both of shape
            ``(d_vals.size(0), d_vals.size(1), 1)``. ``norm_weight`` is
            min-max normalised over the first two dimensions into
            ``[lower_bound, upper_bound]``.
        """
        size1, size2 = d_vals.size(0), d_vals.size(1)
        # reshape (rather than view) also accepts non-contiguous inputs.
        d = d_vals.reshape(-1, self.num_quantiles)
        pdf = presum_tau.reshape(-1, self.num_quantiles)

        # Sort the quantile values and carry the masses along with them.
        sorted_d, indices = torch.sort(d, dim=1)
        sorted_pdf = torch.gather(pdf, 1, indices)

        # Trapezoid-like sum over neighbouring sorted quantiles.
        sorted_d_diff = sorted_d[:, 1:] - sorted_d[:, :-1]
        mid_pdf = (sorted_pdf[:, 1:] + sorted_pdf[:, :-1]) / 2

        log_mid_pdf = torch.log(mid_pdf + 1e-10)
        entropy = -torch.sum(sorted_d_diff * mid_pdf * log_mid_pdf, dim=1)
        entropy = entropy.view(size1, size2, 1)

        # Original code used sorted_d.mean() without abs().
        # Although odd if Q < 0, we must match original implementation behavior.
        weight = torch.exp(- entropy / (sorted_d.mean() + 1e-6))

        # Global min / max over the first two dims (the last dim has size 1).
        w_min = weight.amin(dim=(0, 1), keepdim=True)
        w_max = weight.amax(dim=(0, 1), keepdim=True)
        norm_weight = (
            (weight - w_min) / (w_max - w_min + 1e-6)
            * (self.upper_bound - self.lower_bound)
            + self.lower_bound
        )

        return entropy, norm_weight

    def _current_obser_beta(self) -> float:
        """Return the consistency-loss coefficient for ``self.current_step``.

        The coefficient moves linearly from ``obser_beta_start`` to
        ``obser_beta_end`` over ``total_steps`` steps and is clamped to the
        interval spanned by the two, so decreasing schedules work as well.
        A non-positive ``total_steps`` means there is no ramp and the end
        value is returned directly instead of dividing by zero.
        """
        start, end = self.obser_beta_start, self.obser_beta_end
        if self.total_steps <= 0:
            return end
        slope = (end - start) / self.total_steps
        beta = start + slope * self.current_step
        lower, upper = min(start, end), max(start, end)
        return min(max(beta, lower), upper)

    def _update_critic(
        self,
        critic_input: torch.Tensor,
        next_obss: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Phase 1: one quantile-regression step on the critic.

        Returns the critic loss and the critic output ``d_vals`` (reused,
        detached, as the target of the consistency term in phase 4).
        """
        batch_size = critic_input.shape[0]

        with torch.no_grad():
            tau_hat_next, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
            # NOTE(review): mean over dim 0 presumably averages a leading
            # ensemble dimension of the value net output -- confirm.
            z_pi_targ = self.value_net(next_obss, tau_hat_next).mean(dim=0)
            d_targ = rewards + (1.0 - terminals) * self._gamma * z_pi_targ
            d_targ = torch.clamp(d_targ, -100, 1000)

        _, tau_hat, presum_tau = get_tau(batch_size, self.num_quantiles, self.device)
        tau_hat_in = tau_hat.unsqueeze(0).repeat(self.num_q, 1, 1)

        d_vals = self.critic(critic_input, tau_hat_in)
        loss_q = qr_loss(d_vals, d_targ, tau_hat_in, presum_tau, sigma=self.sigma)

        self.critic_optim.zero_grad()
        loss_q.backward()
        self.critic_optim.step()
        return loss_q, d_vals

    def _update_value(
        self, obss: torch.Tensor, critic_input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Phase 2: one asymmetric squared-loss step on the value net.

        Returns the value loss and the (detached) advantage
        ``target_d - z_vals`` that the actor phase consumes.
        """
        batch_size = obss.shape[0]

        with torch.no_grad():
            target_tau_hat, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
            target_tau_in = target_tau_hat.unsqueeze(0).repeat(self.num_q, 1, 1)
            target_d_all = self.critic_target(critic_input, target_tau_in)
            # Reduce dim 0 of the target-critic output to the configured quantile.
            target_d = torch.quantile(target_d_all, self.quantile, dim=0)

        tau_hat_v, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
        z_vals = self.value_net(obss, tau_hat_v)

        adv = target_d - z_vals
        # Weight is iql_tau where adv >= 0 and (1 - iql_tau) where adv < 0.
        asym_weight = torch.abs(self.iql_tau - (adv < 0).float())
        loss_v = torch.mean(asym_weight * adv.pow(2))

        self.value_optim.zero_grad()
        loss_v.backward()
        self.value_optim.step()
        return loss_v, adv.detach()

    def _update_actor(
        self, obss: torch.Tensor, actions: torch.Tensor, adv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Phase 3: one advantage-weighted log-likelihood step on the actor.

        Returns the actor loss and the advantage averaged over its last
        dimension.
        """
        adv_mean = adv.mean(dim=-1).detach()
        exp_adv = torch.exp(self.beta * adv_mean).clamp(max=self.clip_score)

        dist = self.actor(obss)
        # NOTE(review): when the actor module itself exposes `log_prob`, the
        # distribution's log-prob is taken as already reduced per sample;
        # otherwise it is summed over the last dimension -- confirm intent.
        if hasattr(self.actor, 'log_prob'):
            loss_bc = -dist.log_prob(actions)
        else:
            loss_bc = -dist.log_prob(actions).sum(-1)

        loss_pi = torch.mean(exp_adv * loss_bc)

        self.actor_optim.zero_grad()
        loss_pi.backward()
        self.actor_optim.step()
        return loss_pi, adv_mean

    def _update_obser_model(
        self,
        obss: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_obss: torch.Tensor,
        critic_input: torch.Tensor,
        d_vals: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Phase 4: one step on the observation model (and the critic).

        The loss is ``recon_data + beta_t * recon_d`` where ``recon_d`` is a
        quantile-regression loss between the critic evaluated on partially
        reconstructed inputs and the detached critic output from phase 1.

        Returns ``(loss_recon_data, loss_recon_d, beta_t)``.
        """
        batch_size = obss.shape[0]

        with torch.no_grad():
            tau_map, _, _ = get_tau(batch_size, self.num_quantiles, self.device)
            tau_map_in = tau_map.unsqueeze(0).repeat(self.num_q, 1, 1)
            d_vals_map = self.critic(critic_input, tau_map_in)
            d_cond = torch.quantile(d_vals_map, self.quantile, dim=0)

        loss_recon_data, _ = self.obser_model.get_loss(
            obss, actions, rewards, next_obss, d_cond, self.obser_sigma
        )

        # Reparameterised sample from the model's (mean, log_std) output.
        mean_recon, log_std_recon = self.obser_model(
            obss, actions, rewards, next_obss, d_cond.detach()
        )
        std_recon = torch.exp(log_std_recon)
        eps = torch.randn_like(mean_recon)
        samples = mean_recon + std_recon * eps

        s_dim = obss.shape[1]
        a_dim = actions.shape[1]

        # Leading index 0/1/2 supplies the state / action / reward slice
        # (presumably three model heads laid out as [s | a | r] -- confirm).
        q_s = samples[0, :, :s_dim]
        q_a = samples[1, :, s_dim:s_dim + a_dim]
        q_r = samples[2, :, s_dim + a_dim:]

        # Replace exactly one component of (s, a, r) per copy of the batch.
        data1 = torch.cat([q_s, actions, rewards], dim=-1)
        data2 = torch.cat([obss, q_a, rewards], dim=-1)
        data3 = torch.cat([obss, actions, q_r], dim=-1)
        d_in = torch.cat([data1, data2, data3], dim=0)

        _, tau_hat_cons, presum_tau_cons = get_tau(
            batch_size * 3, self.num_quantiles, self.device
        )
        tau_hat_cons_in = tau_hat_cons.unsqueeze(0).repeat(self.num_q, 1, 1)
        presum_tau_cons_in = presum_tau_cons.unsqueeze(0).repeat(self.num_q, 1, 1)

        d_sample = self.critic(d_in, tau_hat_cons_in)
        # Tile the phase-1 critic output to line up with the three stacked copies.
        d_target_cons = d_vals.detach().repeat(1, 3, 1)

        with torch.no_grad():
            _, norm_weight = self.d_entropy(d_sample, presum_tau_cons_in)
            d_weight = norm_weight if self.enable_entropy else 1.0

        loss_d_elem = qr_loss(
            d_sample, d_target_cons, tau_hat_cons_in, presum_tau_cons,
            sigma=self.obser_sigma, reduction='none'
        )
        loss_recon_d = (loss_d_elem * d_weight).mean()

        current_beta = self._current_obser_beta()
        loss_map = loss_recon_data + current_beta * loss_recon_d

        # The consistency term back-propagates into the critic too, so both
        # optimisers are cleared and stepped.
        self.obser_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss_map.backward()
        self.obser_optim.step()
        self.critic_optim.step()
        return loss_recon_data, loss_recon_d, current_beta

    def learn(self, batch: Dict) -> Dict[str, float]:
        """Run one full training step on ``batch`` and return scalar logs.

        Args:
            batch: Mapping with tensor entries ``"observations"``,
                ``"actions"``, ``"next_observations"``, ``"rewards"`` and
                ``"terminals"``. Observations, actions and rewards are
                concatenated along the last dimension, so they must agree on
                all leading dimensions.

        Returns:
            Scalar losses and metrics keyed by ``"loss/actor"``,
            ``"loss/critic"``, ``"loss/value"``, ``"loss/recon_data"``,
            ``"loss/recon_d"``, ``"metrics/adv_mean"`` and ``"metrics/beta"``.
        """
        self.current_step += 1
        obss = batch["observations"]
        actions = batch["actions"]
        next_obss = batch["next_observations"]
        rewards = batch["rewards"]
        terminals = batch["terminals"]

        if self.scaler is not None:
            obss = self.scaler.transform_tensor(obss)
            next_obss = self.scaler.transform_tensor(next_obss)

        # Input to critic is [s, a, r]
        critic_input = torch.cat([obss, actions, rewards], dim=-1)

        # The four phases must run in this order: each consumes results of
        # (and parameters updated by) the previous ones.
        loss_q, d_vals = self._update_critic(critic_input, next_obss, rewards, terminals)
        loss_v, adv = self._update_value(obss, critic_input)
        loss_pi, adv_mean = self._update_actor(obss, actions, adv)
        loss_recon_data, loss_recon_d, current_beta = self._update_obser_model(
            obss, actions, rewards, next_obss, critic_input, d_vals
        )

        self._sync_weight()

        return {
            "loss/actor": loss_pi.item(),
            "loss/critic": loss_q.item(),
            "loss/value": loss_v.item(),
            "loss/recon_data": loss_recon_data.item(),
            "loss/recon_d": loss_recon_d.item(),
            "metrics/adv_mean": adv_mean.mean().item(),
            "metrics/beta": current_beta
        }