import argparse
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
from torchvision import transforms
from PIL import Image
import os
import glob
import re
import numpy as np
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
import json

def extract_prompt_number(filename):
    """Return the prompt number embedded in *filename* as ``_<digits>p_``.

    The first occurrence of an underscore, one or more digits, the letter
    ``p`` and a trailing underscore is used; the digits are returned as an
    ``int``. ``None`` is returned when the filename has no such token.
    """
    found = re.search(r'_(\d+)p_', filename)
    return int(found.group(1)) if found else None

def extract_prompt_number_tmp(filename):
    """Return the prompt number embedded in *filename* as ``_<digits>p``.

    Looser variant of the ``_<digits>p_`` lookup: no trailing underscore is
    required after the ``p``. The first match wins and its digits are
    returned as an ``int``; ``None`` is returned when nothing matches.
    """
    found = re.search(r'_(\d+)p', filename)
    return int(found.group(1)) if found else None

def find_matching_gen_image(original_csv_index, gen_imgs_dir):
    """
    Find generated image matching the given index in the original CSV file

    Only ``*.png`` files directly inside ``gen_imgs_dir`` are considered. A
    file matches when the prompt number parsed from its name (``_<N>p_``
    first, then the looser ``_<N>p``) equals ``original_csv_index + 2``.

    Args:
        original_csv_index: 0-based index in the original COCO CSV file
        gen_imgs_dir: Directory containing generated images

    Returns:
        Path to the matching generated image or None if not found. When
        several files carry the same prompt number, the lexicographically
        smallest path is returned.
    """
    # Escape the directory part so that glob metacharacters in its name
    # ("[", "]", "*", "?") are matched literally; only "*.png" is a pattern.
    # Sorting makes the result independent of filesystem listing order.
    search_pattern = os.path.join(glob.escape(gen_imgs_dir), "*.png")
    all_files = sorted(glob.glob(search_pattern))

    # NOTE(review): the +2 offset presumably accounts for the CSV header row
    # plus 1-based numbering in the filenames -- confirm against the generator.
    prompt_num = original_csv_index + 2

    for file_path in all_files:
        filename = os.path.basename(file_path)
        extracted_num = extract_prompt_number(filename)
        if extracted_num is None:
            # Fall back to the pattern without a trailing underscore.
            extracted_num = extract_prompt_number_tmp(filename)
        if extracted_num == prompt_num:
            return file_path

    return None

def calculate_pickscore(gen_imgs_dir, prompts, case_numbers, original_indices, pick_model, pick_processor, device):
    """Calculate PickScore between generated images and prompts using proper matching

    For every prompt the generated image is looked up by its original CSV
    index. Image and text are embedded with ``pick_model``, both embeddings
    are L2-normalised, and the score is ``exp(logit_scale) * <text, image>``.

    Args:
        gen_imgs_dir: Directory containing generated images.
        prompts: Prompt strings, one per CSV row.
        case_numbers: Iterated in lockstep with ``prompts``; the values
            themselves are not used by this function.
        original_indices: 0-based CSV index of each prompt, used to locate
            the generated image.
        pick_model: Model exposing ``get_image_features``,
            ``get_text_features`` and ``logit_scale``.
        pick_processor: Processor called with ``images=`` / ``text=`` whose
            result supports ``.to(device)``.
        device: Device the processed inputs are moved to.

    Returns:
        ``(average_score, scores, valid_pairs)``; ``(0, [], 0)`` when no
        image/prompt pair could be scored.
    """
    pick_scores = []
    valid_pairs = 0

    for i, (prompt, _case_number, original_idx) in enumerate(zip(prompts, case_numbers, original_indices)):
        # Find matching generated image using the original CSV index
        gen_img_path = find_matching_gen_image(original_idx, gen_imgs_dir)

        prompt_num = original_idx + 2  # CSV row -> prompt number

        if not (gen_img_path and os.path.exists(gen_img_path)):
            print(f"Warning: Matching generated image not found for prompt {prompt_num}p (original idx {original_idx})")
            continue

        # Load inside a context manager so the file handle is released
        # promptly, and normalise to RGB like the BLIP scoring path does.
        # An unreadable image is reported and skipped (as the sibling
        # scorers do) instead of aborting the whole evaluation.
        try:
            with Image.open(gen_img_path) as raw_image:
                image = raw_image.convert("RGB")
        except OSError as e:
            print(f"Error calculating PickScore for prompt {prompt_num}p: {e}")
            continue

        # Preprocess image and text
        image_inputs = pick_processor(
            images=[image],
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)

        text_inputs = pick_processor(
            text=[prompt],
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)

        # Calculate PickScore
        with torch.no_grad():
            image_embs = pick_model.get_image_features(**image_inputs)
            image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

            text_embs = pick_model.get_text_features(**text_inputs)
            text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

            # Single prompt vs. single image -> take the only entry.
            score = pick_model.logit_scale.exp() * (text_embs @ image_embs.T)[0][0]

        score_value = score.item()
        pick_scores.append(score_value)
        valid_pairs += 1

        if i % 10000 == 0:
            print(f"PickScore for prompt {prompt_num}p ({i+1}/{len(prompts)}): {score_value:.4f}")

    if not pick_scores:
        return 0, [], 0

    average_pick_score = sum(pick_scores) / len(pick_scores)
    return average_pick_score, pick_scores, valid_pairs

def calculate_imagereward(gen_imgs_dir, prompts, case_numbers, original_indices, imagereward_model):
    """Score each prompt / generated-image pair with an ImageReward model.

    The generated image for a prompt is located through its original CSV
    index. Prompts without a matching image are reported and skipped; errors
    raised while scoring are printed and the pair is skipped.

    Returns ``(mean_score, scores, valid_pairs)``, or ``(0, [], 0)`` when
    nothing could be scored.
    """
    collected = []
    matched = 0

    for position, (text, _case, csv_idx) in enumerate(zip(prompts, case_numbers, original_indices)):
        candidate = find_matching_gen_image(csv_idx, gen_imgs_dir)
        label = csv_idx + 2  # prompt number used in the filenames

        # Guard clause: nothing to score without an existing image file.
        if not candidate or not os.path.exists(candidate):
            print(f"Warning: Matching generated image not found for prompt {label}p (original idx {csv_idx})")
            continue

        try:
            # The model is given the prompt and a one-element list of paths.
            reward = imagereward_model.score(text, [candidate])
            collected.append(reward)
            matched += 1

            if position % 10000 == 0:
                print(f"ImageReward score for prompt {label}p ({position+1}/{len(prompts)}): {reward:.4f}")
        except Exception as err:
            print(f"Error calculating ImageReward for prompt {label}p: {err}")

    if not collected:
        return 0, [], 0

    return sum(collected) / len(collected), collected, matched

def calculate_blipscore(gen_imgs_dir, prompts, case_numbers, original_indices, clip_model, clip_processor, blip_model, blip_processor, device):
    """Calculate BLIP score between generated images and prompts using transformers models

    For each prompt the matching generated image is captioned with
    ``blip_model``; the prompt and the caption are then embedded with the
    text tower of ``clip_model`` and their cosine similarity is the score.

    Args:
        gen_imgs_dir: Directory containing generated images.
        prompts: Prompt strings, one per CSV row.
        case_numbers: Iterated in lockstep with ``prompts``; the values
            themselves are not used by this function.
        original_indices: 0-based CSV index of each prompt, used to locate
            the generated image.
        clip_model: Model exposing ``get_text_features``.
        clip_processor: Processor called with ``text=`` whose result
            supports ``.to(device)``.
        blip_model: Captioning model exposing ``generate``.
        blip_processor: Processor called with ``images=`` and providing
            ``batch_decode``.
        device: Device the processed inputs are moved to.

    Returns:
        ``(average_score, scores, captions, valid_pairs)`` where
        ``captions[k]`` is the caption that produced ``scores[k]``;
        ``(0, [], [], 0)`` when no pair could be scored.
    """
    import traceback  # only needed for error reporting; imported once, not per failure

    blip_scores = []
    blip_captions = []
    valid_pairs = 0

    for i, (prompt, _case_number, original_idx) in enumerate(zip(prompts, case_numbers, original_indices)):
        if i % 10 == 0:
            torch.cuda.empty_cache()

        # Find matching generated image using the original CSV index
        gen_img_path = find_matching_gen_image(original_idx, gen_imgs_dir)

        prompt_num = original_idx + 2

        if not (gen_img_path and os.path.exists(gen_img_path)):
            print(f"Warning: Matching generated image not found for prompt {prompt_num}p (original idx {original_idx})")
            continue

        try:
            # Context manager releases the file handle as soon as the RGB
            # copy has been made.
            with Image.open(gen_img_path) as raw_image:
                image = raw_image.convert("RGB")

            with torch.no_grad():
                inputs = blip_processor(images=image, return_tensors="pt").to(device)
                generated_ids = blip_model.generate(**inputs, max_length=50)
                blip_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            with torch.no_grad():
                # Truncate to CLIP's 77-token context so long prompts are
                # scored instead of raising inside the text encoder.
                prompt_inputs = clip_processor(
                    text=[prompt], return_tensors="pt", padding=True,
                    truncation=True, max_length=77,
                ).to(device)
                prompt_outputs = clip_model.get_text_features(**prompt_inputs)
                prompt_embedding = prompt_outputs / prompt_outputs.norm(dim=1, keepdim=True)

                caption_inputs = clip_processor(
                    text=[blip_caption], return_tensors="pt", padding=True,
                    truncation=True, max_length=77,
                ).to(device)
                caption_outputs = clip_model.get_text_features(**caption_inputs)
                caption_embedding = caption_outputs / caption_outputs.norm(dim=1, keepdim=True)

            similarity = torch.nn.functional.cosine_similarity(prompt_embedding, caption_embedding).item()

            # Record caption and score together, only after the whole pair
            # succeeded, so both lists stay index-aligned.
            blip_captions.append(blip_caption)
            blip_scores.append(similarity)
            valid_pairs += 1

            if i % 10000 == 0:
                print(f"BLIP Score for prompt {prompt_num}p ({i+1}/{len(prompts)}): {similarity:.4f}")
                print(f"  Prompt: {prompt[:50]}...")
                print(f"  Caption: {blip_caption}")

            # Drop tensor references before asking CUDA to release cached blocks.
            del prompt_inputs, prompt_outputs, prompt_embedding
            del caption_inputs, caption_outputs, caption_embedding
            del inputs, generated_ids
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error calculating BLIP Score for prompt {prompt_num}p: {str(e)}")
            traceback.print_exc()

    if not blip_scores:
        return 0, [], [], 0

    average_blip_score = sum(blip_scores) / len(blip_scores)
    return average_blip_score, blip_scores, blip_captions, valid_pairs

def _load_prompts(csv_file_path):
    """Read the prompt CSV and return ``(prompts, case_numbers, original_indices)``.

    Prompts come from the ``prompt`` column, or from the first column by
    position when that column is absent. Case numbers come from the
    ``case_number`` column, or fall back to ``0..len(prompts)-1`` with a
    warning. ``original_indices`` is always ``0..len(prompts)-1``.
    """
    df = pd.read_csv(csv_file_path)

    if 'prompt' in df.columns:
        prompts = df['prompt'].tolist()
    else:
        # read_csv labels columns with strings taken from the header row, so
        # a label lookup like df[0] would raise KeyError; select by position.
        prompts = df.iloc[:, 0].tolist()

    if 'case_number' in df.columns:
        case_numbers = df['case_number'].tolist()
    else:
        print("Warning: 'case_number' column not found in CSV. Using indices as case numbers.")
        case_numbers = list(range(len(prompts)))

    original_indices = list(range(len(prompts)))
    return prompts, case_numbers, original_indices


def _append_score_stats(content, label, scores):
    """Return *content* extended with min/median/max and std lines for *scores*."""
    content += f'\n{label} Min/Median/Max: {np.min(scores):.4f}/{np.median(scores):.4f}/{np.max(scores):.4f}'
    content += f'\n{label} Std: {np.std(scores):.4f}'
    return content


# Script entry point: run the metric selected with --job on the generated
# images, print the textual summary and write it to the output file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='FID_Eval',
        description='Evaluate FID score')

    parser.add_argument('--job', help='calculate CLIP score or FID', type=str, required=False, default='fid',
                        choices=['fid', 'pickscore', 'imagereward', 'blipscore'])
    parser.add_argument('--gen_imgs_path', help='generated image folder for evaluation', type=str, required=True)
    parser.add_argument('--coco_imgs_path', help='coco real image folder for evaluation', type=str, required=False, default='datasets/coco_10k')
    parser.add_argument('--prompt_path', help='prompt for clip score', type=str, required=False, default='datasets/coco_10k.csv')
    parser.add_argument('--classify_prompt_path', help='prompt for classification', type=str, required=False, default='data/prompts/imagenette_5k.csv')
    parser.add_argument('--devices', help='cuda devices to train on', type=str, required=False, default='0,0')
    parser.add_argument('--kid_subset_size', help='subset size for KID calculation', type=int, required=False, default=1000)
    parser.add_argument('--kid_degree', help='polynomial kernel degree for KID', type=int, required=False, default=3)
    parser.add_argument('--output_file', help='output file name', type=str, required=False, default=None)

    args = parser.parse_args()
    devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]

    # Default report path: "<gen_imgs_path>_<job>.txt".
    if args.output_file is None:
        args.output_file = args.gen_imgs_path + f'_{args.job}.txt'

    if args.job == 'fid':
        # Imported lazily so the other jobs do not require T2IBenchmark.
        from T2IBenchmark import calculate_fid
        fid, _ = calculate_fid(args.gen_imgs_path, args.coco_imgs_path, device=devices[0])

        content = f'FID={fid}'
        print(fid)

    elif args.job == 'pickscore':
        # Import necessary libraries for PickScore
        from transformers import AutoProcessor, AutoModel

        # Load PickScore model and processor
        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"

        processor = AutoProcessor.from_pretrained(processor_name_or_path)
        model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(devices[0])

        prompts, case_numbers, original_indices = _load_prompts(args.prompt_path)

        print(f"Calculating PickScores for {len(prompts)} prompts...")
        average_pick_score, pick_scores, valid_pairs = calculate_pickscore(
            args.gen_imgs_path,
            prompts,
            case_numbers,
            original_indices,
            model,
            processor,
            devices[0]
        )

        content = f'Mean PickScore = {average_pick_score}\nValid image-prompt pairs: {valid_pairs}/{len(prompts)}'

        if pick_scores:
            content = _append_score_stats(content, 'PickScore', np.array(pick_scores))

    elif args.job == 'imagereward':
        # Import necessary libraries for ImageReward
        import ImageReward as RM

        # Load ImageReward model
        model = RM.load("ImageReward-v1.0", device=devices[0])

        prompts, case_numbers, original_indices = _load_prompts(args.prompt_path)

        print(f"Calculating ImageReward scores for {len(prompts)} prompts...")
        average_reward, reward_scores, valid_pairs = calculate_imagereward(
            args.gen_imgs_path,
            prompts,
            case_numbers,
            original_indices,
            model
        )

        content = f'Mean ImageReward Score = {average_reward}\nValid image-prompt pairs: {valid_pairs}/{len(prompts)}'

        if reward_scores:
            content = _append_score_stats(content, 'ImageReward Score', np.array(reward_scores))

    elif args.job == 'blipscore':
        torch.cuda.empty_cache()
        gc.collect()

        print("Loading CLIP model...")
        device = devices[0]
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        clip_model.eval()

        print("Loading BLIP model...")
        blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        blip_model.eval()

        prompts, case_numbers, original_indices = _load_prompts(args.prompt_path)

        print(f"Calculating BLIP scores for all {len(prompts)} prompts...")

        torch.cuda.empty_cache()
        gc.collect()

        average_blip_score, blip_scores, blip_captions, valid_pairs = calculate_blipscore(
            args.gen_imgs_path,
            prompts,
            case_numbers,
            original_indices,
            clip_model,
            clip_processor,
            blip_model,
            blip_processor,
            device
        )

        content = f'Mean BLIP Score = {average_blip_score}\nValid image-prompt pairs: {valid_pairs}/{len(prompts)}'

        if blip_scores:
            blip_scores = np.array(blip_scores)
            content = _append_score_stats(content, 'BLIP Score', blip_scores)

            stats_data = {
                'average_score': average_blip_score,
                'valid_pairs': valid_pairs,
                'std': float(np.std(blip_scores)),
                'min': float(np.min(blip_scores)),
                'median': float(np.median(blip_scores)),
                'max': float(np.max(blip_scores))
            }
            # Derive the JSON name from the report path's stem. A plain
            # str.replace('.txt', ...) returns the report path itself when it
            # has no ".txt" (the report written below would then overwrite
            # the JSON) and also rewrites any ".txt" earlier in the path.
            json_file_path = os.path.splitext(args.output_file)[0] + '_stats.json'
            with open(json_file_path, 'w', encoding='utf-8') as json_file:
                json.dump(stats_data, json_file, indent=2)
            print(f"BLIP score statistics saved to: {json_file_path}")

        torch.cuda.empty_cache()
        gc.collect()

    print(content)

    with open(args.output_file, 'w', encoding='utf-8') as file:
        file.write(content)
