import requests
import json
import os
from base64 import b64encode, b64decode
import urllib3
import cv2
import numpy as np
from PIL import Image
from io import BytesIO
import sys
import imageio.v3 as iio
import time
import torch
import torchvision


def extract_first_k_frames(video_path, k):
    """Return the first ``k`` frames of a video as base64-encoded JPEG strings.

    Args:
        video_path: Location of the video, passed straight to ``iio.imiter``.
        k: Maximum number of leading frames to return.

    Returns:
        A list of at most ``k`` UTF-8 strings, each the base64 encoding of one
        JPEG-compressed frame, in playback order.
    """
    jpeg_message_list = []
    reader = iio.imiter(video_path, plugin="FFMPEG")
    try:
        for idx, frame in enumerate(reader):
            if idx >= k:
                break
            # Frames are assumed to arrive as (H, W, C); the permute moves the
            # last axis first before handing the tensor to encode_jpeg.
            chw = torch.tensor(frame, dtype=torch.uint8).permute(2, 0, 1)
            jpeg_tensor = torchvision.io.encode_jpeg(chw)
            jpeg_message_list.append(b64encode(jpeg_tensor.numpy().tobytes()).decode("utf-8"))
    finally:
        # Breaking out early leaves the generator suspended; closing it
        # releases the underlying video file immediately rather than at GC.
        reader.close()

    return jpeg_message_list


def extract_first_k_frames_to_mp4_bytes(video_path, k, fps=8):
    """Re-encode the first ``k`` frames of a video into an in-memory MP4.

    Args:
        video_path: Location of the video, passed straight to ``iio.imiter``.
        k: Maximum number of leading frames to keep.
        fps: Frame rate handed to ``iio.imwrite``. Defaults to 8.

    Returns:
        A ``BytesIO`` containing the written MP4 data, rewound to position 0.
    """
    frames = []
    reader = iio.imiter(video_path, plugin="FFMPEG")
    try:
        for idx, frame in enumerate(reader):
            if idx >= k:
                break
            frames.append(frame)
    finally:
        # Breaking out early leaves the generator suspended; closing it
        # releases the underlying video file immediately rather than at GC.
        reader.close()

    # NOTE: an explicit pyav writer (iio.imopen(..., plugin="pyav") with
    # init_video_stream("libx264", fps=...) and write_frame per frame) is a
    # possible alternative to the single imwrite call below.
    output_bytes = BytesIO()
    iio.imwrite(output_bytes, frames, extension=".mp4", fps=fps)

    output_bytes.seek(0)
    return output_bytes


def decode_frames(imgs):
    """Decode base64-encoded JPEG frames and stack them into one tensor.

    Args:
        imgs: Iterable of base64-encoded JPEG payloads.

    Returns:
        The decoded RGB frames stacked along dimension 1 (a new axis inserted
        after the first axis of each decoded frame).

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``imgs`` is empty.
    """
    frame_data = []
    for img in imgs:
        # b64decode returns immutable bytes, and torch.frombuffer warns when
        # given a read-only buffer; copy into a writable bytearray first.
        buf = bytearray(b64decode(img))
        raw = torch.frombuffer(buf, dtype=torch.uint8)
        frame_data.append(torchvision.io.decode_jpeg(raw, mode=torchvision.io.ImageReadMode.RGB))
    return torch.stack(frame_data, dim=1)


def _save_frames(out_dir, stem, encoded_frames, fps, frame_size):
    """Write base64-encoded images to ``<stem>.mp4`` and ``<stem>_<idx>.jpg``.

    Args:
        out_dir: Existing directory that receives the files.
        stem: File-name prefix ("pred" or "mask" in this script).
        encoded_frames: Base64 strings, each decoded with ``cv2.imdecode``.
        fps: Frame rate of the written video.
        frame_size: ``(width, height)`` handed to ``cv2.VideoWriter``.
    """
    writer = cv2.VideoWriter(
        os.path.join(out_dir, f"{stem}.mp4"),
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        frame_size,
    )
    try:
        for idx, encoded in enumerate(encoded_frames):
            img = cv2.imdecode(np.frombuffer(b64decode(encoded), np.uint8), cv2.IMREAD_COLOR)
            cv2.imwrite(os.path.join(out_dir, f"{stem}_{idx}.jpg"), img)
            writer.write(img)
    finally:
        # Always finalize the container, even if a frame fails to decode/write.
        writer.release()


def main():
    """Drive the generation server with frames from a dataset video.

    Command line:
        argv[1]: Port of the server on localhost.
        argv[2]: Number of new frames per request; sent as "num_new_frames"
            and also used as the stride over the conditioning frames.
        argv[3]: Optional prompt file path relative to ``./datasets/``; a
            built-in default prompt file is used when omitted.

    The prompt text is read from the prompt file and the matching video is
    located by swapping ``metas/`` for ``videos/`` and ``.txt`` for ``.mp4``.
    After every request the accumulated predictions and masks are rewritten to
    ``./output/client`` as JPEG files plus ``pred.mp4`` / ``mask.mp4``.
    """
    import traceback

    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} PORT NUM_NEW_FRAMES [PROMPT_FILE_UNDER_DATASETS]")

    port = sys.argv[1]
    step = int(sys.argv[2])
    if step <= 0:
        # range() rejects a zero step and a negative one would silently skip
        # every request, so fail early with a clear message.
        sys.exit("NUM_NEW_FRAMES must be a positive integer")

    # When False, each response is only timed and then discarded.
    return_imgs = True
    total_frames = 113

    # NOTE(review): the credential below is hard-coded in source; consider
    # loading it from the environment or a config file instead.
    data = {
        "prompt": "",
        "imgs": "",
        "num_conditional_frames": 1,
        "num_new_frames": step,
        "seed": 1234,
        "num_sampling_step": 5,
        "guide_scale": 5.0,
        "password": "r49h8fieuwK",
        "return_imgs": return_imgs,
        "clean_cache": False,
    }
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    if len(sys.argv) > 3:
        prompt_file = f"./datasets/{sys.argv[3]}"
    else:
        prompt_file = "./datasets/0710_aloha_3/metas/0710_aloha_3_fold_janes_twice_using_both_arms_0.txt"
    with open(prompt_file, 'r') as f:
        data["prompt"] = f.read()
    video_path = prompt_file.replace("metas/", "videos/").replace(".txt", ".mp4")

    imgs = extract_first_k_frames(video_path, total_frames)
    out_imgs = []
    out_masks = []

    url = f"http://localhost:{port}"
    headers = {
        "Content-Type": "application/json",
    }

    out_dir = "./output/client"
    fps = 8
    # NOTE(review): VideoWriter is opened with a fixed (width, height); this
    # presumably matches the size of the images the server returns — confirm.
    frame_size = (640, 736)

    t0 = time.time()
    num_conditional_frames = 1
    for start in range(1, total_frames, step):
        # NOTE(review): 60 and 33 look like the server's context limit and the
        # window kept after a cache reset — confirm against the server.
        if num_conditional_frames + step > 60:
            num_conditional_frames = 33
            data["imgs"] = imgs[max(0, start - num_conditional_frames): start]
            data["clean_cache"] = True
        else:
            num_conditional_frames += step
            data["imgs"] = imgs[max(0, start - step): start]
            data["clean_cache"] = False
        data["num_conditional_frames"] = num_conditional_frames

        t = time.time()
        response = requests.post(url, headers=headers, data=json.dumps(data), verify=False)
        # Surface HTTP errors directly instead of failing later on a missing
        # key or an undecodable body.
        response.raise_for_status()
        payload = response.json()

        if not return_imgs:
            print(f"time usage at step {start}:", time.time() - t)
            continue

        out_imgs += payload["imgs"]
        out_masks += payload["masks"]
        print(f"full time usage at step {start}:", time.time() - t)

        try:
            os.makedirs(out_dir, exist_ok=True)
            _save_frames(out_dir, "pred", out_imgs, fps, frame_size)
            _save_frames(out_dir, "mask", out_masks, fps, frame_size)
        except Exception:
            # Show what failed before dropping into the debugger; catching
            # Exception (not everything) lets Ctrl-C still abort the run.
            traceback.print_exc()
            breakpoint()
    print(f"time usage: {time.time() - t0}")


# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
