import os
import torch
import numpy as np
from pathlib import Path
from utils.clip_utils import Selector

def load_prompts(file_path):
    """Return the lines of a UTF-8 text file with surrounding whitespace stripped.

    Exactly one entry is produced per line; blank lines become ``""``. Position
    ``i`` in the result therefore corresponds to line ``i`` of the file, which
    keeps prompts aligned with the index-numbered images compared later.
    """
    with open(file_path, "r", encoding="utf-8") as prompt_file:
        stripped = list(map(str.strip, prompt_file))
    return stripped

def compare_images(
    prompts,
    original_dir,
    improved_dir,
    selector,
    *,
    original_name="SD1-5",
    improved_name="Diffusion-RainbowPA",
    filename_template="img{idx}.jpg",
):
    """Compare paired images from two models using CLIP prompt-image scores.

    For the prompt at position ``idx``, the file named by
    ``filename_template.format(idx=idx)`` is looked up in both directories.
    Each image is scored against the prompt with ``selector.score``, and the
    higher-scoring image wins the pair.

    Args:
        prompts: Sequence of text prompts; prompt ``i`` is paired with image ``i``.
        original_dir: Directory of baseline images (str or path-like).
        improved_dir: Directory of images from the improved model.
        selector: Object exposing ``score(img_path=..., prompt=...)``. The first
            element of its result is taken as the score. It is assumed to be a
            scalar convertible with ``float()``; TODO confirm against Selector.
        original_name: Label for the baseline model in printed messages.
        improved_name: Label for the improved model in printed messages.
        filename_template: Format string with an ``{idx}`` field naming each image.

    Returns:
        A tuple ``(wins, total, comparison_results)``:

        - ``wins``: number of pairs where the improved image scored strictly higher.
        - ``total``: number of pairs that were scored successfully. Pairs whose
          scoring raised are excluded, so they do not count as losses.
        - ``comparison_results``: indices of the successfully scored pairs, in order.
    """
    wins = 0
    comparison_results = []

    for idx, prompt in enumerate(prompts):
        filename = filename_template.format(idx=idx)
        original_img = os.path.join(original_dir, filename)
        improved_img = os.path.join(improved_dir, filename)

        # Only the scoring can realistically fail (missing/unreadable image,
        # model error). Keep the try narrow so logic bugs below still surface.
        try:
            original_score = float(selector.score(img_path=original_img, prompt=prompt)[0])
            improved_score = float(selector.score(img_path=improved_img, prompt=prompt)[0])
        except Exception as e:
            # Best-effort evaluation: skip this pair rather than abort the run.
            print(f"Error processing image pair {idx}: {e}")
            continue

        if original_score > improved_score:
            print(f"Image {idx}: {original_name} image is better than {improved_name} image")
        elif improved_score > original_score:
            print(f"Image {idx}: {improved_name} image is better than {original_name} image")
            wins += 1
        else:
            # Equal (or NaN) scores are not credited to either model.
            print(f"Image {idx}: {original_name} and {improved_name} images scored equally")

        comparison_results.append(idx)

    # Denominator counts only pairs that were actually compared.
    total = len(comparison_results)
    return wins, total, comparison_results

def main():
    """Compare SD-1-5 and RainbowPA generations on the Gen-AI prompts.

    Loads the prompt list and scores each image pair with the CLIP ``Selector``
    on ``cuda:0``. Prints the indices of compared pairs and the fraction won by
    RainbowPA. If no pair could be compared, it prints a message instead of
    dividing by zero.
    """
    # Pure inference: disable autograd globally for the rest of the script.
    torch.set_grad_enabled(False)

    # Setup the paths
    base_dir = Path("./evaluating")
    gen_ai_dir = base_dir / "image_generation/Gen-AI"
    prompt_file = gen_ai_dir / "Gen-AI.txt"
    original_dir = gen_ai_dir / "Generation/SD-1-5"
    improved_dir = gen_ai_dir / "Generation/RainbowPA"

    # Load prompts
    prompts = load_prompts(prompt_file)
    print(f"Loaded {len(prompts)} prompts")

    # Initialization
    selector = Selector('cuda:0')

    # Compare images
    wins, total, comparison_results = compare_images(
        prompts, original_dir, improved_dir, selector
    )

    # Print results
    print(comparison_results)
    if total == 0:
        # Empty prompt file, or every pair failed to score: the rate is undefined.
        print("No image pairs were compared; winning rate is undefined.")
        return
    winning_rate = wins / total
    print(f"Diffusion_RainbowPA winning rate: {winning_rate:.4f}")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()