
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch

import torch
'''
@torch.no_grad
def retrace(
    rewards: torch.Tensor,
    dones: torch.Tensor,
    q_values: torch.Tensor,
    actions: torch.Tensor,
    behavior_policy_probs: torch.Tensor,
    target_policy_probs: torch.Tensor,
    gamma: float = 0.99,
    c: float = 1.0,
):
    """
    Compute Retrace targets.

    Args:
        rewards (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size), containing rewards at each step.
        dones (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size), with 1 if episode ended, 0 otherwise.
        q_values (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size, action_dim), Q-values from the target policy.
        actions (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size), actions taken.
        behavior_policy_probs (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size), behavior policy probabilities.
        target_policy_probs (torch.Tensor): Tensor of shape (timesteps, envs, buffer_size), target policy probabilities.
        gamma (float): Discount factor for rewards (default: 0.99).
        c (float): Clipping factor for importance sampling ratios (default: 1.0).

    Returns:
        torch.Tensor: Retrace targets of shape (timesteps, envs, buffer_size).
    """
    timesteps, envs, buffer_size = rewards.shape
    retrace_targets = torch.zeros_like(rewards)

    # Compute importance sampling ratios
    rho = torch.exp(target_policy_probs - (behavior_policy_probs + 1e-8))  # Avoid division by zero
    truncated_rho = torch.min(rho, torch.tensor(c, device=rho.device))

    # Initialize last target (for bootstrap)
    next_return = q_values[-1]*(1-dones[-1])
    retrace_targets[-1] = next_return

    # Iterate backward through timesteps
    for t in reversed(range(timesteps-1)):
        delta = rewards[t] + gamma*(1 - dones[t])*q_values[t+1] - q_values[t]
        #delta = cumulative[t]*delta
        # Truncated importance sampling ratio
        
        # Retrace target: R + gamma * V(s')
        next_return = delta + gamma * truncated_rho[t] * (1 - dones[t]) * next_return

        # Update next return
        retrace_targets[t] = next_return + q_values[t]
        #next_return = retrace_targets[t] + gamma * truncated_rho * (action_q_values - retrace_targets[t])
    print("retrace_targets",retrace_targets.mean(),retrace_targets.max(),retrace_targets.min())
    return retrace_targets
'''

@torch.no_grad()
def retrace(q, rewards, cs, dones, gamma, q_bootstrap, next_dones):
    """
    Compute Retrace-style targets with a backward recursion over time.

    For every step ``t`` (processed from last to first) the target is

        G_t = q_t + delta_t + gamma * cs_t * (G_{t+1} - V_{t+1}) * (1 - d_{t+1})

    where ``V_{t+1} = (1 - d_{t+1}) * q_{t+1}`` and
    ``delta_t = r_t + gamma * V_{t+1} - q_t``.  For the final step the
    "next" quantities are ``q_bootstrap`` and ``next_dones``; the recursion is
    seeded with ``G_T = q_bootstrap``.

    The next-state value is approximated by the Q-value of the action taken
    at ``t + 1`` (``q[t + 1]``), not by an expectation over actions.

    The function runs under ``torch.no_grad()``, so the returned targets are
    never attached to the autograd graph.

    Args:
        q (torch.Tensor): Q-values for the actions taken.
                          Expected shape: [T, B, 1]
        rewards (torch.Tensor): Rewards observed.
                                Expected shape: [T, B, 1]
        cs (torch.Tensor): Truncation coefficients (typically min(1, rho_t)).
                           Expected shape: [T, B, 1]
        dones (torch.Tensor): Done flags (1.0 if done, 0.0 otherwise). The flag
                              at index ``t + 1`` masks the value and trace that
                              flow back into step ``t``.
                              Expected shape: [T, B, 1]
        gamma (float): Discount factor.
        q_bootstrap (torch.Tensor): Bootstrap value for the state after the
                                    last timestep. Expected shape: [B, 1]
        next_dones (torch.Tensor): Done flag for the state after the last
                                   timestep; masks ``q_bootstrap``. Must
                                   broadcast against ``q_bootstrap``
                                   (presumably [B, 1] -- confirm with caller).

    Returns:
        torch.Tensor: Retrace targets stacked along the time dimension
        (shape [T, B, 1] for inputs of the expected shapes).

    NOTE(review): the correction coming from step ``t + 1`` is weighted by
    ``cs[t]``. In the Retrace paper that correction carries ``c_{t+1}``;
    confirm that callers pass ``cs`` aligned for this indexing.
    """
    num_steps = rewards.shape[0]
    last_step = num_steps - 1

    # Filled back-to-front by the loop below.
    targets = [None] * num_steps

    # Seed of the recursion: the return "after" the final step.
    g_ret = q_bootstrap

    for t in reversed(range(num_steps)):
        # Pick the value estimate and done flag that belong to step t + 1.
        if t == last_step:
            next_q, next_done = q_bootstrap, next_dones
        else:
            next_q, next_done = q[t + 1], dones[t + 1]

        # No bootstrapping from a next state that is flagged as done.
        v_next = (1.0 - next_done) * next_q

        # One-step TD error: r_t + gamma * V(s_{t+1}) - Q(s_t, a_t).
        delta = rewards[t] + gamma * v_next - q[t]

        # Add the discounted, truncated off-policy correction from t + 1;
        # the (1 - next_done) factor stops it at episode boundaries.
        g_ret = q[t] + delta + gamma * cs[t] * (g_ret - v_next) * (1.0 - next_done)

        targets[t] = g_ret

    return torch.stack(targets, dim=0)
