import os
import torch
import torch.nn as nn
import pandas as pd
import csv
from PIL import Image
import open_clip
from urllib.request import urlretrieve
from os.path import expanduser
import argparse
import ast
from tqdm import tqdm
from transformers import AutoProcessor, AutoModel, CLIPModel, AutoTokenizer, AutoModelForCausalLM, CLIPModel, AutoModelForVision2Seq
import torch.nn.functional as F
from pathlib import Path
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
from pathlib import Path
import io
from datasets import load_dataset
from io import BytesIO
import sys

def get_clip_score(image1, image2, clip_model, processor, device):
    """Return the CLIP image-to-image similarity of two images on a 0-100 scale.

    Each image is run through ``processor`` and ``clip_model.get_image_features``
    without gradient tracking. The cosine similarity of the first embedding of
    each batch (range [-1, 1]) is mapped linearly onto [0, 1], expressed as a
    percentage and rounded to two decimals.

    Any exception is printed and reported as a score of 0.
    """
    try:
        with torch.no_grad():
            embeddings = []
            for picture in (image1, image2):
                batch = processor(images=picture, return_tensors="pt").to(device)
                embeddings.append(clip_model.get_image_features(**batch))

        first, second = embeddings
        similarity = nn.CosineSimilarity(dim=0)(first[0], second[0]).item()
        # Shift [-1, 1] onto [0, 1] before scaling to a percentage.
        return round((similarity + 1) / 2 * 100, 2)
    except Exception as e:
        print(f"Error calculating CLIP score {str(e)}")
        return 0
    

def get_pic_score(prompt, images, model, processor, device):
    """Return PickScore preference scores of ``prompt`` against each image.

    Images and text are preprocessed separately by ``processor`` and moved to
    ``device``. Both embeddings are L2-normalised, and the first text row is
    matched against every image embedding. The result is scaled by
    ``model.logit_scale.exp()``.

    Args:
        prompt: Text passed to ``processor(text=...)``.
        images: Image(s) passed to ``processor(images=...)``.
        model: Model exposing ``get_image_features``, ``get_text_features``
            and ``logit_scale``.
        processor: Processor that returns an object supporting ``.to(device)``
            and ``**``-unpacking into the model's feature methods.
        device: Target device for the processed inputs.

    Returns:
        A list of Python floats, one per image embedding.
    """
    # Same preprocessing arguments for both modalities; max_length=77 is the
    # CLIP text context length.
    image_inputs = processor(
        images=images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        # Row 0 of the text-image similarity matrix: the first prompt vs. every image.
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]

    return scores.cpu().tolist()
   
def setup_vila_model(gpu_id):
    """Load the aesthetic-predictor v2.5 model and its preprocessor onto one GPU.

    Although this script calls it "vila", the model is built by
    ``convert_v2_5_from_siglip``. Its weights are cast to bfloat16 and then
    moved to ``cuda:<gpu_id>``.

    Returns:
        ``(model, preprocessor)``. Unlike the other ``setup_*`` helpers, no
        device string is returned.
    """
    target = f'cuda:{gpu_id}'
    predictor, image_processor = convert_v2_5_from_siglip(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    return predictor.to(torch.bfloat16).to(target), image_processor

def setup_pic_model(gpu_id):
    """Load PickScore v1 and the CLIP ViT-H/14 processor it is used with.

    Returns:
        ``(model, processor, device)``, with the model in eval mode on
        ``cuda:<gpu_id>``.
    """
    target = f'cuda:{gpu_id}'
    pickscore_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    pickscore_model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1")
    pickscore_model = pickscore_model.eval().to(target)
    return pickscore_model, pickscore_processor, target

def setup_laion_model(gpu_id, clip_model="vit_l_14"):
    """Load an OpenAI CLIP backbone and the matching LAION aesthetic linear head.

    The head weights are downloaded once into ``~/.cache/emb_reader`` and reused
    on later calls.

    Args:
        gpu_id: Index of the CUDA device both models are placed on.
        clip_model: Which LAION head to use: ``"vit_l_14"`` (768-d embeddings)
            or ``"vit_b_32"`` (512-d). The CLIP backbone is chosen to match, so
            the head always receives embeddings of the width it expects.

    Returns:
        ``(model, preprocess, amodel, device)``, where:

        - ``model`` is the CLIP model in eval mode.
        - ``preprocess`` is its deterministic evaluation transform.
        - ``amodel`` is the linear aesthetic head in eval mode.
        - ``device`` is the device string.

    Raises:
        ValueError: If ``clip_model`` is not a supported head name.
    """
    # head name -> (open_clip architecture, embedding width)
    supported = {
        "vit_l_14": ("ViT-L-14", 768),
        "vit_b_32": ("ViT-B-32", 512),
    }
    # Validate before doing any model loading or network access.
    if clip_model not in supported:
        raise ValueError("Unsupported model type")
    arch, embed_dim = supported[clip_model]

    device = f'cuda:{gpu_id}'
    # create_model_and_transforms returns (model, train_transform, eval_transform).
    # The train transform applies random augmentation, so score with the eval
    # transform. The model also comes back in train mode, so switch it to eval.
    model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained='openai')
    model = model.eval().to(device)

    cache_folder = os.path.join(expanduser("~"), ".cache", "emb_reader")
    path_to_model = os.path.join(cache_folder, f"sa_0_4_{clip_model}_linear.pth")

    if not os.path.exists(path_to_model):
        os.makedirs(cache_folder, exist_ok=True)
        url_model = f"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_{clip_model}_linear.pth?raw=true"
        # Download to a temporary name first, so an interrupted download never
        # leaves a truncated file that the exists() check above would trust.
        tmp_path = path_to_model + ".part"
        urlretrieve(url_model, tmp_path)
        os.replace(tmp_path, path_to_model)

    amodel = nn.Linear(embed_dim, 1)
    # map_location makes loading independent of the device the checkpoint was saved from.
    amodel.load_state_dict(torch.load(path_to_model, map_location="cpu"))
    amodel = amodel.eval().to(device)

    return model, preprocess, amodel, device

def setup_clip_model(gpu_id):
    """Load OpenAI CLIP ViT-B/32 and its processor for image-to-image similarity.

    Returns:
        ``(model, processor, device)``, with the model moved to ``cuda:<gpu_id>``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    target = f'cuda:{gpu_id}'
    clip = CLIPModel.from_pretrained(checkpoint)
    clip_processor = AutoProcessor.from_pretrained(checkpoint)
    return clip.to(target), clip_processor, target

def _vila_score(image, model, processor):
    """Return the aesthetic-predictor v2.5 output for ``image``, times 10, rounded to 2 places."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).to(model.device)
    with torch.inference_mode():
        # .item() gives a plain Python float, so the CSV receives a number
        # rather than a 0-d numpy array.
        raw = model(pixel_values).logits.squeeze().float().cpu().item()
    return round(raw * 10, 2)


def _laion_score(image, model, preprocess, amodel, device):
    """Return the LAION linear head's prediction on the L2-normalised CLIP embedding of ``image``."""
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        features = features / features.norm(dim=-1, keepdim=True)
        prediction = amodel(features)
    return float(prediction[0, 0].item())


def calculate_scores(row, image_dir, base_dir, models, expname):
    """Score one CSV row's generated image with all four models.

    Args:
        row: A pandas row with ``filename`` and ``caption`` fields. ``caption``
            is parsed with ``ast.literal_eval`` and its first element is the
            PickScore prompt.
        image_dir: Directory holding the generated image ``row['filename']``.
        base_dir: Directory holding the reference image with the same file name.
        models: Dict with keys ``'vila'``, ``'pic'``, ``'laion'`` and
            ``'clip'``, each holding the tuple returned by the matching
            ``setup_*`` function.
        expname: Experiment name embedded in every returned key.

    Returns:
        Dict mapping ``f'{model}_score_{expname}_{temp}'`` to a score, for each
        model in (vila, pic, laion, clip) and each temp in
        ('0.75', '0.85', '0.95').
    """
    vila_model, vila_processor = models['vila']
    pic_model, pic_processor, pic_device = models['pic']
    laion_model, laion_preprocess, laion_amodel, laion_device = models['laion']
    clip_model, clip_processor, clip_device = models['clip']

    image_path = os.path.join(image_dir, row['filename'])
    base_image_path = os.path.join(base_dir, row['filename'])
    # Presumably 'caption' holds a stringified list of captions; only the first
    # is scored. TODO confirm against the input CSV.
    caption = ast.literal_eval(row['caption'])[0]
    print(image_path)

    # convert() returns an independent, loaded copy, so each file is closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    with Image.open(base_image_path) as img:
        base_image = img.convert("RGB")

    vila_score = _vila_score(image, vila_model, vila_processor)
    print("vila:", vila_score)

    pic_score = get_pic_score(caption, [image], pic_model, pic_processor, pic_device)[0]
    print("pic:", pic_score)

    laion_score = _laion_score(image, laion_model, laion_preprocess, laion_amodel, laion_device)
    print("laion:", laion_score)

    clip_score = get_clip_score(image, base_image, clip_model, clip_processor, clip_device)
    print("clip:", clip_score)

    # NOTE(review): neither image path depends on the temperature, so every
    # temperature column necessarily gets the same scores. They are computed
    # once and reused. If images are meant to differ per temperature, the paths
    # above must be built from ``temp``. TODO confirm.
    per_model = {
        'vila': vila_score,
        'pic': pic_score,
        'laion': laion_score,
        'clip': clip_score,
    }
    scores = {}
    for temp in ('0.75', '0.85', '0.95'):
        for model_type, value in per_model.items():
            scores[f'{model_type}_score_{expname}_{temp}'] = value
    return scores
    
def process_dataset(image_dir, base_dir, scores_file, expname):
    """Score every row of ``scores_file`` and write ``<name>_scored<ext>`` beside it.

    Loads the four scoring models onto GPUs 0-3 and reads the input CSV with
    pandas. It then writes one output row per input row: the original column
    values followed by one score column per (model, temperature) pair.

    A row whose scoring raises is still written, with empty score cells, and
    the error is printed. Each row is flushed immediately so partial results
    survive an interrupted run.

    Raises:
        RuntimeError: If CUDA is unavailable or fewer than four GPUs are visible.
    """
    required_gpus = 4
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU available")
    if torch.cuda.device_count() < required_gpus:
        raise RuntimeError(
            f"Expected at least {required_gpus} GPUs, found {torch.cuda.device_count()}"
        )

    print("Loading models on different GPUs...")
    # One model per GPU.
    models = {
        'vila': setup_vila_model(0),
        'pic': setup_pic_model(1),
        'laion': setup_laion_model(2),
        'clip': setup_clip_model(3),
    }
    print("Models loaded successfully")

    df = pd.read_csv(scores_file)

    # splitext only touches the real extension. str.replace would also rewrite
    # ".csv" inside directory names, and would return the input path unchanged
    # (overwriting it) when there is no ".csv" suffix.
    root, ext = os.path.splitext(scores_file)
    output_file = f"{root}_scored{ext or '.csv'}"
    print(output_file)

    # Column names must match the keys produced by calculate_scores.
    temps = ['0.75', '0.85', '0.95']
    score_columns = [
        f'{model_type}_score_{expname}_{temp}'
        for model_type in ['vila', 'pic', 'laion', 'clip']
        for temp in temps
    ]
    print(score_columns)

    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(df.columns.tolist() + score_columns)

        for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing images"):
            print(index)
            try:
                scores = calculate_scores(row, image_dir, base_dir, models, expname)
            except Exception as e:
                # Best effort: keep the row with empty score cells, but report why.
                print(f"Error scoring row {index}: {e}")
                scores = {}

            score_values = [scores.get(col) for col in score_columns]
            print(score_values)
            csv_writer.writerow(list(row) + score_values)
            csvfile.flush()

    print(f"Scoring complete. Results saved to {output_file}")

# Command-line entry point.
def main(argv=None):
    """Parse command-line arguments and score the given CSV file.

    Args:
        argv: Optional argument list, mainly for testing. ``None`` means
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Calculate aesthetic scores for images')
    parser.add_argument('--image_dir', required=True, help='Directory containing the generated images')
    parser.add_argument('--base_dir', required=True, help='Directory containing the base images')
    # Accept both spellings. Either one is stored as ``args.scores_file``,
    # which is the attribute passed to process_dataset below.
    parser.add_argument('--scores_file', '--file', dest='scores_file', required=True,
                        help='Path to the scores CSV file in which scores are to be added')
    parser.add_argument('--expname', required=True,
                        help='Experiment name appended to every score column name')
    args = parser.parse_args(argv)

    process_dataset(args.image_dir, args.base_dir, args.scores_file, args.expname)


if __name__ == "__main__":
    main()


# Example usage:
# python3 aesthetics_scorer.py --image_dir it1-sdxl-images/ --base_dir base_flickr_images_train/ --scores_file "input_it1_prompts.csv"  --expname  "it1-sdxl"

