
from main_utils import * 
import os, shutil
import openai 
from openai import AzureOpenAI
from dotenv import load_dotenv
from tqdm import tqdm 
import torch
import cv2, av
import numpy as np
import random 
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer
import time
import logging
# import wandb
from tqdm import tqdm
from typing import List
import argparse
import torchvision.transforms as transforms
from torchvision.transforms import Resize
from torchvision.utils import save_image
from diffusers import StableDiffusionXLPipeline
import requests
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.bleu.bleu import Bleu


## Code for main_iter.py 


# Path templates of the DSG metadata JSON for every supported benchmark.
DSG_JSON_TEMPLATES = {
    'evalcrafter': '/datasets/our_dsg_depend_v2/dsg_{section}.json',
    'vbench': '/datasets/our_dsg_depend_v2/vbench/{section}.json',
    'compbench': '/datasets/our_dsg_depend_v2/compbench/{section}.json',
}


def resolve_dsg_json_path(data, eval_section):
    """Return the DSG metadata JSON path for a benchmark and section.

    Args:
        data: Benchmark name; one of the keys of ``DSG_JSON_TEMPLATES``.
        eval_section: Section name inserted into the file name.

    Raises:
        ValueError: If ``data`` is not a supported benchmark.
    """
    try:
        template = DSG_JSON_TEMPLATES[data]
    except KeyError:
        supported = ', '.join(sorted(DSG_JSON_TEMPLATES))
        raise ValueError(
            f"Unknown --data value {data!r}; expected one of: {supported}"
        ) from None
    return template.format(section=eval_section)


def list_mp4_files(directory):
    """Return the sorted ``.mp4`` file names in ``directory``.

    A missing directory is treated like an empty one so callers can fall
    back to another source instead of crashing on ``os.listdir``.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(name for name in os.listdir(directory) if name.endswith('.mp4'))


def collect_prompt_results(prompt_idx, cur_meta, cur_result_root, num_rounds):
    """Copy the init video, per-round best videos and prompt of one entry.

    Args:
        prompt_idx: 0-based position of the entry in the metadata list.
        cur_meta: Metadata dict; ``origin_prompt`` and ``idx`` are read.
        cur_result_root: Directory holding the per-prompt result folders and
            the ``init_videos`` / ``init_prompts`` / ``best_videos_<round>``
            output folders (the output folders must already exist).
        num_rounds: Number of ``best_videos_<round>`` folders to fill.

    Raises:
        FileNotFoundError: If a required source video is missing.
    """
    init_prompt = cur_meta['origin_prompt']
    unique_id = cur_meta['idx']

    # NOTE(review): videos are numbered from 1 but prompt files from 0. This
    # is kept exactly as before; confirm that downstream tooling expects it.
    video_name = f"{prompt_idx + 1:04}.mp4"
    prompt_name = f"{prompt_idx:04}.txt"

    prompt_root = os.path.join(cur_result_root, str(unique_id))
    best_dir = os.path.join(prompt_root, 'best_videos_dsg_blip')
    first_round_dir = os.path.join(prompt_root, '0_round')

    best_videos = list_mp4_files(best_dir)
    if best_videos:
        init_video_path = os.path.join(first_round_dir, init_prompt + '_0.mp4')
    else:
        # Without any best video, the first video of round 0 stands in for
        # both the init video and every round's best video.
        print('There are no videos in best_videos_dsg_blip')
        first_round_videos = list_mp4_files(first_round_dir)
        if not first_round_videos:
            raise FileNotFoundError(
                f"No .mp4 files found in {first_round_dir!r} for prompt id {unique_id!r}"
            )
        init_video_path = os.path.join(first_round_dir, first_round_videos[0])

    for cur_round in range(num_rounds):
        round_dir = os.path.join(cur_result_root, 'best_videos_' + str(cur_round))
        target_path = os.path.join(round_dir, video_name)
        if os.path.exists(target_path):
            continue  # already exported by an earlier run
        print('Best path: ', round_dir)

        if best_videos:
            source_path = os.path.join(best_dir, f"{cur_round}_best_video.mp4")
            if not os.path.exists(source_path):
                # This round has no dedicated best video: use the last
                # (alphabetically sorted) video that is available.
                source_path = os.path.join(best_dir, best_videos[-1])
        else:
            source_path = init_video_path
        shutil.copy(source_path, target_path)

    shutil.copy(init_video_path, os.path.join(cur_result_root, 'init_videos', video_name))

    prompt_path = os.path.join(cur_result_root, 'init_prompts', prompt_name)
    with open(prompt_path, 'w', encoding='utf-8') as file:
        file.write(init_prompt)


def collect_section_results(meta_data, cur_result_root, num_rounds):
    """Create the output folders and export every metadata entry.

    Args:
        meta_data: Iterable of metadata dicts (may be wrapped in a progress
            bar); entries are numbered in iteration order starting at 0.
        cur_result_root: Root directory of the section's results.
        num_rounds: Number of ``best_videos_<round>`` folders to create/fill.
    """
    os.makedirs(os.path.join(cur_result_root, 'init_videos'), exist_ok=True)
    os.makedirs(os.path.join(cur_result_root, 'init_prompts'), exist_ok=True)
    for cur_round in range(num_rounds):
        os.makedirs(
            os.path.join(cur_result_root, 'best_videos_' + str(cur_round)),
            exist_ok=True,
        )

    for prompt_idx, cur_meta in enumerate(meta_data):
        print(str(prompt_idx) + ' processing')
        collect_prompt_results(prompt_idx, cur_meta, cur_result_root, num_rounds)


def main(argv=None):
    """Parse the command line and gather the videos of one eval section.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.
    """
    # Imported here so the script does not rely on ``json`` leaking in
    # through ``from main_utils import *``.
    import json

    parser = argparse.ArgumentParser(description="Video Selection")
    parser.add_argument("--eval_section", type=str, default='count', help="evalcrafter section=[action, amp, color, count, face, none, text]")
    parser.add_argument("--output_root", type=str, required=True, help="eval path root")
    parser.add_argument("--data", type=str, default="evalcrafter", choices=sorted(DSG_JSON_TEMPLATES))
    parser.add_argument("--round", type=int, default=5)
    args = parser.parse_args(argv)

    print('EVAL_SECTION : ', args.eval_section)
    # NOTE(review): plain concatenation is kept on purpose, so --output_root
    # must end with a path separator to get "<root>/<section>"; existing
    # result trees may depend on this exact naming.
    cur_result_root = args.output_root + args.eval_section

    dsg_json_path = resolve_dsg_json_path(args.data, args.eval_section)
    with open(dsg_json_path, 'r', encoding='utf-8') as json_file:
        meta_data = json.load(json_file)

    collect_section_results(tqdm(meta_data), cur_result_root, args.round)


if __name__ == '__main__':
    main()