from steering import SourceTemperingSampler
import torch 
# from plotting import tile_images
import datetime
import time
# from reward import HPScorer
import os 
# from huggingface_hub import whoami
import diffusers, transformers
from Mean_Flow_Steering.steer_stable_diffusion.random_scripts.tile import tile_images_with_best
import json 
import hpsv2 
import csv
import argparse

# Fixed sampler settings shared by every invocation of this script.
beta: float = 10.0      # forwarded to SourceTemperingSampler(beta=...)
n_chains: int = 3       # forwarded to SourceTemperingSampler(n_chains=...)
device: str = "cuda"    # hard-coded target device for the sampler
batch_size: int = 4     # forwarded to the sampler and also the number of PNGs saved


# Prompt catalogue read from the current working directory. It is expected to be
# an iterable of records with "id" and "prompt" keys (see get_prompt_by_id).
with open("prompts.json", encoding="utf-8") as prompts_file:
    data = json.load(prompts_file)


# Command-line interface: which prompt (by id) to render and how many
# sampling steps to run.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--id",
    help="please indicate your prompt id",
    required=True,
    # NOTE(review): no `type=` is given, so this is a string; it only matches
    # prompts.json records whose "id" is also a string -- confirm the file format.
)

parser.add_argument(
    "--steps",
    # Previously this help text was copy-pasted from --id.
    help="please indicate the number of sampling steps",
    required=True,
    type=int,
)

args = parser.parse_args()


def get_prompt_by_id(data, target_id):
    """Return the "prompt" of the first record in *data* whose "id" equals *target_id*.

    Records are scanned in order and the scan stops at the first match.
    If no record matches, ``None`` is returned.
    """
    matches = (record["prompt"] for record in data if record["id"] == target_id)
    return next(matches, None)

# Resolve the text prompt for the requested id.
prompt = get_prompt_by_id(data, args.id)
if prompt is None:
    # Fail fast: otherwise the sampler below would be built with text_prompt=None.
    parser.error(f"no prompt with id {args.id!r} found in prompts.json")

# The prompt id doubles as the filename prefix for the saved images.
prompt_subject = args.id

# Number of steps handed to sampler.sample().
num_spt_steps = args.steps

# Model variant passed to SourceTemperingSampler; it is also part of the output path.
model = "v1.5"

# Earlier output layouts, kept for reference:
# best_image_dir = f'IR/spt_{num_spt_steps}/{model}/best_images/'
# all_image_dir = f'IR/abalation_test_({n_chains},{num_spt_steps})/'
# ir_score_path = f'{best_image_dir}ir_scores.txt'
# os.makedirs(best_image_dir, exist_ok=True)
# best_image_path = f'{best_image_dir}{prompt_subject}_best_img_k_{batch_size}_chains_{n_chains}_.jpg'
# tiled_path = f'{best_image_dir}{prompt_subject}_tiled_img_k_{

# All images for this run go to CLIP/<model>_<steps>/. The trailing slash is
# kept because filenames are appended to this string directly.
run_tag = f'{model}_{num_spt_steps}'
all_image_dir = f'CLIP/{run_tag}/'
os.makedirs(all_image_dir, exist_ok=True)



# Build the sampler from the module-level settings and the resolved prompt.
sampler_kwargs = dict(
    beta=beta,
    n_chains=n_chains,
    device=device,
    text_prompt=prompt,
    batch_size=batch_size,
    model=model,
)
sampler = SourceTemperingSampler(**sampler_kwargs)



# Disabled hooks, kept for reference:
# hp = HPScorer(prompt=prompt, device=device)
# print(f'Entering sampling for prompt: {prompt} ')

# Run the sampler; only `best_images` is consumed below.
cold_chain, best_images, rewards = sampler.sample(num_spt_steps)

# Save one PNG per batch slot from the first entry of `best_images`, named
# <prompt id>_<slot>.png inside all_image_dir.
# NOTE(review): assumes best_images[0] holds at least batch_size objects
# with a .save(path) method -- confirm against SourceTemperingSampler.sample.
for s in range(batch_size):
    best_images[0][s].save(f'{all_image_dir}{prompt_subject}_{s}.png')

    # Disabled: per-image reward export.
    # csv_path = f'{all_image_dir}{prompt_subject}_{s}_ir_scores.csv'
    # with open(csv_path, 'w', newline='') as f:
    #     writer = csv.writer(f)
    #     writer.writerow([f'ID: {prompt_subject}, IR: {rewards[s]}'])


    


