import os
import pathlib
import pickle

import imageio
import numpy as np
from absl import app, flags
from PIL import Image
from tqdm import tqdm, trange

from bpref_v2.data.qlearning_dataset import qlearning_factorworld_dataset
from bpref_v2.envs import MetaWorld
from bpref_v2.reward_learning.preference_dataset import PrefDataset

# from bpref_v2.reward_learning.reward_transform import load_queries_with_indices

# Command-line flags; absl parses them when ``app.run(main)`` starts the script.
FLAGS = flags.FLAGS

# Full env name. main() drops the first "-"-separated component before passing it to
# MetaWorld; the full name is used as the output sub-directory for the videos.
flags.DEFINE_string("env_name", "factorworld-pick-place-v2", "Environment name.")
# Root output directory for the rendered videos and the saved index pickles.
flags.DEFINE_string("save_dir", "./video/", "saving dir.")
# Directory holding the pickled query files (indices_num*, indices_2_num*, label_scripted_num*).
flags.DEFINE_string("query_path", "./human_label/", "query path")
# Passed to qlearning_factorworld_dataset to load the offline dataset.
flags.DEFINE_string("dataset_path", "/home/ICLR2024", "dataset path")
# Number of queries: selects which query files are loaded and how many videos are rendered.
flags.DEFINE_integer("num_query", 1000, "number of query.")
# Length of each query segment; also part of the query/index file names.
flags.DEFINE_integer("query_len", 100, "length of each query.")
# Copied into the PrefDataset config as ``label_type``.
flags.DEFINE_integer("label_type", 0, "label type.")
# Seed for the MetaWorld env; also embedded in the saved index file names.
flags.DEFINE_integer("seed", 3407, "seed for reproducibility.")
# Output format for each query video; visualize_query handles "gif" and "mp4".
flags.DEFINE_string("video_type", "gif", "video_type")


def save_video(imgs, save_path, duration=50):
    """Save a sequence of frames as an endlessly looping animated image (e.g. a GIF).

    Args:
        imgs: Iterable of frames; each one is passed to ``PIL.Image.fromarray``
            (presumably uint8 HxW or HxWxC arrays -- confirm against callers).
        save_path: Output file path; PIL picks the output format from its extension.
        duration: Display time of each frame in milliseconds. The default of 50 ms
            corresponds to 20 frames per second.

    Raises:
        ValueError: If ``imgs`` contains no frames.
    """
    frames = [Image.fromarray(img) for img in imgs]
    if not frames:
        # Fail with a clear message instead of an IndexError on frames[0].
        raise ValueError("save_video needs at least one frame.")
    first, *rest = frames
    # save_all + append_images writes every frame into one file; loop=0 repeats forever.
    first.save(save_path, save_all=True, append_images=rest, duration=duration, loop=0)


def visualize_query(env_name, dataset, query_len, num_query, save_dir="./video", video_type=None, seed=None):
    """Render the first ``num_query`` preference queries to videos and save their indices.

    For each query index ``i``, the image sequences of the two segments
    (``dataset[i]["images"]`` and ``dataset[i]["images_2"]``) are concatenated along
    axis 2 and written to ``<save_dir>/<env_name>/idx<i>.<gif|mp4>``. Axis 2 is
    presumably the width of (T, H, W, C) frames, which would put the segments side by
    side (confirm). Afterwards, the segment indices of the rendered queries
    (``dataset.query_1`` / ``dataset.query_2``) are pickled into the same directory.

    Args:
        env_name: Sub-directory created under ``save_dir``.
        dataset: Indexable preference dataset exposing ``query_1`` and ``query_2``.
        query_len: Query length; only used in the index file names.
        num_query: Number of queries to render.
        save_dir: Root output directory.
        video_type: "gif" or "mp4". Defaults to ``FLAGS.video_type``.
        seed: Seed embedded in the index file names. Defaults to ``FLAGS.seed``.

    Raises:
        ValueError: If ``video_type`` is neither "gif" nor "mp4".
    """
    if video_type is None:
        video_type = FLAGS.video_type
    if seed is None:
        seed = FLAGS.seed
    # Validate up front. An unknown type would otherwise render nothing yet still save indices.
    if video_type not in ("gif", "mp4"):
        raise ValueError(f"Unsupported video_type: {video_type!r} (expected 'gif' or 'mp4').")

    save_dir = os.path.join(save_dir, env_name)
    os.makedirs(save_dir, exist_ok=True)

    for seg_idx in trange(num_query):
        sample = dataset[seg_idx]
        video = np.concatenate((np.array(sample["images"]), np.array(sample["images_2"])), axis=2)
        if video_type == "gif":
            save_video(video, os.path.join(save_dir, f"idx{seg_idx}.gif"))
        else:
            # Use one writer per file; the context manager closes it even if a frame fails.
            with imageio.get_writer(os.path.join(save_dir, f"idx{seg_idx}.mp4"), fps=30) as writer:
                for frame in tqdm(video, leave=False):
                    writer.append_data(frame)

    print("save query indices.")
    # Save the indices of every rendered query, not just those of the last loop iteration.
    # Assumes query_1/query_2 are sliceable sequences aligned with dataset indices -- confirm.
    index_files = (
        (f"human_indices_numq{num_query}_len{query_len}_s{seed}.pkl", dataset.query_1),
        (f"human_indices_2_numq{num_query}_len{query_len}_s{seed}.pkl", dataset.query_2),
    )
    for file_name, indices in index_files:
        with open(os.path.join(save_dir, file_name), "wb") as f:
            pickle.dump(indices[:num_query], f)


def main(_):
    """Load saved preference queries, build an image PrefDataset, and render the queries.

    The query files are read from ``FLAGS.query_path`` before the environment and the
    offline dataset are constructed, so a missing query set fails fast.

    Raises:
        FileNotFoundError: If ``FLAGS.query_path`` or one of the query files is missing.
    """
    query_path = pathlib.Path(FLAGS.query_path).expanduser()
    if not query_path.exists():
        # Raising a plain str is itself a TypeError in Python 3; raise a real exception.
        raise FileNotFoundError(f"NO such query sets: {query_path}")

    indices_1_file = query_path / f"indices_num{FLAGS.num_query}_q{FLAGS.query_len}"
    indices_2_file = query_path / f"indices_2_num{FLAGS.num_query}_q{FLAGS.query_len}"
    label_file = query_path / f"label_scripted_num{FLAGS.num_query}_q{FLAGS.query_len}"
    # NOTE: pickle.load can execute arbitrary code; only point query_path at trusted files.
    with indices_1_file.open("rb") as fp, indices_2_file.open("rb") as gp, label_file.open("rb") as hp:
        query_1 = pickle.load(fp)
        query_2 = pickle.load(gp)
        label = pickle.load(hp)

    # Drop the first "-"-separated component, e.g. "factorworld-pick-place-v2" -> "pick-place-v2".
    env_name = "-".join(FLAGS.env_name.split("-")[1:])
    gym_env = MetaWorld(env_name, seed=FLAGS.seed)
    ds = qlearning_factorworld_dataset(FLAGS.dataset_path)

    data = PrefDataset.get_default_config()
    data.num_query = FLAGS.num_query
    data.query_len = FLAGS.query_len
    data.label_type = FLAGS.label_type
    data.use_image = True  # the videos are built from the dataset's image observations

    dataset = PrefDataset(update=data, env=gym_env, ds=ds, query_1=query_1, query_2=query_2, label=label)

    visualize_query(FLAGS.env_name, dataset, FLAGS.query_len, FLAGS.num_query, save_dir=FLAGS.save_dir)


if __name__ == "__main__":
    # absl's app.run parses the command-line flags, then calls main(argv).
    app.run(main)
