import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import clip
import torch

class CustomDataset(Dataset):
    """Pair each prompt with a generated image and score both with CLIP.

    Line ``i`` of ``text_file`` is matched with ``image_folder1/{i:05d}.jpg``.
    Every item is also compared against one fixed reference image,
    ``image_folder2/1.jpg``.

    Args:
        text_file: Path to a text file with one prompt per line.
        image_folder1: Folder holding the generated images, named by
            zero-padded line index (``00000.jpg``, ``00001.jpg``, ...).
        image_folder2: Folder holding the reference image ``1.jpg``.
        clip_model: Model exposing ``encode_text`` / ``encode_image``; it is
            moved to ``device`` and put in eval mode.
        clip_processor: Callable applied to a PIL image; its result must
            support ``unsqueeze(0)`` (presumably the preprocessing transform
            that ships with the model -- confirm against the caller).
        device: Device the model runs on and inputs are moved to.
    """

    def __init__(self, text_file, image_folder1, image_folder2, clip_model, clip_processor, device="cuda"):
        self.texts = self.load_texts(text_file)
        self.image_paths1 = self.load_image_paths(image_folder1)
        self.image_path2 = os.path.join(image_folder2, "1.jpg")  # Assuming the image is named "1.jpg"
        self.clip_model = clip_model.to(device)
        self.clip_model.eval()
        self.clip_processor = clip_processor
        self.device = device
        # The reference image never changes, so its tensor and features are
        # computed on first use and reused for every item afterwards.
        self._reference_image = None
        self._reference_features = None

    def __len__(self):
        """Return the number of prompts (one dataset item per prompt line)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the prompt, both image tensors and both CLIP scores for ``idx``.

        Returns:
            dict with keys ``"text"``, ``"image1"``, ``"image2"``,
            ``"text_score"`` (generated image vs. prompt) and ``"image_score"``
            (generated image vs. reference image). Scores are Python floats.
        """
        text = self.texts[idx]
        image1 = self._load_image(self.image_paths1[idx])
        image2, reference_features = self._get_reference()

        # Encode the generated image once and reuse it for both scores.
        image1_features = self._encode_image(image1)
        text_features = self._encode_text(text)

        text_score = torch.nn.functional.cosine_similarity(image1_features, text_features).item()
        image_score = torch.nn.functional.cosine_similarity(image1_features, reference_features).item()

        return {"text": text, "image1": image1, "image2": image2, "text_score": text_score, "image_score": image_score}

    def load_texts(self, text_file):
        """Read ``text_file`` and return its lines, stripped of whitespace.

        Blank lines are kept on purpose: the line index is what selects the
        matching image file, so dropping lines would shift the pairing.
        """
        # Explicit encoding so prompts decode the same way on every platform.
        with open(text_file, 'r', encoding="utf-8") as file:
            texts = [line.strip() for line in file]
        return texts

    def load_image_paths(self, image_folder):
        """Return one ``{index:05d}.jpg`` path inside ``image_folder`` per prompt."""
        image_paths = [os.path.join(image_folder, f"{i:05d}.jpg") for i in range(len(self.texts))]
        return image_paths

    def calculate_clip_text_score(self, image, text):
        """Return the cosine similarity between ``image`` and ``text`` as a float.

        ``image`` is an already preprocessed, batched image tensor.
        """
        image_features = self._encode_image(image)
        text_features = self._encode_text(text)
        similarity_score = torch.nn.functional.cosine_similarity(image_features, text_features)
        return similarity_score.item()

    def calculate_clip_image_score(self, image1, image2):
        """Return the cosine similarity between two preprocessed images as a float."""
        image_features1 = self._encode_image(image1)
        image_features2 = self._encode_image(image2)
        similarity_score = torch.nn.functional.cosine_similarity(image_features1, image_features2)
        return similarity_score.item()

    def _load_image(self, image_path):
        """Open ``image_path``, convert to RGB, preprocess and add a batch dim."""
        # The context manager closes the file handle right away instead of
        # leaving it open until the Image object is garbage-collected.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        # unsqueeze(0) adds the leading batch dimension the encoder is fed with
        # (expected shape [1, 3, 224, 224] -- TODO confirm for the processor used).
        return self.clip_processor(rgb_image).unsqueeze(0)

    def _get_reference(self):
        """Return ``(reference image tensor, its normalized features)``, cached."""
        if self._reference_image is None:
            self._reference_image = self._load_image(self.image_path2)
            self._reference_features = self._encode_image(self._reference_image)
        return self._reference_image, self._reference_features

    def _encode_image(self, image):
        """Return L2-normalized image features for a batched image tensor."""
        with torch.no_grad():
            features = self.clip_model.encode_image(image.to(self.device))
            return torch.nn.functional.normalize(features, p=2, dim=1)

    def _encode_text(self, text):
        """Return L2-normalized text features for ``text``."""
        with torch.no_grad():
            tokens = clip.tokenize(text).to(self.device)
            features = self.clip_model.encode_text(tokens)
            return torch.nn.functional.normalize(features, p=2, dim=1)


def summarize_scores(scores):
    """Return ``(mean, sample_std)`` for a non-empty sequence of float scores.

    The standard deviation uses Bessel's correction, which is what
    ``torch.std`` applies by default. That estimator is undefined for a single
    sample (torch returns NaN and warns), so ``0.0`` is reported in that case.

    Raises:
        ValueError: If ``scores`` is empty.
    """
    values = list(scores)
    if not values:
        raise ValueError("scores must not be empty")
    # Plain CPU float64 tensor: a handful of scalars gains nothing from the GPU.
    score_tensor = torch.tensor(values, dtype=torch.float64)
    mean = score_tensor.mean().item()
    std = score_tensor.std().item() if score_tensor.numel() > 1 else 0.0
    return mean, std


def main():
    """Score every prompt/image pair with CLIP and print per-item and summary stats."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_processor = clip.load("ViT-B/32", device)
    clip_model.eval()

    # Replace with your actual paths
    text_file = "/home/model/final_txt/db_normal_dog_prompt_50.txt"
    image_folder1 = "/home/model/final_results/corgi_dog_bloody_man/db_normal/samples"
    image_folder2 = "/home/model/datasets_eval/corgi_dog"

    dataset = CustomDataset(text_file, image_folder1, image_folder2, clip_model, clip_processor, device)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    text_scores = []
    image_scores = []

    for batch in dataloader:
        # Default collation turns the per-item float scores into tensors and
        # the prompts into a list; unwrap them so plain numbers are
        # accumulated and printed, whatever the batch size.
        batch_text_scores = batch["text_score"].tolist()
        batch_image_scores = batch["image_score"].tolist()

        for text, text_score, image_score in zip(batch["text"], batch_text_scores, batch_image_scores):
            print(f"Text: {text}, Text Score: {text_score}, Image Score: {image_score}")

        text_scores.extend(batch_text_scores)
        image_scores.extend(batch_image_scores)

    if not text_scores:
        print("No samples in the dataset.")
        return

    average_text_score, std_text_score = summarize_scores(text_scores)
    average_image_score, std_image_score = summarize_scores(image_scores)

    print(f"Average Text Score: {average_text_score}")
    print(f"Average Image Score: {average_image_score}")
    print(f"Standard Deviation of Text Score: {std_text_score}")
    print(f"Standard Deviation of Image Score: {std_image_score}")


if __name__ == "__main__":
    main()