import argparse
import contextlib
import logging
import os
import pdb
import sys

import cv2
import numpy as np
from tqdm import tqdm

from utils import rdp, resize_video_to_length, resize_state_to_length, clip_and_resize, rdp_fix_frames

logger = logging.getLogger(__name__)

# Command-line options. --method chooses how keyframes are picked from each
# resampled episode:
#   'rdp'      - indices chosen by utils.rdp_fix_frames from the episode states
#   'fixed'    - evenly spaced indices
#   'original' - keep every resampled frame
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset_path',
    type=str,
    default='libero_dataset',
    help='path to dataset',
)
parser.add_argument(
    '--save_folder',
    type=str,
    default='libero_dataset/libero_90/keyframes',
    help='path to save keyframes',
)
parser.add_argument(
    '--method',
    type=str,
    choices=['rdp', 'fixed', 'original'],
    default='rdp',
    help='method to select keyframes',
)
parser.add_argument(
    '--target_frames',
    type=int,
    default=81,
    help='target number of frames',
)

# Parsed once at import time; the script body reads options from this global.
args = parser.parse_args()

# Usage:
#   python3 make_keyframe.py --dataset_path ./libero_object_episode \
#       --save_folder ./finetune_dataset/libero_object_rpd17 --method rdp --target_frames 17
#
# Layout read below: <dataset_path>/<task>/<episode>/imgs/<name>_<n>.<ext> and
# <dataset_path>/<task>/<episode>/state.npy. The task folder name, with '_'
# replaced by ' ', is used as the text prompt for every episode of that task.

NUM_FRAME = 81  # frames and states of every episode are first resampled to this length
VIDEO_HEIGHT = 768
VIDEO_WIDTH = 1360
VIDEO_FPS = 16


def _frame_number(filename):
    """Return the integer after the last '_' of the filename stem ('img_12.png' -> 12)."""
    return int(filename.split('.')[0].split('_')[-1])


def _load_frames(image_folder):
    """Read every image in ``image_folder`` in numeric filename order and stack them.

    Raises:
        ValueError: if the folder contains no files.
        IOError: if a file cannot be decoded by ``cv2.imread``.
    """
    frames = []
    for name in sorted(os.listdir(image_folder), key=_frame_number):
        path = os.path.join(image_folder, name)
        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError(f'Could not read image {path}')
        frames.append(frame)
    if not frames:
        raise ValueError(f'No frames found in {image_folder}')
    return np.array(frames)


def _select_keyframe_indices(states, method, target_frames):
    """Return the indices of the frames to keep, or None when ``method`` keeps all frames."""
    if method == 'rdp':
        _, idx = rdp_fix_frames(states, target_frames)
        return idx
    if method == 'fixed':
        return np.linspace(0, len(states) - 1, target_frames, dtype=int)
    return None


def _write_video(path, frames):
    """Encode ``frames`` as an mp4v video at ``path`` (VIDEO_WIDTH x VIDEO_HEIGHT, VIDEO_FPS).

    Raises:
        IOError: if OpenCV cannot open the output file (e.g. missing directory).
    """
    writer = cv2.VideoWriter(
        path, cv2.VideoWriter_fourcc(*'mp4v'), VIDEO_FPS, (VIDEO_WIDTH, VIDEO_HEIGHT)
    )
    # VideoWriter does not raise on failure; isOpened() is the only signal.
    if not writer.isOpened():
        raise IOError(f'Could not open video writer for {path}')
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()


def main():
    """Write keyframe videos plus videos.txt / prompts.txt (and idx.txt for rdp/fixed)."""
    video_dir = os.path.join(args.save_folder, 'videos')
    if os.path.exists(args.save_folder):
        print(f'Save folder {args.save_folder} already exists')
    # exist_ok: an existing save folder must still get its 'videos' subfolder,
    # otherwise every VideoWriter below would fail to open.
    os.makedirs(video_dir, exist_ok=True)

    video_txt_path = os.path.join(args.save_folder, 'videos.txt')
    prompt_txt_path = os.path.join(args.save_folder, 'prompts.txt')
    idx_txt_path = os.path.join(args.save_folder, 'idx.txt')
    save_idx = args.method in ('rdp', 'fixed')

    print("VIDEO_PATH ", video_txt_path, "\nPROMPT_PATH ", prompt_txt_path)

    with contextlib.ExitStack() as stack:
        # 'w' truncates lists left over from a previous run.
        video_txt = stack.enter_context(open(video_txt_path, 'w'))
        prompt_txt = stack.enter_context(open(prompt_txt_path, 'w'))
        idx_txt = stack.enter_context(open(idx_txt_path, 'w')) if save_idx else None

        # Sorted so the order of the output lists is reproducible across runs.
        for task in sorted(os.listdir(args.dataset_path)):
            task_path = os.path.join(args.dataset_path, task)
            if not os.path.isdir(task_path):
                continue  # skip stray files next to the task folders
            print(f'Processing task: {task}')
            task_prompt = task.replace("_", " ")

            for episode in tqdm(sorted(os.listdir(task_path))):
                episode_path = os.path.join(task_path, episode)
                if not os.path.isdir(episode_path):
                    continue

                images = _load_frames(os.path.join(episode_path, 'imgs'))
                images = resize_video_to_length(images, NUM_FRAME)
                states = np.load(os.path.join(episode_path, 'state.npy'))
                states = resize_state_to_length(states, NUM_FRAME)
                if not len(images) == len(states) == NUM_FRAME:
                    raise ValueError(
                        f'{task}_{episode}: expected {NUM_FRAME} frames/states, '
                        f'got {len(images)}/{len(states)}'
                    )

                idx = _select_keyframe_indices(states, args.method, args.target_frames)
                if idx is not None:
                    images = images[idx]
                    if images.shape[0] != args.target_frames:
                        raise ValueError(
                            f'{task}_{episode}: selected {images.shape[0]} keyframes, '
                            f'expected {args.target_frames}'
                        )

                images = clip_and_resize(images, VIDEO_HEIGHT, VIDEO_WIDTH)
                # The frame count must have the form 8k + 1; abort the whole run
                # with a non-zero status otherwise (open lists are flushed on exit).
                if images.shape[0] % 8 != 1:
                    sys.exit(f'Error: {task}_{episode} has {images.shape[0]} keyframes')

                clip_name = f'{task}_{episode}'
                _write_video(os.path.join(video_dir, f'{clip_name}.mp4'), images)

                # Record the sample only after its video was written successfully,
                # so the lists never reference a missing video.
                video_txt.write(f'./videos/{clip_name}.mp4\n')
                prompt_txt.write(f'{task_prompt}\n')
                if idx_txt is not None:
                    idx_txt.write(''.join(f'{i} ' for i in idx) + '\n')


if __name__ == '__main__':
    main()

