import argparse
from rtpt import RTPT
import os
import clip
import torch
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from PIL import Image
import pandas as pd

@torch.no_grad()
def get_alignment(img_folder, name, num_samples, prompts):
    """Programmatic entry point for the alignment-score computation.

    Packs the arguments into an ``argparse.Namespace`` with the attribute
    names ``folder``, ``name``, ``num_samples`` and ``prompts`` and forwards
    it to ``compute_similarity_scores``, returning that function's result.
    """
    settings = {
        'folder': img_folder,
        'name': name,
        'num_samples': num_samples,
        'prompts': prompts,
    }
    namespace = argparse.Namespace(**settings)
    return compute_similarity_scores(namespace)

def _median_and_mad(scores):
    """Return ``(median, median absolute deviation)`` of a list of scalar tensors."""
    stacked = torch.stack(scores)
    median = stacked.median().item()
    deviation = (stacked - median).abs().median().item()
    return median, deviation


@torch.no_grad()
def compute_similarity_scores(args):
    """Compute CLIP text-image cosine similarities for a folder of images.

    For every row of the prompt CSV, ``args.num_samples`` images named
    ``img_{row_index:04d}_{sample:02d}.jpg`` are loaded from ``args.folder``,
    embedded with CLIP ViT-B/32 and compared against the embedding of the
    row's ``Caption``. Per prompt, the median and the maximum similarity over
    the samples are recorded; across prompts, the median and the median
    absolute deviation of those values are reported.

    Args:
        args: Object with the attributes
            ``folder`` (image directory),
            ``prompts`` (path of a ``;``-separated CSV with a ``Caption``
            column and an optional ``type`` column holding ``VM`` or ``TM``),
            ``num_samples`` (images per prompt) and
            ``name`` (user name passed to RTPT).

    Returns:
        dict mapping ``overall_median`` / ``overall_max`` (and, when rows of
        that type were scored, ``vm_median`` / ``vm_max`` / ``tm_median`` /
        ``tm_max``) to ``{'median': ..., 'deviation': ...}``.
        ``overall_max`` additionally carries the legacy ``median_max`` key.

    Raises:
        FileNotFoundError: If ``args.folder`` is not an existing directory.
        RuntimeError: If no prompt could be scored at all.
    """
    image_folder = args.folder
    # Fail fast on a bad folder; otherwise every image load would fail and be
    # skipped one by one inside the best-effort loop below.
    if not image_folder or not os.path.isdir(image_folder):
        raise FileNotFoundError(f'Image folder not found: {image_folder}')

    torch.set_num_threads(4)

    # load csv file
    df = pd.read_csv(args.prompts, sep=';')

    # Prefer the GPU but fall back to the CPU so the script also runs on
    # machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess = clip.load("ViT-B/32", device=device)

    rtpt = RTPT(args.name, 'Alignment score', len(df))
    rtpt.start()

    # Per-prompt scores, overall ('all') and per memorization type.
    median_scores = {'all': [], 'VM': [], 'TM': []}
    max_scores = {'all': [], 'VM': [], 'TM': []}

    skipped = 0
    first_error = None

    for idx, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
        try:
            # load images
            imgs = []
            for sample_num in range(args.num_samples):
                img_path = os.path.join(image_folder, f'img_{idx:04d}_{sample_num:02d}.jpg')
                # Context manager closes the file handle once preprocessing
                # has produced the tensor.
                with Image.open(img_path) as img:
                    imgs.append(preprocess(img).unsqueeze(0).to(device))
            imgs = torch.cat(imgs, dim=0)

            # compute embeddings
            image_features = model.encode_image(imgs)
            text = clip.tokenize([row['Caption']]).to(device)
            text_features = model.encode_text(text)

            # compute similarity score
            similarity_score = cosine_similarity(image_features, text_features).cpu()
            prompt_median = similarity_score.median()
            prompt_max = similarity_score.max()
            median_scores['all'].append(prompt_median)
            max_scores['all'].append(prompt_max)

            if 'type' in row:
                mem_type = row['type']
                if mem_type in ('VM', 'TM'):
                    median_scores[mem_type].append(prompt_median)
                    max_scores[mem_type].append(prompt_max)
                else:
                    print(f'Invalid memorization type {row["type"]}')
            else:
                print('No memorization type provided')
        except Exception as e:
            # Best effort: a prompt whose images are missing or unreadable is
            # skipped, but the failure is counted and reported after the loop.
            skipped += 1
            if first_error is None:
                first_error = (idx, e)
        rtpt.step()

    if skipped:
        print(f'Skipped {skipped} of {len(df)} prompts; '
              f'first error at prompt {first_error[0]}: {first_error[1]!r}')

    if not median_scores['all']:
        raise RuntimeError(
            f'No prompt could be scored (all {len(df)} prompts were skipped); '
            'check the image folder, file naming and the prompt file.'
        )

    result_dict = {}
    # compute statistics over the whole set
    median, deviation = _median_and_mad(median_scores['all'])
    median_max, deviation_max = _median_and_mad(max_scores['all'])

    result_dict['overall_median'] = {'median': median, 'deviation': deviation}
    # 'median' matches the vm_max/tm_max layout; 'median_max' is kept so
    # existing callers that read the old key keep working.
    result_dict['overall_max'] = {
        'median': median_max,
        'median_max': median_max,
        'deviation': deviation_max,
    }

    print(f'Similarity score (Max, All): {median_max:.2f}\\pm {deviation_max:.2f}')

    # compute statistics over VM and TM samples
    for mem_type in ('VM', 'TM'):
        if not median_scores[mem_type]:
            continue
        type_median, type_deviation = _median_and_mad(median_scores[mem_type])
        type_median_max, type_deviation_max = _median_and_mad(max_scores[mem_type])

        key = mem_type.lower()
        result_dict[f'{key}_median'] = {'median': type_median, 'deviation': type_deviation}
        result_dict[f'{key}_max'] = {'median': type_median_max, 'deviation': type_deviation_max}

        print(f'Similarity score (Max, {mem_type}) for {len(median_scores[mem_type])} samples: '
              f'{type_median_max:.2f}\\pm {type_deviation_max:.2f}')

    return result_dict

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the scoring.
    parser = argparse.ArgumentParser()
    # Folder and prompt file have no sensible default, so they are required;
    # omitting them previously passed None on and failed with an unrelated
    # error further down.
    parser.add_argument('-f', '--folder', type=str, required=True, help='Folder containing the images')
    parser.add_argument('-p', '--prompts', type=str, required=True, help='csv file containing the prompts')
    parser.add_argument('-n', '--name', default='XX', type=str, help='RTPT user name (Default: XX)')
    parser.add_argument('--num_samples', default=10, type=int, help='Number of samples per prompt to compute the alignment score (Default: 10)')

    args = parser.parse_args()

    compute_similarity_scores(args)