import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import os
from tqdm import tqdm
from typing import List, Tuple, Dict, Any
import jax.numpy as jnp

# Assuming EpisodicDatasetDvrkGeneric and a policy type are available in the scope
# If not, they need to be imported, e.g.:
# from openpi.training.dvrk_dataset import EpisodicDatasetDvrkGeneric
# from openpi.policies.base import TrainedPolicy # Adjust based on actual policy base class


def _get_episode_start_indices(episode_lengths: List[int]) -> List[int]:
    """Calculates the starting flat index for each episode."""
    indices = [0] * len(episode_lengths)
    current_index = 0
    for i, length in enumerate(episode_lengths):
        indices[i] = current_index
        current_index += length
    return indices

def _frame_to_bgr_uint8(frame: np.ndarray) -> np.ndarray:
    """Convert an RGB HWC frame into a BGR uint8 image suitable for cv2.VideoWriter.

    Float frames are assumed to lie in [0, 1] (TODO confirm against the policy's
    input transform) and are scaled by 255. Values are clipped before the uint8
    cast so that out-of-range pixels saturate instead of wrapping around.
    """
    if np.issubdtype(frame.dtype, np.floating):
        frame = np.clip(frame * 255.0, 0, 255)
    elif frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255)
    frame_uint8 = frame.astype(np.uint8, copy=False)
    # OpenCV writes frames in BGR channel order.
    return cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR)


def _extract_pi0_action_chunk(policy_output: Any, step: int) -> "np.ndarray | None":
    """Return the (horizon, action_dim) action chunk from a policy.infer() result, or None.

    Accepts any array-like under the ``"actions"`` key (e.g. NumPy or JAX arrays),
    and drops a leading singleton batch dimension if one is present.
    """
    try:
        actions = policy_output["actions"]
    except (KeyError, TypeError):
        actions = None
    if actions is None:
        print(f"Warning: 'actions' key not found in policy output at step {step}. Setting action_chunk to None.")
        return None

    chunk = np.asarray(actions)
    if chunk.ndim == 3 and chunk.shape[0] == 1:
        chunk = chunk[0]  # (1, H, D) -> (H, D)
    if chunk.ndim != 2:
        print(
            f"Warning: expected a 2-D action chunk (horizon, action_dim) at step {step}, "
            f"got shape {chunk.shape}. Setting action_chunk to None."
        )
        return None
    return chunk


def _plot_pi0_action_trajectories(
    gt_actions: np.ndarray,
    pred_actions: np.ndarray,
    action_horizon: int,
    traj_id: int,
    ep_name: str,
    instruction: Any,
    overall_mse: float,
) -> None:
    """Plot GT vs. predicted actions (one subplot per action dimension) and save them to eval_plots/."""
    num_steps, num_action_dims = gt_actions.shape
    time_steps = np.arange(num_steps)
    # Steps at which policy.infer() was called (start of each action_horizon window).
    inference_indices = np.arange(0, num_steps, action_horizon)

    # squeeze=False always returns a 2-D axes array, so one action dimension needs no special case.
    fig, axes = plt.subplots(
        nrows=num_action_dims, ncols=1, figsize=(12, 2 * num_action_dims), sharex=True, squeeze=False
    )
    try:
        fig.suptitle(
            f"Episode {traj_id} ({ep_name}) - Action Trajectories (Dims 0-{num_action_dims-1})\nInstruction: {instruction}\nMSE: {overall_mse:.4f}",
            fontsize=14,
        )
        for i, ax in enumerate(axes[:, 0]):
            ax.plot(time_steps, gt_actions[:, i], label="GT Action", alpha=0.8, color="blue")

            # Steps without a prediction hold NaN placeholders; leave them out of the line.
            valid = ~np.isnan(pred_actions[:, i])
            ax.plot(
                time_steps[valid], pred_actions[valid, i],
                label="Pred Action", linestyle="--", alpha=0.8, color="red",
            )

            # Mark inference points on the GT line for context. Every subplot gets the label
            # so the legend, which is drawn on the last subplot, actually includes it.
            ax.plot(
                time_steps[inference_indices], gt_actions[inference_indices, i],
                "o", markersize=4, color="green", alpha=0.6, label="Inference Start",
            )

            ax.set_title(f"Action Dimension {i}")
            ax.grid(True, linestyle=":")
            ax.set_ylabel("Value")

        last_ax = axes[-1, 0]
        last_ax.legend(loc="best")
        last_ax.set_xlabel("Time Step in Episode")

        fig.tight_layout(rect=[0, 0.03, 1, 0.93])  # leave room for the multi-line suptitle
        plot_filename = f"eval_episode_{traj_id}_{ep_name.replace('.zip','')}_plot.png"
        os.makedirs("eval_plots", exist_ok=True)
        plot_path = os.path.join("eval_plots", plot_filename)
        fig.savefig(plot_path)
        print(f"Plot saved to {plot_path}")
    finally:
        plt.close(fig)  # free the figure even if saving fails


def calc_mse_for_single_trajectory_pi0(
    policy: Any, # Replace Any with the actual openpi policy type, e.g., TrainedPolicy
    dataset: Any, # Replace Any with EpisodicDatasetDvrkGeneric
    traj_id: int,
    steps: int = 300,
    action_horizon: int = 16, # Should match dataset's action_horizon ideally
    plot: bool = False,
    save_video: bool = True,
    video_fps: int = 10,
    video_frame_key: str = "observation.images.left", # Only referenced in the missing-frame warning
):
    """
    Compute the action MSE of an openpi policy over one episode of a dVRK dataset.

    The episode is replayed step by step. Every ``action_horizon`` steps the policy is
    queried once with ``policy.infer(data_point)``. The returned action chunk is then
    compared, index by index, against the ground-truth action chunk
    ``data_point["action"]`` of that same data point. Steps without a usable prediction
    are recorded as NaN and ignored when averaging.

    Args:
        policy: openpi policy. It must provide ``infer(data_point)``, returning a dict
            whose ``"actions"`` entry is array-like with shape (H, D) or (1, H, D).
            It must also provide ``_input_transform``, which is only called when
            ``save_video`` is True.
        dataset: An EpisodicDatasetDvrkGeneric-like dataset exposing ``episode_list``,
            ``episode_lengths``, ``action_dim``, ``__len__`` and flat ``__getitem__``.
        traj_id: Index of the episode within ``dataset.episode_list``.
        steps: Maximum number of steps to evaluate. It is capped at the episode
            length minus 1.
        action_horizon: Number of steps between policy queries. Must be >= 1.
        plot: If True, save a per-dimension GT vs. prediction plot under ``eval_plots/``.
        save_video: If True, write the policy-transformed ``image["base_0_rgb"]`` frames
            to an MP4 under ``eval_videos/``.
        video_fps: Frame rate of the saved video.
        video_frame_key: Kept for API compatibility. Frames are NOT read from
            ``data_point[video_frame_key]``; the key only appears in the warning that is
            printed when no frame is available.

    Returns:
        The mean over action dimensions of the per-dimension MSE, with NaN steps
        ignored. Returns ``np.nan`` if the episode is too short or no step could be
        evaluated.

    Raises:
        ValueError: If ``traj_id`` is out of range or ``action_horizon < 1``.
        AttributeError: If the dataset has no (or an empty) ``episode_lengths``.
        AssertionError: If the collected state/action arrays have inconsistent shapes.
    """
    if action_horizon < 1:
        # The step % action_horizon scheduling below needs a positive horizon.
        raise ValueError(f"action_horizon must be >= 1, got {action_horizon}.")

    # --- Episode bookkeeping ---
    num_episodes = len(dataset.episode_list)
    if traj_id < 0 or traj_id >= num_episodes:
        raise ValueError(f"traj_id {traj_id} is out of bounds for dataset with {num_episodes} episodes.")

    # Use an explicit length test rather than truthiness, so a NumPy array of lengths also works.
    episode_lengths = getattr(dataset, "episode_lengths", None)
    if episode_lengths is None or len(episode_lengths) == 0:
        raise AttributeError("Dataset must have 'episode_lengths' calculated.")

    episode_start_flat_idx = _get_episode_start_indices(episode_lengths)[traj_id]
    episode_len = episode_lengths[traj_id]
    # Capped at episode_len - 1, so the last frame of the episode is never evaluated.
    # NOTE(review): presumably a safety margin for the dataset's indexing; confirm it is needed.
    effective_steps = min(steps, episode_len - 1)

    if effective_steps <= 0:
        print(f"Warning: Episode {traj_id} has length {episode_len}. Cannot evaluate for {steps} steps. Skipping.")
        return np.nan

    episode_path, instruction, _ = dataset.episode_list[traj_id]
    ep_name = os.path.basename(episode_path)
    print(f"\nEvaluating episode {traj_id} ({ep_name})")
    print(f"Instruction: {instruction}")
    print(f"Episode Length: {episode_len}. Running for {effective_steps} steps.")
    print(f"Action Horizon (Inference Interval): {action_horizon}")

    # --- Step-by-step evaluation ---
    state_across_time: List[Any] = []
    gt_action_across_time: List[Any] = []
    pred_action_across_time: List[Any] = []
    action_chunk = None       # latest predicted chunk, shape (H, D)
    chunked_gt_action = None  # GT chunk of the data point used for the latest inference
    video_writer = None

    try:
        for step_in_episode in tqdm(range(effective_steps), desc=f"Eval Ep {traj_id}"):
            current_flat_index = episode_start_flat_idx + step_in_episode
            try:
                data_point = dataset[current_flat_index]
            except IndexError:
                print(f"Warning: Index {current_flat_index} out of bounds for dataset (len: {len(dataset)}). Stopping trajectory evaluation early.")
                break

            # --- Video frame (the input transform is only needed for the video) ---
            if save_video:
                transformed_data = policy._input_transform(data_point)
                frame_data = (transformed_data.get("image") or {}).get("base_0_rgb")
                if isinstance(frame_data, np.ndarray):
                    frame_bgr = _frame_to_bgr_uint8(frame_data)
                    if video_writer is None:
                        h, w = frame_bgr.shape[:2]
                        video_filename = f"eval_episode_{traj_id}_{ep_name.replace('.zip','')}_main_video.mp4"
                        os.makedirs("eval_videos", exist_ok=True)
                        video_path = os.path.join("eval_videos", video_filename)
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
                        video_writer = cv2.VideoWriter(video_path, fourcc, video_fps, (w, h))
                        if video_writer.isOpened():
                            print(f"Saving video to {video_path} ({w}x{h} @ {video_fps} FPS)")
                        else:
                            # A writer that failed to open silently drops every frame; stop instead.
                            print(f"Warning: Could not open video writer for {video_path}. Disabling video for episode {traj_id}.")
                            video_writer.release()
                            video_writer = None
                            save_video = False
                    if video_writer is not None:
                        video_writer.write(frame_bgr)
                else:
                    print(
                        f"Warning: Frame 'image'/'base_0_rgb' not found or invalid in transformed data at step "
                        f"{step_in_episode} (video_frame_key='{video_frame_key}' is not used). "
                        f"Cannot save video for episode {traj_id}."
                    )
                    save_video = False  # disable saving for the rest of the episode

            # --- State / action bookkeeping ---
            current_state = data_point["observation.state"]
            # data_point["action"] is an action chunk. Row 0 serves only as the
            # shape/dtype template for the NaN placeholders below.
            current_gt_action = data_point["action"][0]
            state_across_time.append(current_state)

            step_in_chunk = step_in_episode % action_horizon
            if step_in_chunk == 0:
                chunked_gt_action = data_point["action"]
                action_chunk = _extract_pi0_action_chunk(policy.infer(data_point), step_in_episode)

            pred_action = np.full_like(current_gt_action, np.nan)
            gt_action = np.full_like(current_gt_action, np.nan)
            if action_chunk is not None:
                # Bound by BOTH chunks: the policy's horizon may differ from the dataset's.
                usable_len = min(len(action_chunk), len(chunked_gt_action))
                if step_in_chunk < usable_len:
                    pred_action = action_chunk[step_in_chunk]
                    gt_action = chunked_gt_action[step_in_chunk]
                else:
                    print(
                        f"Warning: Prediction step index {step_in_chunk} out of bounds for action chunk "
                        f"(pred length {len(action_chunk)}, GT length {len(chunked_gt_action)}) "
                        f"at step {step_in_episode}. Using NaNs."
                    )

            gt_action_across_time.append(gt_action)
            pred_action_across_time.append(pred_action)
    finally:
        # Always finalize the video file, even if inference or data loading raised.
        if video_writer is not None:
            video_writer.release()
            print("Video saving complete.")

    if not gt_action_across_time:
        print(f"Warning: No steps were evaluated for episode {traj_id}. Returning NaN.")
        return np.nan

    state_across_time = np.array(state_across_time)
    gt_action_across_time = np.array(gt_action_across_time)
    pred_action_across_time = np.array(pred_action_across_time)

    # --- Shape checks (the three lists are appended in lockstep above) ---
    if not (state_across_time.shape[0] == gt_action_across_time.shape[0] == pred_action_across_time.shape[0]):
        raise AssertionError(f"Mismatched lengths after processing: {state_across_time.shape[0]} vs {gt_action_across_time.shape[0]} vs {pred_action_across_time.shape[0]}")

    action_dim = dataset.action_dim  # Assuming dataset has this property
    if (
        gt_action_across_time.ndim != 2
        or pred_action_across_time.ndim != 2
        or not (gt_action_across_time.shape[1] == pred_action_across_time.shape[1] == action_dim)
    ):
        raise AssertionError(
            f"Mismatched action dimensions: GT={gt_action_across_time.shape[1:]}, "
            f"Pred={pred_action_across_time.shape[1:]}, Expected={action_dim}"
        )

    # --- MSE: per-dimension mean over time (NaN steps ignored), then mean over dimensions ---
    squared_error = (gt_action_across_time - pred_action_across_time) ** 2
    mse_per_dim = np.nanmean(squared_error, axis=0)
    overall_mse = np.nanmean(mse_per_dim)

    print(f"Unnormalized Action MSE across episode {traj_id} (NaNs ignored): {overall_mse:.6f}")

    if plot:
        _plot_pi0_action_trajectories(
            gt_action_across_time,
            pred_action_across_time,
            action_horizon,
            traj_id,
            ep_name,
            instruction,
            overall_mse,
        )

    return overall_mse