import argparse
import os
import time
import json
import sys
from utils import set_seed

from diffusers import StableDiffusionPipeline
import torch
import logging

import pickle
# from datasets import load_dataset
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


def set_random_seed(seed):
    """Seed PyTorch's random number generators.

    Calls ``torch.manual_seed`` followed by ``torch.cuda.manual_seed_all``
    with the same value.

    Args:
        seed: Integer seed handed to both seeding functions.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

# Number of prompts processed with one runner instance before it is dropped,
# the CUDA cache is emptied, and a fresh runner is built.
reset_every = 20


def _parse_int_list(raw, message):
    """Parse a comma-separated string of integers such as "1,2,3" or "1".

    Args:
        raw: Raw command-line value.
        message: Text of the ValueError raised when an item is not an integer.

    Returns:
        The parsed integers, in input order.

    Raises:
        ValueError: If any comma-separated item cannot be converted to int.
    """
    try:
        return [int(item) for item in raw.split(",")]
    except ValueError as err:
        raise ValueError(message) from err


def _load_prompts(prompt, json_path=None):
    """Build the list of prompts to benchmark.

    Args:
        prompt: A literal prompt, a path ending in ".json", or None.
        json_path: Path of a JSON prompt file, used only when ``prompt`` is
            None.

    Returns:
        ``[prompt]`` for a literal prompt. For a JSON file, the "Prompt"
        value of every entry when the first entry is a dict carrying that
        key, otherwise the loaded list unchanged.

    Raises:
        ValueError: If neither argument is given, or the JSON file does not
            contain a non-empty list.
    """
    if prompt is None and json_path is None:
        raise ValueError("Either --prompt or --json_path should be provided")
    if prompt is not None and not prompt.endswith(".json"):
        return [prompt]

    # NOTE(review): assumes --json_path names a prompt file in the same format
    # as a ".json" value of --prompt; confirm against the runner's options.
    source = json_path if prompt is None else prompt
    with open(source, "r") as file:
        data = json.load(file)
    if not isinstance(data, list) or not data:
        raise ValueError("{} should contain a non-empty JSON list of prompts".format(source))

    # Only treat entries as records when they really are dicts; for plain
    # strings `"Prompt" in entry` would be a substring test.
    first = data[0]
    if isinstance(first, dict) and "Prompt" in first:
        return [entry["Prompt"] for entry in data]
    return data


def _run_silenced(runner, **run_kwargs):
    """Call ``runner.run(**run_kwargs)`` with stdout and stderr sent to os.devnull.

    Both streams are restored and the devnull handle is closed even when the
    run raises.
    """
    import contextlib

    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            runner.run(**run_kwargs)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_warmup", action="store_true")
    parser.add_argument("--mask_threshold", type=str, default="100")
    # Comma-separated list or a single value, e.g. "1,2,3" or "1".
    parser.add_argument("--interval", type=str, default="1")

    # Imported here so the runner's options are added before parsing.
    from highlight.runner import HighlightRunner as Runner

    parser = Runner.modify_commandline_options(parser)
    args = parser.parse_args()
    if args.prompt is None and args.json_path is None:
        raise ValueError("Either --prompt or --json_path should be provided")

    mask_threshold = _parse_int_list(
        args.mask_threshold,
        "mask_threshold should be separated by comma or just singular value, e.g. 100,200,300 or 100",
    )
    intervals = _parse_int_list(
        args.interval,
        "interval should be separated by comma or just singular value, e.g. 1,2,3 or 1",
    )

    # Load the prompts once: args.prompt is overwritten with individual
    # prompts below, so re-reading it inside the sweep would lose the file path.
    data = _load_prompts(args.prompt, args.json_path)

    # time_list / total_time are not updated by the code visible here; they
    # are kept in case later code reads them.
    time_list = []
    for threshold in mask_threshold:
        for interval in intervals:
            total_time = 0
            runner = None

            if not args.skip_warmup:
                # One untimed, silenced run on the first prompt.
                args.prompt = data[0]
                runner = Runner(args)
                logging.info("Warming up GPU...")
                set_seed(args.seed)
                _run_silenced(
                    runner,
                    interval=interval,
                    prompt_i=0,
                    threshold=threshold,
                    save_pkl=False,
                )

            for i, prompt in enumerate(data):
                args.prompt = prompt
                if i % reset_every == 0:
                    # Drop the old runner before emptying the cache so its
                    # memory can be released, then start from a fresh one.
                    runner = None
                    torch.cuda.empty_cache()
                    runner = Runner(args)
                set_seed(args.seed)
                start_time = time.perf_counter()
                runner.run(interval=interval, prompt_i=i, threshold=threshold, save_pkl=True)
                use_time = time.perf_counter() - start_time
                logging.info(
                    "Highlight Pipeline with interval=%s, threshold=%s: %.2f seconds",
                    interval,
                    threshold,
                    use_time,
                )