import numpy as np
import random
import pandas as pd
import os
import math
import numpy as np
import imageio
# import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

import time
import os.path as osp

from copy import deepcopy
import itertools
import numpy as np
import torch
import time
import os
from logx import EpochLogger, setup_logger_kwargs

import os
import imageio
import numpy as np
from PIL import Image
import PIL.ImageDraw as ImageDraw
import matplotlib.pyplot as plt
import argparse
from warnings import filterwarnings

import logging

import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)

import os
from dotenv import load_dotenv
load_dotenv()


# ANSI escape sequences for colouring terminal output. Prefix a string with one
# of the colour codes and append ENDC to restore the terminal's default colour.
RED = "\033[91m"
GREEN = "\033[92m"
BLUE = "\033[94m"
CYAN = "\033[96m"
YELLOW = "\033[93m"
MAGENTA = "\033[95m"
ENDC = "\033[0m"  # Reset color

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def setup_logging(log_level):
    """Configure root logging and this module's logger at ``log_level``.

    Args:
        log_level: Name of a ``logging`` level (e.g. ``"info"``, ``"DEBUG"``);
            matched case-insensitively against the ``logging`` module.

    Raises:
        ValueError: If ``log_level`` does not name an integer level constant.
    """
    print(log_level)

    # Resolve the level name to its numeric constant.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Invalid log level: {log_level}")

    log_format = "%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s"

    # Detach whatever is currently attached to the root logger before
    # reconfiguring it (iterate over a copy because the list is mutated).
    root = logging.root
    for stale_handler in list(root.handlers):
        root.removeHandler(stale_handler)

    # Single root configuration; force=True additionally asks basicConfig to
    # replace any handlers that are still attached to the root logger.
    logging.basicConfig(level=level, format=log_format, force=True)

    # The module logger gets the same explicit level as the root logger.
    logger.setLevel(level)

    # Fallback: attach a stream handler only when neither this logger nor its
    # parent has one.
    has_handler = bool(logger.handlers) or bool(logger.parent.handlers)
    if not has_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(level)
        stream_handler.setFormatter(logging.Formatter(log_format))
        logger.addHandler(stream_handler)

    # Echo the resulting logger so the effective configuration is visible.
    print(f"Logger after setup: {logger}")

# Suppress the DeprecationWarning whose message announces the `np.bool8` alias
# deprecation.
filterwarnings(
    "ignore",
    message="`np.bool8` is a deprecated alias for `np.bool_`",
    category=DeprecationWarning,
)


# Compute device: a non-empty `device` environment variable takes precedence;
# otherwise use CUDA when torch reports it as available, else fall back to CPU.
device = os.getenv("device") or ("cuda" if torch.cuda.is_available() else "cpu")

def convert_to_underscore(num):
    """Format ``num`` with two decimals and use ``_`` as the decimal separator.

    Example: ``1.5`` -> ``"1_50"``.
    """
    two_decimals = format(num, ".2f")
    return two_decimals.replace(".", "_")

def format_number(num):
    """
    Render ``str(num)`` with the decimal point replaced by an underscore.

    Values whose string form contains no ``.`` are returned as that string
    unchanged.

    Examples:
    0.2 -> "0_2"
    0.002 -> "0_002"
    20 -> "20"
    20.0 -> "20_0"
    """
    text = str(num)

    # Nothing to rewrite when there is no decimal point.
    if '.' not in text:
        return text

    # Integer part and fractional part are rejoined with an underscore; the
    # integer part is kept verbatim (including a leading "0").
    whole, decimal = text.split('.')
    return f"{whole}_{decimal}"

def get_last_file_number(directory):
    """Return the largest ``N`` among ``info_N*.csv`` files in ``directory``.

    Only names that start with ``info_`` and end with ``.csv`` are considered;
    the number is the text between the first and second underscore, cut at the
    first dot. Names whose number does not parse as an integer are skipped.

    Returns:
        The highest number found, or ``-1`` when no file qualifies.
    """
    # Seed with the sentinel so max() always has a value to return.
    numbers = [-1]
    for name in os.listdir(directory):
        if not (name.startswith("info_") and name.endswith(".csv")):
            continue
        try:
            numbers.append(int(name.split("_")[1].split(".")[0]))
        except ValueError:
            # Non-numeric suffix (e.g. "info_abc.csv"): ignore this file.
            pass
    return max(numbers)

def test_model_human_model_no_log_file(
    env, model, path: str, render_mode="human", max_eps_length=1000, deterministic=False
):
    """Roll out one episode of ``model`` in ``env`` and print debug output.

    Nothing is written to disk and nothing is returned; the function prints
    the first observation component at every step, then the episode length
    and ``env.energy`` once the rollout ends.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(a)`` returns a 5-tuple ``(obs, reward, done, _, info)``.
            It must also expose an ``energy`` attribute.
        model: Object exposing ``get_action(obs_tensor, deterministic=...)``.
        path: Unused; kept for interface compatibility.
        render_mode: Unused; kept for interface compatibility.
        max_eps_length: Maximum number of steps before the rollout is cut.
        deterministic: Forwarded to ``model.get_action``.
    """
    o, d, ep_ret, ep_len = env.reset()[0], False, 0, 0
    # Position at the previous step, kept so a per-step displacement
    # (x_pos - last_x, y_pos - last_y) can be inspected while debugging.
    last_x, last_y = 0, 0

    while not d and ep_len < max_eps_length:
        a = model.get_action(
            torch.as_tensor(o, dtype=torch.float32), deterministic=deterministic
        )
        # NOTE(review): only the third element of the step tuple ends the loop;
        # the fourth one is discarded. Confirm this is intended.
        o, r, d, _, info = env.step(a)
        print(o[0])
        if o[0] > 0.9:
            # Debug pause so the state can be inspected in the viewer.
            time.sleep(10)

        x_pos = info.get("x_position", 0)
        y_pos = info.get("y_position", 0)

        # Remember the current position for the next iteration. The original
        # assignment was inverted and overwrote the fresh reading instead.
        last_x, last_y = x_pos, y_pos

        ep_ret += r
        ep_len += 1
    print(ep_len)
    print(env.energy)

# Integer codes for a "save form" option.
# NOTE(review): presumably these select how rollout frames are persisted
# (nothing / a video / individual frames); they are not referenced in the
# visible part of this file, so confirm against the callers.
SAVE_FORM_NO = 0
SAVE_FORM_VIDEO = 1
SAVE_FORM_FRAME = 2

def save_frames(frame_dir, frames):
    """Write every frame to ``frame_dir`` as ``frame_0000.png``, ``frame_0001.png``, ...

    Args:
        frame_dir: Destination directory; created (with parents) if missing.
        frames: Iterable of images accepted by ``imageio.imwrite``.
    """
    # exist_ok avoids the check-then-create race of a separate
    # os.path.exists() test and matches how save_video_grid creates its
    # output directory.
    os.makedirs(frame_dir, exist_ok=True)

    for index, frame in enumerate(frames):
        imageio.imwrite(os.path.join(frame_dir, f"frame_{index:04d}.png"), frame)

def save_mp4_video(video_dir, video_name, frames, fps=30):
    """Encode ``frames`` into ``<video_dir>/<video_name>.mp4``.

    Args:
        video_dir: Destination directory; created (with parents) if missing.
        video_name: File name without the ``.mp4`` extension.
        frames: Iterable of images accepted by the imageio writer.
        fps: Frames per second of the output video (default: 30, the value
            that used to be hard-coded).
    """
    # exist_ok avoids the check-then-create race of a separate
    # os.path.exists() test and matches how save_video_grid creates its
    # output directory.
    os.makedirs(video_dir, exist_ok=True)

    video_path = os.path.join(video_dir, f"{video_name}.mp4")

    # The context manager closes the writer even if a frame fails to encode.
    with imageio.get_writer(
        video_path, fps=fps, codec="libx264", format="mp4", mode="I"
    ) as writer:
        for frame in frames:
            writer.append_data(frame)

def expand_array_to_columns(data, prefix):
    """Flatten an array into a ``{column_name: value}`` dictionary.

    Singleton dimensions are squeezed away first. The resulting keys are:

    * ``prefix`` for a single value,
    * ``prefix_<i>`` for a 1-D array,
    * ``prefix_<row>_<col>`` for a 2-D array.

    Args:
        data: NumPy array or PyTorch tensor.
        prefix: Base name for the generated keys.

    Returns:
        dict mapping column names to the individual elements.

    Raises:
        TypeError: If ``data`` is neither a NumPy array nor a PyTorch tensor.
        ValueError: If ``data`` still has more than two dimensions after
            squeezing.
    """
    if isinstance(data, torch.Tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        data = data.detach().cpu().numpy()

    if not isinstance(data, np.ndarray):
        raise TypeError("Input must be a NumPy array or PyTorch tensor.")

    data = np.squeeze(data)

    if data.size == 1:
        # A single element collapses to one column named exactly `prefix`.
        return {prefix: data.item()}

    if data.ndim > 2:
        raise ValueError("Input array must be 1D or 2D after squeezing.")

    if data.ndim == 1:
        return {f"{prefix}_{i}": value for i, value in enumerate(data)}

    return {
        f"{prefix}_{row}_{col}": value
        for row, row_values in enumerate(data)
        for col, value in enumerate(row_values)
    }

def prep_action_info_for_save(action_info):
    """Flatten an action-info dictionary into a single row of scalar columns.

    The optional array entries ``a``, ``pi``, ``mu``, ``std`` and ``cov`` are
    expanded with ``expand_array_to_columns``; ``logp_a`` (required) is
    reduced to its first element; every entry of the optional
    ``activation_outputs`` mapping is expanded under its own key.
    """
    row = {}

    # Optional per-action arrays, expanded in a fixed order.
    for key in ('a', 'pi', 'mu', 'std', 'cov'):
        if key in action_info:
            row.update(expand_array_to_columns(action_info[key], key))

    row["logp_a"] = action_info["logp_a"][0]

    # Optional activations, each expanded under its own name.
    if "activation_outputs" in action_info:
        activations = action_info["activation_outputs"]
        for name in activations:
            row.update(expand_array_to_columns(activations[name], name))

    return row

def parse_layer_sizes(sizes_str):
    """Parse a string such as ``'256,128,64'`` into ``[256, 128, 64]``.

    Raises:
        argparse.ArgumentTypeError: If any comma-separated piece is not an
            integer.
    """
    try:
        pieces = sizes_str.split(',')
        return list(map(int, pieces))
    except ValueError:
        raise argparse.ArgumentTypeError(
            "Layer sizes must be comma-separated integers (e.g., '256,128,64')"
        )

def _fit_frame_to_cell(frame, frame_height, frame_width):
    """Return ``frame`` sized to one grid cell; a black tile if it cannot be resized."""
    if frame.shape[:2] == (frame_height, frame_width):
        return frame
    try:
        import cv2
    except ImportError:
        # Without OpenCV the frame cannot be resized; use a black placeholder.
        return np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
    return cv2.resize(frame, (frame_width, frame_height))


def save_video_grid(frames_list, output_path, fps=30, grid_cols=None, frame_height=None, frame_width=None):
    """
    Creates a grid video from a list of frame sequences and saves it to the specified path.

    Videos are laid out row by row. Unused grid cells stay black, and videos
    shorter than the longest one are padded by repeating their last frame.
    The caller's ``frames_list`` and its inner sequences are not modified.

    Args:
        frames_list: List of lists, where each inner list contains frames for one video
        output_path: Path where the output video will be saved
        fps: Frames per second for the output video (default: 30)
        grid_cols: Number of columns in the grid (default: at most 3)
        frame_height: Height of each frame (default: use height from first frame)
        frame_width: Width of each frame (default: use width from first frame)

    Returns:
        The path to the saved video file, or None when ``frames_list`` is empty

    Raises:
        ValueError: If ``grid_cols`` is given and is smaller than 1.
    """
    # os.path.dirname() is "" for a bare file name and os.makedirs("") raises,
    # so only create the directory when the path actually has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Handle empty input
    if not frames_list:
        logging.warning("No frames provided to save_video_grid")
        return None

    # Shallow copies: all padding below happens on these lists, never on the
    # sequences owned by the caller.
    videos = [list(frames) for frames in frames_list]
    num_videos = len(videos)

    # Grid layout
    if grid_cols is None:
        grid_cols = min(3, num_videos)
    elif grid_cols < 1:
        raise ValueError("grid_cols must be a positive integer")
    grid_rows = math.ceil(num_videos / grid_cols)
    total_cells = grid_cols * grid_rows

    # Cell dimensions: taken from the first frame of the first video unless
    # given explicitly; fixed defaults when that video has no frames.
    if frame_height is None or frame_width is None:
        if videos[0]:
            sample_frame = videos[0][0]
            if frame_height is None:
                frame_height = sample_frame.shape[0]
            if frame_width is None:
                frame_width = sample_frame.shape[1]
        else:
            frame_height = frame_height or 480
            frame_width = frame_width or 640

    black_frame = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)

    # Fill the unused cells of the last row with black videos as long as the
    # first video (or 10 frames when the first video is empty).
    filler_length = len(videos[0]) or 10
    videos.extend([black_frame] * filler_length for _ in range(total_cells - num_videos))

    # Pad shorter videos by repeating their last frame (black if they are empty).
    max_frames = max(len(video) for video in videos)
    for video in videos:
        shortfall = max_frames - len(video)
        if shortfall > 0:
            video.extend([video[-1] if video else black_frame] * shortfall)

    grid_width = grid_cols * frame_width
    grid_height = grid_rows * frame_height

    # Compose one grid frame per time step.
    grid_frames = []
    for frame_idx in range(max_frames):
        grid_frame = np.zeros((grid_height, grid_width, 3), dtype=np.uint8)
        for cell, video in enumerate(videos):
            row, col = divmod(cell, grid_cols)
            tile = _fit_frame_to_cell(video[frame_idx], frame_height, frame_width)
            grid_frame[
                row * frame_height:(row + 1) * frame_height,
                col * frame_width:(col + 1) * frame_width,
            ] = tile
        grid_frames.append(grid_frame)

    # Encode the grid video; the context manager finalises the file.
    with imageio.get_writer(output_path, fps=fps, codec="libx264", format="mp4", mode="I") as writer:
        for frame in grid_frames:
            writer.append_data(frame)

    logging.info("Grid video saved to %s", output_path)
    return output_path


def read_args() -> argparse.Namespace:
    """Declare the command-line options of this project and parse them.

    ``parse_args()`` is called without arguments, so the options are read
    from the process command line.

    Returns:
        The parsed namespace, with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    # Environment id.
    parser.add_argument("--env", type=str, default="Ant-v4")
    
    # Logging verbosity; restricted to the standard level names.
    parser.add_argument(
        '--log_level',
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        help='Set the logging level'
    )
    
    parser.add_argument("--path", type=str)
    parser.add_argument("--bias_config_path", type=str)
    
    # Converted by parse_layer_sizes; argparse also runs a string default
    # through `type`, so the default ends up as a list of ints as well.
    parser.add_argument(
        "--layers",
        type=parse_layer_sizes,
        default="256,256,256",
        help="Comma-separated list of layer sizes (e.g., '256,128,64')"
    )
     
    # General experiment / optimisation hyper-parameters.
    parser.add_argument("--exp_name", type=str, default="test")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--seed", "-s", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--bs", type=int, default=32)
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--epsilon", type=float, default=2)
    # NOTE(review): --eplen is parsed as a float although its default is the
    # integer 200; if it is an episode length, confirm whether int was meant.
    parser.add_argument("--eplen", type=float, default=200)
    parser.add_argument("--dropout_rate", type=float, default=0.1)
    parser.add_argument("--early_stopping", action="store_true")
     
    parser.add_argument("--steps_per_epoch", type=int, default=2000)
    parser.add_argument("--start_steps", type=int, default=2000)
    
    
    parser.add_argument("--num_workers", type=int, default=4)
    
    
    # Component selection options.
    parser.add_argument("--n_components", type=int, default=10)
    parser.add_argument("--random_component", action="store_true")
    parser.add_argument("--component", type=int, default=-1)
    
    # MM args
    parser.add_argument("--repulsion_beta", type=float, default=0.1)
    parser.add_argument("--repulsion_lambda", type=float, default=0.1)
     
    # VQ-VAE args
    parser.add_argument("--vq", action="store_true")
    parser.add_argument("--vq_num_emb", type=int, default=256)
    parser.add_argument("--vq_cd", type=float, default=2.0)
    parser.add_argument("--vq_beta", type=float, default=0.25)
    parser.add_argument("--vq_lambda_reg", type=float, default=0.1)
    parser.add_argument("--vq_temp", type=float, default=1.0)
    parser.add_argument("--vq_hard", action="store_true")
    
    # HRL
    parser.add_argument("--hrl_test", action="store_true")
    parser.add_argument("--hrl_hard", action="store_true")
    parser.add_argument("--hrl_exp_name", type=str)
    parser.add_argument("--hrl_model_path", type=str)
    parser.add_argument("--hrl_timesteps", type=int, default=1_000_000)
    parser.add_argument("--hrl_std", type=float, default=0.2)
    parser.add_argument("--hrl_nc", type=int, default=64)
    parser.add_argument("--hrl_eval_freq", type=int, default=10_000)
    parser.add_argument("--hrl_config_name", type=str, default="default")
    # No type/default given: the value stays a string, or None when omitted.
    parser.add_argument("--hrl_algo")
    parser.add_argument("--hrl_test_episodes", type=int, default=9)
    parser.add_argument("--hrl_norm_rewards", action="store_true")
    parser.add_argument("--hrl_rewards_scale", type=float, default=1.0)
    parser.add_argument("--hrl_csv_dir", type=str)
    parser.add_argument("--hrl_action_magnitude", type=float, default=0.5) 
    # HRL mean / std options (--hrl_gmean, --hrl_gstd)
    parser.add_argument("--hrl_gmean", type=float, default=0.0)
    parser.add_argument("--hrl_gstd", type=float, default=0.2)
    
    # Food
    parser.add_argument("--env_rd_action_magnitude", type=float, default=0.2)
    parser.add_argument("--env_rd_num_actions", type=int, default=64)
    # Mean / std options (--env_rd_mean, --env_rd_std)
    parser.add_argument("--env_rd_mean", type=float, default=0.0)
    parser.add_argument("--env_rd_std", type=float, default=0.2)
     
    ## ENV
    parser.add_argument("--boundary_limit", type=int, default=1)
     
    parser.add_argument("--model", type=str)
    parser.add_argument("--ac_model", type=str)

    parser.add_argument("--test", action="store_true")
    args = parser.parse_args()
    return args

def format_tensor(tensor):
    """Return ``str(tensor)`` with the ``tensor(`` prefix and every ``)`` removed.

    Side effect: sets NumPy's global print options (precision 2, suppressed
    scientific notation, line width 120) on every call.
    """
    np.set_printoptions(precision=2, suppress=True, linewidth=120)
    text = str(tensor)
    # Strip the wrapper tokens one after the other, in this order.
    for token in ('tensor(', ')'):
        text = text.replace(token, '')
    return text

def get_env_dims(env):
    """
    Read the observation size, action size and action limit from an environment.

    Args:
        env: Environment exposing ``action_space`` and ``observation_space``

    Returns:
        obs_dim: First dimension of the observation space (see below)
        act_dim: First dimension of the action space shape
        act_limit: First entry of ``action_space.high`` (the same bound is
            assumed for every action dimension and, negated, for the lower side)

    Raises:
        ValueError: If the observation space matches none of the handled layouts.

    The observation size is resolved in this order:
      1. space with a ``spaces`` mapping: its ``'observation'`` entry if
         present, otherwise the sum of the first dimensions of all entries;
      2. space with a non-None ``shape``: its first dimension;
      3. indexable space containing ``'observation'``: that entry's first
         dimension.
    """
    action_space = env.action_space
    act_dim = action_space.shape[0]
    act_limit = action_space.high[0]
    logger.debug(f"Action space: {act_dim}D, [-{act_limit}, {act_limit}]")
    logger.debug(f"Observation space: {env.observation_space}")

    obs_space = env.observation_space

    # 1. Dict-style space exposing its sub-spaces through `.spaces`.
    if hasattr(obs_space, 'spaces'):
        sub_spaces = obs_space.spaces
        if 'observation' in sub_spaces:
            obs_dim = sub_spaces['observation'].shape[0]
        else:
            obs_dim = sum(
                space.shape[0]
                for space in sub_spaces.values()
                if hasattr(space, 'shape')
            )
        return obs_dim, act_dim, act_limit

    # 2. Plain space with a usable shape.
    if hasattr(obs_space, 'shape') and obs_space.shape is not None:
        return obs_space.shape[0], act_dim, act_limit

    # 3. Dictionary-like space without a `.spaces` attribute.
    if hasattr(obs_space, '__getitem__') and 'observation' in obs_space:
        return obs_space['observation'].shape[0], act_dim, act_limit

    raise ValueError(f"Observation space not recognized: {type(env.observation_space)}")