from steering import SourceTemperingSampler
import torch 
import datetime
import time
import os 
import diffusers, transformers
import json 
import csv


# Sampler hyperparameters shared by every prompt in this run.
beta = 10.0      # passed to SourceTemperingSampler as `beta`; printed as "MAX BETA"
n_chains = 3     # passed to the sampler as `n_chains`; also part of output file names
device = "cuda"  # device string handed to the sampler
batch_size = 4   # passed to the sampler as `batch_size`; also part of output file names



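# Prompts used in the paper (kept for reference; the loop below reads its
# prompts from json_files/partial_prompts.json instead):
# prompt = "a photo of a green tennis racket and a black dog"
# prompt_subject = "dog"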

# prompt = "a photo of an orange cow with a purple sandwich"
# prompt_subject = "cow"


# prompt = "a photo of a black motorbike and a yellow bird"
# prompt_subject = "bird"

# prompt = "a photo of a brown knife and blue donut"
# prompt_subject = "knife"

# prompt = "a photo of a blue clock and a white cup"
# prompt_subject = "clock"

# prompt = "a green stop sign in a red field"
# prompt_subject = "green_stop_sign"




# Load the prompts to sample. Each entry must provide 'prompt' (the text
# prompt) and 'id' (a short tag used in output file names and as the CSV
# 'prompt_id' column).
with open('json_files/partial_prompts.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Run settings shared by every prompt. They do not change per item, so they
# are computed once here instead of inside the loop. (50 steps was used in
# earlier runs.)
num_spt_steps = 30
model = "xl"

# Output layout: IR/<num_spt_steps>/<model>/{best_images,all_images}/
best_image_dir = os.path.join('IR', str(num_spt_steps), model, 'best_images')
all_image_dir = os.path.join('IR', str(num_spt_steps), model, 'all_images')
csv_path = os.path.join(best_image_dir, 'ir_scores.csv')

os.makedirs(best_image_dir, exist_ok=True)
os.makedirs(all_image_dir, exist_ok=True)

for item in data:
    prompt = item['prompt']
    prompt_subject = item['id']

    best_image_path = os.path.join(
        best_image_dir,
        f'{prompt_subject}_best_img_k_{batch_size}_chains_{n_chains}_.jpg',
    )

    # The prompt is fixed at construction time, so each item needs its own sampler.
    sampler = SourceTemperingSampler(
        beta=beta,
        n_chains=n_chains,
        device=device,
        text_prompt=prompt,
        batch_size=batch_size,
        model=model,
    )

    print(f'Entering sampling for prompt: {prompt} ')
    # cold_chain is returned by the sampler but not used by this script.
    cold_chain, best_images, rewards = sampler.sample(num_spt_steps)

    # NOTE(review): assumes `rewards` is a tensor whose (flattened) argmax
    # index lines up with the images in best_images[0] -- confirm against
    # SourceTemperingSampler.sample.
    candidates = best_images[0]
    best_img_index = torch.argmax(rewards).item()
    best_img_max = torch.max(rewards).item()

    candidates[best_img_index].save(best_image_path)

    # Also keep every candidate image, not just the best one.
    for i, image in enumerate(candidates):
        all_image_path = os.path.join(
            all_image_dir,
            f'{prompt_subject}_img_{i}_k_{batch_size}_chains_{n_chains}_.jpg',
        )
        image.save(all_image_path)

    print(f'N CHAINS: {n_chains}')
    print(f'MAX BETA: {beta}')
    print(f'IR SCORE: {best_img_max}')

    # Append one result row per prompt. Use UTF-8 explicitly: prompts are
    # read as UTF-8 and may contain non-ASCII text.
    with open(csv_path, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)

        # Write the header only when the file is empty. Opening in 'a' mode
        # has just created the file if it was missing; an existing but
        # empty file also gets a header.
        if os.path.getsize(csv_path) == 0:
            writer.writerow([
                'prompt_id',
                'prompt',
                'model',
                'num_spt_steps',
                'n_chains',
                'beta',
                'batch_size',
                'best_image',
                'ir_score',
            ])

        writer.writerow([
            prompt_subject,
            prompt,
            model,
            num_spt_steps,
            n_chains,
            beta,
            batch_size,
            os.path.basename(best_image_path),
            best_img_max,
        ])

    # Drop this iteration's sampler and outputs before the next sampler is
    # built, so two samplers (presumably each holding model weights on
    # `device`) are not alive at the same time, then release cached GPU
    # memory (a no-op if CUDA was never initialized).
    del sampler, cold_chain, best_images, candidates, rewards
    torch.cuda.empty_cache()


