import argparse
from rtpt import RTPT
import os
import clip
import torch
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torchmetrics.functional.pairwise import pairwise_cosine_similarity
import pandas as pd

# Download SSCD model first using 'wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt'

@torch.no_grad()
def get_diversity(img_folder, name, num_samples, prompts):
    """Run the similarity evaluation on ``img_folder`` with the SSCD method.

    Packs the arguments into an ``argparse.Namespace`` (with ``method`` fixed
    to ``'sscd'``) and forwards it to ``compute_similarity_scores``.

    Args:
        img_folder: Stored as the namespace's ``folder`` attribute.
        name: Stored as the namespace's ``name`` attribute.
        num_samples: Stored as the namespace's ``num_samples`` attribute.
        prompts: Stored as the namespace's ``prompts`` attribute.

    Returns:
        Whatever ``compute_similarity_scores`` returns for that namespace.
    """
    settings = {
        'folder': img_folder,
        'method': 'sscd',
        'name': name,
        'num_samples': num_samples,
        'prompts': prompts,
    }
    namespace = argparse.Namespace(**settings)
    scores = compute_similarity_scores(namespace)
    return scores

def _load_model(method, device):
    """Return ``(model, preprocess)`` for an already validated, lowercase method."""
    if method == 'clip':
        return clip.load("ViT-B/32", device=device)

    model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)
    preprocess = transforms.Compose([
        transforms.Resize([320, 320]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, preprocess


def _load_image_batch(folder, prompt_id, num_samples, preprocess, device):
    """Load and preprocess all samples of one prompt into a single batch tensor."""
    tensors = []
    for sample_num in range(num_samples):
        path = os.path.join(folder, f'img_{prompt_id:04d}_{sample_num:02d}.jpg')
        # The context manager closes the file handle; convert('RGB') keeps
        # grayscale/RGBA files compatible with the 3-channel normalisation.
        with Image.open(path) as img:
            tensors.append(preprocess(img.convert('RGB')).unsqueeze(0))
    return torch.cat(tensors, dim=0).to(device)


def _median_and_mad(scores):
    """Return the median and median absolute deviation of a list of 0-dim tensors."""
    stacked = torch.stack(scores)
    median = stacked.median().item()
    deviation = (stacked - median).abs().median().item()
    return {'median': median, 'deviation': deviation}


@torch.no_grad()
def compute_similarity_scores(args):
    """Compute per-prompt pairwise image similarity statistics.

    For every prompt id ``i`` the images ``img_{i:04d}_{s:02d}.jpg`` with
    ``s`` in ``range(args.num_samples)`` are read from ``args.folder``,
    embedded with the selected model, and the pairwise cosine similarities of
    all distinct image pairs are reduced to their median and maximum. These
    per-prompt values are then summarised by their median and median absolute
    deviation, overall and (if a prompt file with a ``type`` column is given)
    separately for the ``VM`` and ``TM`` rows.

    Args:
        args: Namespace-like object with the attributes
            ``folder`` (image directory), ``method`` (``'clip'`` or
            ``'sscd'``, case-insensitive), ``name`` (RTPT initials),
            ``num_samples`` (images per prompt) and optionally ``prompts``
            (path to a ``;``-separated csv file; ``None`` or a missing
            attribute means no prompt file).

    Returns:
        dict: ``'overall_max'`` and ``'overall_median'`` plus, when available,
        ``'vm_median'``, ``'vm_max'``, ``'tm_median'`` and ``'tm_max'``. Each
        value is a dict with the keys ``'median'`` and ``'deviation'``.

    Raises:
        ValueError: If ``args.method`` is neither ``clip`` nor ``sscd``.
        RuntimeError: If no prompt could be scored at all.
    """
    torch.set_num_threads(4)

    folder = args.folder
    num_samples = args.num_samples
    num_prompts = len(os.listdir(folder)) // num_samples

    # Normalise once so model selection and the embedding call always agree.
    method = args.method.lower()
    if method not in ('clip', 'sscd'):
        raise ValueError('Invalid method. Select one of [clip, sscd]')

    if num_prompts == 0:
        raise RuntimeError(
            f'No complete group of {num_samples} images found in {folder!r}.'
        )

    # A Namespace coming from the CLI always has a `prompts` attribute, set to
    # None when the flag is omitted, so test the value and not mere presence.
    prompts_path = getattr(args, 'prompts', None)
    memorization_types = None
    if prompts_path is None:
        print('No prompt file provided.')
    else:
        df = pd.read_csv(prompts_path, sep=';')
        if 'type' in df.columns:
            memorization_types = df['type'].tolist()
        else:
            print('No memorization type provided')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, preprocess = _load_model(method, device)

    # One RTPT step is taken per prompt, so size the tracker accordingly.
    rtpt = RTPT(args.name, 'Sim score', num_prompts)
    rtpt.start()

    # Strict lower triangle: every unordered image pair once, no self-pairs.
    pair_mask = torch.tril(torch.ones(num_samples, num_samples), diagonal=-1).bool()

    medians = {'overall': [], 'VM': [], 'TM': []}
    maxima = {'overall': [], 'VM': [], 'TM': []}

    for prompt_id in tqdm(range(num_prompts)):
        try:
            batch = _load_image_batch(folder, prompt_id, num_samples, preprocess, device)
            if method == 'clip':
                embeddings = model.encode_image(batch)
            else:
                embeddings = model(batch)
            pair_scores = pairwise_cosine_similarity(embeddings).cpu()[pair_mask]
            prompt_median = pair_scores.median()
            prompt_max = pair_scores.max()
        except Exception as exc:
            # Deliberately best-effort: one unreadable or missing image group
            # must not abort the whole evaluation, but it should be visible.
            print(f'Skipping prompt {prompt_id}: {exc!r}')
            continue
        finally:
            rtpt.step()

        groups = ['overall']
        if memorization_types is not None:
            if prompt_id >= len(memorization_types):
                print(f'No prompt entry for image group {prompt_id}')
            else:
                mem_type = memorization_types[prompt_id]
                if mem_type in ('VM', 'TM'):
                    groups.append(mem_type)
                else:
                    print(f'Invalid memorization type {mem_type}')

        for group in groups:
            medians[group].append(prompt_median)
            maxima[group].append(prompt_max)

    if not medians['overall']:
        raise RuntimeError('No similarity scores could be computed for any prompt.')

    result_dict = {
        'overall_max': _median_and_mad(maxima['overall']),
        'overall_median': _median_and_mad(medians['overall']),
    }
    overall_max = result_dict['overall_max']
    print(
        f"Diversity score (Max, All): "
        f"{overall_max['median']:.2f}\\pm {overall_max['deviation']:.2f}"
    )

    # Per-memorization-type statistics, only for types that actually occurred.
    for group in ('VM', 'TM'):
        if not medians[group]:
            continue
        key = group.lower()
        result_dict[f'{key}_median'] = _median_and_mad(medians[group])
        result_dict[f'{key}_max'] = _median_and_mad(maxima[group])
        group_max = result_dict[f'{key}_max']
        print(
            f"Diversity score (Max, {group}): "
            f"{group_max['median']:.2f}\\pm {group_max['deviation']:.2f}"
        )

    return result_dict

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    # Required so that a forgotten --folder is rejected by argparse instead of
    # being passed on as None.
    parser.add_argument('-f', '--folder', type=str, required=True, help='Folder containing the images')
    # Lower-case the value while parsing so the method is spelled the same way
    # everywhere it is compared; unknown methods are rejected before any model
    # is loaded.
    parser.add_argument('-m', '--method', default='sscd', type=str.lower, choices=('clip', 'sscd'), help='Method to compute the similarity score. Select one of [clip, sscd]. (Default: sscd)')
    parser.add_argument('-n', '--name', default='XX',  type=str,help='RTPT user name (Default: XX)')
    parser.add_argument('--num_samples', default=10, type=int, help='Number of samples per prompt to compute the similarity score (Default: 10)')
    # Accepted for command-line compatibility; not read anywhere in this block.
    parser.add_argument('--mem_type', default=None, type=str, help='Type of memorization (Default: None)')
    parser.add_argument('-p', '--prompts', type=str, help='csv file containing the prompts')

    args = parser.parse_args()

    compute_similarity_scores(args)