# Standard library
import json
import os
import tempfile
from base64 import b64encode
from collections import deque
from dataclasses import dataclass
from typing import Optional

# Third-party (kept in their original relative order in case of import-time side effects)
from gymnasium import spaces
import gymnasium as gym
import mani_skill
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from datasets import load_dataset
import matplotlib.pyplot as plt
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from moviepy import ImageSequenceClip
from IPython.display import HTML
import numpy as np
from torch.utils.data import DataLoader

# Project-local
from mikasa_utils import get_mikasa_eval_env

@dataclass
class Args:
    """Evaluation settings for this script.

    An instance is handed to ``get_mikasa_eval_env`` to build the environment,
    and ``num_eval_steps`` bounds the rollout loop at the bottom of the file.
    Each attribute docstring sits directly below the field it describes.
    """

    env_img_res: int = 256                           # Resolution for environment images (not policy input resolution)
    exp_name: Optional[str] = None
    env_id: str = "RememberColor3-v0" #'RememberColor3-v0' #"ShellGamePush-v0"
    """The environment ID of the task you want to simulate."""
    # NOTE(review): presumably the natural-language task prompt; it is not read
    # directly in the visible part of this file -- confirm where it is consumed.
    language_instruction: str = ""
    shader: str = "default"
    num_episodes: int = 100
    """Number of episodes to run and record evaluation metrics over"""
    record_dir: str = "videos"
    """The directory to save videos and results"""
    # NOTE(review): the visible script always loads a PI0 policy and never reads
    # this field directly; the docstring below is kept as originally written.
    model: Optional[str] = "octo-base"
    """The model to evaluate on the given environment. Can be one of octo-base, octo-small, rt-1x. If not given, random actions are sampled."""
    ckpt_path: str = ""
    """Checkpoint path for models. Only used for RT models"""
    seed: int = 0
    """Seed the model and environment. Default seed is 0"""
    reset_by_episode_id: bool = True
    """Whether to reset by fixed episode ids instead of random sampling initial states."""
    info_on_video: bool = False
    """Whether to write info text onto the video"""
    save_video: bool = True
    """Whether to save videos"""
    # NOTE(review): presumably the torch device string used when building the
    # environment; the visible script hard-codes "cuda:0" for the policy.
    device: str = 'cuda:0'
    camera_width: Optional[int] = 128
    """the width of the camera image. If none it will use the default the environment specifies"""
    camera_height: Optional[int] = 128
    """the height of the camera image. If none it will use the default the environment specifies."""
    include_oracle: bool = False
    """if toggled, oracle info (such as cup_with_ball_number in ShellGamePush-v0) will be used during the training, i.e. reducing memory task to MDP"""
    noop_steps: int = 1
    """if = 1, then no noops, if > 1, then noops for t ~ [0, noop_steps-1]"""
    include_rgb: bool = True
    """if toggled, rgb images will be included in the observation space"""
    include_joints: bool = False
    """[works only with include_rgb=True] if toggled, joints will be included in the observation space"""
    reward_mode: str = 'normalized_dense' # sparse | normalized_dense
    """the mode of the reward function"""
    control_mode: Optional[str] = "pd_ee_delta_pose"
    """the control mode to use for the environment"""
    render_mode: str = "all"
    """the environment rendering mode"""
    include_state: bool = False
    """whether to include state information in observations"""
    num_eval_steps: int = 60
    """Number of environment steps executed by the rollout loop in this script"""
    num_eval_episodes: int = 100
    sim_backend: str = 'gpu'


# Fixed locations and chunk size for this evaluation run.
ACTIONS_PER_CHUNK = 4
DATASET_PATH = "..."
POLICY_PATH = '.../pretrained_model' 

# Default configuration; the evaluation environment is built from it.
args = Args()

# NOTE(review): `seed` is not referenced anywhere else in the visible part of
# this file (the environment is reset with a literal seed further down).
seed = 42
env = get_mikasa_eval_env(args)

# The dataset supplies the normalization statistics; the policy is the
# pretrained PI0 checkpoint, moved onto the first CUDA device.
dataset = LeRobotDataset(DATASET_PATH, video_backend="pyav")
policy = PI0Policy.from_pretrained(POLICY_PATH)
policy = policy.to("cuda:0")

def inject_normalization_stats(policy, dataset):
    """Copy dataset normalization statistics into the policy's buffers.

    Reads ``dataset.meta.stats`` and overwrites the mean/std entries of the
    policy state dict for ``normalize_inputs`` (observation state),
    ``normalize_targets`` (action) and ``unnormalize_outputs`` (action), then
    reloads the state dict into the policy.

    Args:
        policy: Object exposing ``state_dict()`` and ``load_state_dict()``.
        dataset: Object whose ``meta.stats`` maps a feature name to a mapping
            of statistic name ("mean", "std") to an array-like value.

    Returns:
        int: Number of state-dict entries that were overwritten. A value
        below 6 means some statistics were not found and were left untouched.
    """
    stats = dataset.meta.stats
    pol_state_dict = policy.state_dict()

    print("Available stats keys:", list(stats.keys()))

    keys_to_update = {
        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
        "normalize_targets.buffer_action.mean": ("action", "mean"),
        "normalize_targets.buffer_action.std": ("action", "std"),
        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
        "unnormalize_outputs.buffer_action.std": ("action", "std"),
    }

    updated_count = 0
    for pol_key, (stat_key, stat_type) in keys_to_update.items():
        # Also require the statistic itself to exist, so a dataset lacking
        # e.g. "std" is reported instead of raising KeyError.
        if (
            pol_key in pol_state_dict
            and stat_key in stats
            and stat_type in stats[stat_key]
        ):
            # as_tensor accepts numpy arrays and tensors alike; from_numpy
            # would raise TypeError when the stats are already tensors.
            pol_state_dict[pol_key] = torch.as_tensor(stats[stat_key][stat_type])
            updated_count += 1
        else:
            print(f"Could not find {pol_key} or {stat_key}")

    policy.load_state_dict(pol_state_dict)
    print(
        "Normalization stats injected into the policy "
        f"({updated_count}/{len(keys_to_update)} buffers updated)."
    )
    return updated_count

def get_action_chunk(policy, batch, device, actions_per_chunk=10, unnormalize=False):
    """Sample one chunk of actions from the policy for the given observation.

    The observation is preprocessed, normalized with ``policy.normalize_inputs``
    and fed through the policy's image/state/language preparation before
    ``policy.model.sample_actions`` is called. Gradients are disabled.

    Args:
        policy: Policy exposing ``normalize_inputs``, ``prepare_images``,
            ``prepare_state``, ``prepare_language``, ``model.sample_actions``
            and ``config.action_feature``.
        batch: Raw observation as returned by the environment.
        device: Device passed on to ``prepare_observation_for_policy``.
        actions_per_chunk: Maximum number of leading actions (axis 1) to keep.
        unnormalize: If True, map the sampled actions back through
            ``policy.unnormalize_outputs``. Defaults to False, which returns
            the sampled actions as-is (the original behaviour).

    Returns:
        The sampled actions, sliced to at most ``actions_per_chunk`` entries
        along axis 1 and to the policy's action dimension along the last axis.
    """
    with torch.no_grad():
        # NOTE(review): prepare_observation_for_policy is neither defined nor
        # imported in the visible part of this file -- confirm where it lives.
        batch_processed = prepare_observation_for_policy(batch, device)
        batch_normalized = policy.normalize_inputs(batch_processed)

        images, img_masks = policy.prepare_images(batch_normalized)
        state = policy.prepare_state(batch_normalized)
        lang_tokens, lang_masks = policy.prepare_language(batch_normalized)

        actions = policy.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state
        )

        # Keep only the leading actions and drop any columns beyond the
        # policy's configured action dimension.
        original_action_dim = policy.config.action_feature.shape[0]
        actions = actions[:, :actions_per_chunk, :original_action_dim]

        if unnormalize:
            # NOTE(review): assumes unnormalize_outputs takes and returns a
            # dict keyed by "action", as in LeRobot's own select_action.
            actions = policy.unnormalize_outputs({"action": actions})["action"]

        return actions
    

# Load the dataset statistics into the policy and switch it to inference mode.
inject_normalization_stats(policy, dataset)
policy.eval()

obs, _ = env.reset(seed = 0)
obs_list = []  # collects one rendered frame per environment step

# Chunked rollout: the policy is re-queried with the latest observation each
# time the current chunk is used up. Indexing a single chunk with the global
# step counter would run past its end once the step reaches the chunk length.
actions = None
chunk_len = 0
chunk_step = 0

for i in range(args.num_eval_steps):
    if chunk_step >= chunk_len:
        actions = get_action_chunk(policy, obs, "cuda:0", ACTIONS_PER_CHUNK)
        chunk_len = actions.shape[1]
        chunk_step = 0
        if chunk_len == 0:
            raise RuntimeError("Policy returned an empty action chunk.")

    obs_list.append(env.render().detach().cpu().clone())
    # The 0.1 factor scales the sampled actions before stepping, as in the
    # original script.
    obs, reward, terminated, truncated, info = env.step(0.1*actions[:, chunk_step, :])
    chunk_step += 1
