import torch
import torch.nn.functional as F
import copy

def bellman_error(value_net_s, value_net_sp, states, rewards, next_states, gamma):
    """
    Compute the per-sample Bellman error: V_{theta'}(s) - (r + gamma * V_theta(s')).

    value_net_s: The value network for the current state (theta' or theta).
        Gradients flow through this network only.
    value_net_sp: The value network for the next state (theta, always fixed).
        It is evaluated without gradient tracking.
    states: Current states (torch.Tensor), one row per sample.
    rewards: Rewards obtained (torch.Tensor). A multi-dimensional tensor such
        as (N, 1) is flattened to (N,) so it lines up with the value estimates.
    next_states: Next states after transition (torch.Tensor).
    gamma: Discount factor (float).

    Returns: 1-D tensor with the Bellman error of each sample for the current
        theta' and fixed theta.
    """
    # V_{theta'}(s), flattened to one value per sample.
    # reshape (rather than view) also accepts non-contiguous network outputs.
    V_s = value_net_s(states).reshape(-1)

    # V_{theta}(s') is a fixed target, so skip building an autograd graph for it.
    with torch.no_grad():
        V_sp = value_net_sp(next_states).reshape(-1)

    # Without flattening, (N, 1) rewards + (N,) values would silently broadcast
    # to an (N, N) matrix and corrupt every downstream mean.
    if torch.is_tensor(rewards) and rewards.dim() > 1:
        rewards = rewards.reshape(-1)

    # Bellman target: r + gamma * V_{theta}(s')
    bellman_target = rewards + gamma * V_sp

    # Bellman error: V_{theta'}(s) - (r + gamma * V_theta(s'))
    return V_s - bellman_target

def evaluate_bellman_gap(value_net, traj, gamma, optimizer_class, lr, num_steps, lr_decay=0.995):
    """
    Compute the Bellman error gap: BE(theta, theta) - min_{theta'} BE(theta', theta).

    BE is the mean squared Bellman error over the given transitions. The
    minimum over theta' is approximated by gradient descent on a copy of the
    network, starting from theta; the original network is not modified.

    value_net: The original value network (theta) to evaluate.
    traj: Iterable of transitions, each a 5-tuple
        (state, action, reward, next_state, value) of torch.Tensors.
        Only state, reward and next_state are used.
    gamma: Discount factor (float).
    optimizer_class: Optimizer class (e.g., torch.optim.Adam).
    lr: Learning rate for optimizing the copied network.
    num_steps: Number of optimization steps for theta'. With 0 steps the
        copy is never updated.
    lr_decay: Multiplicative learning-rate decay applied after every step.

    Returns: The Bellman error gap (float).
    Raises: ValueError if traj contains no transitions.
    """
    transitions = list(traj)
    if not transitions:
        raise ValueError("traj must contain at least one transition")

    # Actions and stored values belong to each transition but are not needed here.
    states, _actions, rewards, next_states, _values = zip(*transitions)

    states = torch.stack(states)
    # Flatten so per-sample rewards of shape (1,) cannot broadcast against the
    # flattened value estimates inside bellman_error.
    rewards = torch.stack(rewards).reshape(-1)
    next_states = torch.stack(next_states)

    def mean_squared_bellman_error(net):
        # BE(net, theta): the bootstrap target always comes from the original network.
        errors = bellman_error(net, value_net, states, rewards, next_states, gamma)
        return torch.mean(errors ** 2)

    # Compute Bellman error for the original network (theta, theta)
    with torch.no_grad():
        bellman_error_theta = mean_squared_bellman_error(value_net).item()

    # Create a copy of the network for optimization (theta')
    value_net_prime = copy.deepcopy(value_net)
    optimizer = optimizer_class(value_net_prime.parameters(), lr=lr)
    # NOTE: the scheduler's `gamma` is the LR decay factor, unrelated to the discount factor.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

    # Losses of every iterate of theta', including the starting point theta' = theta.
    observed_losses = []

    # Perform optimization to minimize BE(theta', theta)
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = mean_squared_bellman_error(value_net_prime)
        observed_losses.append(loss.detach())
        loss.backward()
        optimizer.step()
        scheduler.step()

    # The loss inside the loop is measured before each update, so the final
    # parameters need one more evaluation (this also covers num_steps == 0).
    with torch.no_grad():
        observed_losses.append(mean_squared_bellman_error(value_net_prime))

    # Use the best iterate rather than the last one: an unstable step must not
    # make the estimated minimum worse than a value already reached.
    min_bellman_error_prime = torch.stack(observed_losses).min().item()

    # Compute and return the gap
    return bellman_error_theta - min_bellman_error_prime

# Example usage (traj is an iterable of (state, action, reward, next_state, value) tensor tuples):
# gap = evaluate_bellman_gap(value_net, traj, gamma, torch.optim.Adam, lr=0.001, num_steps=50)
# print(f"Bellman Error Gap: {gap}")
