# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib.pyplot as plt
import numpy as np
import torch  # Import torch
import pandas as pd  # Import pandas for DataFrame operations
import cv2 # Import OpenCV

from gr00t.data.dvrk_dataset import EpisodicDatasetDvrkGeneric
from gr00t.model.policy import BasePolicy
from gr00t.data.transform.base import ComposedModalityTransform
# Print NumPy arrays with 3 digits of precision and suppress exponential (scientific) notation.
np.set_printoptions(precision=3, suppress=True)


def download_from_hg(repo_id: str, repo_type: str) -> str:
    """Download a repository snapshot from the Hugging Face Hub.

    Args:
        repo_id: Identifier of the repository on the Hub.
        repo_type: Repository type, forwarded unchanged to ``snapshot_download``.

    Returns:
        The value returned by ``snapshot_download`` (the local path of the
        downloaded snapshot).
    """
    # Imported lazily so that huggingface_hub is only needed when a download is requested.
    import huggingface_hub

    return huggingface_hub.snapshot_download(repo_id, repo_type=repo_type)


# Precompute start indices for faster lookup
def _get_episode_start_indices(episode_lengths):
    start_indices = [0] * len(episode_lengths)
    current_start = 0
    for i, length in enumerate(episode_lengths):
        start_indices[i] = current_start
        current_start += length
    return start_indices


def _value_to_numpy(value, as_float=False):
    """Convert a torch tensor to a NumPy array; return any other value unchanged."""
    if not isinstance(value, torch.Tensor):
        return value
    value = value.detach()
    if as_float:
        # Cast to float32 first so dtypes NumPy cannot represent (e.g. bfloat16) still convert.
        value = value.float()
    return value.cpu().numpy()


def _video_frame_to_bgr(frame):
    """Convert one 'video.main' frame into a uint8 BGR image for OpenCV.

    NOTE(review): the frame is assumed to be RGB in (H, W, C) layout (optionally with a
    leading batch axis) and, when floating point, scaled to [0, 1] -- confirm against
    the dataset's transforms.
    """
    if isinstance(frame, torch.Tensor):
        frame = _value_to_numpy(frame, as_float=frame.is_floating_point())
    frame = np.asarray(frame)
    if frame.ndim == 4:
        frame = frame[0]  # Drop the batch axis if present.

    if np.issubdtype(frame.dtype, np.floating):
        # Clip before scaling so slightly out-of-range values cannot wrap around in uint8.
        frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
    elif frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)

    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)


def _plot_action_trajectories(dataset, traj_id, modality_keys, gt_actions, pred_actions, action_horizon):
    """Plot ground-truth vs. predicted actions (one subplot per dimension), save and show the figure.

    ``gt_actions`` and ``pred_actions`` must be 2-D arrays of identical shape (steps, dims).
    """
    num_steps, num_of_joints = gt_actions.shape

    # squeeze=False always yields a 2-D array of Axes, so the single-dimension case needs no special path.
    fig, axes = plt.subplots(
        nrows=num_of_joints, ncols=1, figsize=(10, 3 * num_of_joints), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    # Global title showing the episode ID, instruction and modality keys.
    episode_path, instruction, _ = dataset.episode_list[traj_id]
    ep_name = episode_path.split('/')[-1]
    fig.suptitle(
        f"Episode {traj_id} ({ep_name})\nInstruction: {instruction}\n Modalities: {', '.join(modality_keys)}",
        fontsize=14,
        color="blue",
    )

    # Use the number of recorded steps so the x-axis always matches the plotted data.
    time_steps = np.arange(num_steps)
    inference_steps = time_steps[::action_horizon]

    for i, ax in enumerate(axes):
        ax.plot(time_steps, gt_actions[:, i], label="GT Action (Step 0)", alpha=0.7)
        # Matplotlib breaks the line at NaN values, so missing predictions show up as gaps.
        ax.plot(time_steps, pred_actions[:, i], label="Pred Action (Step j)", linestyle='--')
        # Mark the steps at which the policy was queried on the GT line.
        ax.plot(inference_steps, gt_actions[inference_steps, i], "ro", markersize=5, label="Inference Point")

        ax.set_title(f"Dimension {i}")
        ax.grid(True, linestyle=':')
        if i == num_of_joints - 1:  # Legend and x-label only on the bottom subplot.
            ax.legend(loc='upper right')
            ax.set_xlabel("Time Step in Episode")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle.
    plot_filename = f"eval_episode_{traj_id}_plot.png"
    plt.savefig(plot_filename)
    print(f"Plot saved to {plot_filename}")
    plt.show()
    plt.close(fig)  # Release the figure so repeated evaluations do not accumulate open figures.


def calc_mse_for_single_trajectory(
    policy: BasePolicy,
    dataset: EpisodicDatasetDvrkGeneric,
    traj_id: int,
    modality_keys: list,
    steps=300,
    action_horizon=16,
    plot=False,
    save_video=True,  # Flag to control video saving
    video_fps=10,     # FPS of the saved video
):
    """
    Calculates the Mean Squared Error (MSE) between denormalized predicted and
    ground-truth actions over a single episode, and optionally saves a video
    and a plot.

    The policy is queried every ``action_horizon`` steps. The ground-truth action
    chunk stored with the queried sample and the predicted chunk are then compared
    row by row for the following steps. Steps for which the predicted chunk has no
    row are recorded as NaN and ignored in the MSE.

    Args:
        policy: The policy model to evaluate; must provide ``get_action(data)``
            returning a mapping with ``"action.<key>"`` entries.
        dataset: Dataset providing ``episode_list``, ``episode_lengths``,
            ``_retrieve_data(flat_index)`` and ``denormalize_action(action)``.
        traj_id: The index of the episode to evaluate within the dataset.
        modality_keys: List of sub-keys (e.g., "psm1", "psm2") whose
            ``"action.<key>"`` entries are concatenated and evaluated.
        steps: The maximum number of steps to evaluate within the episode.
        action_horizon: The interval (in steps) at which inference is performed.
        plot: Whether to plot the trajectories; the figure is saved to
            ``eval_episode_<traj_id>_plot.png`` and shown.
        save_video: Whether to save the 'video.main' frames to
            ``eval_episode_<traj_id>_main_video.mp4``.
        video_fps: Frames per second for the saved video.

    Returns:
        The MSE over all evaluated steps and action dimensions (NaNs ignored).

    Raises:
        ValueError: If ``traj_id`` is out of bounds, ``action_horizon`` is not
            positive, ``modality_keys`` is empty, there are no steps to evaluate,
            or ground-truth and predicted actions end up with different shapes.
    """
    num_episodes = len(dataset.episode_list)
    if traj_id < 0 or traj_id >= num_episodes:
        raise ValueError(f"traj_id {traj_id} is out of bounds for dataset with {num_episodes} episodes.")
    if action_horizon < 1:
        raise ValueError(f"action_horizon must be a positive integer, got {action_horizon}.")
    if len(modality_keys) == 0:
        raise ValueError("modality_keys must contain at least one key.")

    # --- Get episode information ---
    episode_lengths = dataset.episode_lengths
    # Only this episode's offset is needed: the summed length of all preceding episodes.
    episode_start_flat_idx = sum(episode_lengths[:traj_id])
    episode_len = episode_lengths[traj_id]
    # Never run past the end of the episode.
    effective_steps = min(steps, episode_len)
    if effective_steps <= 0:
        raise ValueError(
            f"No steps to evaluate for episode {traj_id} (steps={steps}, episode length={episode_len})."
        )
    print(f"Evaluating episode {traj_id} (Length: {episode_len}). Running for {effective_steps} steps.")

    action_keys = [f"action.{key}" for key in modality_keys]
    gt_action_joints_across_time = []
    pred_action_joints_across_time = []
    action_chunk = None      # Latest predicted action chunk (NumPy values).
    concat_gt_action = None  # Ground-truth chunk stored with the latest inference sample.
    video_writer = None

    try:
        for step_in_episode in range(effective_steps):
            current_flat_index = episode_start_flat_idx + step_in_episode
            with torch.no_grad():  # No gradients are needed while fetching evaluation data.
                data_point_torch = dataset._retrieve_data(current_flat_index)

            # --- Video frame processing ---
            if save_video:
                main_img = data_point_torch.get("video.main")
                if main_img is not None:
                    frame_bgr = _video_frame_to_bgr(main_img)
                    if video_writer is None:
                        # The writer is created lazily because the frame size is only known now.
                        h, w = frame_bgr.shape[:2]
                        video_filename = f"eval_episode_{traj_id}_main_video.mp4"
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
                        video_writer = cv2.VideoWriter(video_filename, fourcc, video_fps, (w, h))
                        if not video_writer.isOpened():
                            print(f"Warning: could not open video writer for {video_filename}.")
                        print(f"Saving video to {video_filename} ({w}x{h} @ {video_fps} FPS)")
                    video_writer.write(frame_bgr)
                elif step_in_episode == 0:
                    print("Warning: 'video.main' not found in data point. Cannot save video.")
                    save_video = False  # Disable saving for the rest of the episode.

            # --- Periodic inference ---
            pred_step_in_chunk = step_in_episode % action_horizon
            if pred_step_in_chunk == 0:
                print(f"Inferencing at step: {step_in_episode} (Flat Index: {current_flat_index})")
                predicted = policy.get_action(data_point_torch)
                action_chunk = {key: _value_to_numpy(value) for key, value in predicted.items()}
                # Ground-truth chunk of this sample: modalities are joined along the feature axis,
                # one row per step of the chunk.
                concat_gt_action = np.concatenate(
                    [_value_to_numpy(data_point_torch[key], as_float=True) for key in action_keys],
                    axis=1,
                )

            # Ground truth and prediction are always appended as a pair so both lists stay aligned.
            denormalized_gt_action = dataset.denormalize_action(concat_gt_action[pred_step_in_chunk])
            gt_action_joints_across_time.append(denormalized_gt_action)

            prediction_available = all(
                key in action_chunk and pred_step_in_chunk < len(action_chunk[key]) for key in action_keys
            )
            if prediction_available:
                concat_pred_action = np.concatenate(
                    [np.atleast_1d(action_chunk[key][pred_step_in_chunk]) for key in action_keys],
                    axis=0,
                )
                pred_action_joints_across_time.append(dataset.denormalize_action(concat_pred_action))
            else:
                print(f"Warning: Prediction not found for step {step_in_episode} (chunk index {pred_step_in_chunk}). Appending NaNs.")
                # Placeholder with the shape of a single step so the stacked arrays stay rectangular.
                pred_action_joints_across_time.append(np.full(np.shape(denormalized_gt_action), np.nan))
    finally:
        # Release the writer even if the loop failed, so the file handle is closed and the video finalized.
        if video_writer is not None:
            video_writer.release()
            print("Video saving complete.")

    gt_action_joints_across_time = np.array(gt_action_joints_across_time)
    pred_action_joints_across_time = np.array(pred_action_joints_across_time)
    if gt_action_joints_across_time.shape != pred_action_joints_across_time.shape:
        raise ValueError(
            "Mismatched ground-truth/prediction shapes: "
            f"{gt_action_joints_across_time.shape} != {pred_action_joints_across_time.shape}"
        )

    # Calculate MSE, ignoring NaNs introduced for steps without a prediction.
    mse = np.nanmean((gt_action_joints_across_time - pred_action_joints_across_time) ** 2)
    print(f"Unnormalized Action MSE across episode {traj_id} (NaNs ignored): {mse:.6f}")

    if plot:
        _plot_action_trajectories(
            dataset,
            traj_id,
            modality_keys,
            gt_action_joints_across_time,
            pred_action_joints_across_time,
            action_horizon,
        )

    return mse