import os
import h5py
import yaml
import numpy as np
import argparse
import cv2
import torch
import hydra
from torchvision import transforms as T
from omegaconf import OmegaConf
from easydict import EasyDict
from libero.libero import get_libero_path, benchmark
from libero.libero.benchmark import get_benchmark
from libero.libero.envs import DemoRenderEnv, OffScreenRenderEnv
from libero.lifelong.utils import control_seed

# Select the torch device string: "cuda" when torch reports a usable CUDA
# runtime, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def parse_args():
    """Parse the command-line options of the dataset update script.

    Returns:
        The ``argparse.Namespace`` produced by the parser. Its only option is
        the required ``--benchmark``, restricted to one of the LIBERO
        benchmark names listed below.
    """
    benchmark_choices = ["libero_10", "libero_spatial", "libero_object", "libero_goal"]

    cli = argparse.ArgumentParser(description="Update Dataset Script")
    cli.add_argument(
        "--benchmark",
        type=str,
        required=True,
        choices=benchmark_choices,
        help="Benchmark name",
    )
    return cli.parse_args()

def setup_env(task_bddl_file, seed):
    """Build a seeded ``OffScreenRenderEnv`` for one task.

    Args:
        task_bddl_file: Path of the BDDL file describing the task; passed to
            the environment as ``bddl_file_name``.
        seed: Value handed to ``env.seed`` right after construction.

    Returns:
        The constructed environment, configured for 128x128 camera
        observations from the ``agentview`` and ``robot0_eye_in_hand`` cameras.
    """
    env = OffScreenRenderEnv(
        bddl_file_name=task_bddl_file,
        camera_heights=128,
        camera_widths=128,
        has_renderer=True,
        has_offscreen_renderer=True,
        use_camera_obs=True,
        # Both cameras are requested so each step yields two image streams.
        camera_names=["agentview", "robot0_eye_in_hand"],
    )
    env.seed(seed)
    return env

def copy_attrs(src, dst):
    """Mirror every entry of ``src.attrs`` into ``dst.attrs``.

    Entries already present on ``dst`` under the same name are overwritten;
    entries that exist only on ``dst`` are left untouched.
    """
    source_attrs = src.attrs
    target_attrs = dst.attrs
    for name in source_attrs:
        target_attrs[name] = source_attrs[name]

# Dataset name written under each demo's ``obs`` group -> key of the
# environment observation it is rebuilt from during replay.
_RECOMPUTED_OBS = {
    "agentview_rgb": "agentview_image",
    "eye_in_hand_rgb": "robot0_eye_in_hand_image",
    "gripper_states": "robot0_gripper_qpos",
    "joint_states": "robot0_joint_pos",
}


def _replay_demo(env, demo_group):
    """Replay one recorded demo in ``env`` and collect fresh observations.

    Args:
        env: Environment exposing ``reset``, ``step`` and
            ``sim.set_state_from_flattened``.
        demo_group: Source HDF5 group holding the ``actions`` and ``states``
            datasets of the demo.

    Returns:
        A ``(recorded, done)`` tuple. ``recorded`` maps each name in
        ``_RECOMPUTED_OBS`` to the list of per-step observations gathered
        while replaying the actions. ``done`` is the flag returned by the last
        ``env.step`` call, or ``False`` when the demo has no actions.
    """
    env.reset()
    env.sim.set_state_from_flattened(demo_group['states'][0])

    # Five zero-action steps are taken before the replay starts.
    # NOTE(review): presumably this lets the simulation settle after the
    # state is restored -- confirm.
    for _ in range(5):
        env.step(np.zeros(7))

    recorded = {name: [] for name in _RECOMPUTED_OBS}
    done = False
    for action in demo_group['actions'][:]:
        obs, _, done, _ = env.step(action)
        for name, obs_key in _RECOMPUTED_OBS.items():
            recorded[name].append(obs[obs_key])
    return recorded, done


def _write_demo(demo_group, new_demo_group, recorded):
    """Fill ``new_demo_group`` from the source demo plus replayed observations.

    Everything except ``obs`` is copied from ``demo_group``. Inside ``obs``
    the datasets named in ``recorded`` are created from the replayed data and
    all remaining source observations are copied unchanged.
    """
    copy_attrs(demo_group, new_demo_group)

    for key in demo_group:
        if key != "obs":
            demo_group.copy(key, new_demo_group)
            copy_attrs(demo_group[key], new_demo_group[key])

    obs_group = demo_group['obs']
    new_obs_group = new_demo_group.create_group("obs")

    for name, frames in recorded.items():
        new_obs_group.create_dataset(name, data=np.array(frames))

    for obs_key in obs_group:
        # Observations that were not recomputed are taken from the original
        # dataset, not from the new simulation.
        if obs_key not in recorded:
            obs_group.copy(obs_key, new_obs_group)
        copy_attrs(obs_group[obs_key], new_obs_group[obs_key])


def update_images_in_hdf5(src_file, dst_file, benchmark, task_id, cfg):
    """Replay every demo of ``src_file`` in simulation and write a re-rendered copy.

    Demos under ``data`` in the source file are visited in numeric order of
    the suffix of their name. Each one is replayed from its first recorded
    state in an environment built from the task's BDDL file. A demo whose
    last replay step does not report ``done`` is dropped. The others are
    written to ``dst_file`` as ``data/demo_0``, ``data/demo_1``, ... with the
    camera images, gripper states and joint states taken from the replay and
    all other data copied from the source.

    Because failed demos are dropped and no padding is performed, the output
    may contain fewer demos than the source.

    NOTE(review): attributes on the ``data`` group are copied verbatim, so any
    demo-count metadata stored there is not adjusted when demos are dropped.

    Args:
        src_file: Path of the HDF5 dataset to read.
        dst_file: Path of the HDF5 dataset to create (opened with mode ``'w'``).
        benchmark: Object providing ``get_task(task_id)``; the returned task
            must expose ``problem_folder`` and ``bddl_file``.
        task_id: Index of the task within ``benchmark``.
        cfg: Configuration object; only ``cfg.seed`` is read here.
    """
    with h5py.File(src_file, 'r') as src, h5py.File(dst_file, 'w') as dst:
        src_data = src['data']
        dst_data = dst.create_group("data")
        copy_attrs(src_data, dst_data)

        # Resolve the BDDL file that describes this task.
        task = benchmark.get_task(task_id)
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

        env = setup_env(task_bddl_file, cfg.seed)
        try:
            demo_names = sorted(src_data, key=lambda name: int(name.split('_')[-1]))

            # Number of demos written so far; also the index of the next one,
            # so surviving demos are renamed sequentially without gaps.
            n_valid = 0

            for demo_name in demo_names:
                print(f"Processing: {demo_name}")
                demo_group = src_data[demo_name]

                recorded, done = _replay_demo(env, demo_group)
                if not done:
                    print(f"Demo {demo_name} has failed and will be removed.")
                    continue

                new_demo_group = dst_data.create_group(f"demo_{n_valid}")
                _write_demo(demo_group, new_demo_group, recorded)
                n_valid += 1
        finally:
            # Release the simulator even when replay or writing raises.
            env.close()


@hydra.main(config_path="../configs", config_name="config", version_base=None)
def main(hydra_cfg):
    """Re-render the demonstration dataset of every task in the configured benchmark.

    The Hydra config is converted to an ``EasyDict`` via a YAML round-trip and
    the global seed is set from ``cfg.seed``. Each task of the benchmark named
    by ``cfg.benchmark_name`` then has its dataset regenerated with
    ``update_images_in_hdf5``. Paths are resolved relative to ``cfg.folder``,
    which falls back to the LIBERO ``datasets`` path when it is empty.
    """
    cfg = EasyDict(yaml.safe_load(OmegaConf.to_yaml(hydra_cfg)))

    # Seed first so the rest of the run is reproducible.
    control_seed(cfg.seed)

    if not cfg.folder:
        cfg.folder = get_libero_path("datasets")

    suite_cls = get_benchmark(cfg.benchmark_name)
    suite = suite_cls(cfg.data.task_order_index)

    for task_id in range(suite.n_tasks):
        task_demo_path = os.path.join(cfg.folder, suite.get_task_demonstration(task_id))
        new_task_demo_path = os.path.join(cfg.folder, suite.get_task_demonstration_new_rendering(task_id))

        # The output directory may not exist yet.
        os.makedirs(os.path.dirname(new_task_demo_path), exist_ok=True)

        print(f"Processing dataset for task {task_id} at {task_demo_path}")
        update_images_in_hdf5(task_demo_path, new_task_demo_path, suite, task_id, cfg)
        print(f"Updated dataset saved at {new_task_demo_path}")

# Script entry point: main() is called without arguments because the
# @hydra.main decorator on it is expected to supply hydra_cfg.
if __name__ == "__main__":
    main()
