import torch
import cv2
import numpy as np
import os
import bisect
from PIL import Image
from torchvision.transforms import functional as F
from util import load_image
from interpolator import Interpolator
from tqdm import tqdm
import argparse
import time

def _to_rgb_batch(img, size):
    """Convert a BGR image to an RGB float64 batch of shape (1, 3, H, W) in [0, 1].

    ``img`` is presumably a ``uint8`` BGR frame as produced by ``cv2``; it is
    resized to ``size`` (``(width, height)``) before being turned into a tensor.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
    rgb = cv2.resize(rgb, size)
    return torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)


def _to_bgr_uint8(batch):
    """Convert a (1, 3, H, W) RGB tensor in [0, 1] to an (H, W, 3) BGR uint8 array."""
    # Round (instead of truncating) so that x / 255 * 255 maps back to x, and
    # clamp so tiny float excursions cannot wrap around in the uint8 cast.
    scaled = (batch[0] * 255).round().clamp(0, 255).byte()
    # flip(0) reverses the channel axis: RGB -> BGR, the order cv2 writers expect.
    return scaled.flip(0).permute(1, 2, 0).numpy().copy()


def inference(model, img1, img2, device, inter_frames, precision=None, size=(1360, 768)):
    """Generate ``inter_frames`` intermediate frames between ``img1`` and ``img2``.

    Frames are produced one at a time. At each step, among all timesteps not
    yet generated, the one whose relative position inside its enclosing known
    interval is closest to that interval's midpoint is chosen. It is predicted
    from the two already-known frames bracketing it.

    Args:
        model: interpolation network, called as ``model(x0, x1, dt)`` with two
            (1, 3, H, W) batches and a (1, 1) relative time in (0, 1). Its
            output is clamped to [0, 1], so presumably it predicts images in
            that range -- TODO confirm against ``Interpolator``.
        img1, img2: BGR images as read by ``cv2`` (presumably ``uint8``).
        device: torch device the model and its inputs are moved to.
        inter_frames: number of intermediate frames to generate (>= 0).
        precision: dtype of the model inputs. Defaults to the dtype of the
            model's first parameter, or ``torch.float32`` if it has none.
        size: ``(width, height)`` both inputs are resized to.

    Returns:
        A list of ``inter_frames + 2`` BGR ``uint8`` arrays of shape
        ``(height, width, 3)``: resized ``img1``, the interpolated frames in
        temporal order, then resized ``img2``.

    Raises:
        ValueError: if ``inter_frames`` is negative.
    """
    if inter_frames < 0:
        raise ValueError(f"inter_frames must be >= 0, got {inter_frames}")
    model.eval()
    model.to(device)
    if precision is None:
        # Follow the model's own dtype instead of relying on a module-level
        # ``precision`` global that only exists when this file runs as a script.
        first_param = next(model.parameters(), None)
        precision = first_param.dtype if first_param is not None else torch.float32

    # Frames generated so far, kept in the same order as ``idxes``.
    results = [_to_rgb_batch(img1, size), _to_rgb_batch(img2, size)]
    # Sorted timestep indices whose frames are already in ``results``.
    idxes = [0, inter_frames + 1]
    # Timestep indices still to generate.
    remains = list(range(1, inter_frames + 1))
    # Normalized time of every timestep index: 0 -> img1, 1 -> img2.
    splits = torch.linspace(0, 1, inter_frames + 2)

    with torch.no_grad():
        while remains:
            starts = splits[idxes[:-1]]
            ends = splits[idxes[1:]]
            # distances[j, k] = |relative position of remaining timestep k inside
            # known interval j - 0.5|. Only a timestep's enclosing interval gives
            # a value below 0.5, so the argmin always selects an enclosing
            # interval, namely the one whose timestep sits closest to its midpoint.
            distances = ((splits[None, remains] - starts[:, None])
                         / (ends[:, None] - starts[:, None]) - .5).abs()
            start_i, step = (int(v) for v in
                             np.unravel_index(torch.argmin(distances).item(), distances.shape))
            end_i = start_i + 1
            target = remains[step]

            x0 = results[start_i].to(device=device, dtype=precision)
            x1 = results[end_i].to(device=device, dtype=precision)
            t0 = splits[idxes[start_i]].item()
            t1 = splits[idxes[end_i]].item()
            # Time of the target relative to its bracketing pair, in (0, 1).
            dt = x0.new_full((1, 1), (splits[target].item() - t0) / (t1 - t0))

            prediction = model(x0, x1, dt)
            insert_position = bisect.bisect_left(idxes, target)
            idxes.insert(insert_position, target)
            results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
            del remains[step]

    return [_to_bgr_uint8(tensor) for tensor in results]

def make_video_from_keyframe(model, keyframes_path, keyframe_gap, device=None):
    """Build a dense frame sequence from a video whose frames are keyframes.

    Every frame of ``keyframes_path`` is read as a keyframe. For each gap
    ``keyframe_gap[i]``, the sequence is extended with frames ending at
    keyframe ``i + 1``:

    * a gap of 1 appends keyframe ``i + 1`` itself;
    * a gap ``g > 1`` appends ``g - 1`` frames interpolated by ``inference``
      followed by keyframe ``i + 1``.

    Keyframes beyond the last gap are ignored.

    Args:
        model: interpolation model forwarded to ``inference``.
        keyframes_path: path of the keyframe video, readable by
            ``cv2.VideoCapture``.
        keyframe_gap: sequence of ints, each >= 1, giving the number of output
            frames between consecutive keyframes.
        device: torch device used for inference. Defaults to the device of
            the model's first parameter, or CPU if the model has none.

    Returns:
        A list of ``1 + sum(keyframe_gap)`` BGR frames, each resized to 1360x768.

    Raises:
        ValueError: if no frame can be read, a gap is < 1, or there are more
            gaps than consecutive keyframe pairs.
    """
    if device is None:
        # Use the model's own device instead of relying on a module-level
        # ``device`` global that only exists when this file runs as a script.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Validate up front so bad input fails before any expensive inference.
    # (An ``assert`` would be stripped under ``python -O``.)
    for i, gap in enumerate(keyframe_gap):
        if gap < 1:
            raise ValueError(f"keyframe_gap[{i}] = {gap}, expected >= 1")

    # Read every frame of the keyframe video.
    video = cv2.VideoCapture(keyframes_path)
    keyframes = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            keyframes.append(frame)
    finally:
        video.release()

    if not keyframes:
        raise ValueError(f"no frames could be read from {keyframes_path!r}")
    if len(keyframe_gap) > len(keyframes) - 1:
        raise ValueError(
            f"{len(keyframe_gap)} keyframe gaps given but {keyframes_path!r} "
            f"has only {len(keyframes)} keyframes"
        )

    # Keyframes are downscaled to half of the 1360x768 output size before use.
    # NOTE(review): this discards detail that the final upscale cannot restore;
    # presumably intentional -- confirm.
    keyframes = [cv2.resize(frame, (1360 // 2, 768 // 2)) for frame in keyframes]

    frames = [keyframes[0]]
    for i, gap in enumerate(keyframe_gap):
        if gap == 1:
            frames.append(keyframes[i + 1])
        else:
            interpolated = inference(model, keyframes[i], keyframes[i + 1], device, gap - 1)
            # Skip the first frame: it duplicates keyframe i, already in ``frames``.
            frames += interpolated[1:]
    return [cv2.resize(frame, (1360, 768)) for frame in frames]

def _read_nonempty_lines(path):
    """Return the stripped, non-blank lines of the UTF-8 text file at ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return [stripped for stripped in (line.strip() for line in f) if stripped]


def gen_batch_video(task_folder, save_folder, model, fps=16.0):
    """Interpolate and save every video listed in ``task_folder``.

    ``task_folder`` must contain:

    * ``video_name.txt``: one video path per line. Only the part after the
      last ``/`` is used. It names both the keyframe video inside
      ``task_folder`` and the output video inside ``save_folder``.
    * ``idx.txt``: one line per video of whitespace-separated integers. Each
      line is the ``keyframe_gap`` passed to ``make_video_from_keyframe``.

    Blank lines in both files are ignored. Outputs are written as ``mp4v``
    videos at ``fps`` frames per second. Progress is logged to
    ``save_folder/log.txt``, which is overwritten on each call.

    Raises:
        ValueError: if the two list files have different numbers of entries.
        RuntimeError: if an output video cannot be opened for writing.
    """
    logging_file = os.path.join(save_folder, "log.txt")
    time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    with open(logging_file, 'w', encoding='utf-8') as f:
        f.write(f"Start generating videos at {time_now}\n")

    video_names = _read_nonempty_lines(os.path.join(task_folder, "video_name.txt"))
    # split() with no argument tolerates repeated spaces/tabs; split(' ')
    # would yield '' tokens and crash in int().
    idxs = [[int(token) for token in line.split()]
            for line in _read_nonempty_lines(os.path.join(task_folder, "idx.txt"))]
    if len(video_names) != len(idxs):
        raise ValueError(
            f"video_name.txt lists {len(video_names)} videos but idx.txt has "
            f"{len(idxs)} gap lines"
        )

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    for video_path, keyframe_gap in tqdm(zip(video_names, idxs), total=len(video_names)):
        video_name = video_path.split("/")[-1]
        frames = make_video_from_keyframe(model, os.path.join(task_folder, video_name), keyframe_gap)

        # save video
        out_path = os.path.join(save_folder, video_name)
        height, width = frames[0].shape[:2]
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        # An unopened writer silently drops every frame; fail loudly instead.
        if not out.isOpened():
            raise RuntimeError(f"could not open {out_path!r} for writing")
        try:
            for frame in frames:
                out.write(frame)
        finally:
            out.release()

        time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        with open(logging_file, 'a', encoding='utf-8') as f:
            f.write(f"Finish generating video {video_name} at {time_now}\n")

    print("Done!")
        


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Interpolate keyframe videos listed in a task folder.")
    parser.add_argument('--task_folder', type=str,
                        default='/path to libero_dataset/inference_result/libero_object_rpd17',
                        help='folder containing video_name.txt, idx.txt and the keyframe videos')
    parser.add_argument('--save_folder', type=str,
                        default='/path to libero_dataset/final_result/libero_object_rpd17',
                        help='folder the interpolated videos and log.txt are written to')
    parser.add_argument('--model_path', type=str, default='./ckpts/model_final.pth',
                        help='path of the Interpolator state_dict checkpoint')
    parser.add_argument('--device', type=str, default='cuda:6',
                        help='torch device to run inference on')
    args = parser.parse_args()

    # Kept as module-level globals on purpose: functions in this file may
    # look ``device`` and ``precision`` up by name.
    device = torch.device(args.device)
    precision = torch.float32
    model = Interpolator()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on torch versions supporting it).
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
    model.eval().to(device=device, dtype=precision)
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(args.save_folder, exist_ok=True)

    gen_batch_video(args.task_folder, args.save_folder, model)