import sys

sys.path.append("./")

import os
import h5py
import numpy as np
import pickle
import cv2
import argparse
import yaml
from scripts.encode_lang_batch_once import encode_lang


def load_hdf5(dataset_path):
    """Load joint actions and camera data from one episode HDF5 file.

    Args:
        dataset_path: Path to the episode ``.hdf5`` file.

    Returns:
        A tuple ``(left_gripper, left_arm, right_gripper, right_arm,
        image_dict)``. The first four items are the full contents of the
        matching ``/joint_action/*`` datasets. ``image_dict`` maps every
        entry name found under ``/observation`` to the full contents of that
        entry's ``rgb`` dataset.

    Raises:
        SystemExit: With exit status 1 when ``dataset_path`` is not an
            existing file.
    """
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        # A missing dataset is a failure: exit non-zero so wrapping shell
        # scripts can detect it. ``sys.exit`` is used rather than the ``site``
        # builtin ``exit``, which is intended for interactive sessions.
        sys.exit(1)

    with h5py.File(dataset_path, "r") as root:
        # ``[()]`` reads each dataset fully into memory, so the returned
        # values stay usable after the file is closed.
        left_gripper = root["/joint_action/left_gripper"][()]
        left_arm = root["/joint_action/left_arm"][()]
        right_gripper = root["/joint_action/right_gripper"][()]
        right_arm = root["/joint_action/right_arm"][()]
        image_dict = {
            cam_name: root[f"/observation/{cam_name}/rgb"][()]
            for cam_name in root["/observation/"].keys()
        }

    return left_gripper, left_arm, right_gripper, right_arm, image_dict


def images_encoding(imgs):
    """JPEG-encode every image and report the size of the longest encoding.

    Args:
        imgs: Sequence of images accepted by ``cv2.imencode``.

    Returns:
        A tuple ``(encode_data, max_len)``: the list of encoded JPEG byte
        strings in input order, and the length in bytes of the longest one
        (``0`` when ``imgs`` is empty).

    Raises:
        RuntimeError: If ``cv2.imencode`` reports that an image could not be
            encoded.
    """
    encode_data = []
    for index, img in enumerate(imgs):
        success, encoded_image = cv2.imencode(".jpg", img)
        if not success:
            # Fail loudly instead of silently storing an unusable buffer.
            raise RuntimeError(f"cv2.imencode failed for image {index}")
        encode_data.append(encoded_image.tobytes())

    # Only the unpadded encodings and their maximum length are returned, so
    # no padded copy of the data is built here.
    max_len = max(map(len, encode_data), default=0)
    return encode_data, max_len


def get_task_config(task_name):
    """Read and parse the YAML configuration for ``task_name``.

    The file is looked up at ``./task_config/<task_name>.yml`` (relative to
    the current working directory), read as UTF-8 text and parsed with
    ``yaml.FullLoader``. The parsed document is returned as-is.
    """
    config_path = f"./task_config/{task_name}.yml"
    with open(config_path, "r", encoding="utf-8") as stream:
        raw_text = stream.read()
    return yaml.load(raw_text, Loader=yaml.FullLoader)


# (camera name in the source episode, dataset name under observations/images)
_CAMERA_DATASETS = (
    ("head_camera", "cam_high"),
    ("right_camera", "cam_right_wrist"),
    ("left_camera", "cam_left_wrist"),
)


def _decode_and_resize(encoded_bits, size=(640, 480)):
    """Decode one encoded camera frame and resize it to ``size`` (width, height)."""
    frame = cv2.imdecode(np.frombuffer(encoded_bits, np.uint8), cv2.IMREAD_COLOR)
    return cv2.resize(frame, size)


def data_transform(path, episode_num, save_path):
    """Convert the first ``episode_num`` episodes under ``path`` to per-episode HDF5 files.

    For each ``i`` in ``range(episode_num)`` the file ``<path>/episode{i}.hdf5``
    is loaded with ``load_hdf5`` and written to
    ``<save_path>/episode_{i}/episode_{i}.hdf5`` with this layout:

    * ``action``: joint state of every step except the first.
    * ``observations/qpos``: joint state of every step except the last.
    * ``observations/left_arm_dim`` / ``right_arm_dim``: length of the left /
      right arm vector, one entry per action.
    * ``observations/images/{cam_high,cam_right_wrist,cam_left_wrist}``:
      JPEG bytes of the head / right / left camera frames, resized to
      640x480, one entry per ``qpos`` row.

    The joint state is ``left_arm + [left_gripper] + right_arm +
    [right_gripper]`` cast to ``float32``.

    Args:
        path: Directory containing the source ``episode{i}.hdf5`` files.
        episode_num: Number of episodes to convert, starting at index 0.
        save_path: Output root directory; created if missing.

    Returns:
        The number of episodes processed.

    Raises:
        AssertionError: If ``path`` contains fewer entries than
            ``episode_num``.
    """
    episode_files = os.listdir(path)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if episode_num > len(episode_files):
        raise AssertionError("data num not enough")

    os.makedirs(save_path, exist_ok=True)

    processed = 0
    for i in range(episode_num):
        left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = load_hdf5(
            os.path.join(path, f"episode{i}.hdf5"))

        qpos = []
        actions = []
        left_arm_dim = []
        right_arm_dim = []
        frames = {cam_name: [] for cam_name, _ in _CAMERA_DATASETS}

        num_steps = left_gripper_all.shape[0]
        last_step = num_steps - 1
        for j in range(num_steps):
            left_arm = left_arm_all[j]
            right_arm = right_arm_all[j]
            state = np.concatenate(
                (left_arm, [left_gripper_all[j]], right_arm, [right_gripper_all[j]]),
                axis=0,
            ).astype(np.float32)

            # Every step except the last contributes an observation
            # (joint state plus the three camera frames).
            if j != last_step:
                qpos.append(state)
                for cam_name, cam_frames in frames.items():
                    cam_frames.append(_decode_and_resize(image_dict[cam_name][j]))

            # Every step except the first contributes an action, so
            # actions[k] is the joint state at step k + 1 while qpos[k] is
            # the joint state at step k.
            if j != 0:
                actions.append(state)
                left_arm_dim.append(left_arm.shape[0])
                right_arm_dim.append(right_arm.shape[0])

        episode_dir = os.path.join(save_path, f"episode_{i}")
        os.makedirs(episode_dir, exist_ok=True)
        hdf5path = os.path.join(episode_dir, f"episode_{i}.hdf5")

        with h5py.File(hdf5path, "w") as f:
            f.create_dataset("action", data=np.array(actions))
            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=np.array(qpos))
            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
            image = obs.create_group("images")
            # Encode every camera before creating any image dataset.
            encoded = {
                dataset_name: images_encoding(frames[cam_name])
                for cam_name, dataset_name in _CAMERA_DATASETS
            }
            for dataset_name, (jpeg_list, max_len) in encoded.items():
                # Fixed-width byte-string dtype sized to the longest JPEG.
                image.create_dataset(dataset_name, data=jpeg_list, dtype=f"S{max_len}")

        processed += 1
        print(f"proccess {i} success!")

    return processed


if __name__ == "__main__":
    # Usage: <script> task_name task_config expert_data_num
    # Converts the first ``expert_data_num`` episodes of a task and then
    # encodes the language instructions of each converted episode.
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("task_name", type=str)
    parser.add_argument("task_config", type=str)
    parser.add_argument("expert_data_num", type=int)
    args = parser.parse_args()

    task_name = args.task_name
    task_config = args.task_config
    expert_data_num = args.expert_data_num

    # Dataset root: the ROBOTWIN_DATA environment variable if set, otherwise
    # "../../data" relative to this file.
    DATA_ROOT = os.environ.get(
        "ROBOTWIN_DATA",
        os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data")),
    )
    task_root = os.path.join(DATA_ROOT, str(task_name), str(task_config))
    # Fallback layout in which the task_config directory is nested twice.
    nested_root = os.path.join(task_root, str(task_config))

    load_dir = os.path.join(task_root, "data")
    print(f"[INFO] Read data from path: {load_dir}")

    if not os.path.exists(load_dir):
        alt_dir = os.path.join(nested_root, "data")
        if os.path.exists(alt_dir):
            print(f"[WARN] Path not found, using alternative: {alt_dir}")
            load_dir = alt_dir

    output_name = f"{task_name}-{task_config}-{expert_data_num}"
    data_transform(
        load_dir,
        expert_data_num,
        f"./processed_data/{output_name}",
    )

    # encode_lang returns the tokenizer / text encoder; they are passed back
    # in on the next iteration.
    tokenizer, text_encoder = None, None
    for idx in range(expert_data_num):
        print(f"Processing Language: {idx}", end="\r")
        episode_json = f"episode{idx}.json"
        # Look under DATA_ROOT first so instructions come from the same
        # dataset root as the episode data; the CWD-relative path is kept as
        # a fallback for backward compatibility, followed by the nested layout.
        candidates = (
            os.path.join(task_root, "instructions", episode_json),
            f"../../data/{task_name}/{task_config}/instructions/{episode_json}",
            os.path.join(nested_root, "instructions", episode_json),
        )
        data_file_path = next((p for p in candidates if os.path.exists(p)), None)

        if data_file_path is None:
            print(f"[ERROR] Cannot find instruction file for episode {idx}")
            continue
        if data_file_path != candidates[0]:
            print(f"[WARN] Using alternative instruction path: {data_file_path}")

        target_dir = f"processed_data/{output_name}/episode_{idx}"
        tokenizer, text_encoder = encode_lang(
            DATA_FILE_PATH=data_file_path,
            TARGET_DIR=target_dir,
            GPU=0,
            desc_type="seen",
            tokenizer=tokenizer,
            text_encoder=text_encoder,
        )
