from steering import SourceTemperingSampler
import torch 
# from plotting import tile_images
import datetime
import time
# from reward import HPScorer
import os 
# from huggingface_hub import whoami
import diffusers, transformers
import json 
import csv
import argparse

# Sampler hyper-parameters shared by every run of this script; they are
# passed unchanged to SourceTemperingSampler below.
beta = 10.0
n_chains = 3
device = "cuda"
batch_size = 4


parser = argparse.ArgumentParser()

parser.add_argument(
    "--id",
    help="please indicate your prompt id",
    required=True,
)

parser.add_argument(
    "--steps",
    help="please indicate the number of sampling steps",
    required=True,
    type=int,
)

args = parser.parse_args()

# Load the prompt table only after argument parsing, so `--help` and usage
# errors still work when json_files/prompts.json is not present.
with open("json_files/prompts.json", "r", encoding="utf-8") as file:
    data = json.load(file)


def get_prompt_by_id(data, target_id):
    """Look up a prompt string by its id.

    Scans *data* (an iterable of mappings with ``"id"`` and ``"prompt"``
    keys) in order and returns the ``"prompt"`` of the first entry whose
    ``"id"`` equals *target_id*. Returns ``None`` when nothing matches.

    The comparison is plain ``==``, so *target_id* must have the same type
    as the stored ids (e.g. the string ``"1"`` does not match the int ``1``).
    """
    matching_prompts = (
        entry["prompt"] for entry in data if entry["id"] == target_id
    )
    return next(matching_prompts, None)

prompt = get_prompt_by_id(data, args.id)
if prompt is None:
    # Fail fast: otherwise text_prompt=None would be handed to
    # SourceTemperingSampler. Note args.id is always a str (no `type=`), so
    # it only matches ids stored as strings in the JSON file.
    parser.error(f"no prompt with id {args.id!r} in json_files/prompts.json")

prompt_subject = args.id

print(f'prompt id: {prompt_subject}')

num_spt_steps = args.steps

model = "v1.4"

# Directory that receives every saved image and its per-image score file.
all_image_dir = f'budget_4650_{model}/'
os.makedirs(all_image_dir, exist_ok=True)


sampler = SourceTemperingSampler(
    beta=beta,
    n_chains=n_chains,
    device=device,
    text_prompt=prompt,
    batch_size=batch_size,
    model=model,
)



cold_chain, best_images, rewards = sampler.sample(num_spt_steps)

# NOTE(review): best_images[0] is assumed to hold at least batch_size
# objects with a .save(path) method (PIL-like), and rewards to be indexable
# per sample; confirm against SourceTemperingSampler.sample.
for s in range(batch_size):
    # Build the k/chains tags from the run settings rather than hard-coding
    # "k_4_chains_3", so filenames stay truthful if these constants change.
    image_path = (
        f'{all_image_dir}{prompt_subject}_img_{s}'
        f'_k_{batch_size}_chains_{n_chains}.jpg'
    )
    best_images[0][s].save(image_path)

    csv_path = f'{all_image_dir}{prompt_subject}_{s}_ir_scores.csv'
    # Explicit encoding: prompt_subject comes from the command line and may
    # be non-ASCII, which would otherwise depend on the platform default.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        # One cell holding "ID: ..., IR: ..."; csv.writer quotes it because
        # the text contains the delimiter (a comma).
        writer.writerow([f'ID: {prompt_subject}, IR: {rewards[s]}'])


    


