import torch
from diffusers import FluxPipeline, DiffusionPipeline, SemanticStableDiffusionPipeline
from diffusers import StableDiffusion3Pipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import SafeStableDiffusion3Pipeline, SemanticStableDiffusion3Pipeline
import pdb
from tqdm import tqdm
import json
import os
import gc
from tqdm import tqdm
from openai import OpenAI
import torch
from diffusers import FluxPipeline
import pdb
import os

# Seed used for the torch.Generator of every generated image.
SEED = 42
# At most this many prompts are rendered per company.
num_prompts_per_company = 10

with open("companies.json", "r", encoding="utf-8") as file:
    companies = json.load(file)

device = "cuda"
model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
# Methods are processed in this order; each gets its own output sub-directory.
methods = [ "before_unlearn", "np", "sld_v1", "sld_v2", "sld_v3", "sega_v1", "sega_v2", "sega_v3", "ours"]

# Which prompt set to render: 'explicit' or 'implicit'.
prompt_mode = 'explicit'

# Fail fast on a typo: otherwise the directory names below would never be
# bound and the script would crash much later with a NameError.
if prompt_mode not in ('explicit', 'implicit'):
    raise ValueError(f"prompt_mode must be 'explicit' or 'implicit', got {prompt_mode!r}")

baseline_images_dir = f"baseline_images_{prompt_mode}"
logo_prompt_dir = f"logo_prompt_{prompt_mode}"
logo_prompt_ours_dir = f"logo_prompt_ours_{prompt_mode}"

def _strip_prompt_numbering(prompt_item):
    """Remove a leading list number such as ``"3. "`` from *prompt_item*.

    Only a numeric prefix is stripped, so a prompt that merely contains
    ``". "`` (e.g. two sentences) keeps its first sentence.
    """
    import re

    return re.sub(r"^\s*\d+\.\s+", "", prompt_item, count=1)


def _pipeline_kwargs(method, concept):
    """Return a fresh dict of method-specific keyword arguments for ``pipe``.

    Raises:
        ValueError: if *method* is not one of the supported method names.
    """
    # Arguments shared by every call routed through the semantic pipeline.
    semantic_base = {
        "num_images_per_prompt": 1,
        "reverse_editing_direction": [True],
        "edit_momentum_scale": 0.3,
        "edit_mom_beta": 0.6,
    }

    if method in ("before_unlearn", "ours", "np"):
        # editing_prompt=None: the edit_* values are passed but no edit concept is given.
        kwargs = {
            **semantic_base,
            "editing_prompt": None,
            "edit_warmup_steps": [10],
            "edit_guidance_scale": [4],
            "edit_threshold": [0.99],
        }
        if method == "np":
            kwargs["negative_prompt"] = concept
        return kwargs

    # (sld_warmup_steps, sld_guidance_scale, sld_threshold) per SLD variant.
    sld_settings = {
        "sld_v1": (7, 2000, 0.025),
        "sld_v2": (4, 3000, 0.5),
        "sld_v3": (0, 5000, 1),
    }
    if method in sld_settings:
        warmup, scale, threshold = sld_settings[method]
        return {
            "sld_warmup_steps": warmup,
            "sld_guidance_scale": scale,
            "sld_threshold": threshold,
            "sld_momentum_scale": 0.5,
            "sld_mom_beta": 0.7,
            "safety_concept": concept,
        }

    # (edit_warmup_steps, edit_guidance_scale, edit_threshold) per SEGA variant.
    sega_settings = {
        "sega_v1": (10, 4, 0.99),
        "sega_v2": (7, 5, 0.95),
        "sega_v3": (5, 5, 0.9),
    }
    if method in sega_settings:
        warmup, scale, threshold = sega_settings[method]
        return {
            **semantic_base,
            "editing_prompt": [concept],
            "edit_warmup_steps": [warmup],
            "edit_guidance_scale": [scale],
            "edit_threshold": [threshold],
        }

    raise ValueError(f"unknown method: {method}")


@torch.inference_mode()
def generate_images(company, prompt_item, i, pipe, method, *, height=1024, width=1024):
    """Generate and save one image for *prompt_item* with the given *method*.

    The image is written to ``{baseline_images_dir}/{method}/{company}/{i}.png``.
    Generation is skipped if that file already exists.

    Args:
        company: Company name; ``"{company} logo"`` is the concept to suppress.
        prompt_item: Prompt text, optionally prefixed with a list number ("3. ").
        i: 1-based prompt index, used as the output file name.
        pipe: Loaded pipeline whose call signature matches *method*.
        method: One of the names handled by ``_pipeline_kwargs``.
        height, width: Output resolution (default 1024x1024, as before).

    Raises:
        ValueError: if *method* is unknown.
    """
    actual_prompt = _strip_prompt_numbering(prompt_item)
    print(f"processing {company} {i}th prompt: {actual_prompt}")
    concept = f"{company} logo"
    # Resolve the arguments first so an unknown method fails loudly instead of
    # silently producing no image.
    method_kwargs = _pipeline_kwargs(method, concept)

    output_path = f"{baseline_images_dir}/{method}/{company}/{i}.png"
    if os.path.exists(output_path):
        print(f"image already exists: {output_path}, skip")
        return

    print(f"start generating {company} {method} {i}th image")
    # Same seed for every image, so the methods differ only in their own settings.
    generator = torch.Generator(device).manual_seed(SEED)
    out = pipe(
        prompt=actual_prompt,
        generator=generator,
        height=height,
        width=width,
        **method_kwargs,
    )
    image = out.images[0]
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    image.save(output_path)
    print(f"save {company} {method} {i}th image")
    torch.cuda.empty_cache()
    gc.collect()



# Pipeline class loaded for each method: the SLD variants use the safe
# pipeline, every other method uses the semantic pipeline.
PIPELINE_CLASS_BY_METHOD = {
    "before_unlearn": SemanticStableDiffusion3Pipeline,
    "np": SemanticStableDiffusion3Pipeline,
    "sld_v1": SafeStableDiffusion3Pipeline,
    "sld_v2": SafeStableDiffusion3Pipeline,
    "sld_v3": SafeStableDiffusion3Pipeline,
    "sega_v1": SemanticStableDiffusion3Pipeline,
    "sega_v2": SemanticStableDiffusion3Pipeline,
    "sega_v3": SemanticStableDiffusion3Pipeline,
    "ours": SemanticStableDiffusion3Pipeline,
}


def _load_indexed_prompts(company, method):
    """Return ``(index, prompt)`` pairs for *company*, or ``None`` to skip it.

    The 'ours' method reads ``prompt_3`` from ``{logo_prompt_ours_dir}/{company}/{i}.json``.
    The index is the file number, so ``N.png`` always corresponds to ``N.json``
    even when earlier files are missing. Other methods read the
    ``{logo_prompt_dir}/{company}.json`` list and number its entries from 1.
    """
    if method == 'ours':
        pairs = []
        company_path = f"{logo_prompt_ours_dir}/{company}"
        for idx in range(1, num_prompts_per_company + 1):
            file_path = os.path.join(company_path, f"{idx}.json")
            if not os.path.isfile(file_path):
                print(f"can't find {company} {idx}.json file, skip")
                continue
            with open(file_path, 'r', encoding='utf-8') as fh:
                prompt_value = json.load(fh).get("prompt_3")
            if prompt_value is not None:
                pairs.append((idx, prompt_value))
        return pairs

    json_path = f"{logo_prompt_dir}/{company}.json"
    if not os.path.exists(json_path):
        print(f"can't find {company}.json file, skip")
        return None
    with open(json_path, 'r', encoding='utf-8') as fh:
        prompts = json.load(fh)
    return list(enumerate(prompts[:num_prompts_per_company], 1))


for method in methods:
    print(f"processing {method} method")
    pipeline_cls = PIPELINE_CLASS_BY_METHOD.get(method)
    if pipeline_cls is None:
        # Without this, `pipe` would be unbound (or already deleted) below.
        raise ValueError(f"unknown method: {method}")
    pipe = pipeline_cls.from_pretrained(model_name).to(device)

    for company in tqdm(companies):
        indexed_prompts = _load_indexed_prompts(company, method)
        if indexed_prompts is None:
            continue

        for i, prompt_item in indexed_prompts:
            generate_images(company, prompt_item, i, pipe, method)

        torch.cuda.empty_cache()
        gc.collect()

    # Free the model before the next method loads its own copy.
    del pipe
    torch.cuda.empty_cache()
    gc.collect()







