# algorithms/cvar_utils.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantileCritic(nn.Module):
    """MLP critic that emits a vector of quantile values for the return distribution.

    Every entry of `hidden_sizes` contributes a Linear layer followed by a ReLU;
    a final Linear layer maps the last width to `quantiles` outputs. The number
    of outputs is kept on `self.quantiles`.
    """

    def __init__(self, state_dim, quantiles=50, hidden_sizes=(256,256)):
        super().__init__()
        self.quantiles = quantiles
        # Consecutive pairs of widths give the (in, out) size of each hidden layer.
        widths = [state_dim, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: one value per quantile, no activation.
        modules.append(nn.Linear(widths[-1], quantiles))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run `x` through the network; a (batch, state_dim) input yields (batch, quantiles)."""
        return self.net(x)

def cvar_from_quantiles(quantile_values, alpha):
    """Estimate CVaR_alpha as the mean of the lowest alpha-fraction of quantiles.

    Args:
        quantile_values: Tensor of shape (batch, N) whose columns are quantile
            estimates sorted ascending (low to high). This function does not
            sort; the caller must pass sorted values.
        alpha: Tail fraction. The tail holds floor(alpha * N) quantiles,
            clamped to the range [1, N], so very small (or non-positive)
            alpha falls back to the single lowest quantile and alpha >= 1
            averages all of them.

    Returns:
        Tensor of shape (batch,) with the mean of the first k columns per row.
    """
    num_quantiles = quantile_values.size(1)
    # alpha * N can land a hair below an integer in floating point
    # (0.29 * 100 == 28.999999999999996); nudge before truncating so the
    # tail does not silently lose one quantile.
    tail = int(alpha * num_quantiles + 1e-9)
    tail = min(num_quantiles, max(1, tail))
    # Rows are sorted ascending, so the worst outcomes are the leading columns.
    return quantile_values[:, :tail].mean(dim=1)

@torch.no_grad()
def get_next_values(critic, trajectories, device, alpha):
    """Compute ρ̂_α[V(x_{i+1})] for every transition that has a successor state.

    Args:
        critic: Callable mapping a batch of states to values. A `QuantileCritic`
            is treated as distributional (one row of quantiles per state); any
            other critic is treated as producing one value per state.
        trajectories: Iterable of mappings with keys 'states' and 'dones',
            indexed per step.
        device: Device on which the state batch and the result are created.
        alpha: Tail fraction used for the CVaR estimate.

    Returns:
        1-D float tensor with one entry per collected next state, in trajectory
        order. Steps flagged done, and the final step of each trajectory, have
        no successor and are skipped, so the result can be shorter than the
        total number of transitions; an empty float32 tensor is returned when
        nothing qualifies. With a non-distributional critic every entry is the
        same number: the mean of the worst alpha-fraction of all next-state
        values pooled together.

    NOTE: mapping these entries back onto the full transition batch is left to
    the caller.
    """
    # Collect successor states of non-terminal steps.
    next_states = []
    for traj in trajectories:
        states, dones = traj['states'], traj['dones']
        last_idx = len(states) - 1
        next_states.extend(
            states[idx + 1]
            for idx, done in enumerate(dones)
            if not done and idx < last_idx
        )
    if not next_states:
        return torch.tensor([], dtype=torch.float32, device=device)

    if torch.is_tensor(next_states[0]):
        # States that are already tensors are stacked directly; this avoids a
        # NumPy round-trip, which fails for tensors that live off-CPU.
        next_states_t = torch.stack(list(next_states)).to(device=device, dtype=torch.float32)
    else:
        next_states_t = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)

    if isinstance(critic, QuantileCritic):
        # Distributional case: sort each state's quantiles ascending, then
        # average the lowest alpha-fraction per state.
        quantiles = torch.sort(critic(next_states_t), dim=1)[0]
        return cvar_from_quantiles(quantiles, alpha)

    # Single-critic case: approximate CVaR from the empirical distribution of
    # next-state values. Flatten first: torch.sort works along the last
    # dimension, so sorting an (n, 1) critic output would leave it unsorted.
    values_next = critic(next_states_t).reshape(-1)
    values_sorted, _ = torch.sort(values_next)
    # One pooled tail mean, computed with the same tail-size rule as the
    # distributional branch, then broadcast to every collected transition.
    pooled_cvar = cvar_from_quantiles(values_sorted.unsqueeze(0), alpha)
    return pooled_cvar.expand(len(next_states))
