import io
import os
import pickle
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torchvision.transforms as T
from absl import app, flags
from ml_collections import ConfigDict
from PIL import Image
from tqdm import trange

from bpref_v2.data.arp_furniturebench_dataset_inmemory_stream import (
    get_failure_skills_and_phases,
)
from bpref_v2.data.instruct import TASK_TO_PHASE
from bpref_v2.utils.reward_model_loader import load_reward_fn, load_reward_model

FLAGS = flags.FLAGS

flags.DEFINE_string("furniture", "one_leg", "Name of furniture (task name passed to the reward model).")
flags.DEFINE_string(
    "demo_dir",
    "square_table_parts_state",
    "Directory containing the demo *.pkl files; videos are written to a sub-folder of it.",
)
# NOTE(review): out_dir is not read anywhere in this file.
flags.DEFINE_string("out_dir", None, "Path to save converted data.")
flags.DEFINE_integer(
    "num_success_demos",
    -1,
    "Number of *_success.pkl demos to visualize (<= 0 means all); ignored when --target_indices is set.",
)
flags.DEFINE_integer(
    "num_failure_demos",
    -1,
    "Number of *_failure.pkl demos to visualize (<= 0 means all); ignored when --target_indices is set.",
)
flags.DEFINE_string(
    "target_indices",
    "",
    "Comma-separated positions (in sorted file order) of the demos to visualize, applied to every demo type; "
    "empty means use --num_<type>_demos.",
)
flags.DEFINE_integer("batch_size", 512, "Batch size for encoding images")
flags.DEFINE_string("ckpt_path", "", "ckpt path of reward model.")
flags.DEFINE_string(
    "demo_type", "success", "'|'-separated demo types to visualize; each selects files matching *_<type>.pkl."
)
flags.DEFINE_string("rm_type", "REDS", "reward model type.")
flags.DEFINE_string(
    "pvr_type", "liv", "Visual embedding passed to load_embedding (vip, r3m, liv, clip_vit_b16, clip_vit_l14)."
)
flags.DEFINE_boolean("smoothe", False, "Smooth rewards with an exponential moving average before plotting.")
flags.DEFINE_boolean(
    "predict_phase",
    False,
    "Use the reward model's phase predictions (plotted next to the rewards) instead of the demos' skill labels.",
)
flags.DEFINE_integer("window_size", 4, "Window size passed to the reward function (args.window_size).")
flags.DEFINE_integer("skip_frame", 16, "Frame skip passed to the reward function (args.skip_frame).")


# Fall back to CPU when CUDA is unavailable instead of failing later when models are moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gaussian_smoothe(rewards, sigma=3.0, mode="nearest"):
    """Smooth a 1-D reward sequence with a Gaussian kernel.

    :param rewards: Array-like of per-step rewards.
    :param sigma: Standard deviation of the Gaussian kernel, in steps.
    :param mode: Boundary mode forwarded to ``scipy.ndimage.gaussian_filter1d``
                 ("nearest" repeats the edge values).
    :return: Smoothed array with the same shape as ``rewards``. Integer/bool
             input is promoted to float so the result is not truncated.
    """
    # Import the submodule explicitly: on older SciPy releases a bare
    # ``import scipy`` does not make ``scipy.ndimage`` available.
    from scipy import ndimage

    values = np.asarray(rewards)
    if not np.issubdtype(values.dtype, np.inexact):
        # gaussian_filter1d writes into an array of the input dtype by default.
        values = values.astype(np.float64)
    return ndimage.gaussian_filter1d(values, sigma=sigma, mode=mode)


def exponential_moving_average(a, alpha=0.3):
    """
    Compute the Exponential Moving Average of a 1-D sequence.

    ``ema[0] = a[0]`` and ``ema[i] = alpha * a[i] + (1 - alpha) * ema[i - 1]``.

    :param a: Array-like of values to compute the EMA for.
    :param alpha: Smoothing factor in the range [0,1].
                  The closer to 1, the more weight given to recent values.
    :return: Numpy array containing the EMA of the input. Integer/bool input is
             promoted to float so the smoothed values are not truncated; an
             empty input yields an empty array.
    """
    values = np.asarray(a)
    if not np.issubdtype(values.dtype, np.inexact):
        # An integer output buffer would truncate every fractional EMA value.
        values = values.astype(np.float64)

    ema = np.empty_like(values)
    if len(values) == 0:
        return ema

    ema[0] = values[0]  # Seed the recursion with the first input value.
    for i in range(1, len(values)):
        ema[i] = alpha * values[i] + (1 - alpha) * ema[i - 1]

    return ema


def save_episode(episode, fn):
    """Serialize ``episode`` into a compressed ``.npz`` archive at ``fn``.

    ``episode`` is unpacked as keyword arguments to ``np.savez_compressed``, so
    each key becomes an array name in the archive. ``fn`` must provide an
    ``open(mode)`` method (e.g. a ``pathlib.Path``). The archive is assembled in
    memory first and then written to ``fn`` in one call.
    """
    with io.BytesIO() as buffer:
        np.savez_compressed(buffer, **episode)
        payload = buffer.getvalue()

    with fn.open("wb") as out_file:
        out_file.write(payload)


def load_embedding(rep="vip"):
    """Load a visual embedding model together with its preprocessing transform.

    :param rep: One of "vip", "r3m", "liv", "clip_vit_b16", "clip_vit_l14".
    :return: ``(model, transform, feature_dim)``. The model is put in eval mode;
             vip/r3m/liv models are moved to the module-level ``device``, CLIP
             models are left wherever ``clip.load`` placed them. ``feature_dim``
             is the hard-coded feature size for the chosen model.
    :raises ValueError: If ``rep`` is not one of the supported names.
    """
    # CLIP variant -> (name understood by clip.load, feature size).
    clip_variants = {"clip_vit_b16": ("ViT-B/16", 512), "clip_vit_l14": ("ViT-L/14", 768)}

    if rep == "vip":
        from vip import load_vip

        model = load_vip()
        transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
        feature_dim = 1024
    elif rep == "r3m":
        from r3m import load_r3m

        model = load_r3m("resnet50")
        transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
        feature_dim = 2048
    elif rep == "liv":
        from liv import load_liv

        model = load_liv()
        transform = T.Compose([T.ToTensor()])
        feature_dim = 1024
    elif rep in clip_variants:
        import clip

        clip_name, feature_dim = clip_variants[rep]
        model, transform = clip.load(clip_name)
    else:
        # Fail loudly here instead of with an UnboundLocalError on `model` below.
        supported = ["vip", "r3m", "liv", *clip_variants]
        raise ValueError(f"Unsupported representation {rep!r}; expected one of {supported}")

    model.eval()
    if rep in ("vip", "r3m", "liv"):
        model = model.to(device)
    return model, transform, feature_dim


def _collect_demo_files(dir_path):
    """Return the ``(index, path)`` pairs of demo pickles to visualize.

    For each ``|``-separated entry of ``--demo_type``, the files matching
    ``*_<type>.pkl`` in ``dir_path`` are sorted. If ``--target_indices`` is set,
    those comma-separated positions are taken from every type's list; otherwise
    the first ``--num_<type>_demos`` files are used (all of them when <= 0).
    """
    target_indices = None
    if FLAGS.target_indices != "":
        target_indices = [int(elem) for elem in FLAGS.target_indices.split(",")]

    files = []
    for demo_type in FLAGS.demo_type.split("|"):
        print(f"Loading {demo_type} demos...")
        demo_files = sorted(dir_path.glob(f"*_{demo_type}.pkl"))
        if target_indices is not None:
            files.extend((elem, demo_files[elem]) for elem in target_indices)
        else:
            # NOTE(review): only num_success_demos / num_failure_demos flags are
            # defined, so other demo types presumably require --target_indices.
            num_demos = getattr(FLAGS, f"num_{demo_type}_demos")
            len_demos = num_demos if num_demos > 0 else len(demo_files)
            files.extend(enumerate(demo_files[:len_demos]))
    return files


def _load_episode(file_path):
    """Load one demo pickle and build the inputs ``reward_fn`` is called with.

    Returns ``(tp, images, actions, skills)``:
      - ``tp``: the file-name suffix after the last ``_`` (the demo type).
      - ``images``: ``{"color_image2": ..., "color_image1": ...}`` uint8 arrays
        with axis 1 moved to the end (presumably channel-first -> channel-last).
      - ``actions`` / ``skills``: per-step labels; all zeros with --predict_phase.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted demo files.
    with open(file_path, "rb") as f:
        x = pickle.load(f)
    tp = file_path.stem.split("_")[-1].split(".")[0]

    if len(x["observations"]) == len(x["actions"]):
        # Dummy: repeat the last observation so there is one more observation than actions.
        x["observations"].append(x["observations"][-1])
    length = len(x["observations"])

    images = {}
    for key in ("color_image2", "color_image1"):
        frames = np.asarray([obs[key] for obs in x["observations"]]).astype(np.uint8)
        images[key] = np.transpose(frames, (0, 2, 3, 1))

    if FLAGS.predict_phase:
        # Phases come from the reward model, so the label inputs are placeholders.
        skills = np.zeros(length, dtype=int)
        actions = np.zeros(length, dtype=int)
    else:
        raw_skills = np.asarray(x["skills"])
        actions = x["actions"]
        # Running sum of the positive skill values (presumably the phase index so far).
        cumulative = np.cumsum(np.where(raw_skills > 0.0, raw_skills, 0.0))
        if "success" in file_path.name:
            skills = cumulative
        else:
            _, skills = get_failure_skills_and_phases(
                skill=raw_skills,
                phase=cumulative,
                task_name=FLAGS.furniture,
                failure_phase=x.get("failure_phase", -1),
            )
    return tp, images, actions, skills


def _render_frames(frames, rewards, phases, processed_phases, title, out_dir):
    """Save one PNG per step (camera frame + reward curve up to that step) in ``out_dir``.

    ``phases`` / ``processed_phases`` are drawn on a twin y-axis when not None.
    Returns the written file paths in step order.
    """
    frame_files = []
    for step in trange(len(frames), desc="Creating frames", leave=False, ncols=0):
        fig, ax = plt.subplots(1, 2, figsize=(6, 3))
        fig.suptitle(title)
        ax[0].imshow(frames[step])
        ax[0].axis("off")
        ax[1].plot(rewards[: step + 1])
        if phases is not None:
            ax1_2 = ax[1].twinx()
            ax1_2.plot(phases[: step + 1], color="red", linestyle="--")
            ax1_2.plot(processed_phases[: step + 1], color="green", linestyle="--")

        frame_path = os.path.join(out_dir, f"frame_{step:04d}.png")
        fig.savefig(frame_path)
        plt.close(fig)
        # Track exactly what was written instead of re-listing the directory.
        frame_files.append(frame_path)
    return frame_files


def _write_videos(frame_files, stem):
    """Encode ``frame_files`` as ``<stem>.mp4`` (mp4v, 30 fps) and ``<stem>.gif`` (400x200)."""
    if not frame_files:
        return
    frames = [cv2.imread(path) for path in frame_files]  # cv2 decodes to BGR.
    height, width = frames[0].shape[:2]

    video = cv2.VideoWriter(f"{stem}.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height))
    try:
        for frame in frames:
            video.write(frame)
    finally:
        video.release()

    # PIL assumes RGB; convert from cv2's BGR so the GIF colors are not swapped.
    gif_frames = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).resize((400, 200)) for frame in frames]
    gif_frames[0].save(f"{stem}.gif", save_all=True, append_images=gif_frames[1:], duration=30, loop=0)


def main(_):
    """Render reward-curve videos for each demo selected by the flags.

    Each episode is scored by the reward model, rendered frame by frame, and
    saved as MP4 and GIF under ``<demo_dir>/<checkpoint grandparent dir name>/``.
    """
    import tempfile

    demo_dir = Path(FLAGS.demo_dir)
    ckpt_path = Path(FLAGS.ckpt_path).expanduser()
    if len(ckpt_path.parents) < 2:
        raise ValueError(f"--ckpt_path needs at least two parent directories, got {FLAGS.ckpt_path!r}")
    video_dir = demo_dir / ckpt_path.parents[1].name
    video_dir.mkdir(exist_ok=True)

    # load reward model.
    reward_model = load_reward_model(rm_type=FLAGS.rm_type, task_name=FLAGS.furniture, ckpt_path=ckpt_path)
    reward_fn = load_reward_fn(rm_type=FLAGS.rm_type, reward_model=reward_model)
    pvr_model, pvr_transform, feature_dim = load_embedding(rep=FLAGS.pvr_type)

    files = _collect_demo_files(demo_dir)
    len_files = len(files)
    if len_files == 0:
        raise ValueError(f"No pkl files found in {demo_dir}")

    for count, (idx, file_path) in enumerate(files, start=1):
        print(f"Loading [{count}/{len_files}] {file_path}...")
        tp, images, actions, skills = _load_episode(file_path)

        args = ConfigDict()
        args.task_name = FLAGS.furniture
        args.image_keys = "color_image2|color_image1"
        args.window_size = FLAGS.window_size
        args.skip_frame = FLAGS.skip_frame
        args.return_images = True

        output = reward_fn(
            images=images,
            actions=actions,
            skills=skills,
            args=args,
            pvr_model=pvr_model,
            pvr_transform=pvr_transform,
            model_type=FLAGS.pvr_type,
            feature_dim=feature_dim,
            texts=None,
            device=device,
            batch_size=FLAGS.batch_size,
            get_text_feature=True,
            predict_phase=FLAGS.predict_phase,
        )
        rewards = output["rewards"]
        if FLAGS.smoothe:
            rewards = exponential_moving_average(rewards)

        phases = processed_phases = None
        if FLAGS.predict_phase:
            # Plot entries equal to TASK_TO_PHASE[task] as -1
            # (presumably the out-of-range / terminal phase id).
            terminal_phase = TASK_TO_PHASE[FLAGS.furniture]
            phases = output["phases"]
            processed_phases = output["processed_phases"]
            phases = np.where(phases == terminal_phase, -1, phases)
            processed_phases = np.where(processed_phases == terminal_phase, -1, processed_phases)

        # A fresh, uniquely named temp dir per episode: stale frames from an
        # earlier crashed run cannot leak into this video, and the directory
        # is removed even if rendering or encoding raises.
        with tempfile.TemporaryDirectory(prefix=f"temp_frames_{idx}_", dir=demo_dir) as temp_folder:
            frame_files = _render_frames(
                images["color_image2"],
                rewards,
                phases,
                processed_phases,
                title=f"{tp}.{idx}",
                out_dir=temp_folder,
            )
            _write_videos(frame_files, str(video_dir / f"video_{tp}_{idx}"))


if __name__ == "__main__":
    # Run through absl so the flags are parsed before main executes.
    app.run(main=main)
