import torch
from PIL import Image
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
import os
from tqdm import tqdm
import argparse
import re
from pytorch_lightning import seed_everything
# Set TOKENIZERS_PARALLELISM at import time, i.e. before any tokenizer is created below.
# NOTE(review): presumably this disables/silences Hugging Face tokenizers'
# internal parallelism (and its fork warning) - confirm against the library docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def compute_clip_score(image, prompt, model, processor, device):
    """Return the CLIP score between a single image and a single text prompt.

    The score is ``100 * cosine_similarity(image_embedding, text_embedding)``,
    floored at zero so that anti-correlated pairs score 0 rather than negative.

    Args:
        image: Image handed to ``processor`` as ``images`` (a PIL image when
            called from ``calculate_clip_scores``).
        prompt: Text prompt to compare with the image.
        model: Object exposing ``get_image_features`` and ``get_text_features``.
        processor: Callable that turns the text/image pair into a mapping with
            ``pixel_values``, ``input_ids`` and ``attention_mask`` tensors.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        The score as a Python ``float``.
    """
    # truncation=True: CLIP's text encoder has a fixed maximum context length;
    # without truncation an over-long prompt yields more tokens than the model
    # has position embeddings for and the text forward pass fails.
    inputs = processor(
        text=[prompt],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        img_features = model.get_image_features(inputs['pixel_values'])
        # L2-normalise so the dot product below is a cosine similarity.
        img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

        txt_features = model.get_text_features(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)

        clip_score = 100 * (img_features * txt_features).sum(dim=-1)
        # Floor negative similarities at zero.
        clip_score = torch.maximum(clip_score, torch.zeros_like(clip_score))

    # .item() requires exactly one element, i.e. one image/prompt pair per call.
    return clip_score.cpu().item()

# Substrings of the target filename recognised as prompt-set labels, in
# priority order: the first one found wins ("unsafe" must be tested before
# "safe" because the latter is a substring of the former).
_PROMPT_TYPE_TAGS = (
    "unsafe", "safe", "mma", "sneaky", "coco", "i2p", "p4d", "ringabell",
    "clean", "nudity",
)

# Substrings of the target filename recognised as method labels, in priority
# order: the first one found wins, so more specific tags precede shorter ones
# (e.g. "sdtnp" before "_tnp", "sdnp" before "_np").
_MODE_TAGS = (
    "ori", "none", "None", "promptfiltering", "stringfiltering", "classifier",
    "guardt2i", "latentguard", "sdtnp", "_tnp", "sdnp", "_np", "sld", "safree",
    "uce", "esd", "spm", "fmn", "salun", "nudity", "advunlearn", "visu", "fft",
    "pt", "des",
)

_UNET_EPOCH_RE = re.compile(r'unet(\d+)ep')
# Tried in order: "<digits>p_" first, then "<digits>p." as a fallback.
_PROMPT_INDEX_RES = (re.compile(r'(\d+)p_'), re.compile(r'(\d+)p\.'))


def _first_tag(name, tags, default):
    """Return the first entry of *tags* that occurs in *name*, else *default*."""
    return next((tag for tag in tags if tag in name), default)


def _unet_epoch(name):
    """Return the digits of an ``unet<N>ep`` marker in *name*, or None if absent."""
    found = _UNET_EPOCH_RE.search(name)
    return found.group(1) if found else None


def _prompt_index(name):
    """Return the integer prompt index encoded in an image filename, or None."""
    for pattern in _PROMPT_INDEX_RES:
        found = pattern.search(name)
        if found:
            return int(found.group(1))
    return None


def calculate_clip_scores(image_folder: str, target_filename: str, csv_file: str, device: str, prompt_filter: bool) -> pd.DataFrame:
    """Score every ``.png`` in *image_folder* against its CSV prompt with CLIP.

    Each image filename must contain ``<N>p_`` or ``<N>p.``; ``N`` selects the
    prompt at position ``N - 2`` of the CSV. Files whose name cannot be parsed,
    or whose index falls outside the CSV, are reported and skipped.

    Side effects: writes ``clip_scores_<target_filename>.csv`` and appends a
    summary to ``clipscore_<target_filename>.log`` in the working directory,
    and prints the same summary.

    Args:
        image_folder: Directory containing the generated ``.png`` images.
        target_filename: Run identifier. It names the output files, and its
            substrings determine the prompt-type/mode labels used in the
            summary; an ``unet<N>ep`` marker restricts scoring to images whose
            own filename carries the same epoch marker.
        csv_file: CSV with the prompts, taken from the ``prompt`` column if
            present and from the first column otherwise.
        device: Torch device string the CLIP model and inputs are moved to.
        prompt_filter: If true, every token listed in ``google_words.txt`` (one
            per line) is removed from the prompt and whitespace is collapsed
            before scoring.

    Returns:
        DataFrame with columns ``file_name``, ``prompt``, ``clip_score`` and
        ``prompt_idx``, sorted by ``file_name``.

    Raises:
        ValueError: If *csv_file* is empty/None or no image could be scored.
    """
    # Fail fast on the cheap check before loading the large CLIP checkpoint.
    if not csv_file:
        raise ValueError("CSV file is required to determine prompts")

    # Set random seed for reproducibility
    seed_everything(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    prompt_type = _first_tag(target_filename, _PROMPT_TYPE_TAGS, "nudity")
    mode = _first_tag(target_filename, _MODE_TAGS, "none")
    unet_epoch = _unet_epoch(target_filename)

    df = pd.read_csv(csv_file)
    prompts = df['prompt'].tolist() if 'prompt' in df.columns else df.iloc[:, 0].tolist()

    # Read the filter vocabulary once instead of once per image.
    target_tokens = []
    if prompt_filter:
        with open('google_words.txt', 'r') as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines are dropped; replacing '' with '' was a no-op anyway.
            target_tokens = [token for token in stripped_lines if token]

    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # Ensure model is in eval mode
    model.eval()

    image_files = sorted([f for f in os.listdir(image_folder) if f.endswith('.png')])
    print(f"Found {len(image_files)} image files in directory.")

    clip_scores = []
    for image_file in tqdm(image_files):
        prompt_idx = _prompt_index(image_file)
        if prompt_idx is None:
            print(f"Warning: Cannot extract prompt index from filename: {image_file}. Skipping file.")
            continue

        # When the target names a UNet epoch, only score images from that epoch.
        if unet_epoch and _unet_epoch(image_file) != unet_epoch:
            continue

        # NOTE(review): the -2 offset presumably maps spreadsheet-style row
        # numbers (header = row 1, first prompt = row 2) to 0-based positions -
        # confirm against the image generator's naming scheme.
        csv_idx = prompt_idx - 2
        # Guard the lookup: a negative index would silently wrap around to the
        # end of the list, and a too-large one would abort the whole run.
        if not 0 <= csv_idx < len(prompts):
            print(f"Warning: Prompt index {prompt_idx} in filename {image_file} is outside the CSV range. Skipping file.")
            continue

        prompt = prompts[csv_idx]

        if prompt_filter:
            filtered_prompt = prompt
            for token in target_tokens:
                filtered_prompt = filtered_prompt.replace(token, '')
            # Collapse the whitespace left behind by removed tokens.
            prompt = ' '.join(filtered_prompt.split())

        # Context manager closes the file handle once the image is scored.
        with Image.open(os.path.join(image_folder, image_file)) as image:
            clip_score = compute_clip_score(image, prompt, model, processor, device)

        clip_scores.append({
            'file_name': image_file,
            'prompt': prompt,
            'clip_score': clip_score,
            'prompt_idx': prompt_idx
        })

    if not clip_scores:
        raise ValueError("No valid images were processed. Check your image folder and CSV file.")

    results_df = pd.DataFrame(clip_scores)

    output_path = f'clip_scores_{target_filename}.csv'
    results_df = results_df.sort_values(by='file_name')
    results_df.to_csv(output_path, index=False)

    avg_clip_score = results_df['clip_score'].mean()

    summary_lines = (
        f"Average CLIP Score for {prompt_type}_{mode}: {avg_clip_score:.4f}",
        f"Processed {len(clip_scores)} images out of {len(prompts)} prompts.",
    )

    # Append so that repeated runs accumulate in the same log file.
    log_filename = f"clipscore_{target_filename}.log"
    with open(log_filename, 'a') as f:
        f.writelines(f"{line}\n" for line in summary_lines)

    for line in summary_lines:
        print(line)

    return results_df

def main():
    """Parse the command-line options and run the CLIP-score evaluation."""
    cli = argparse.ArgumentParser(description='Calculate CLIP scores for generated images')

    # (flag, add_argument keyword options), registered in display order.
    option_specs = (
        ('--image_folder', dict(type=str, required=True,
                                help='Path to the folder containing generated images')),
        ('--target_filename', dict(type=str, required=True,
                                   help='Target filename to extract patterns from')),
        ('--csv_file', dict(type=str, default=None,
                            help='csv file to read text prompts')),
        ('--prompt_filter', dict(action='store_true',
                                 help='Filter prompts using google_words.txt')),
        ('--device', dict(type=str, default='cuda:0',
                          help='Device to run the model on (e.g., "cuda" or "cpu")')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # The returned DataFrame is not needed here; results are written to disk
    # and printed by calculate_clip_scores itself.
    calculate_clip_scores(
        opts.image_folder,
        opts.target_filename,
        opts.csv_file,
        opts.device,
        opts.prompt_filter,
    )

# Script entry point: run the command-line interface when executed directly.
if __name__ == "__main__":
    main()