import numpy as np
import os
import json
from tqdm import tqdm
import torch

from diffusers import StableDiffusionPipeline

"""This script creates the dataset to train concept vectors. 
This include generating images from prompts using Stable Diffusion 
and saving the images and labels in a folder."""


def update_concept_dict():
    """Return the fixed mapping from concept name to its integer index.

    Indices follow the listed order: woman=0, man=1, young=2, old=3.
    """
    names = ("woman", "man", "young", "old")
    return dict(zip(names, range(len(names))))


def repeat_ntimes(x, n):
    """Return a list in which each element of ``x`` appears ``n`` times in a row.

    Element order is preserved, e.g. ``["a", "b"]`` with ``n=2`` becomes
    ``["a", "a", "b", "b"]``. A non-positive ``n`` yields an empty list.
    """
    repeated = []
    for element in x:
        for _ in range(n):
            repeated.append(element)
    return repeated


class DataCreator:
    """Generate a folder of Stable Diffusion images from a list of prompts.

    ``cfg`` must expose:
        resolution: output height and width in pixels.
        root_dir: folder the images are written to (created if missing).
        image_prompt: list of prompts. Each prompt is either a plain string or a
            ``(prompt, negative_prompt)`` list/tuple.
        num_samples: number of images generated per prompt.
    """

    def __init__(self, cfg):
        self.resolution = cfg.resolution
        self.root_dir = cfg.root_dir
        # Each prompt is repeated num_samples times consecutively, so image
        # indices [k * num_samples, (k + 1) * num_samples) belong to prompt k.
        self.image_prompt = repeat_ntimes(cfg.image_prompt, cfg.num_samples)
        print(f"to create {len(self.image_prompt)} total number of samples in {cfg.root_dir}")

    def create_images(self, num_inference_steps=30,
                      model_id="CompVis/stable-diffusion-v1-4", device="cuda"):
        """Run the pipeline once per entry of ``self.image_prompt``.

        Image ``i`` is saved as ``<root_dir>/<i>.jpg``.

        Args:
            num_inference_steps: denoising steps per image.
            model_id: pretrained Stable Diffusion checkpoint to load.
            device: device the fp16 pipeline is moved to.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        pipe = pipe.to(device)
        pipe.safety_checker = None
        pipe.set_progress_bar_config(disable=True)

        os.makedirs(self.root_dir, exist_ok=True)
        for idx, prompt in tqdm(enumerate(self.image_prompt), total=len(self.image_prompt)):
            # None is the pipeline's own default, i.e. "no negative prompt".
            negative_prompt = None
            text = prompt
            if isinstance(prompt, (list, tuple)):
                text, negative_prompt = prompt[0], prompt[1]
            output = pipe(
                text,
                height=self.resolution,
                width=self.resolution,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                return_dict=True,
            )
            # output[0] is the list of generated images; one image per call here.
            image = output[0][0]
            image.save(os.path.join(self.root_dir, f"{idx}.jpg"))

    def run(self):
        """Generate all images with the default settings."""
        self.create_images()

def _parse_prompt_line(line, class_field):
    """Split one prompt-file line into ``(class_name, prompt)``.

    The class name is the ``class_field``-th ``/``-separated component of the
    line. The prompt is everything after the first ``".png "``, or after the
    first ``".jpg "`` when the line contains no ``".png"``. Surrounding
    whitespace, including the trailing newline, is stripped from the prompt.
    """
    cls = line.split("/")[class_field]
    sep = ".png " if ".png" in line else ".jpg "
    prompt = line.split(sep, 1)[1].strip()
    return cls, prompt


def syn_generation(dataset, idx, num_shards=5, images_per_prompt=20,
                   prompt_dir="/root/InterpretDiffusion"):
    """Generate images for one shard of ``<prompt_dir>/<dataset>_prompts.txt``.

    The prompt file is split into ``num_shards`` contiguous shards of
    ``len(lines) // num_shards`` lines; shard ``idx`` is processed. The last
    shard also takes the remainder lines, so that no line is dropped.
    Each processed line yields ``images_per_prompt`` images, saved as
    ``datasets_fgl/<dataset>/<class>/<line_no * images_per_prompt + j>.jpg``.
    File names therefore depend only on the line number, not on the sharding.

    Args:
        dataset: one of the keys of ``class_field_by_dataset`` below.
        idx: shard index, ``0 <= idx < num_shards``.
        num_shards: number of shards the prompt file is divided into.
        images_per_prompt: images generated for every prompt line.
        prompt_dir: directory containing the ``*_prompts.txt`` files.

    Raises:
        ValueError: for an unknown ``dataset`` or an out-of-range ``idx``.
    """
    # Index of the class-name component in each line's "/"-separated path.
    class_field_by_dataset = {
        "officehome": 5,
        "pacs": 6,
        "domainnet": 5,
        'ucm': 5,
        'dermamnist': 1,
    }
    if dataset not in class_field_by_dataset:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(class_field_by_dataset)}"
        )
    if not 0 <= idx < num_shards:
        raise ValueError(f"shard index {idx} out of range for {num_shards} shards")
    class_field = class_field_by_dataset[dataset]

    prompt_path = os.path.join(prompt_dir, f"{dataset}_prompts.txt")
    with open(prompt_path, 'r', encoding="utf-8") as f:
        lines = f.readlines()

    chunk = len(lines) // num_shards
    start = idx * chunk
    end = len(lines) if idx == num_shards - 1 else start + chunk

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")
    pipe.safety_checker = None
    pipe.set_progress_bar_config(disable=False)

    for line_no in range(start, end):
        line = lines[line_no]
        if not line.strip():
            continue
        cls, prompt = _parse_prompt_line(line, class_field)
        root_dir = os.path.join("datasets_fgl", dataset, cls)
        os.makedirs(root_dir, exist_ok=True)
        for j in range(images_per_prompt):
            # NOTE(review): the prompt is also passed as its own negative
            # prompt. Presumably this makes the conditional and unconditional
            # branches of classifier-free guidance identical, which effectively
            # disables guidance. Confirm that this is intended.
            output = pipe(
                prompt,
                height=512,
                width=512,
                negative_prompt=prompt,
                num_inference_steps=50,
                return_dict=True,
            )
            image = output[0][0]
            image.save(os.path.join(root_dir, f"{line_no * images_per_prompt + j}.jpg"))

def syn_base_generation():
    """Generate 200 synthetic 512x512 dermatofibroma images.

    Builds an ad-hoc config class and runs ``DataCreator`` on it. The images are
    written to ``datasets_dermamnist/dermatofibroma``.
    """
    lesion = "dermatofibroma"
    config = type(
        "Cfg_class",
        (),
        {
            "root_dir": f"datasets_dermamnist/{lesion}",
            "num_samples": 200,
            "resolution": 512,
            "image_prompt": [
                "A dermatoscopic image of a dermatofibroma, a type of pigmented skin lesions."
            ],
        },
    )
    DataCreator(config).run()

def syn_style_generation(args, dataset, domain, categories, num_samples=80):
    """Generate ``num_samples`` images per category, rendered in one style domain.

    For each class name in ``categories``, a ``DataCreator`` is run with the
    prompt ``"a <domain> style of a <class>"``, with underscores in ``domain``
    replaced by spaces. Output is written to ``datasets_<dataset>/<class>/<domain>``.

    Args:
        args: configuration object. Its other attributes (e.g. ``resolution``)
            are forwarded to ``DataCreator``. A shallow copy is used, so the
            caller's object is not modified.
        dataset: dataset name used in the output path.
        domain: style domain used in both the prompt and the output path.
        categories: iterable of class names.
        num_samples: images generated per category.
    """
    import copy

    for cls in categories:
        print(f"Creating dataset for {cls}, {domain} in {dataset}")
        cfg = copy.copy(args)
        cfg.root_dir = f"datasets_{dataset}/{cls}/{domain}"
        cfg.num_samples = num_samples
        # Use the ``domain`` argument (not ``args.domain``) so the prompt always
        # matches the folder the images are saved to.
        cfg.image_prompt = [
            f'a {domain.replace("_", " ")} style of a {cls}'
        ]

        creator = DataCreator(cfg)
        creator.run()

if __name__=="__main__":
    # add argparser
    import argparse
    from config import parse_args

    args = parse_args()
    syn_style_generation(args, args.dataset, args.domain, args.categories)


