import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.backends.backend_agg import FigureCanvasAgg
import torch
from agent import *
from envs import *
import random
import time
from envs import *
from PIL import Image
from datetime import datetime
import os
import numpy as np
from datetime import datetime
from PIL import Image
import torch

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available the CUDA generators are seeded as well.  cuDNN is
    always switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    # CPU-side generators all take the seed through a single-argument call.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def log_allmus(all_mus, act_dim, logger, COLOR="", ENDC=""):
    """
    Log distribution metrics for all action dimensions and their averages.

    For every dimension ``i`` in ``range(act_dim)`` the values
    ``tanh(all_mus[0, ..., i])`` are rounded to two decimals, flattened and
    summarised (range, variance, unique count, histogram entropy, IQR, excess
    kurtosis, skewness).  The sorted values of each dimension, the full
    squashed tensor and the per-metric averages are written with
    ``logger.debug``.

    Args:
        all_mus (torch.Tensor): Tensor of action means; only index 0 of the
            first axis is inspected and the last axis is indexed by dimension.
        act_dim (int): Number of action dimensions
        logger: Logger object to use for output (needs a ``debug`` method)
        COLOR (str): ANSI color code prefixed to the averages line (default: "")
        ENDC (str): ANSI color code for ending color (default: "")
    """
    # The full-tensor dump was historically colored with a module-level CYAN
    # constant.  Keep using it when the module provides one, but fall back to
    # the caller-supplied COLOR instead of raising NameError when it does not.
    dump_color = globals().get("CYAN", COLOR)

    # Initialize accumulators for averaging
    avg_diff = 0
    avg_var = 0
    avg_unique = 0
    avg_entropy = 0
    avg_iqr = 0
    avg_kurtosis = 0
    avg_skewness = 0

    # Calculate metrics for each dimension
    for i in range(act_dim):
        # Flatten so that 0-d and multi-dimensional slices are handled alike.
        values = (
            torch.tanh(all_mus[0, ..., i]).detach().cpu().numpy().round(2).reshape(-1)
        )

        values_sorted = np.sort(values)
        logger.debug(f"Dim {i} values: {values_sorted}")

        if values.size == 0:
            # Nothing to summarise; the dimension contributes zeros to the averages.
            continue

        # Basic spread metrics
        diff = values.max() - values.min()
        var = np.var(values).round(2)

        # Additional spread metrics
        unique_count = len(np.unique(values))

        # Shannon entropy of the 20-bin histogram.  Bin counts are normalised
        # to probabilities; density values (counts / (N * bin_width)) are not
        # probabilities and would yield a meaningless, possibly negative, sum.
        counts, _ = np.histogram(values, bins=20)
        probs = counts[counts > 0] / counts.sum()  # drop empty bins: avoid log(0)
        entropy = -np.sum(probs * np.log(probs)) if probs.size > 0 else 0

        iqr = np.percentile(values, 75) - np.percentile(values, 25)

        # Shape metrics are undefined for zero variance; report 0 in that case.
        centered = values - values.mean()
        variance = (centered ** 2).mean()
        if variance > 0:
            kurtosis = (centered ** 4).mean() / (variance ** 2) - 3
            skewness = (centered ** 3).mean() / (variance ** (3 / 2))
        else:
            kurtosis = 0
            skewness = 0

        # Add to averages
        avg_diff += diff
        avg_var += var
        avg_unique += unique_count
        avg_entropy += entropy
        avg_iqr += iqr
        avg_kurtosis += kurtosis
        avg_skewness += skewness

    logger.debug(f"{dump_color} all_mus: {torch.tanh(all_mus[0, ..., :]).detach().cpu().numpy().round(2)} {ENDC}")

    # Calculate averages
    if act_dim > 0:
        avg_diff /= act_dim
        avg_var /= act_dim
        avg_unique /= act_dim
        avg_entropy /= act_dim
        avg_iqr /= act_dim
        avg_kurtosis /= act_dim
        avg_skewness /= act_dim

    # Log the averages across all dimensions
    logger.debug(f"{COLOR}AVERAGES: diff: {avg_diff:.2f} var: {avg_var:.2f} unique: {avg_unique:.2f} "
                f"entropy: {avg_entropy:.2f} iqr: {avg_iqr:.2f} kurt: {avg_kurtosis:.2f} skew: {avg_skewness:.2f}{ENDC}")
    
class VideoRecorderwithPraw:
    """Write frames to a video file with a small ``p_raw`` bar chart overlaid.

    The OpenCV writer is created lazily from the shape of the first frame
    passed to :meth:`add_frame`; call :meth:`finish` to release it.
    """

    def __init__(self, output_path='output_video.mp4', fps=30):
        self.output_path = output_path
        self.fps = fps
        self.writer = None  # created on the first add_frame() call
        self.frame_count = 0

    def setup_writer(self, frame_shape):
        """Initialize video writer with frame dimensions.

        Args:
            frame_shape: ``(height, width, ...)`` of the frames to be written.

        Raises:
            RuntimeError: If OpenCV could not open ``output_path`` for writing.
        """
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # OpenCV expects the frame size as (width, height).
        self.writer = cv2.VideoWriter(
            self.output_path,
            fourcc,
            self.fps,
            (frame_shape[1], frame_shape[0])
        )
        if not self.writer.isOpened():
            # An unopened writer drops every frame silently; fail loudly instead.
            self.writer = None
            raise RuntimeError(f"Could not open video writer for {self.output_path}")

    def create_praw_plot(self, p_raw_values, step_num, active_indices=None):
        """Create a bar plot of p_raw distribution.

        Args:
            p_raw_values: Bar heights; a ``torch.Tensor`` is detached, moved
                to CPU and flattened, anything else is plotted as given.
            step_num: Accepted for API compatibility; currently not drawn.
            active_indices: Optional index, sequence or tensor of bar indices
                to highlight in red.  Out-of-range indices are ignored.

        Returns:
            np.ndarray: RGB image (H, W, 3), uint8, of the rendered figure.
        """
        # Create figure with small size for corner overlay
        fig, ax = plt.subplots(figsize=(3, 2))
        try:
            fig.patch.set_facecolor('white')
            fig.patch.set_alpha(0.8)

            # Convert tensor to numpy if needed (detach: may require grad)
            if isinstance(p_raw_values, torch.Tensor):
                p_raw_values = p_raw_values.detach().cpu().numpy().flatten()

            # Create bar plot with default blue color
            indices = np.arange(len(p_raw_values))
            bars = ax.bar(indices, p_raw_values, color='lightblue', alpha=0.7)

            # Color bars based on active indices
            if active_indices is not None:
                if isinstance(active_indices, torch.Tensor):
                    active_indices = active_indices.detach().cpu().numpy()
                # reshape(-1) also turns a bare scalar into a one-element array.
                for idx in np.asarray(active_indices).reshape(-1):
                    idx = int(idx)  # float indices cannot index the bar container
                    if 0 <= idx < len(bars):
                        bars[idx].set_color('red')
                        bars[idx].set_alpha(0.9)

            ax.set_title(r'$w_k(s_t)$ Distribution', fontsize=8)
            ax.set_xlabel('k', fontsize=6)
            ax.set_ylabel('Probability', fontsize=6)
            ax.tick_params(labelsize=6)

            # Set y-axis to log scale for better visualization of small values
            ax.set_yscale('log')
            ax.set_ylim(1e-12, 1)

            # Tight layout to minimize whitespace
            fig.tight_layout()

            # Render the figure off-screen and read back the RGBA buffer
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
            plot_img = buf.reshape(canvas.get_width_height()[::-1] + (4,))
            # Drop alpha and copy so the image no longer aliases the canvas buffer
            return plot_img[:, :, :3].copy()
        finally:
            # Always release the figure, even when plotting fails.
            plt.close(fig)

    def overlay_plot_on_frame(self, frame, plot_img, position='top_right'):
        """Overlay the plot on the main frame.

        ``frame`` is modified in place and returned.  The plot is placed with
        a 10-pixel margin in the requested corner (unknown positions fall back
        to ``'top_right'``); if it is larger than the frame it is cropped to
        the frame size instead of raising a shape error.
        """
        frame_h, frame_w = frame.shape[:2]
        plot_h, plot_w = plot_img.shape[:2]

        # Calculate position
        if position == 'top_left':
            y_offset = 10
            x_offset = 10
        elif position == 'bottom_right':
            y_offset = frame_h - plot_h - 10
            x_offset = frame_w - plot_w - 10
        elif position == 'bottom_left':
            y_offset = frame_h - plot_h - 10
            x_offset = 10
        else:  # 'top_right' and any unrecognised value
            y_offset = 10
            x_offset = frame_w - plot_w - 10

        # Ensure plot fits within frame
        y_offset = max(0, min(y_offset, frame_h - plot_h))
        x_offset = max(0, min(x_offset, frame_w - plot_w))

        # Crop an oversized plot to the area that is actually available.
        visible_h = min(plot_h, frame_h - y_offset)
        visible_w = min(plot_w, frame_w - x_offset)

        # Overlay plot on frame
        frame[y_offset:y_offset + visible_h, x_offset:x_offset + visible_w] = (
            plot_img[:visible_h, :visible_w]
        )

        return frame

    def add_frame(self, frame, p_raw_values, step_num, active_indices=None):
        """Add a frame with p_raw plot overlay to the video.

        ``frame`` is treated as RGB (it is converted with ``COLOR_RGB2BGR``
        before writing) and is copied, so the caller's array is left untouched.
        """
        frame_rgb = frame.copy()

        # Create plot overlay
        plot_img = self.create_praw_plot(p_raw_values, step_num, active_indices)

        # Overlay plot on frame
        combined_frame = self.overlay_plot_on_frame(frame_rgb, plot_img)

        # Convert to BGR for video writing (OpenCV expects BGR)
        combined_frame_bgr = cv2.cvtColor(combined_frame, cv2.COLOR_RGB2BGR)

        # Initialize writer if not done
        if self.writer is None:
            self.setup_writer(combined_frame_bgr.shape)

        # Write frame
        self.writer.write(combined_frame_bgr)
        self.frame_count += 1

    def finish(self):
        """Close video writer (no-op when no frame was ever written)."""
        if self.writer is not None:
            self.writer.release()
            print(f"Video saved as {self.output_path} with {self.frame_count} frames")

def get_frame_from_env(env):
    """Get the current frame from environment.

    Tries ``env.render()`` first, then ``env.get_frame()`` or
    ``env.render_frame()`` when the environment has one of them.  Failures of
    these calls are treated as "no frame available"; interpreter-level
    signals such as ``KeyboardInterrupt`` are deliberately not swallowed.

    Returns:
        Whatever the environment call returned, or a black ``(480, 640, 3)``
        uint8 placeholder when no frame could be obtained.
    """
    try:
        # Try to get frame using render
        frame = env.render()
        if frame is not None:
            return frame
    except Exception:
        # Best effort: fall through to the alternative accessors.
        pass

    # Alternative: try to get frame from environment's internal state
    try:
        if hasattr(env, 'get_frame'):
            return env.get_frame()
        elif hasattr(env, 'render_frame'):
            return env.render_frame()
    except Exception:
        pass

    # If no frame available, create a placeholder
    return np.zeros((480, 640, 3), dtype=np.uint8)

def record_video_with_praw():
    """Roll out the saved policy for one episode and record it with a p_raw overlay.

    The policy object is loaded from ``args.path``, the environment is built
    in ``rgb_array`` mode and each rendered frame is written to
    ``agent_praw_video.mp4`` together with a bar chart of ``info['p_raw']``
    (bars listed in ``info['indices']`` are highlighted).  Recording stops
    when the environment reports termination or after a little over 1000
    steps.  The video writer and the environment are released even when the
    rollout fails.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    print(f"{RED} args: {args} {ENDC}")
    logger.debug(f"args: {args}")
    print(logger)
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="rgb_array")()  # Use rgb_array for frame capture

    # Initialize video recorder
    video_recorder = VideoRecorderwithPraw('agent_praw_video.mp4', fps=24)

    ep_len = 0
    try:
        with torch.no_grad():
            o, d = env.reset()[0], False

            while not d and ep_len < 1000000:
                # NOTE(review): k is drawn and printed but never passed to the
                # policy call below — confirm whether it should select the
                # component.  The draw is kept so the RNG stream is unchanged.
                k = random.randint(0, args.n_components - 1)
                print(f"Step: {ep_len}, Using component: {k}")

                # Get action and info (first and last element of the policy output)
                obs_batch = torch.as_tensor(
                    np.expand_dims(o, axis=0), dtype=torch.float32
                ).to(device)
                a, *_, info = ac.pi(obs_batch)

                print(f"info: {info}")
                a = a.cpu().numpy()[0]

                # Get current frame
                frame = get_frame_from_env(env)

                # Extract p_raw and indices from info
                p_raw_values = info.get('p_raw', torch.zeros(64))  # Default to zeros if not found
                active_indices = info.get('indices', None)

                # Add frame to video with p_raw plot
                video_recorder.add_frame(frame, p_raw_values, ep_len, active_indices)

                # Step environment; only the termination flag ends the episode
                # here, the step cap below bounds the recording length.
                o, _, d, _, _ = env.step(a)
                ep_len += 1

                # Optional: limit recording duration
                if ep_len > 1000:  # Record first 1000 steps
                    break
    finally:
        # Finish video recording and free the environment on every exit path.
        video_recorder.finish()
        env.close()

    print(f"Episode length: {ep_len}")

def Env_random_action_live():
    """Run one random-action episode in a ``human``-rendered environment.

    Steps the environment with actions sampled from its action space until it
    reports termination or 5000 steps have been taken, printing every reward
    and finally the episode length and return.  The environment is closed on
    every exit path.
    """
    args = read_args()
    args.test = True
    print(f"Running with args: {args}")
    setup_logging(args.log_level)
    torch.set_num_threads(torch.get_num_threads())

    env = build_env(args, render_mode="human")()
    ep_ret, ep_len = 0, 0
    try:
        env.reset()
        d = False
        with torch.no_grad():
            while not d and ep_len < 5000:
                a = env.action_space.sample()
                # Only the termination flag is honoured; the 5000-step cap
                # above bounds the run otherwise.
                _, r, d, _, _ = env.step(a)
                ep_len += 1
                ep_ret += r
                print(r)
    finally:
        # Release the render window / simulator even if stepping fails.
        env.close()

    print(f"Episode length: {ep_len}")
    print(f"Episode return: {ep_ret}")

class SimpleVideoRecorder:
    """Write frames to a video file with a text info panel in the top-left corner.

    The OpenCV writer is created lazily from the shape of the first frame
    passed to :meth:`add_frame`; call :meth:`finish` to release it.
    """

    def __init__(self, output_path='high_quality_humanoid_video.mp4', fps=30):
        self.output_path = output_path
        self.fps = fps
        self.writer = None  # created on the first add_frame() call
        self.frame_count = 0

    def setup_writer(self, frame_shape):
        """Initialize video writer from ``(height, width, ...)`` frame dimensions.

        Raises:
            RuntimeError: If OpenCV could not open ``output_path`` for writing.
        """
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # OpenCV expects the frame size as (width, height).
        self.writer = cv2.VideoWriter(
            self.output_path,
            fourcc,
            self.fps,
            (frame_shape[1], frame_shape[0])
        )
        if not self.writer.isOpened():
            # An unopened writer drops every frame silently; fail loudly instead.
            self.writer = None
            raise RuntimeError(f"Could not open video writer for {self.output_path}")

    def create_info_overlay(self, magnitude, num_actions, episode, step, total_episodes, env_name, env_type="StepSize"):
        """Create information overlay with current settings.

        Args:
            magnitude: Action magnitude; only used (and negated) when
                ``env_type == "StepSize"``, so it must be numeric in that case.
            num_actions: Number of discrete actions shown in the panel.
            episode: Zero-based episode index (displayed one-based).
            step: Step counter shown in the panel.
            total_episodes: Total number of episodes shown next to ``episode``.
            env_name: Environment name shown in the panel.
            env_type: ``"StepSize"`` shows the magnitude lines; any other
                value shows the uniform-action-space variant.

        Returns:
            np.ndarray: RGB image (H, W, 3), uint8, of the rendered panel.
        """
        fig, ax = plt.subplots(figsize=(4, 1.8))
        try:
            fig.patch.set_facecolor('black')
            fig.patch.set_alpha(0.8)

            ax.text(0.05, 0.85, f'Episode: {episode + 1}/{total_episodes}',
                    fontsize=12, color='white', fontweight='bold', transform=ax.transAxes)
            ax.text(0.05, 0.7, f'Environment: {env_name}',
                    fontsize=11, color='lightgreen', transform=ax.transAxes)
            ax.text(0.05, 0.55, f'Step: {step}',
                    fontsize=11, color='white', transform=ax.transAxes)

            if env_type == "StepSize":
                ax.text(0.05, 0.4, f'Action Magnitude: {magnitude}',
                        fontsize=11, color='yellow', transform=ax.transAxes)
                ax.text(0.05, 0.25, f'Num Actions: {num_actions}',
                        fontsize=11, color='cyan', transform=ax.transAxes)
                # Format the three values individually; interpolating a tuple
                # inside the parentheses rendered doubled parentheses.
                ax.text(0.05, 0.1, f'Uniform Sampled from ({-magnitude}, 0.0, {magnitude})',
                        fontsize=10, color='orange', transform=ax.transAxes)
            else:
                ax.text(0.05, 0.4, f'Num Actions: {num_actions}',
                        fontsize=11, color='cyan', transform=ax.transAxes)
                ax.text(0.05, 0.25, 'Uniform over action space',
                        fontsize=10, color='orange', transform=ax.transAxes)

            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')

            fig.tight_layout()

            # Render off-screen and read back the RGBA buffer
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            buf = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
            info_img = buf.reshape(canvas.get_width_height()[::-1] + (4,))
            # Drop alpha and copy so the image no longer aliases the canvas buffer
            return info_img[:, :, :3].copy()
        finally:
            # Always release the figure, even when rendering fails.
            plt.close(fig)

    def overlay_on_frame(self, frame, info_img):
        """Overlay info on the main frame (in place) and return the frame.

        The panel is placed at a 10-pixel offset from the top-left corner,
        shifted towards the origin when it would not fit, and cropped to the
        frame size when it is larger than the frame.
        """
        frame_h, frame_w = frame.shape[:2]
        info_h, info_w = info_img.shape[:2]

        # Position info in top left
        info_y = max(0, min(10, frame_h - info_h))
        info_x = max(0, min(10, frame_w - info_w))

        # Crop an oversized panel to the area that is actually available.
        visible_h = min(info_h, frame_h - info_y)
        visible_w = min(info_w, frame_w - info_x)
        frame[info_y:info_y + visible_h, info_x:info_x + visible_w] = (
            info_img[:visible_h, :visible_w]
        )

        return frame

    def add_frame(self, frame, magnitude, num_actions, episode, step, total_episodes, env_name, env_type):
        """Add a frame with info overlay to the video.

        ``frame`` is treated as RGB (it is converted with ``COLOR_RGB2BGR``
        before writing) and is copied, so the caller's array is left untouched.
        """
        frame_rgb = frame.copy()

        # Create info overlay
        info_img = self.create_info_overlay(magnitude, num_actions, episode, step, total_episodes, env_name, env_type)

        # Overlay on frame
        combined_frame = self.overlay_on_frame(frame_rgb, info_img)
        combined_frame_bgr = cv2.cvtColor(combined_frame, cv2.COLOR_RGB2BGR)

        if self.writer is None:
            self.setup_writer(combined_frame_bgr.shape)

        self.writer.write(combined_frame_bgr)
        self.frame_count += 1

    def finish(self):
        """Close video writer (no-op when no frame was ever written)."""
        if self.writer is not None:
            self.writer.release()
            print(f"High quality video saved as {self.output_path} with {self.frame_count} frames")

def _record_random_episode(env, video_recorder, magnitude, num_actions,
                           episode_idx, total_episodes, env_name, env_type):
    """Run one random-action episode on ``env``, recording every frame.

    The episode ends when the environment reports either termination or
    truncation, so a time-limited environment cannot keep the loop running
    forever.

    Returns:
        tuple: ``(episode_length, episode_return)``.
    """
    env.reset()
    done, ep_ret, ep_len = False, 0, 0

    while not done:
        # Sample random action
        a = env.action_space.sample()

        # Record the frame that precedes the action
        frame = get_frame_from_env(env)
        video_recorder.add_frame(
            frame, magnitude, num_actions,
            episode_idx, ep_len, total_episodes, env_name, env_type
        )

        # Step environment
        _, r, terminated, truncated, _ = env.step(a)
        done = terminated or truncated
        ep_len += 1
        ep_ret += r

        print(f"Step {ep_len}: Reward = {r:.3f}, Done = {done}")

    return ep_len, ep_ret


def create_high_quality_multi_setting_video():
    """
    Create a comprehensive video showing random actions across all environment settings.

    Two phases are recorded into one file, both on the environment named by
    ``env_name``:

    1. ``EnvRandomDiscreteActions`` for every combination of action magnitude
       and number of actions.
    2. ``UniformActionSetEnv`` for every number of actions (no magnitude).

    Each environment is closed after its episode and the video writer is
    released even when an episode fails.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    print(f"{RED} Creating high quality multi-setting video with args: {args} {ENDC}")

    set_seed(args.seed)

    # Define all combinations
    action_magnitudes = [0.05, 0.1, 0.2, 0.3, 0.4]
    num_actions_list = [4, 8, 16, 32, 64, 128]

    # Phase 1 sweeps magnitude x num_actions, phase 2 only num_actions.
    stepsize_episodes = len(action_magnitudes) * len(num_actions_list)
    uniform_episodes = len(num_actions_list)
    total_episodes = stepsize_episodes + uniform_episodes

    env_name = "Humanoid-v5"
    banner = '=' * 60

    def make_base_env():
        # Both phases wrap the same high-resolution rgb_array environment.
        return gym.make(
            env_name,
            height=1200,
            width=1200,
            render_mode="rgb_array"
        )

    # Initialize video recorder
    video_recorder = SimpleVideoRecorder(f'{env_name}_comprehensive_analysis_1.mp4', fps=30)

    episode_count = 0

    try:
        # First set of episodes: EnvRandomDiscreteActions
        print(f"\n{banner}")
        print(f"STARTING {env_name} EPISODES WITH EnvRandomDiscreteActions")
        print(banner)
        for magnitude in action_magnitudes:
            for num_actions in num_actions_list:
                print(f"\n{banner}")
                print(f"Episode {episode_count + 1}/{total_episodes}")
                print(f"Environment: {env_name} (EnvRandomDiscreteActions)")
                print(f"Action Magnitude: {magnitude}, Num Actions: {num_actions}")
                print(banner)

                # Update args for current setting
                args.env_rd_action_magnitude = magnitude
                args.env_rd_num_actions = num_actions

                env = EnvRandomDiscreteActions(
                    make_base_env(),
                    env_rd_action_magnitude=magnitude,
                    env_rd_num_actions=num_actions
                )
                try:
                    ep_len, ep_ret = _record_random_episode(
                        env, video_recorder, magnitude, num_actions,
                        episode_count, total_episodes, env_name, "StepSize"
                    )
                finally:
                    env.close()

                print(f"Episode {episode_count + 1} completed:")
                print(f"  Environment: {env_name} (EnvRandomDiscreteActions)")
                print(f"  Length: {ep_len} steps")
                print(f"  Return: {ep_ret:.3f}")
                print(f"  Settings: magnitude={magnitude}, num_actions={num_actions}")

                episode_count += 1

        # Second set of episodes: UniformActionSetEnv (no magnitude, only num_actions)
        print(f"\n{banner}")
        print(f"STARTING {env_name} EPISODES WITH UniformActionSetEnv")
        print(banner)

        for num_actions in num_actions_list:
            print(f"\n{banner}")
            print(f"Episode {episode_count + 1}/{total_episodes}")
            print(f"Environment: {env_name} (UniformActionSetEnv)")
            print(f"Num Actions: {num_actions} (no magnitude parameter)")
            print(banner)

            # Update args for current setting
            args.env_rd_num_actions = num_actions

            env = UniformActionSetEnv(
                make_base_env(),
                num_actions=num_actions
            )
            try:
                # magnitude=None: the "Uniform" overlay does not display it
                ep_len, ep_ret = _record_random_episode(
                    env, video_recorder, None, num_actions,
                    episode_count, total_episodes, env_name, "Uniform"
                )
            finally:
                env.close()

            print(f"Episode {episode_count + 1} completed:")
            print(f"  Environment: {env_name} (UniformActionSetEnv)")
            print(f"  Length: {ep_len} steps")
            print(f"  Return: {ep_ret:.3f}")
            print(f"  Settings: num_actions={num_actions} (no magnitude)")

            episode_count += 1
    finally:
        # Finish video recording on every exit path so the file is finalised.
        video_recorder.finish()

    print(f"\n{banner}")
    print(f"COMPREHENSIVE MULTI-ENVIRONMENT VIDEO CREATION COMPLETED!")
    print(f"Total episodes recorded: {episode_count}")
    print(f"  - EnvRandomDiscreteActions episodes: {stepsize_episodes}")
    print(f"  - UniformActionSetEnv episodes: {uniform_episodes}")
    print(f"Video saved as: {video_recorder.output_path}")
    print(banner)

class VideoRecorder:
    """Plain video writer: converts frames from RGB to BGR and appends them to a file.

    The underlying OpenCV writer is created on the first ``add_frame`` call,
    using that frame's shape.
    """

    def __init__(self, output_path='output_video.mp4', fps=30):
        self.output_path, self.fps = output_path, fps
        self.writer, self.frame_count = None, 0

    def setup_writer(self, frame_shape):
        """Create the OpenCV writer for frames of shape ``(height, width, ...)``."""
        codec = cv2.VideoWriter_fourcc(*'mp4v')
        # OpenCV takes the size as (width, height).
        size = (frame_shape[1], frame_shape[0])
        self.writer = cv2.VideoWriter(self.output_path, codec, self.fps, size)

    def add_frame(self, frame):
        """Append one frame, converting it with ``COLOR_RGB2BGR`` first."""
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Lazily open the writer with the first frame's dimensions.
        if self.writer is None:
            self.setup_writer(bgr.shape)

        self.writer.write(bgr)
        self.frame_count += 1

    def finish(self):
        """Release the writer, if any, and report the number of frames written."""
        if not self.writer:
            return
        self.writer.release()
        print(f"Video saved as {self.output_path} with {self.frame_count} frames")

# NOTE: this redefines the identical get_frame_from_env declared earlier in
# this file; being the later definition, it is the one in effect at runtime.
def get_frame_from_env(env):
    """Get the current frame from environment.

    Tries ``env.render()`` first, then ``env.get_frame()`` or
    ``env.render_frame()`` when the environment has one of them.  Failures of
    these calls are treated as "no frame available"; interpreter-level
    signals such as ``KeyboardInterrupt`` are deliberately not swallowed.

    Returns:
        Whatever the environment call returned, or a black ``(480, 640, 3)``
        uint8 placeholder when no frame could be obtained.
    """
    try:
        # Try to get frame using render
        frame = env.render()
        if frame is not None:
            return frame
    except Exception:
        # Best effort: fall through to the alternative accessors.
        pass

    # Alternative: try to get frame from environment's internal state
    try:
        if hasattr(env, 'get_frame'):
            return env.get_frame()
        elif hasattr(env, 'render_frame'):
            return env.render_frame()
    except Exception:
        pass

    # If no frame available, create a placeholder
    return np.zeros((480, 640, 3), dtype=np.uint8)

def record_video():
    """Roll out the saved policy for one episode and record the raw frames.

    The policy object is loaded from ``args.path`` and the environment is
    built in ``rgb_array`` mode with ``camera_id=1``.  The video is written to
    ``video_f<timestamp>.mp4`` in the directory obtained by dropping the last
    three ``/``-separated components of ``args.path``.  Recording stops when
    the environment reports termination or after a little over 1000 steps.
    The video writer and the environment are released even when the rollout
    fails.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    print(f"{RED} args: {args} {ENDC}")
    logger.debug(f"args: {args}")
    print(logger)
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="rgb_array", camera_id=1)()  # Use rgb_array for frame capture
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Initialize video recorder
    video_recorder = VideoRecorder(f"{'/'.join(args.path.split('/')[:-3])}/video_f{timestamp}.mp4", fps=24)

    ep_len = 0
    try:
        with torch.no_grad():
            o, d = env.reset()[0], False

            while not d:
                obs_batch = torch.as_tensor(
                    np.expand_dims(o, axis=0), dtype=torch.float32
                ).to(device)
                # Only the action (first element of the policy output) is needed.
                a, *_ = ac.pi(obs_batch)
                a = a.cpu().numpy()[0]

                # Record the current frame (no overlay)
                video_recorder.add_frame(get_frame_from_env(env))

                # Step environment; only the termination flag ends the episode
                # here, the step cap below bounds the recording length.
                o, _, d, _, _ = env.step(a)
                ep_len += 1

                # Optional: limit recording duration
                if ep_len > 1000:  # Record first 1000 steps
                    break
    finally:
        # Finish video recording and free the environment on every exit path.
        video_recorder.finish()
        env.close()

    print(f"Episode length: {ep_len}")

def record_video_one_per_component():
    """Record a 4x4 grid video with one rollout per policy component.

    One environment is built per component (at most 16).  In every grid frame
    component ``k`` acts in its own environment through
    ``ac.pi(..., manual_indices=[k], deterministic=False)``; its rendered
    frame is resized to 600x600 and placed in cell ``k``.  Components whose
    episode has ended, or that reached their step cap, show a black cell.
    The video writer and all environments are released even when a rollout
    fails, and the per-component episode lengths are printed at the end.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    logger.debug(f"args: {args}")
    print(logger)
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    ac = torch.load(args.path, map_location=device)
    n_components = min(ac.pi.n_components, 16)  # Max 4x4 = 16 components
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"ac: {ac}")

    # Create environments for each component
    envs = []
    max_ep_lens = []
    for k in range(n_components):
        env = build_env(args, render_mode="rgb_array", camera_id=0)()
        envs.append(env)
        max_ep_lens.append(500)

    # Initialize video recorder
    video_recorder = VideoRecorder(f"{'/'.join(args.path.split('/')[:-3])}/video_grid_{timestamp}.mp4", fps=24)

    ep_lens = [0] * n_components
    try:
        with torch.no_grad():
            # Initialize all environments
            obs = [env.reset()[0] for env in envs]
            dones = [False] * n_components

            while not all(dones) and max(ep_lens) < max(max_ep_lens):
                frames = []

                # Get frames from each component
                for k in range(n_components):
                    if dones[k] or ep_lens[k] >= max_ep_lens[k]:
                        # Black frame when episode is done
                        frames.append(np.zeros((600, 600, 3), dtype=np.uint8))
                        dones[k] = True
                        continue

                    a, *_ = ac.pi(
                        torch.as_tensor(
                            np.expand_dims(obs[k], axis=0), dtype=torch.float32
                        ).to(device),
                        manual_indices=torch.tensor([k]).to(device),
                        deterministic=False,
                    )
                    a = a.cpu().numpy()[0]

                    # Get current frame and resize to 600x600
                    frame = get_frame_from_env(envs[k])
                    frames.append(np.array(Image.fromarray(frame).resize((600, 600))))

                    # Step environment; only the termination flag and the
                    # per-component step cap end an episode here.
                    obs[k], _, d, _, _ = envs[k].step(a)
                    ep_lens[k] += 1
                    dones[k] = d or ep_lens[k] >= max_ep_lens[k]

                # Create 4x4 grid
                grid_frame = create_grid_frame(frames, 4, 4)
                video_recorder.add_frame(grid_frame)
    finally:
        # Finish video recording and free every environment on all exit paths.
        video_recorder.finish()
        for env in envs:
            env.close()

    # Report the lengths actually reached, not the configured caps.
    print(f"Episode lengths: {ep_lens}")

def create_grid_frame(frames, rows, cols):
    """Tile ``frames`` row by row into a ``rows`` x ``cols`` uint8 RGB mosaic.

    The cell size is taken from the first frame.  Frames beyond
    ``rows * cols`` are ignored and unused cells stay black.  With no frames
    at all a black image sized for 480x640 cells is returned.
    """
    if not frames:
        return np.zeros((480 * rows, 640 * cols, 3), dtype=np.uint8)

    cell_h, cell_w = frames[0].shape[:2]
    mosaic = np.zeros((cell_h * rows, cell_w * cols, 3), dtype=np.uint8)

    # zip() with the range stops at whichever runs out first: cells or frames.
    for slot, tile in zip(range(rows * cols), frames):
        grid_row, grid_col = divmod(slot, cols)
        top = grid_row * cell_h
        left = grid_col * cell_w
        mosaic[top:top + cell_h, left:left + cell_w] = tile

    return mosaic

def _mean_and_sample_std(values):
    """Return ``(mean, sample_std)`` for a non-empty sequence of numbers.

    The standard deviation uses the unbiased ``n - 1`` denominator and is
    reported as 0.0 when fewer than two samples are available.
    """
    import math

    count = len(values)
    mean = sum(values) / count
    if count < 2:
        return mean, 0.0
    variance = sum((x - mean) ** 2 for x in values) / (count - 1)
    return mean, math.sqrt(variance)


def env_random_action_eplen_report(n_envs=10, n_episodes=10, max_ep_len=1000):
    """Roll out random actions on several fresh environments and report stats.

    For each of ``n_envs`` environment instances built with ``build_env``,
    runs ``n_episodes`` episodes using ``env.action_space.sample()`` and
    records episode length and return. Per-environment and pooled statistics
    are printed and returned.

    Args:
        n_envs: Number of environment instances to build and evaluate.
        n_episodes: Number of episodes to run on each instance.
        max_ep_len: Step cap applied to every episode.

    Returns:
        dict with keys ``individual_env_results`` (list of per-env dicts),
        ``all_episode_lengths``, ``all_episode_returns``,
        ``env_to_env_length_variance`` / ``env_to_env_return_variance``
        (despite the names these hold the *standard deviation* of the
        per-env averages), ``overall_avg_length``, ``overall_avg_return``,
        ``mean_of_env_lengths`` and ``mean_of_env_returns``.

    Raises:
        ValueError: If ``n_envs`` or ``n_episodes`` is smaller than 1.
    """
    import math

    if n_envs < 1 or n_episodes < 1:
        raise ValueError("n_envs and n_episodes must both be at least 1")

    args = read_args()
    args.test = True
    print(f"Running with args: {args}")
    setup_logging(args.log_level)
    torch.set_num_threads(torch.get_num_threads())

    # Store results for all environments
    all_env_results = []
    all_episode_lengths = []
    all_episode_returns = []

    print(f"Testing {n_envs} different environment instances...")

    for env_idx in range(n_envs):
        print(f"\n=== Environment {env_idx + 1}/{n_envs} ===")

        # Initialize new environment instance
        env = build_env(args, render_mode=None)()

        episode_lengths = []
        episode_returns = []

        print(f"Running {n_episodes} episodes with random actions...")

        try:
            with torch.no_grad():
                for episode in range(n_episodes):
                    env.reset()
                    d, ep_ret, ep_len = False, 0, 0

                    # The fourth value returned by env.step is ignored here.
                    # NOTE(review): if that is a "truncated" flag, episodes
                    # only end on `d` or on the max_ep_len cap - confirm.
                    while not d and ep_len < max_ep_len:
                        _, r, d, _, _ = env.step(env.action_space.sample())
                        ep_len += 1
                        ep_ret += r

                    episode_lengths.append(ep_len)
                    episode_returns.append(ep_ret)

                    # Print progress every 50 episodes for brevity
                    if (episode + 1) % 50 == 0:
                        print(f"  Completed {episode + 1}/{n_episodes} episodes")
        finally:
            # Release the environment even if a rollout raised.
            env.close()

        # Calculate statistics for this environment
        avg_ep_length, ep_length_std = _mean_and_sample_std(episode_lengths)
        avg_ep_return, ep_return_std = _mean_and_sample_std(episode_returns)

        all_env_results.append({
            'env_id': env_idx + 1,
            'episode_lengths': episode_lengths,
            'episode_returns': episode_returns,
            'avg_episode_length': avg_ep_length,
            'avg_episode_return': avg_ep_return,
            'ep_length_std': ep_length_std,
            'ep_return_std': ep_return_std,
            'min_ep_length': min(episode_lengths),
            'max_ep_length': max(episode_lengths),
            'min_ep_return': min(episode_returns),
            'max_ep_return': max(episode_returns),
        })
        all_episode_lengths.extend(episode_lengths)
        all_episode_returns.extend(episode_returns)

        print(f"  Env {env_idx + 1} Results:")
        print(f"    Avg episode length: {avg_ep_length:.2f} (std: {ep_length_std:.2f})")
        print(f"    Avg episode return: {avg_ep_return:.2f} (std: {ep_return_std:.2f})")

    # Calculate overall statistics across all environments
    print(f"\n=== Overall Results Across All {n_envs} Environments ===")

    # Statistics from individual environment averages
    env_avg_lengths = [result['avg_episode_length'] for result in all_env_results]
    env_avg_returns = [result['avg_episode_return'] for result in all_env_results]

    mean_of_env_lengths, length_std_across_envs = _mean_and_sample_std(env_avg_lengths)
    length_se_across_envs = length_std_across_envs / math.sqrt(len(env_avg_lengths))

    mean_of_env_returns, return_std_across_envs = _mean_and_sample_std(env_avg_returns)
    return_se_across_envs = return_std_across_envs / math.sqrt(len(env_avg_returns))

    # Statistics from all episodes combined
    total_episodes = len(all_episode_lengths)
    overall_avg_length, overall_length_std = _mean_and_sample_std(all_episode_lengths)
    overall_length_se = overall_length_std / math.sqrt(total_episodes)

    overall_avg_return, overall_return_std = _mean_and_sample_std(all_episode_returns)
    overall_return_se = overall_return_std / math.sqrt(total_episodes)

    print(f"Total episodes across all environments: {total_episodes}")
    print(f"\nEnvironment-to-Environment Variance:")
    print(f"  Mean episode length across envs: {mean_of_env_lengths:.2f} ± {length_se_across_envs:.2f} (SE)")
    print(f"  Std dev of env averages: {length_std_across_envs:.2f}")
    print(f"  Min env avg length: {min(env_avg_lengths):.2f}")
    print(f"  Max env avg length: {max(env_avg_lengths):.2f}")
    print(f"  Mean episode return across envs: {mean_of_env_returns:.2f} ± {return_se_across_envs:.2f} (SE)")
    print(f"  Std dev of env averages: {return_std_across_envs:.2f}")
    print(f"  Min env avg return: {min(env_avg_returns):.2f}")
    print(f"  Max env avg return: {max(env_avg_returns):.2f}")

    print(f"\nOverall Statistics (all {total_episodes} episodes):")
    print(f"  Overall avg episode length: {overall_avg_length:.2f} ± {overall_length_se:.2f} (SE)")
    print(f"  Overall episode length std: {overall_length_std:.2f}")
    print(f"  Overall avg episode return: {overall_avg_return:.2f} ± {overall_return_se:.2f} (SE)")
    print(f"  Overall episode return std: {overall_return_std:.2f}")

    return {
        'individual_env_results': all_env_results,
        'all_episode_lengths': all_episode_lengths,
        'all_episode_returns': all_episode_returns,
        'env_to_env_length_variance': length_std_across_envs,
        'env_to_env_return_variance': return_std_across_envs,
        'overall_avg_length': overall_avg_length,
        'overall_avg_return': overall_avg_return,
        'mean_of_env_lengths': mean_of_env_lengths,
        'mean_of_env_returns': mean_of_env_returns
    }
    
def record_video_with_random_component():
    """Record one episode to MP4 while picking a random component every step.

    Loads the actor-critic from ``args.path``, builds an ``rgb_array``
    environment, and at every step draws ``k`` uniformly from
    ``[0, args.n_components - 1]`` and passes it to ``ac.pi`` as
    ``manual_indices`` with ``deterministic=True``. Each rendered frame is
    appended to ``<env>_agent_video_uniform_random_component_<timestamp>.mp4``.
    Recording stops when the environment reports done or after 1000 steps.

    The video recorder is finalized and the environment closed even when the
    rollout raises.
    """
    max_steps = 1000  # upper bound on the number of recorded steps

    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    print(f"{RED} args: {args} {ENDC}")
    logger.debug(f"args: {args}")
    print(logger)
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="rgb_array", camera_id=1)()  # Use rgb_array for frame capture

    ep_len = 0
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        video_recorder = VideoRecorder(
            f'{args.env}_agent_video_uniform_random_component_{timestamp}.mp4', fps=24
        )
        try:
            with torch.no_grad():
                o, d = env.reset()[0], False

                while not d and ep_len < max_steps:
                    k = random.randint(0, args.n_components - 1)
                    print(f"Using random component: {k}")

                    a, *_ = ac.pi(
                        torch.as_tensor(
                            np.expand_dims(o, axis=0), dtype=torch.float32
                        ).to(device),
                        manual_indices=torch.tensor([k]).to(device),
                        deterministic=True,
                    )
                    a = a.cpu().numpy()[0]

                    # Capture the frame for the current state (no overlay)
                    # before stepping the environment.
                    video_recorder.add_frame(get_frame_from_env(env))

                    # The fourth value returned by env.step is ignored.
                    o, _, d, _, _ = env.step(a)
                    ep_len += 1
        finally:
            # Finish video recording so the file is finalized on any exit path.
            video_recorder.finish()
    finally:
        env.close()

    print(f"Episode length: {ep_len}")

def live_uniform_components():
    """Run one on-screen episode, drawing a random component index each step.

    Loads the actor-critic from ``args.path``, opens a ``human``-rendered
    environment, and at every step passes a ``k`` drawn uniformly from
    ``[0, args.n_components - 1]`` to ``ac.pi`` as ``manual_indices``. The
    reported ``info["indices"]`` and the rounded action are printed per step,
    and the episode length is printed at the end. The episode is capped at
    10000 steps.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    logger.debug(f"args: {args}")
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())

    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="human", camera_id=1)()
    print(f"ac: {ac}")
    print(f"ac.pi.n_components: {ac.pi.n_components}")

    with torch.no_grad():
        obs = env.reset()[0]
        done = False
        ep_len = 0
        while ep_len < 10000 and not done:
            k = random.randint(0, args.n_components - 1)

            obs_batch = torch.as_tensor(
                np.expand_dims(obs, axis=0), dtype=torch.float32
            ).to(device)
            chosen = torch.tensor([k]).to(device)
            # ac.pi returns a 7-tuple; only the action and the info dict are used.
            action, _, _, _, _, _, pi_info = ac.pi(obs_batch, manual_indices=chosen)

            action_np = action.cpu().numpy()[0]
            print(pi_info["indices"].cpu().numpy(), action_np.round(2))

            obs, _, done, _, _ = env.step(action_np)
            ep_len += 1

    print(ep_len)

def live_uniform_components_with_distributions():
    """Run one on-screen episode and print per-component mus/stds each step.

    At every step a component index is drawn uniformly from
    ``[0, ac.pi.n_components - 1]`` and passed to ``ac.pi`` as
    ``manual_indices`` (stochastic sampling). For each action dimension the
    sorted ``tanh`` of the means, the sorted standard deviations and their
    ranges across components are printed, taken from ``info['all_mus']`` and
    ``info['all_stds']``. The episode is capped at 10000 steps and its length
    is printed at the end.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    logger.debug(f"args: {args}")
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())

    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="human")()
    n_components = ac.pi.n_components

    with torch.no_grad():
        obs = env.reset()[0]
        done = False
        ep_len = 0
        while ep_len < 10000 and not done:
            k = random.randint(0, n_components - 1)

            obs_batch = torch.as_tensor(
                np.expand_dims(obs, axis=0), dtype=torch.float32
            ).to(device)
            chosen = torch.tensor([k]).to(device)
            # ac.pi returns a 7-tuple; only the action and the info dict are used.
            action, _, _, _, _, _, pi_info = ac.pi(
                obs_batch,
                manual_indices=chosen,
                deterministic=False,
            )

            # Basic per-step summary
            print(f"Component: {pi_info['indices'].cpu().numpy()}, Action: {action.cpu().numpy()[0].round(2)}")

            # First batch entry of the per-component statistics;
            # indexed below as [component, action_dim].
            all_mus = pi_info['all_mus'].cpu().numpy()[0]
            all_stds = pi_info['all_stds'].cpu().numpy()[0]
            print(f"{GREEN} {all_mus.round(2)} {ENDC}")

            # Spread of means / stds across components, one action dim at a time
            for i in range(all_mus.shape[1]):
                mus_dim = np.tanh(all_mus[:, i])
                stds_dim = all_stds[:, i]
                mu_range = mus_dim.max() - mus_dim.min()
                std_range = stds_dim.max() - stds_dim.min()
                print(f"{BLUE}  Mu: Dim {i}, {sorted(mus_dim.round(2))}{ENDC}")
                print(f"{GREEN} Sigma: {sorted(stds_dim.round(2))} {ENDC}")
                print(f"{YELLOW} Dim {i}: mu_range={mu_range:.3f}, std_range={std_range:.3f}{ENDC}")

            print("-" * 50)

            obs, _, done, _, _ = env.step(action.cpu().numpy()[0])
            ep_len += 1

    print(ep_len)

def _extract_xy_position(env, info):
    """Return the ``(x, y)`` position after a step, or ``(None, None)``.

    Sources are tried in order: ``info['x_position']``/``info['y_position']``,
    the first two entries of ``info['qpos']``, then
    ``env.get_body_com("torso")`` when the environment exposes that method.
    """
    if 'x_position' in info and 'y_position' in info:
        return info['x_position'], info['y_position']
    if 'qpos' in info:
        return info['qpos'][0], info['qpos'][1]
    if hasattr(env, 'get_body_com'):
        pos = env.get_body_com("torso")
        return pos[0], pos[1]
    return None, None


def plot_component_trajectories(episodes_per_component=5, max_ep_len=200):
    """Roll out every policy component and plot the resulting x/y trajectories.

    For each component index ``k`` in ``range(ac.pi.n_components)`` the policy
    is run deterministically with ``manual_indices=[k]`` for
    ``episodes_per_component`` episodes of at most ``max_ep_len`` steps. The
    positions collected per episode are drawn in one subplot per component
    (start marked with a circle, end with an ``x``) and the figure is saved as
    ``traj_diffent_components.png`` three directory levels above the
    checkpoint path (falling back to the current directory when the path is
    too short to have such a parent).

    Args:
        episodes_per_component: Number of episodes to roll out per component.
        max_ep_len: Step cap for each episode.

    Returns:
        dict mapping component index to a list of ``(T, 2)`` position arrays,
        one per episode in which at least one position could be extracted.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode=None)()
    n_components = ac.pi.n_components

    # Store trajectories for each component
    component_trajectories = {}
    components_to_plot = range(n_components)

    try:
        with torch.no_grad():
            for k in components_to_plot:
                print(f"Running component {k}/{n_components}")
                component_trajectories[k] = []

                for episode in range(episodes_per_component):
                    print(f"  Episode {episode+1}/{episodes_per_component}")
                    o, d, ep_len = env.reset()[0], False, 0
                    episode_positions = []

                    while not d and ep_len < max_ep_len:
                        a, *_ = ac.pi(
                            torch.as_tensor(
                                np.expand_dims(o, axis=0), dtype=torch.float32
                            ).to(device),
                            manual_indices=torch.tensor([k]).to(device),
                            deterministic=True,
                        )
                        a = a.cpu().numpy()[0]
                        # The fourth value returned by env.step is ignored.
                        o, _, d, _, info = env.step(a)

                        x_pos, y_pos = _extract_xy_position(env, info)
                        if x_pos is not None and y_pos is not None:
                            episode_positions.append([x_pos, y_pos])

                        ep_len += 1

                    if episode_positions:
                        component_trajectories[k].append(np.array(episode_positions))
    finally:
        # The environment is only needed for the rollouts above.
        env.close()

    if not components_to_plot:
        # Nothing to draw; avoids a zero-column subplot grid.
        return component_trajectories

    # Create subplots: at most 4 columns, as many rows as needed
    n_cols = min(4, len(components_to_plot))
    n_rows = (len(components_to_plot) + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D axes array, so one flatten() covers
    # the single-axes, single-row and full-grid cases alike.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, k in zip(axes, components_to_plot):
        for episode_traj in component_trajectories[k]:
            if len(episode_traj) > 0:
                line = ax.plot(episode_traj[:, 0], episode_traj[:, 1],
                               alpha=0.7, linewidth=1)[0]
                line_color = line.get_color()
                # Add start point (circle) and end point (x) with same color
                ax.plot(episode_traj[0, 0], episode_traj[0, 1], 'o',
                        color=line_color, markersize=6)
                ax.plot(episode_traj[-1, 0], episode_traj[-1, 1], 'x',
                        color=line_color, markersize=8)

        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_title(f'Component {k}')
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for unused_ax in axes[len(components_to_plot):]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    # An empty prefix would turn the target into "/traj_..." (filesystem
    # root); fall back to the current directory instead.
    out_root = '/'.join(args.path.split('/')[:-3]) or '.'
    fig.savefig(f"{out_root}/traj_diffent_components.png")
    return component_trajectories

def _to_uint8_rgb(frame):
    """Convert a rendered frame to a ``uint8`` array for PNG export.

    2-D input is stacked into three identical channels. Non-``uint8`` data
    whose maximum is <= 1.0 is scaled by 255 after clipping to [0, 1]; any
    other non-``uint8`` data is clipped to [0, 255] and cast.
    """
    arr = np.asarray(frame)
    if arr.ndim == 2:  # grayscale -> convert to RGB
        arr = np.stack([arr] * 3, axis=-1)
    if arr.dtype != np.uint8:
        maxv = arr.max() if arr.size > 0 else 1.0
        if maxv <= 1.0:
            arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
        else:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr


def record_frames(save_every=10, max_steps=1000):
    """Run one episode and save sampled RGB frames as PNG files.

    Loads the policy from ``args.path``, builds an ``rgb_array`` environment
    and, at every step, draws ``k`` uniformly from
    ``[0, args.n_components - 1]`` and passes it to ``ac.pi`` as
    ``manual_indices``. Every ``save_every``-th step is written to
    ``frame_<step:06d>.png`` inside a timestamped ``frames_f<timestamp>``
    directory three levels above the checkpoint (or the current directory).
    A ``metadata.txt`` in that directory records the arguments and a summary.

    Args:
        save_every: Write a PNG for steps whose index is a multiple of this.
        max_steps: Maximum number of environment steps to run.

    Raises:
        ValueError: If ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError("save_every must be at least 1")

    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    logger.debug(f"args: {args}")
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())

    # load actor/critic (policy)
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    ac = torch.load(args.path, map_location=device)

    # build environment for offscreen rendering
    env = build_env(args, render_mode="rgb_array", camera_id=1)()

    saved_frames = 0
    try:
        # create output directory relative to the checkpoint path
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_base = '/'.join(args.path.split('/')[:-3]) or '.'
        out_dir = os.path.join(checkpoint_base, f"frames_f{timestamp}")
        os.makedirs(out_dir, exist_ok=True)

        # metadata file
        meta_path = os.path.join(out_dir, "metadata.txt")
        with open(meta_path, "w") as mf:
            mf.write(f"args: {args}\n")
            mf.write(f"timestamp: {timestamp}\n")
            mf.write("notes: frames saved as frame_000000.png ...\n")

        print(f"Saving frames to: {out_dir}")
        logger.info(f"Saving frames to: {out_dir}")

        with torch.no_grad():
            o, d, ep_ret, ep_len = env.reset()[0], False, 0, 0

            while not d:
                k = random.randint(0, args.n_components - 1)
                a, *_ = ac.pi(
                    torch.as_tensor(np.expand_dims(o, axis=0), dtype=torch.float32).to(device),
                    manual_indices=torch.tensor([k]).to(device),
                )
                a = a.cpu().numpy()[0]

                # Render every step (as before); convert and write only the
                # sampled ones. File names carry the step index.
                frame = get_frame_from_env(env)
                if ep_len % save_every == 0:
                    fname = os.path.join(out_dir, f"frame_{ep_len:06d}.png")
                    Image.fromarray(_to_uint8_rgb(frame)).save(fname, format="PNG")
                    saved_frames += 1

                # step environment (fourth return value is ignored)
                o, r, d, _, _ = env.step(a)
                ep_len += 1
                ep_ret += float(r) if np.isscalar(r) else np.sum(r)

                # safety cap
                if not d and ep_len >= max_steps:
                    logger.info(f"Reached max record length ({max_steps}); stopping.")
                    break

        # write summary metadata; saved_frames counts files actually written
        with open(meta_path, "a") as mf:
            mf.write(f"saved_frames: {saved_frames}\n")
            mf.write(f"episode_length: {ep_len}\n")
            mf.write(f"episode_return: {ep_ret}\n")
    finally:
        env.close()

    print(f"Episode finished. saved {saved_frames} frames to {out_dir}")
    logger.info(f"Episode finished. saved {saved_frames} frames to {out_dir}")


def count_k_usage_100_episodes(n_episodes=10, max_ep_len=1000):
    """Count how often each component index is reported by the policy.

    Runs ``n_episodes`` episodes with the policy loaded from ``args.path``
    (no ``manual_indices`` supplied) and tallies the value of
    ``info["indices"][0]`` returned by ``ac.pi`` at every step. The counts and
    their percentages are printed per component.

    Despite the function name, the default number of episodes is 10.

    Args:
        n_episodes: Number of episodes to roll out.
        max_ep_len: Step cap for each episode.

    Returns:
        dict mapping component index (``int``) to the number of steps on
        which it was reported.
    """
    args = read_args()
    args.test = True
    setup_logging(args.log_level)
    logger.debug(f"args: {args}")
    set_seed(args.seed)
    torch.set_num_threads(torch.get_num_threads())
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode="rgb_array")()

    k_counts = {}

    try:
        with torch.no_grad():
            for episode in range(n_episodes):
                print(f"Episode {episode+1}/{n_episodes}")
                o, d, ep_len = env.reset()[0], False, 0
                while not d and ep_len < max_ep_len:
                    # ac.pi returns a 7-tuple; only the action and info are used.
                    a, _, _, _, _, _, info = ac.pi(
                        torch.as_tensor(
                            np.expand_dims(o, axis=0), dtype=torch.float32
                        ).to(device),
                    )
                    # Plain int keys keep the printed dict readable.
                    k = int(info["indices"].cpu().numpy()[0])
                    k_counts[k] = k_counts.get(k, 0) + 1

                    # The fourth value returned by env.step is ignored.
                    o, _, d, _, _ = env.step(a.cpu().numpy()[0])
                    ep_len += 1
    finally:
        env.close()

    print(f"K usage counts over {n_episodes} episodes: {k_counts}")
    total = sum(k_counts.values())
    for k in sorted(k_counts):
        percent = 100.0 * k_counts[k] / total if total > 0 else 0.0
        print(f"Component {k}: {k_counts[k]} times ({percent:.2f}%)")

    return k_counts
 
if __name__ == "__main__":
    # Manual entry-point switch: the single uncommented call below is the
    # routine that runs when this file is executed as a script. The
    # commented-out lines are the alternative routines.
    # record_video_with_praw()
    # Env_random_action_live()
    # create_high_quality_multi_setting_video()
    # record_video()
    # record_video_one_per_component()
    # env_random_action_eplen_report()
    # record_video_with_random_component()
    live_uniform_components()
    # live_uniform_components_with_distributions()
    # record_frames()
    # count_k_usage_100_episodes()

    