import os
import torch
import numpy as np
from pathlib import Path
from utils.clip_utils import Selector

def load_prompts(file_path):
    """Return the lines of a UTF-8 prompt file with surrounding whitespace removed.

    One entry per line, in file order. Blank lines are kept as empty strings,
    so a prompt's list position equals its line number (starting at 0).
    """
    with open(file_path, encoding='utf-8') as handle:
        stripped = list(map(str.strip, handle))
    return stripped

def compare_images(prompts, baseline_dir, improved_dir, selector):
    """Compare paired images from two models using a CLIP-based selector.

    For each prompt at index ``idx``, the files ``img{idx}.jpg`` in
    ``baseline_dir`` and ``improved_dir`` are scored against that prompt with
    ``selector.score`` and the higher-scoring one is reported.

    Args:
        prompts: Sequence of prompt strings; position ``idx`` pairs with the
            image files named ``img{idx}.jpg``.
        baseline_dir: Directory (str or path-like) with the Diffusion-DPO images.
        improved_dir: Directory (str or path-like) with the Diffusion-RainbowPA
            images.
        selector: Object exposing ``score(img_path=..., prompt=...)`` whose
            result is indexable; element ``[0]`` is used as the score.

    Returns:
        ``(wins, total, comparison_results)``:
        ``wins`` is the number of pairs won by the improved (RainbowPA) image,
        ``total`` is the number of pairs that were successfully scored, and
        ``comparison_results`` lists the indices of those pairs in order.
        Pairs whose scoring raised are reported, skipped and excluded from
        ``total``, so they do not silently count as losses.
    """
    wins = 0
    comparison_results = []

    for idx, prompt in enumerate(prompts):
        baseline_img = os.path.join(baseline_dir, f"img{idx}.jpg")
        improved_img = os.path.join(improved_dir, f"img{idx}.jpg")

        try:
            baseline_score = selector.score(img_path=baseline_img, prompt=prompt)[0]
            improved_score = selector.score(img_path=improved_img, prompt=prompt)[0]
            # Compared inside the try so an uncomparable score is handled
            # like any other per-pair failure.
            baseline_better = bool(baseline_score > improved_score)
        except Exception as e:
            # Best-effort evaluation: one bad pair should not abort the run.
            # Broad catch because the selector's failure modes are not
            # visible from here.
            print(f"Error processing image pair {idx}: {e}")
            continue

        if baseline_better:
            print(f"Image {idx}: Diffusion-DPO image is better than Diffusion-RainbowPA image.")
        else:
            # NOTE: equal scores land here and count as a RainbowPA win
            # (strict ``>`` test above), as in the original logic.
            print(f"Image {idx}: Diffusion-RainbowPA image is better than Diffusion-DPO image.")
            wins += 1

        comparison_results.append(idx)

    # Denominator is the number of pairs actually scored, not len(prompts).
    total = len(comparison_results)
    return wins, total, comparison_results

def main():
    """Compare Diffusion-DPO and Diffusion-RainbowPA generations and print the win rate.

    Reads prompts from ``./evaluating/image_generation/Gen-AI/Gen-AI.txt``.
    It then scores each pair of images in the two ``Generation`` subdirectories
    with a CLIP ``Selector`` on ``cuda:0``. Finally it prints the compared
    indices and the fraction of compared pairs won by Diffusion-RainbowPA.
    """
    # Scoring only; no gradients are needed.
    torch.set_grad_enabled(False)

    # Setup the path
    base_dir = Path("./evaluating")
    prompt_file = base_dir / "image_generation/Gen-AI/Gen-AI.txt"
    baseline_dir = base_dir / "image_generation/Gen-AI/Generation/Diffusion-DPO"
    improved_dir = base_dir / "image_generation/Gen-AI/Generation/RainbowPA"

    # Load prompts
    prompts = load_prompts(prompt_file)
    print(f"Loaded {len(prompts)} prompts.")

    # Initialization
    selector = Selector('cuda:0')

    # Compare images
    wins, total, comparison_results = compare_images(
        prompts, baseline_dir, improved_dir, selector
    )

    # Print results
    print(comparison_results)
    if total == 0:
        # Nothing was scored (no prompts, or every pair failed): avoid
        # ZeroDivisionError and say so explicitly.
        print("No image pairs were compared; winning rate is undefined.")
        return
    winning_rate = wins / total
    print(f"Diffusion-RainbowPA winning rate: {winning_rate:.4f}")

if __name__ == "__main__":
    main()