#!/usr/bin/env python3
"""
Script to calculate replay buffer memory usage for SAC with Robosuite
"""

import numpy as np
import robosuite as suite
from robosuite.wrappers import GymWrapper
import gym

def _replay_buffer_bytes(obs_dim, action_dim, buffer_size):
    """
    Estimate the byte size of a replay buffer holding ``buffer_size`` transitions.

    Assumed per-transition layout (NOTE(review): confirm against the replay
    buffer implementation actually used for training):
      - observation and next observation: ``obs_dim`` float32 values each
      - action: ``action_dim`` float32 values
      - reward: one float32 value
      - done flag: one byte
      - ~4 bytes of rough per-sample metadata (indices etc.)

    Args:
        obs_dim: Flattened observation size.
        action_dim: Flattened action size.
        buffer_size: Number of transitions stored.

    Returns:
        Estimated size in bytes.
    """
    bytes_per_float32 = 4

    obs_memory = buffer_size * obs_dim * bytes_per_float32        # observations
    next_obs_memory = buffer_size * obs_dim * bytes_per_float32   # next_observations
    action_memory = buffer_size * action_dim * bytes_per_float32  # actions
    reward_memory = buffer_size * 1 * bytes_per_float32           # rewards (scalar)
    done_memory = buffer_size * 1 * 1                             # dones (1 byte each)
    metadata_memory = buffer_size * 4  # rough estimate for indices and other metadata

    return (obs_memory + next_obs_memory + action_memory
            + reward_memory + done_memory + metadata_memory)


def calculate_buffer_memory(env_id="Lift-Panda", buffer_size=int(1e6)):
    """
    Calculate the memory footprint of the replay buffer for a Robosuite task.

    The environment is created only to read the flattened observation and
    action dimensions. It is always closed afterwards, even if wrapping it
    or reading its spaces fails.

    Args:
        env_id: Task identifier of the form "<EnvName>-<Robot>"
            (e.g. "Lift-Panda").
        buffer_size: Number of transitions the replay buffer holds.

    Returns:
        Estimated memory usage in GB, computed as bytes / 1024**3.

    Raises:
        ValueError: If ``env_id`` does not contain exactly one '-'.
    """
    try:
        env_name, robot = env_id.split('-')
    except ValueError:
        raise ValueError(
            f"env_id must have the form '<EnvName>-<Robot>', got {env_id!r}"
        ) from None

    env = suite.make(
        env_name=env_name,
        robots=robot,
        has_renderer=False,
        has_offscreen_renderer=False,
        use_object_obs=True,
        use_camera_obs=False,
        reward_shaping=True,
    )
    try:
        # If wrapping fails, `env` is still the raw env and is closed below.
        env = GymWrapper(env)
        # Flatten the space shapes to the total number of scalar entries.
        obs_dim = int(np.prod(env.observation_space.shape))
        action_dim = int(np.prod(env.action_space.shape))
    finally:
        env.close()

    print(f"Environment: {env_id}")
    print(f"  Obs dims: {obs_dim}, Action dims: {action_dim}")

    total_memory_bytes = _replay_buffer_bytes(obs_dim, action_dim, buffer_size)
    total_memory_gb = total_memory_bytes / (1024 * 1024 * 1024)

    print(f"  Total replay buffer memory: {total_memory_gb:.3f} GB")
    return total_memory_gb

if __name__ == "__main__":
    # Robosuite tasks to size, as "<EnvName>-<Robot>" identifiers.
    tasks = [
        "NutAssemblySingle-Panda", "NutAssemblySingle-Sawyer",
        "NutAssemblySquare-Panda", "NutAssemblySquare-Sawyer",
        "PickPlaceSingle-Panda", "PickPlaceSingle-Sawyer",
        "PickPlaceMilk-Panda", "PickPlaceMilk-Sawyer",
        "PickPlaceBread-Panda", "PickPlaceBread-Sawyer",
        "PickPlaceCereal-Panda", "PickPlaceCereal-Sawyer",
        "PickPlaceCan-Panda", "PickPlaceCan-Sawyer",
    ]

    buffer_size = int(1e6)  # Your specified buffer size
    print("REPLAY BUFFER MEMORY CALCULATION")
    print(f"Buffer size: {buffer_size:,} samples")
    print("=" * 80)

    # task -> estimated GB, or None if the task could not be measured.
    memory_results = {}

    for task in tasks:
        print(f"\n{task}:")
        try:
            memory_gb = calculate_buffer_memory(task, buffer_size)
        except Exception as e:
            # Top-level boundary: report the failure and continue with the rest.
            print(f"Error with {task}: {e}")
            memory_results[task] = None
        else:
            memory_results[task] = memory_gb
            print()

    # Summary
    print(f"\n{'=' * 80}")
    print("SUMMARY - MEMORY REQUIREMENTS PER TASK (Single Process)")
    print("=" * 80)

    successful = {t: m for t, m in memory_results.items() if m is not None}
    for task, memory_gb in successful.items():
        print(f"{task:<25}: {memory_gb:.3f} GB")

    if not successful:
        # Without any estimate, the range would print as "inf - 0" and every
        # recommendation below would be 0 GB, so stop here with an error status.
        print("\nNo task could be measured; cannot compute recommendations.")
        raise SystemExit(1)

    max_memory = max(successful.values())
    min_memory = min(successful.values())
    print(f"\nRange: {min_memory:.3f} GB - {max_memory:.3f} GB")

    print(f"\n{'=' * 80}")
    print("RECOMMENDATIONS FOR PARALLEL TRAINING:")
    print("=" * 80)
    print(f"Single process (any task): Request {max_memory * 1.5:.1f} GB RAM")
    for n_procs in (2, 4, 8, 16):
        print(f"{n_procs} parallel processes: Request {max_memory * n_procs * 1.3:.1f} GB RAM")
    print("\n(Includes 30-50% overhead for models, gradients, and OS)")

    # Memory per core recommendations
    print("\nMEMORY PER CPU CORE:")
    print(f"Safe estimate: {max_memory * 1.5:.1f} GB per process")
    print(f"If running N parallel processes, request N × {max_memory * 1.5:.1f} GB")
