import os
import datetime
import json
import torch
import argparse
import numpy as np

from diffusers import StableDiffusionPipeline
from glob import glob

from utils import log_creator, \
                    set_seed, \
                    print_args, savemodelDiffusers

def prompts_loader(prompt_set: dict):
    """Load a JSON-lines prompt file and repeat each entry for multi-sample generation.

    Args:
        prompt_set: Config dict providing ``"path"`` (a file with one JSON
            object per line), ``"num_per_item"`` and ``"times"``.

    Returns:
        A flat list in which each parsed line appears
        ``num_per_item * times`` consecutive times. Every parsed object gets
        an ``"idx"`` key holding its 0-based line number. The repeats are the
        *same* dict object, not copies.
    """
    ret = []
    # `with` closes the handle deterministically (it used to be leaked);
    # JSON text is UTF-8, so don't depend on the platform's locale encoding.
    with open(prompt_set["path"], "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            item = json.loads(line)
            item["idx"] = idx
            ret.extend([item] * prompt_set["num_per_item"] * prompt_set["times"])
    return ret


def get_ckpt_from_log(log_path, name="Stable_Diffusion"):
    """Return the path of the highest-epoch checkpoint saved under ``log_path``.

    Looks for ``<log_path>/checkpoints/<name>_epoch_<N>.pth`` and picks the file
    with the largest integer ``N``. Files that match the glob but whose epoch
    part is not an integer (e.g. ``<name>_epoch_best.pth``) are ignored rather
    than aborting the search with a ``ValueError``.

    Args:
        log_path: Run log directory that contains a ``checkpoints`` folder.
        name: Checkpoint file-name prefix.

    Returns:
        The path (as produced by ``glob``) of the latest numbered checkpoint.

    Raises:
        FileNotFoundError: if no numbered checkpoint matches.
    """
    pattern = os.path.join(log_path, "checkpoints", name + "_epoch_*.pth")
    candidates = []
    for path in glob(pattern):
        epoch_str = os.path.basename(path).split("_epoch_")[-1].split(".pth")[0]
        try:
            candidates.append((int(epoch_str), path))
        except ValueError:
            continue  # not a numbered checkpoint; skip it
    if not candidates:
        raise FileNotFoundError(f"No checkpoint matching {pattern!r}")
    # Tuples compare by epoch first; the path only breaks (pathological) ties.
    return max(candidates)[1]

def _ensure_save_dir(root, set_name):
    """Return ``root/<set_name up to its first '_'>``, creating the directory if needed."""
    save_dir = os.path.join(root, set_name.split("_")[0])
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _generate_and_save(pipe, prompts, save_dir, chunk_size, per_chunk_overrides):
    """Run ``pipe`` over ``prompts`` in consecutive chunks and save each image.

    Each image is written as ``<entry idx>_<position in prompts>.png``.
    Sampling uses 50 inference steps and guidance scale 7.5. When
    ``per_chunk_overrides`` is true, the first entry of each chunk may carry
    ``"sd_seed"`` (re-seeds via ``set_seed`` before sampling) and
    ``"sd_guidance_scale"`` (replaces the default 7.5).
    """
    for start in range(0, len(prompts), chunk_size):
        chunk = prompts[start:start + chunk_size]
        guidance = 7.5
        if per_chunk_overrides:
            head = chunk[0]
            if "sd_seed" in head:
                set_seed(head["sd_seed"])
            guidance = head.get("sd_guidance_scale", 7.5)
        images = pipe(
            [entry["prompt"] for entry in chunk],
            num_inference_steps=50,
            guidance_scale=guidance,
        ).images
        for offset, image in enumerate(images):
            image.save(os.path.join(save_dir, f"{chunk[offset]['idx']}_{start + offset}.png"))


def main(cfg):
    """Generate images for every configured prompt set using a fine-tuned UNet.

    Reads ``cfg["task_config"]`` and:
      1. seeds RNGs and opens a timestamped log file inside ``log_path``;
      2. builds a float16 ``StableDiffusionPipeline`` on CUDA and loads the
         latest ``Stable_Diffusion_epoch_*.pth`` checkpoint from
         ``<log_path>/checkpoints`` into its UNet (non-strict);
      3. for each entry of ``generate.prompt_sets`` whose name contains
         ``"common"`` or ``"specifical"``, samples images and saves them under
         ``generate.save_path/<name prefix before the first '_'>``.

    "common" sets are sampled in chunks of ``generate.batch_size`` with
    guidance 7.5. "specifical" sets are sampled in chunks of
    ``num_per_item * times``. Because ``prompts_loader`` repeats each entry that
    many times consecutively, each chunk holds exactly one prompt entry. Those
    chunks honour optional per-entry ``sd_seed`` / ``sd_guidance_scale``
    overrides.
    """
    tsk_cfg = cfg["task_config"]
    set_seed(tsk_cfg["seed"])

    logger = log_creator(
        os.path.join(
            tsk_cfg["log_path"],
            "generate." + datetime.datetime.now().strftime("%Y-%m-%d %H:%M") + ".log",
        )
    )
    print_args(cfg, logger)

    # NOTE(review): the original called from_pretrained() with no model id,
    # which raises TypeError. The "sd_model_path" key and the SD v1-4 fallback
    # are assumptions; confirm against the base model the checkpoints came from.
    model_id = tsk_cfg.get("sd_model_path", "CompVis/stable-diffusion-v1-4")
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to("cuda")
    logger.info("Build generator!")

    state_dict_path = get_ckpt_from_log(tsk_cfg["log_path"])
    logger.info("Loading UNet weights from %s", state_dict_path)
    unet_dict = savemodelDiffusers(state_dict_path, None)
    # strict=False tolerates partial checkpoints; report what was not matched
    # so a silently mismatched checkpoint is visible in the log.
    incompatible = pipe.unet.load_state_dict(unet_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logger.warning(
            "UNet load: %d missing keys, %d unexpected keys",
            len(incompatible.missing_keys),
            len(incompatible.unexpected_keys),
        )
    pipe.safety_checker = None

    erased = tsk_cfg["erased_concept"]
    logger.info("Erased concept(s): %s", erased if isinstance(erased, list) else [erased])

    gen_cfg = tsk_cfg["generate"]
    save_root = gen_cfg["save_path"]
    batch_size = gen_cfg["batch_size"]

    for prompt_set in gen_cfg["prompt_sets"]:
        set_name = prompt_set["name"]
        if "common" in set_name:
            chunk_size, per_chunk_overrides = batch_size, False
        elif "specifical" in set_name:
            # One chunk per loaded entry (see prompts_loader's repetition).
            chunk_size = prompt_set["num_per_item"] * prompt_set["times"]
            per_chunk_overrides = True
        else:
            logger.warning(
                "Skipping prompt set %r: name contains neither 'common' nor 'specifical'",
                set_name,
            )
            continue

        save_dir = _ensure_save_dir(save_root, set_name)
        _generate_and_save(
            pipe, prompts_loader(prompt_set), save_dir, chunk_size, per_chunk_overrides
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images with a fine-tuned Stable Diffusion UNet."
    )
    # Required: the previous default of "" could only fail inside open().
    parser.add_argument(
        "--config_file",
        required=True,
        type=str,
        help="Path to the JSON run configuration (must contain 'task_config').",
    )
    args = parser.parse_args()

    # `with` closes the config file before the long-running generation starts.
    with open(args.config_file, "r") as f:
        cfg = json.load(f)

    main(cfg)
