# -*- coding: utf-8 -*-
"""quick_samples (1).ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1U9Y1CG3TWBu1VKtLUzhkAeK9UjcyKkQL
"""

from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline
import torch
import os
import argparse
import prompts

# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)

parser = argparse.ArgumentParser(
    description="Old server samples"
)

# Hub id or local checkpoint directory of the fine-tuned model; it is loaded
# with subfolder='unet' and also becomes part of the results filename, so it
# cannot be omitted (a None here used to fail deep inside from_pretrained).
parser.add_argument(
    "--model_name",
    type=str,
    required=True,
    help='Name of the model, e.g. "beta2000_64acc_600epoch".'
)

args = parser.parse_args()


# Fine-tuned (DPO) UNet under evaluation, read from the `unet` subfolder of
# --model_name. Values used previously:
#   'mhdang/dpo-sd1.5-text2image-v1', 'mhdang/dpo-sdxl-text2image-v1',
#   or a local checkpoint dir (*/checkpoint-n/) such as 'sd15_2000/',
#   'sd15_2000_30kiter', 'tmp-sd15'.
dpo_unet = UNet2DConditionModel.from_pretrained(
    args.model_name,
    subfolder='unet',
    torch_dtype=torch.float16,
)
dpo_unet = dpo_unet.to('cuda')

# Base model: its pipeline is reused for every generation, its own UNet is the
# baseline, and gen() swaps the DPO UNet into it. Alternatives:
# pretrained_model_name = "CompVis/stable-diffusion-v1-4"
# pretrained_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name = "runwayml/stable-diffusion-v1-5"
_is_sdxl = 'stable-diffusion-xl' in pretrained_model_name

# Guidance scale passed to every pipeline call in gen().
gs = 5 if _is_sdxl else 7.5

# Placement is delegated to device_map="balanced"; the pipeline is therefore
# not moved with .to("cuda") afterwards.
if _is_sdxl:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_name,
        torch_dtype=torch.float16,
        device_map="balanced",
        variant="fp16",
        use_safetensors=True,
    )
else:
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name,
        torch_dtype=torch.float16,
        device_map="balanced",
    )

# Safety checker disabled: it is trigger-happy and blacks out >50% of
# "robot tiger" samples.
pipe.safety_checker = None

# Replace the default sampler with Euler-Ancestral, built from the base
# scheduler's configuration.
from diffusers import EulerAncestralDiscreteScheduler

original_config = pipe.scheduler.config
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(original_config)




# Reward model that scores every generated image automatically.
# Other scorers can be swapped in (clip_utils, aes_utils, pickscore_utils), e.g.
# from utils.pickscore_utils import Selector
from utils.hps_utils import Selector

_SCORER_DEVICE = 'cuda'
ps_selector = Selector(_SCORER_DEVICE)

# Index 0: the base pipeline's own UNet (captured here, before gen() swaps
# UNets in and out of `pipe`); index 1: the fine-tuned DPO UNet.
unets = [pipe.unet, dpo_unet]
# Label by the actual base-model family so printed headers and saved
# filenames are not tagged "SDXL" when an SD 1.x base is in use.
_base_label = "SDXL" if 'stable-diffusion-xl' in pretrained_model_name else "SD"
names = [f"Orig. {_base_label}", f"DPO {_base_label}"]

def gen(image_id, prompt, seed=0, run_baseline=True, output_dir="partial_trained_outputs", save=False):
    """Generate images for ``prompt`` with the DPO UNet and, optionally, the baseline.

    The shared pipeline ``pipe`` is reused; only its ``unet`` attribute is
    swapped between runs. The generator is re-seeded with ``seed`` before each
    UNet's run so both models consume the same random stream.

    Args:
        image_id: Identifier used as the filename prefix when ``save`` is True.
        prompt: Text prompt passed to the pipeline.
        seed: Seed applied to the generator before each UNet's run.
        run_baseline: If True, run ``unets[0]`` (original) and then
            ``unets[1]`` (DPO); otherwise run only ``unets[1]``.
        output_dir: Directory images are written to when ``save`` is True.
        save: If True, create ``output_dir`` if needed and write each image to
            ``{output_dir}/{image_id}_seed{seed}_{name}.png``. Defaults to
            False, so existing callers keep the previous no-save behavior.

    Returns:
        list: ``pipe(...).images[0]`` for each UNet run, in run order.

    Note:
        Leaves ``pipe.unet`` set to the last UNet used.
    """
    ims = []
    generator = torch.Generator(device='cuda')

    if save:
        os.makedirs(output_dir, exist_ok=True)

    for unet_i in ([0, 1] if run_baseline else [1]):
        print(f"Prompt: {prompt}\nSeed: {seed}\n{names[unet_i]}")
        pipe.unet = unets[unet_i]
        # Re-seed per UNet so baseline and DPO runs start from the same RNG state.
        generator = generator.manual_seed(seed)

        im = pipe(prompt=prompt, generator=generator, guidance_scale=gs).images[0]
        ims.append(im)

        if save:
            # Filename encodes the prompt id, the seed, and which UNet produced it.
            filename = os.path.join(
                output_dir, f"{image_id}_seed{seed}_{names[unet_i].replace(' ', '_')}.png"
            )
            im.save(filename)
            print(f"Saved to {filename}")

    return ims

# Free-form example prompts. Not referenced elsewhere in this script: the
# scoring loop iterates prompts.test_prompts instead.
example_prompts = [
    "A pile of sand swirling in the wind forming the shape of a dancer",
    "A giant dinosaur frozen into a glacier and recently discovered by scientists, cinematic still",
    "a smiling beautiful sorceress with long dark hair and closed eyes wearing a dark top surrounded by glowing fire sparks at night, magical light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "A purple raven flying over big sur, light fog, deep focus+closeup, hyper-realistic, volumetric lighting, dramatic lighting, beautiful composition, intricate details, instagram, trending, photograph, film grain and noise, 8K, cinematic, post-production",
    "a smiling beautiful sorceress wearing a modest high necked blue suit surrounded by swirling rainbow aurora, hyper-realistic, cinematic, post-production",
    "Anthro humanoid turtle skydiving wearing goggles, gopro footage",
    "A man in a suit surfing in a river",
    "photo of a zebra dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography",
    "A typhoon in a tea cup, digital render",
    "A cute puppy leading a session of the United Nations, newspaper photography",
    "Worm eye view of rocketship",
    "Glass spheres in the desert, refraction render",
    "anthropmorphic coffee bean drinking coffee",
    "A baby kangaroo in a trenchcoat",
    "A towering hurricane of rainbow colors towering over a city, cinematic digital art",
    "A redwood tree rising up out of the ocean",
]

# 50 benchmark captions, presumably sampled from PartiPrompts (the variable
# name keeps its original spelling). Only consumed via `all_prompts`, which the
# scoring loop does not read.
partipromt_50 = [
    'a family of four posing on the moon',
    'a family of bears passing by the geyser Old Faithful',
    'a high-quality photograph of an armadillo playing a bagpipe while standing on one leg',
    'The Statue of Liberty with the Manhattan skyline in the background.',
    'a milk container in a refrigerator',
    'a Christmas tree',
    'a dolphin in an astronaut suit on saturn',
    "the saying 'do unto others as they would do unto you' written on a white background",
    'a thumbnail image of a person skiing',
    'an old-fashioned cocktail next to a napkin',
    'a small kitchen with a white goat in it',
    'view of a clock tower from above',
    'A helicopter flies over Yosemite.',
    'A punk rock platypus in a studded leather jacket shouting into a microphone while standing on a boulder',
    'background pattern with alternating roses and skulls',
    'molecule',
    'Downtown Seattle at sunrise. detailed ink wash.',
    'a Christmas tree on a toy train',
    'a girl',
    'a red train is coming down the beach',
    'A shiny VW van that has flowers painted on it. A smiling sloth stands on grass in front of the van and is wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. ink sketch.',
    'the mona lisa',
    'The mouse the cat watches is jumping in the air.',
    'a white towel with a cartoon of a cat on it',
    'A smiling sloth wearing a leather jacket, a cowboy hat, a kilt and a bowtie. The sloth is holding a quarterstaff and a big book. A shiny VW van with a cityscape painted on it and parked on grass.',
    'a pen-and-ink crosshatched drawing of a sphere with dark square on it',
    'a hamster dragon',
    'A Vietnam map',
    'a coffee mug floating in the sky',
    'a flag',
    'an owl standing on a wire',
    'a car with tires that have yellow rims',
    'teacup',
    'a three quarters view of a man getting into a car',
    'A giant cobra snake made from pancakes',
    'a black dog jumping up to hug a woman wearing a red sweater',
    'a half moon in the day sky',
    'a tree growing out of the middle of an intersection',
    'three chairs',
    'a turkey',
    'a snail made of harp',
    'a sword in a stone',
    'a doorknocker',
    'an emoji of a baby penguin wearing a blue hat, red gloves, green shirt, and yellow pants',
    'A castle made of cardboard.',
    'ten wine bottles',
    'a girl riding an ostrich',
    'Mars rises on the horizon.',
    'a yellow diamond-shaped sign with a turtle silhouette',
    'A single beam of light enter the room from the ceiling. The beam of light is illuminating an easel. On the easel there is a Rembrandt painting of a raccoon'
]



# Earlier 30-prompt HPSv2 sample (unused; kept for reference):
# hpsv2 = [
#     'A film still of Luke Skywalker as a Sith Lord.',
#     'A minimalistic heart drawing created using Adobe Illustrator.',
#     'Portrait of a male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.',
#     'Renaissance noblewoman with blue eyes and pale skin in a classical portrait pose in the art style of Ib Iwerks.',
#     'A close-up image of a woman wearing a samurai mask, fire dancing in a dirty cyberpunk alley with smoke and mist.',
#     'A lemon character with sunglasses on the beach.',
#     'Abstract yin yang representation by Ivan Bilibin.',
#     'An image of Akira, from the artist Simon Stalenhag.',
#     'Galactus devouring planet earth, depicted in an artwork by Francisco Goya.',
#     'A hyperrealistic mixed media image of a proportionally sized human hand undergoing particle teleportation, with perfect symmetry and dim volumetric lighting.',
#     'A dolphin swimming in front of a Studio Ghibli logo backdrop.',
#     'A man drinking cosmic energy in an anime-style digital art by Park Sung-woo.',
#     'Two messy toilet stalls with toilets where one lid is raised. ',
#     'A lemon character wearing sunglasses on the beach.',
#     'Undertale character Spromple Sploop, third brother of Sans.',
#     'A plane is on display near the water.',
#     'A young woman smiling in the etheric hypothalamus of her mind.',
#     'A portrait painting of a Red Borzoi Dog wearing a red beret as an Overwatch character.',
#     'A giant guardian wearing road sign armor, a popular character design on Artstation.',
#     'A bicycle leaned against the hallway wall in a house',
#     'A white toilet sitting under a window next to a chair.',
#     'A manga-style illustration of Harry Potter as a Gundam mech.',
#     'A Nintendo 64 controller with anthropomorphic features consuming small children.',
#     'A portrait painting of Yondu Udonta in an asymmetrical profile shot, incorporating bold shapes and hard edges with a stylized street art aesthetic.',
#     'A surreal portrait of a young Spanish man wearing sock and titled "Super Spy Captain" with deep purple hair and green eyes on an orange background.',
#     'A portrait of a stylized business cat in sharp focus with a medium shot perspective, resembling boxart.',
#     'A man sitting on a black and yellow bench on the phone.',
#     'A pen illustration of a man wrestling his phone by Gustave Doré with crosshatching and pops of colorful Ben Day dots.',
#     'A masterpiece.',
#     'Albus Dumbledore dressed up as Wonder Woman.',
# ]

# 50 prompts presumably sampled from the HPSv2 benchmark. Two entries
# previously had a stray line-break character embedded mid-literal (which
# split the source line); they are now written as plain single-line strings.
hpsv2_50 = [
    'A teenage mutant ninja turtle, Leonardo, enjoys a cup of tea at a wooden desk in a sci-fi space station orbiting a large planet visible through a window.',
    'The top of a steeped church building with clocks and small windows.  ',
    'A bald general with an angry expression in an intricately detailed and elegant digital painting.',
    'A furry cat girl.',
    'Celine Dion appears angry at a kitten in a hot tub.',
    'A vampire sits at a banquet table in a dungeon setting surrounded by plates of rats and spiders and red candles.',
    'A bathroom with a small sink and toilet. ',
    'Three small dinosaurs entering a grocery store painted by Thomas Kinkade.',
    'Large sized kitchen with a dining room section.',
    '"Front centered symmetrical portrait of Elisha Cuthbert as a D&D paladin with cinematic lighting."',
    'There are orange slices in canning jars without lids.',
    'A portrait of an orc in a fantasy art style.',
    'Exterior image of a small magic items and curios shop in a busy fantasy city.',
    'A portrait of a character in a scenic environment.',
    'A plane riding down a runway of an airport.',
    'A stylized digital art image of a cherry tree overlooking a valley with a waterfall during sunset.',
    'A girl sneaking behind a giant wooden door with archaic symbols embedded onto it, in a cave with the waterfall, illustrated in comics style.',
    'Flag design for communist European Union featuring a hammer and sickle.',
    'A digital painting of Teemo from League of Legends, wearing cyborg parts and a new skin, in a fantasy MMORPG style.',
    'an empty bench sitting on the side of a sidewalk',
    'Image of xqc with a distinctive underbite and big, long nose.',
    'An abstract collage featuring grey and lilac colors with a touch of sparkle.',
    'Patrick Bateman beating an anthropomorphic wolf cosplay.',
    'An ultra-realistic illustration of a bird god swinging a gold metal stick weapon, with a blue man face and yellow bird mouth, and intricate traditional Chinese elements.',
    'Several people standing next to each other that are snow skiing.',
    'A little orange kitten sits on a pink heart-shaped pillow.',
    'A pink bicycle leaning against a fence near a river.',
    'A landscape painting of a China mountain village with a turbulent blood lake.',
    'Abstract yin yang representation by Ivan Bilibin.',
    'A half body portrait of an Asian cyberpunk mechanoid fashion idol wearing a neon jellyfish headdress and xenomorphic body suit.',
    'A realistic anime painting of a cosmic woman wearing clothes made of universes with glowing red eyes.',
    'A Walter White funko pop figurine.',
    'A VTuber model concept art of a beautiful girl in a black and yellow hoodie looking on a smartphone in her hand, with blue eyes, long hair, and a futuristic city background.',
    'A high detail portrait of a royal mansion by Michelangelo Merisi da Caravaggio.',
    "Animation keyframes featuring a wolf's walking motion.",
    'A painted portrait of Persephone in ancient Greece with intricate detail, iridescent coloring, and golden hour lighting.',
    'A digital painting by James Jean depicting a goddess in a strong pose surrounded by planets in a hyper-realistic style.',
    'A digital painting of an anthropomorphic corgi lifting weights in a dim gym with intricate details and a dynamic pose.',
    'A ps2 anime witch from madoka magicka is flying on a broom through New York causing people to run for their lives due to a terrorist attack.',
    'A cat wearing a war helmet.',
    'there is a woman that is cutting a white cake',
    'A white stove top oven inside of a kitchen.',
    'a jet airplane sitting on a runway next to a building',
    'A neon-colored frog in a cyberpunk setting.',
    'A woman in a bathing suit captured in an ink drawing by Sam Bosma with outlined and stippled details.',
    'A girl looks out from the edge of a mountain onto a large city at night.',
    'there is a very beautiful view out of this bathroom window',
    'A female Sonic the Hedgehog with black sclera and bright red pupils.',
    'A futuristic city with a lake, a reflection of utopia, and jungle scenery, featuring drones and androids.',
    'A vivid and intricate depiction of a terrifying god-like creature with rich, bold colors and influences from various artists.',
]

# Combined benchmark pool. NOTE(review): not used below — the loop scores
# prompts.test_prompts instead; confirm which prompt set is intended.
all_prompts = partipromt_50 + hpsv2_50

# Generate one DPO-only image per test prompt (ids start at 1) and collect the
# reward-model scores.
all_scores = []

for i, p in enumerate(prompts.test_prompts, start=1):
    ims = gen(image_id=i, prompt=p, run_baseline=False)
    # Extend in place; `all_scores = all_scores + ...` copied the whole list on
    # every prompt (quadratic in the number of prompts).
    all_scores.extend(ps_selector.score(ims, p))

print(all_scores)
import statistics

median_value = statistics.median(all_scores)
mean_value = statistics.mean(all_scores)
print(median_value)
print(mean_value)

# --model_name may be a hub id ("org/model") or a checkpoint dir with a
# trailing slash ("sd15_2000/"). A raw "/" in the filename makes open() look
# inside a directory that doesn't exist, crashing only after the whole
# generation run, so flatten it into a single path component.
_model_tag = args.model_name.strip("/").replace("/", "_")
results_path = "_sd15_original_" + _model_tag + "_HPS_"

# Append mode: repeated runs accumulate in the same results file.
with open(results_path, "a", encoding="utf-8") as f:
    print(f"Median: {median_value}", file=f)
    print(f"Mean:   {mean_value}", file=f)
    # Two blank lines separate consecutive runs.
    print(file=f)
    print(file=f)


# to get partiprompts captions
#from datasets import load_dataset
#dataset = load_dataset("nateraw/parti-prompts")
#print(dataset['train']['Prompt'])
