# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# from typing import List, Tuple, Optional

# # ==============================================================================
# # Helper Functions
# # ==============================================================================

# def get_tau(batch_size: int, num_quantiles: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
#     """
#     Sample taus for Implicit Quantile Network (IQN) style.
#     """
#     presum_tau = torch.rand(batch_size, num_quantiles, device=device) + 0.1
#     presum_tau /= presum_tau.sum(dim=-1, keepdims=True)
    
#     tau = torch.cumsum(presum_tau, dim=1)
#     with torch.no_grad():
#         tau_hat = torch.zeros_like(tau)
#         tau_hat[:, 0:1] = tau[:, 0:1] / 2.
#         tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.
        
#     return tau, tau_hat, presum_tau

# def qr_loss(q: torch.Tensor, target_q: torch.Tensor, tau: torch.Tensor, 
#             presum_tau: torch.Tensor, sigma: float = 1.0, reduction: str = "mean") -> torch.Tensor:
#     """
#     Robust Quantile Regression Loss (Huber-Quantile).
#     """
#     if target_q.dim() == 2:
#         target_q = target_q.unsqueeze(0)
    
#     # inputs: [num_q, batch, num_quantiles, 1]
#     inputs = q.unsqueeze(-1)
#     # targets: [1, batch, 1, num_quantiles]
#     targets = target_q.unsqueeze(-2)
#     # tau: [num_q, batch, num_quantiles, 1]
#     tau = tau.unsqueeze(-1)

#     # Huber Loss
#     u = inputs - targets
#     beta = 1. / (sigma ** 2)
#     diff = torch.abs(u)
#     cond = diff < beta
#     huber_loss = torch.where(cond, 0.5 * u ** 2 / beta, diff - 0.5 * beta)

#     # Quantile Loss (Soft sign)
#     sign = torch.sign(u) / 2. + 0.5
#     loss = torch.abs(tau - sign) * huber_loss
    
#     loss = loss.sum(dim=-1)
    
#     if reduction == "none":
#         return loss.sum(dim=-1) # Sum over ensemble
#     return loss.mean()

# # ==============================================================================
# # Vectorized Linear Layer (Key for Parameters)
# # ==============================================================================
# class VectorizedLinear(nn.Module):
#     def __init__(self, in_features: int, out_features: int, ensemble_size: int):
#         super().__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.ensemble_size = ensemble_size
        
#         # The nn.Parameter tensors here must be initialized properly for the model to have parameters
#         self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
#         self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
        
#         self.reset_parameters()

#     def reset_parameters(self):
#         # Kaiming Uniform Init
#         for i in range(self.ensemble_size):
#             nn.init.kaiming_uniform_(self.weight[i], a=np.sqrt(5))
#             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
#             bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
#             nn.init.uniform_(self.bias[i], -bound, bound)
            
#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # Support broadcasting for non-ensemble inputs
#         if x.dim() == 2:
#             x = x.unsqueeze(0).repeat(self.ensemble_size, 1, 1)
        
#         original_shape = x.shape
#         # Flatten batch dimensions for baddbmm
#         x_flat = x.view(self.ensemble_size, -1, self.in_features)
        
#         out_flat = torch.baddbmm(self.bias, x_flat, self.weight)
        
#         out_shape = list(original_shape[:-1]) + [self.out_features]
#         return out_flat.view(*out_shape)

# # ==============================================================================
# # 1. TRACER Critic (Input: s, a, r)
# # ==============================================================================

# class VectorizedCritic(nn.Module):
#     def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256,
#                  num_q: int = 2, num_quantiles: int = 32, cosines_dim: int = 64):
#         super().__init__()
#         self.state_dim = state_dim
#         # action_dim passed here should include reward dimension if needed (e.g., act_dim + 1)
#         self.input_dim = state_dim + action_dim 
#         self.num_q = num_q
#         self.num_quantiles = num_quantiles
        
#         # Trunk
#         self.q_trunk = nn.Sequential(
#             VectorizedLinear(self.input_dim, hidden_dim, num_q),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, hidden_dim, num_q),
#             nn.ReLU()
#         )
        
#         # Cosine Embedding
#         self.q_tau = nn.Sequential(
#             VectorizedLinear(cosines_dim, hidden_dim, num_q),
#             nn.ReLU()
#         )

#         # Head
#         self.q_val = nn.Sequential(
#             VectorizedLinear(hidden_dim, hidden_dim, num_q),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, hidden_dim, num_q),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, 1, num_q)
#         )
        
#         self.register_buffer('range_pi', torch.arange(1, cosines_dim + 1).float() * np.pi)

#     def forward(self, inputs: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
#         """
#         Args:
#             inputs: [batch, state_dim + action_dim + reward_dim]
#             tau: [num_q, batch, num_quantiles]
#         """
#         # 1. Embed State-Action-Reward
#         q_feat = self.q_trunk(inputs) # [num_q, batch, hidden]
        
#         # 2. Embed Tau
#         cosines = torch.cos(tau.unsqueeze(-1) * self.range_pi)
#         tau_feat = self.q_tau(cosines) # [num_q, batch, num_quantiles, hidden]
        
#         # 3. Combine
#         q_feat = q_feat.unsqueeze(2)
#         combined = q_feat * tau_feat 
        
#         # 4. Final MLP
#         out = self.q_val(combined)
        
#         return out.squeeze(-1) # [num_q, batch, num_quantiles]

# # ==============================================================================
# # 2. Distributional Value Function
# # ==============================================================================

# class DistributionalValueFunction(nn.Module):
#     def __init__(self, state_dim: int, hidden_dim: int = 256, num_v: int = 1, cosines_dim: int = 64):
#         super().__init__()
#         self.num_v = num_v
        
#         self.v_trunk = nn.Sequential(
#             VectorizedLinear(state_dim, hidden_dim, num_v),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, hidden_dim, num_v),
#             nn.ReLU()
#         )
        
#         self.v_tau = nn.Sequential(
#             VectorizedLinear(cosines_dim, hidden_dim, num_v),
#             nn.ReLU()
#         )
        
#         self.v_val = nn.Sequential(
#             VectorizedLinear(hidden_dim, hidden_dim, num_v),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, hidden_dim, num_v),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, 1, num_v)
#         )

#         self.register_buffer('range_pi', torch.arange(1, cosines_dim + 1).float() * np.pi)

#     def forward(self, state: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
#         if tau.dim() == 2:
#             tau = tau.unsqueeze(0).repeat(self.num_v, 1, 1)

#         v_feat = self.v_trunk(state)
        
#         cosines = torch.cos(tau.unsqueeze(-1) * self.range_pi)
#         tau_feat = self.v_tau(cosines)
        
#         v_feat = v_feat.unsqueeze(2)
#         combined = v_feat * tau_feat
        
#         out = self.v_val(combined)
#         return out.squeeze(-1)

# # ==============================================================================
# # 3. Observation Model (Masked VAE) - make sure this class is defined correctly
# # ==============================================================================

# class ObservationModel(nn.Module):
#     def __init__(self, state_dim: int, action_dim: int, reward_dim: int = 1,
#                  num_quantiles: int = 32, hidden_dim: int = 256,
#                  num_model: int = 3, device='cpu', sigma=0.3):
#         super().__init__()
#         self.device = device
#         self.state_dim = state_dim
#         self.action_dim = action_dim
#         self.reward_dim = reward_dim
#         self.num_model = num_model 
#         self.sigma = sigma
        
#         input_dim = state_dim * 2 + action_dim + reward_dim
#         output_dim = state_dim + action_dim + reward_dim
        
#         # Input includes quantile conditioning
#         net_input_dim = input_dim + num_quantiles

#         # The VectorizedLinear layers here register their parameters under self.trunk
#         self.trunk = nn.Sequential(
#             VectorizedLinear(net_input_dim, hidden_dim, num_model),
#             nn.ReLU(),
#             VectorizedLinear(hidden_dim, hidden_dim, num_model),
#             nn.ReLU()
#         )
        
#         # The VectorizedLinear layers here register their parameters under self.mean_head and self.log_std_head
#         self.mean_head = VectorizedLinear(hidden_dim, output_dim, num_model)
#         self.log_std_head = VectorizedLinear(hidden_dim, output_dim, num_model)
        
#         self._init_mask_matrix(state_dim, action_dim, reward_dim, input_dim, output_dim)

#     def _init_mask_matrix(self, state_dim, action_dim, reward_dim, input_dim, output_dim):
#         s_end = state_dim
#         a_end = state_dim + action_dim
#         s2_end = state_dim * 2 + action_dim
#         r_end = input_dim
        
#         # 1. Input Masks (0 means blocked/masked)
#         self.input_mask = torch.ones(self.num_model, input_dim)
#         self.input_mask[0, 0:s_end] = 0 # Mask S
#         self.input_mask[1, s_end:a_end] = 0 # Mask A
#         self.input_mask[2, s2_end:r_end] = 0 # Mask R (Original code masks R here)
        
#         # 2. Output Masks (1 means target to predict)
#         self.output_mask = torch.zeros(self.num_model, output_dim)
#         self.output_mask[0, 0:state_dim] = 1 # Predict S
#         self.output_mask[1, state_dim:state_dim+action_dim] = 1 # Predict A
#         self.output_mask[2, state_dim+action_dim:] = 1 # Predict R
        
#     def to(self, device):
#         super().to(device)
#         self.input_mask = self.input_mask.to(device)
#         self.output_mask = self.output_mask.to(device)
#         return self

#     def forward(self, state, action, reward, next_state, d):
#         flat_inputs = torch.cat([state, action, next_state, reward], dim=-1)
#         inputs = flat_inputs.unsqueeze(0).repeat(self.num_model, 1, 1)
        
#         masked_inputs = inputs * self.input_mask.unsqueeze(1)
        
#         d_ens = d.unsqueeze(0).repeat(self.num_model, 1, 1)
#         net_in = torch.cat([masked_inputs, d_ens], dim=-1)
        
#         h = self.trunk(net_in)
#         mean = self.mean_head(h)
#         log_std = self.log_std_head(h).clamp(-5, 2)
        
#         mean = mean * self.output_mask.unsqueeze(1)
#         log_std = log_std * self.output_mask.unsqueeze(1)
        
#         return mean, log_std

#     def get_loss(self, state, action, reward, next_state, d, sigma=0.3):
#         mean, log_std = self.forward(state, action, reward, next_state, d)
        
#         target = torch.cat([state, action, reward], dim=-1)
#         target = target.unsqueeze(0).repeat(self.num_model, 1, 1)
#         target = target * self.output_mask.unsqueeze(1)
        
#         inv_std = torch.exp(-log_std)
#         mse_loss = 0.5 * ((mean - target) * inv_std).pow(2) 
#         log_std_loss = log_std 
        
#         loss_element = (mse_loss + log_std_loss) * self.output_mask.unsqueeze(1)
#         loss = loss_element.sum(dim=-1).mean()
        
#         # Calculate unweighted reconstruction error for reliability weight
#         with torch.no_grad():
#             raw_mse = (mean - target).pow(2) * self.output_mask.unsqueeze(1)
#             # Sum over dimension, keep ensemble dim (model 0,1,2) separated initially?
#             # TRACER sums error over all models for a data point to measure total "surprise"
#             recon_error = raw_mse.sum(dim=-1).sum(dim=0, keepdim=True).T # [batch, 1]
            
#         return loss, recon_error

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Optional

# ==============================================================================
# Helper Functions
# ==============================================================================

def get_tau(batch_size: int, num_quantiles: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draw random quantile fractions in the style of an Implicit Quantile Network.

    Args:
        batch_size: number of rows to sample.
        num_quantiles: number of fractions per row.
        device: device on which the random tensor is created.

    Returns:
        A tuple ``(tau, tau_hat, presum_tau)``, each of shape
        ``(batch_size, num_quantiles)``:
        ``presum_tau`` holds positive increments that sum to 1 per row,
        ``tau`` is their running sum along the quantile axis, and
        ``tau_hat`` is the midpoint of each interval ``[tau[i-1], tau[i]]``
        (with ``tau[-1]`` taken as 0).
    """
    # The +0.1 offset keeps every increment strictly positive before normalizing.
    raw_weights = torch.rand(batch_size, num_quantiles, device=device) + 0.1
    presum_tau = raw_weights / raw_weights.sum(dim=-1, keepdim=True)

    tau = presum_tau.cumsum(dim=1)
    with torch.no_grad():
        # Left edge of each interval: 0 for the first one, the previous tau otherwise.
        left_edges = torch.cat([torch.zeros_like(tau[:, :1]), tau[:, :-1]], dim=1)
        tau_hat = (tau + left_edges) / 2.

    return tau, tau_hat, presum_tau

def huber_loss(diff, sigma=1.0):
    """Element-wise Huber (smooth L1) penalty of a residual tensor.

    The switch point between the quadratic and the linear regime is
    ``1 / sigma**2``; below it the penalty is ``0.5 * |diff|**2 / threshold``
    and otherwise ``|diff| - 0.5 * threshold``.

    Args:
        diff: tensor of residuals (any shape).
        sigma: controls the switch point; larger values shrink the quadratic zone.

    Returns:
        Tensor with the same shape as ``diff``.
    """
    threshold = 1. / (sigma ** 2)
    magnitude = torch.abs(diff)
    quadratic_branch = 0.5 * magnitude ** 2 / threshold
    linear_branch = magnitude - 0.5 * threshold
    return torch.where(magnitude < threshold, quadratic_branch, linear_branch)

def qr_loss(q: torch.Tensor, target_q: torch.Tensor, tau: torch.Tensor, 
            presum_tau: torch.Tensor, sigma: float = 1.0, reduction: str = "mean") -> torch.Tensor:
    """Quantile-regression loss with a Huber penalty on the pairwise residuals.

    Every predicted quantile (last axis of ``q``) is compared against every
    target quantile (last axis of ``target_q``).

    Args:
        q: predicted quantile values.
        target_q: target quantile values; a 2-D tensor gets a leading
            singleton axis so that it broadcasts against ``q``.
        tau: quantile fractions aligned with ``q``.
        presum_tau: accepted for interface compatibility; not used here.
        sigma: forwarded to ``huber_loss``.
        reduction: ``"none"`` returns the loss summed over the last two
            remaining axes' final axis (see below); any other value returns
            the scalar mean.

    Returns:
        For ``reduction == "none"``: the pairwise loss summed over the target
        axis and then over the predicted-quantile axis. Otherwise: the mean
        of the loss after summing over the target axis.
    """
    if target_q.dim() == 2:
        target_q = target_q[None]

    # Pairwise residuals: predictions along axis -2, targets along axis -1.
    residual = q.unsqueeze(-1) - target_q.unsqueeze(-2)

    # 1 where the prediction overshoots, 0 where it undershoots, 0.5 on ties.
    overshoot = torch.sign(residual) / 2. + 0.5
    asymmetric_weight = torch.abs(tau.unsqueeze(-1) - overshoot)

    pairwise = asymmetric_weight * huber_loss(residual, sigma=sigma)
    summed_over_targets = pairwise.sum(dim=-1)

    if reduction != "none":
        return summed_over_targets.mean()
    return summed_over_targets.sum(dim=-1)

# ==============================================================================
# Vectorized Linear Layer
# ==============================================================================
class VectorizedLinear(nn.Module):
    """Ensemble of independent linear layers evaluated with one batched matmul.

    Parameters are stored as ``weight`` of shape
    ``(ensemble_size, in_features, out_features)`` and ``bias`` of shape
    ``(ensemble_size, 1, out_features)``.

    Args:
        in_features: size of the last axis of the input.
        out_features: size of the last axis of the output.
        ensemble_size: number of independent layers.
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize every member like ``nn.Linear``'s default initialization.

        ``kaiming_uniform_(a=sqrt(5))`` in fan-in mode reduces to
        ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, and the default bias range is
        the same. The bound is computed from ``in_features`` directly because
        the weights are laid out ``(in, out)``: PyTorch's fan helpers read
        fan-in from ``size(1)``, which would be ``out_features`` here.
        """
        if self.in_features > 0:
            bound = 1.0 / float(np.sqrt(self.in_features))
        else:
            bound = 0.0
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            self.bias.uniform_(-bound, bound)
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member to its slice of ``x``.

        Args:
            x: either a 2-D tensor ``(batch, in_features)``, which is shared
                by all members, or a tensor whose first axis is the ensemble
                axis and whose last axis is ``in_features``.

        Returns:
            Tensor shaped like ``x`` (after the 2-D broadcast, if any) with
            the last axis replaced by ``out_features``.
        """
        if x.dim() == 2:
            # Share one batch across members without copying it.
            x = x.unsqueeze(0).expand(self.ensemble_size, -1, -1)
        leading_shape = x.shape[:-1]
        # Collapse all middle axes so baddbmm sees (ensemble, rows, in_features);
        # reshape (unlike view) also accepts non-contiguous inputs.
        x_flat = x.reshape(self.ensemble_size, -1, self.in_features)
        out_flat = torch.baddbmm(self.bias, x_flat, self.weight)
        return out_flat.reshape(*leading_shape, self.out_features)

# ==============================================================================
# 1. TRACER Critic
# ==============================================================================
class VectorizedCritic(nn.Module):
    """Ensemble of quantile critics conditioned on sampled quantile fractions.

    The critic embeds its input with ``q_trunk``, embeds the quantile
    fractions through a cosine basis and ``q_tau``, multiplies the two
    embeddings element-wise and maps the product to one value per quantile.

    Args:
        state_dim: size of the state part of the input.
        action_dim: size of the remaining part of the input; the trunk
            expects ``state_dim + action_dim`` input features.
        hidden_dim: width of every hidden layer.
        num_q: number of ensemble members.
        num_quantiles: stored on the instance; not used by ``forward``.
        cosines_dim: number of cosine basis functions for the tau embedding.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256,
                 num_q: int = 2, num_quantiles: int = 32, cosines_dim: int = 64):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = state_dim + action_dim 
        self.num_q = num_q
        self.num_quantiles = num_quantiles
        
        self.q_trunk = nn.Sequential(
            VectorizedLinear(self.input_dim, hidden_dim, num_q),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, hidden_dim, num_q),
            nn.ReLU()
        )
        self.q_tau = nn.Sequential(
            VectorizedLinear(cosines_dim, hidden_dim, num_q),
            nn.ReLU()
        )
        # The head is split so the output layer can be re-initialized on its own.
        self.q_head_hidden = nn.Sequential(
            VectorizedLinear(hidden_dim, hidden_dim, num_q),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, hidden_dim, num_q),
            nn.ReLU()
        )
        self.q_head_out = VectorizedLinear(hidden_dim, 1, num_q)
        
        # Frequencies pi, 2*pi, ..., cosines_dim*pi of the cosine basis.
        self.register_buffer('range_pi', torch.arange(1, cosines_dim + 1).float() * np.pi)
        self.edac_init()

    def edac_init(self):
        """Re-initialize the output layer: bias 0.1, weights in U(-3e-3, 3e-3)."""
        for i in range(self.num_q):
            nn.init.constant_(self.q_head_out.bias[i], 0.1)
            nn.init.uniform_(self.q_head_out.weight[i], -3e-3, 3e-3)

    def forward(self, inputs: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
        """Evaluate every ensemble member at the given quantile fractions.

        Args:
            inputs: critic input whose last axis has ``input_dim`` features;
                a 2-D tensor is shared by all members.
            tau: quantile fractions, either ``(num_q, batch, num_quantiles)``
                or ``(batch, num_quantiles)``; the 2-D form is shared by all
                members.

        Returns:
            Tensor of shape ``(num_q, batch, num_quantiles)``.
        """
        if tau.dim() == 2:
            # Without this, the 3-D cosine features would be reshaped across
            # ensemble members by VectorizedLinear, mixing batch rows between
            # members. Mirrors DistributionalValueFunction.forward.
            tau = tau.unsqueeze(0).repeat(self.num_q, 1, 1)

        q_feat = self.q_trunk(inputs)
        cosines = torch.cos(tau.unsqueeze(-1) * self.range_pi)
        tau_feat = self.q_tau(cosines)
        
        # Add a quantile axis so the state-action embedding broadcasts over taus.
        q_feat = q_feat.unsqueeze(2)
        combined = q_feat * tau_feat 
        
        hidden = self.q_head_hidden(combined)
        out = self.q_head_out(hidden)
        
        return out.squeeze(-1)

# ==============================================================================
# 2. Distributional Value Function
# ==============================================================================
class DistributionalValueFunction(nn.Module):
    """Ensemble of quantile state-value networks.

    The state is embedded by ``v_trunk``, the quantile fractions by a cosine
    basis followed by ``v_tau``; the element-wise product of both embeddings
    is mapped to one value per quantile.

    Args:
        state_dim: number of state features.
        hidden_dim: width of every hidden layer.
        num_v: number of ensemble members.
        cosines_dim: number of cosine basis functions for the tau embedding.
    """

    def __init__(self, state_dim: int, hidden_dim: int = 256, num_v: int = 1, cosines_dim: int = 64):
        super().__init__()
        self.num_v = num_v

        def ensemble_linear(n_in: int, n_out: int) -> VectorizedLinear:
            return VectorizedLinear(n_in, n_out, num_v)

        # Layers are created in a fixed order so seeded initialization is reproducible.
        self.v_trunk = nn.Sequential(
            ensemble_linear(state_dim, hidden_dim),
            nn.ReLU(),
            ensemble_linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.v_tau = nn.Sequential(
            ensemble_linear(cosines_dim, hidden_dim),
            nn.ReLU(),
        )
        self.v_head_hidden = nn.Sequential(
            ensemble_linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            ensemble_linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.v_head_out = ensemble_linear(hidden_dim, 1)

        # Frequencies pi, 2*pi, ..., cosines_dim*pi of the cosine basis.
        frequencies = torch.arange(1, cosines_dim + 1).float() * np.pi
        self.register_buffer('range_pi', frequencies)
        self.edac_init()

    def edac_init(self):
        """Re-initialize the output layer: bias 0.1, weights in U(-3e-3, 3e-3)."""
        out_layer = self.v_head_out
        for member in range(self.num_v):
            nn.init.constant_(out_layer.bias[member], 0.1)
            nn.init.uniform_(out_layer.weight[member], -3e-3, 3e-3)

    def forward(self, state: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
        """Evaluate every ensemble member at the given quantile fractions.

        Args:
            state: state batch; a 2-D tensor is shared by all members.
            tau: quantile fractions; a 2-D tensor is repeated for each of the
                ``num_v`` members.

        Returns:
            Quantile values with the trailing singleton axis removed.
        """
        if tau.dim() == 2:
            tau = tau[None].repeat(self.num_v, 1, 1)

        # Quantile axis inserted so the state embedding broadcasts over taus.
        state_embedding = self.v_trunk(state).unsqueeze(2)
        tau_embedding = self.v_tau(torch.cos(tau.unsqueeze(-1) * self.range_pi))

        quantile_values = self.v_head_out(
            self.v_head_hidden(state_embedding * tau_embedding)
        )
        return quantile_values.squeeze(-1)

# ==============================================================================
# 3. Observation Model (Masked VAE)
# ==============================================================================
class ObservationModel(nn.Module):
    """Ensemble of masked reconstruction models over (state, action, reward).

    Each of the first three ensemble members sees the transition
    ``[state, action, next_state, reward]`` with one component zeroed out and
    predicts a Gaussian (mean, log-std) for exactly that component:
    member 0 reconstructs the state, member 1 the action, member 2 the reward.
    Construction indexes members 0..2, so ``num_model`` must be at least 3;
    any further member has an all-zero output mask and contributes nothing.

    The masks are registered as non-persistent buffers, so they follow the
    module through ``.to()``, ``.cuda()`` and parent-module moves without
    adding keys to the ``state_dict``.

    Args:
        state_dim: number of state features.
        action_dim: number of action features.
        reward_dim: number of reward features.
        num_quantiles: size of the last axis of the conditioning input ``d``.
        hidden_dim: width of the hidden layers.
        num_model: number of ensemble members.
        device: stored on the instance; tensors are created on the default
            device and moved with the usual module methods.
        sigma: default Huber parameter used by ``get_loss``.
    """

    def __init__(self, state_dim: int, action_dim: int, reward_dim: int = 1,
                 num_quantiles: int = 32, hidden_dim: int = 256,
                 num_model: int = 3, device='cpu', sigma=0.3):
        super().__init__()
        self.device = device
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.num_model = num_model 
        self.sigma = sigma
        
        input_dim = state_dim * 2 + action_dim + reward_dim
        output_dim = state_dim + action_dim + reward_dim
        
        # The conditioning input d is concatenated after the masked transition.
        net_input_dim = input_dim + num_quantiles

        self.trunk = nn.Sequential(
            VectorizedLinear(net_input_dim, hidden_dim, num_model),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, hidden_dim, num_model),
            nn.ReLU()
        )
        
        self.mean_head = VectorizedLinear(hidden_dim, output_dim, num_model)
        self.log_std_head = VectorizedLinear(hidden_dim, output_dim, num_model)
        
        self._init_mask_matrix(state_dim, action_dim, reward_dim, input_dim, output_dim)

    def _init_mask_matrix(self, state_dim, action_dim, reward_dim, input_dim, output_dim):
        """Build the per-member input and output masks and register them as buffers.

        Input layout is ``[state | action | next_state | reward]``; output
        layout is ``[state | action | reward]``.
        """
        s_end = state_dim
        a_end = state_dim + action_dim
        s2_end = state_dim * 2 + action_dim
        r_end = input_dim
        
        # Input mask: 0 hides a component from the corresponding member.
        input_mask = torch.ones(self.num_model, input_dim)
        input_mask[0, 0:s_end] = 0 
        input_mask[1, s_end:a_end] = 0 
        input_mask[2, s2_end:r_end] = 0 
        
        # Output mask: 1 marks the component the member has to reconstruct.
        output_mask = torch.zeros(self.num_model, output_dim)
        output_mask[0, 0:state_dim] = 1 
        output_mask[1, state_dim:state_dim+action_dim] = 1 
        output_mask[2, state_dim+action_dim:] = 1 

        # persistent=False keeps existing checkpoints loadable (no new keys).
        self.register_buffer('input_mask', input_mask, persistent=False)
        self.register_buffer('output_mask', output_mask, persistent=False)

    def forward(self, state, action, reward, next_state, d):
        """Predict mean and log-std of each member's masked component.

        Args:
            state, action, reward, next_state: 2-D batches that are
                concatenated along the last axis.
            d: 2-D conditioning batch appended to every member's input.

        Returns:
            ``(mean, log_std)``, both of shape
            ``(num_model, batch, state_dim + action_dim + reward_dim)`` and
            zero outside each member's output mask. ``log_std`` is clamped to
            ``[-5, 2]`` before masking.
        """
        flat_inputs = torch.cat([state, action, next_state, reward], dim=-1)
        # Broadcasting yields (num_model, batch, input_dim) without an explicit copy.
        masked_inputs = flat_inputs.unsqueeze(0) * self.input_mask.unsqueeze(1)
        d_ens = d.unsqueeze(0).expand(self.num_model, -1, -1)
        net_in = torch.cat([masked_inputs, d_ens], dim=-1)
        
        h = self.trunk(net_in)
        mean = self.mean_head(h)
        log_std = self.log_std_head(h).clamp(-5, 2)
        
        out_mask = self.output_mask.unsqueeze(1)
        mean = mean * out_mask
        log_std = log_std * out_mask
        return mean, log_std

    def get_loss(self, state, action, reward, next_state, d, sigma=None):
        """Compute the training loss and a per-sample reconstruction error.

        Args:
            state, action, reward, next_state, d: see ``forward``.
            sigma: Huber parameter; ``None`` (default) uses the ``sigma``
                given to the constructor.

        Returns:
            ``(loss, recon_error)`` where ``loss`` is a scalar (Huber penalty
            of the standardized residual plus ``log_std``, summed over
            features and averaged over members and batch) and ``recon_error``
            has shape ``(batch, 1)`` and holds the squared error summed over
            features and members, computed without gradient tracking.
        """
        if sigma is None:
            sigma = self.sigma

        mean, log_std = self.forward(state, action, reward, next_state, d)
        out_mask = self.output_mask.unsqueeze(1)
        
        target = torch.cat([state, action, reward], dim=-1)
        target = target.unsqueeze(0) * out_mask
        
        # Huber penalty on the residual standardized by the predicted std.
        inv_std = torch.exp(-log_std)
        huber = huber_loss((mean - target) * inv_std, sigma=sigma) 
        
        loss_element = (huber + log_std) * out_mask
        loss = loss_element.sum(dim=-1).mean()
        
        # Plain squared error, used by callers as a reliability signal.
        with torch.no_grad():
            raw_mse = (mean - target).pow(2) * out_mask
            recon_error = raw_mse.sum(dim=-1).sum(dim=0, keepdim=True).T 
            
        return loss, recon_error