import argparse
from rtpt import RTPT
import os
import clip
import torch
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
import pandas as pd

# The SSCD checkpoint is loaded from the current working directory. Download it first with
# 'wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt'

@torch.no_grad()
def get_sscd_gen(gen_folder, ref_folder, name, num_samples, prompts):
    """Compute SSCD similarity scores between generated and reference images.

    Convenience wrapper for programmatic use: it packs the arguments into an
    ``argparse.Namespace`` with ``method`` fixed to ``'sscd'`` and delegates to
    ``compute_similarity_scores``.

    Args:
        gen_folder: Folder with the generated images (passed on as ``folder``).
        ref_folder: Folder with the reference images (passed on as ``reference``).
        name: Passed on as ``name``.
        num_samples: Passed on as ``num_samples``.
        prompts: Passed on as ``prompts``.

    Returns:
        Whatever ``compute_similarity_scores`` returns for these arguments.
    """
    sscd_args = argparse.Namespace(
        folder=gen_folder,
        reference=ref_folder,
        method='sscd',
        name=name,
        num_samples=num_samples,
        prompts=prompts,
    )
    scores = compute_similarity_scores(sscd_args)
    return scores

def _median_and_mad(scores):
    """Return ``(median, median absolute deviation)`` of a non-empty list of scalar tensors."""
    stacked = torch.stack(scores)
    median = stacked.median().item()
    deviation = (stacked - median).abs().median().item()
    return median, deviation


def _load_image(path, preprocess):
    """Open ``path``, convert it to RGB and return the preprocessed image as a batch of one."""
    # The context manager closes the underlying file as soon as the pixels are decoded.
    with Image.open(path) as img:
        return preprocess(img.convert('RGB')).unsqueeze(0)


@torch.no_grad()
def compute_similarity_scores(args):
    """Compute per-prompt cosine similarity between generated and reference images.

    Both folders are expected to contain files named
    ``img_{prompt:04d}_{sample:02d}.jpg``. For every prompt index the
    ``num_samples`` generated images are embedded together with the equally
    named reference images, and the pairwise cosine similarities are reduced to
    a per-prompt median and maximum. Across prompts, the median and the median
    absolute deviation of those values are reported.

    Args:
        args: Namespace-like object with the attributes
            ``folder`` (generated images), ``reference`` (reference images),
            ``method`` (``'clip'`` or ``'sscd'``), ``name`` (passed to RTPT),
            ``num_samples`` (images per prompt) and, optionally, ``prompts``
            (path to a ``;``-separated CSV). If the CSV has a ``type`` column,
            rows labelled ``'VM'`` or ``'TM'`` get separate statistics.

    Returns:
        dict: ``'overall_median'`` and ``'overall_max'`` are always present;
        ``'vm_median'``/``'vm_max'`` and ``'tm_median'``/``'tm_max'`` are added
        when at least one prompt of that type was scored. Each value is a dict
        with ``'median'`` and ``'deviation'``; ``'overall_max'`` additionally
        keeps the legacy ``'median_max'`` key (same value as ``'median'``).

    Raises:
        ValueError: If ``args.method`` is neither ``'clip'`` nor ``'sscd'``.
        RuntimeError: If there is nothing to score (too few files for one
            prompt, or every prompt was skipped because images were unreadable).
    """
    folder = args.folder
    ref_folder = args.reference
    num_samples = args.num_samples
    prompts_file = getattr(args, 'prompts', None)

    torch.set_num_threads(4)

    num_files = len(os.listdir(folder))
    # Listing the reference folder makes a missing directory fail right here
    # instead of being reported as one unreadable image per prompt later on.
    os.listdir(ref_folder)

    # Validate the method before any expensive work is done.
    if args.method not in ('clip', 'sscd'):
        raise ValueError('Invalid method. Select one of [clip, sscd]')

    df = None
    if prompts_file is not None:
        df = pd.read_csv(prompts_file, sep=';')
    else:
        print('No prompt file provided.')

    num_prompts = num_files // num_samples
    if num_prompts <= 0:
        # Fail before loading a model; torch.stack on an empty list would
        # otherwise raise a much less helpful RuntimeError at the very end.
        raise RuntimeError(
            f'Found {num_files} file(s) in {folder!r}, which is not enough for a '
            f'single prompt with num_samples={num_samples}.'
        )

    # Prefer the GPU, but stay usable on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.method == 'clip':
        model, preprocess = clip.load("ViT-B/32", device=device)
    else:
        model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)
        preprocess = transforms.Compose([
            transforms.Resize([320, 320]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # One RTPT step is taken per prompt, so size the tracker accordingly.
    rtpt = RTPT(args.name, 'Sim score', num_prompts)
    rtpt.start()

    sim_scores = []
    sim_scores_max = []
    sim_scores_vm = []
    sim_scores_vm_max = []
    sim_scores_tm = []
    sim_scores_tm_max = []

    for prompt_idx in tqdm(range(num_prompts)):
        # Load all generated/reference pairs of this prompt. Only image I/O is
        # guarded: a missing or corrupted file skips the prompt, while other
        # failures (e.g. CUDA errors, Ctrl-C) still propagate.
        try:
            imgs = []
            references = []
            for sample_num in range(num_samples):
                file_name = f'img_{prompt_idx:04d}_{sample_num:02d}.jpg'
                imgs.append(_load_image(os.path.join(folder, file_name), preprocess))
                references.append(_load_image(os.path.join(ref_folder, file_name), preprocess))
        except (OSError, ValueError) as err:
            print(f'Skipping prompt {prompt_idx}: {err}')
            rtpt.step()
            continue

        # Embed generated and reference images in a single forward pass.
        batch = torch.cat(imgs + references, dim=0).to(device)
        if args.method == 'clip':
            embeddings = model.encode_image(batch)
        else:
            embeddings = model(batch)

        # Pairwise similarity between the i-th generated and i-th reference image.
        # Embeddings may be half precision (e.g. CLIP on a GPU); the statistics
        # below are computed in float32 on the CPU.
        similarity_score = cosine_similarity(
            embeddings[:num_samples], embeddings[num_samples:]
        ).cpu().float()
        prompt_median = similarity_score.median()
        prompt_max = similarity_score.max()
        sim_scores.append(prompt_median)
        sim_scores_max.append(prompt_max)

        if df is not None:
            row = df.iloc[prompt_idx]
            if 'type' in row:
                mem_type = row['type']
                if mem_type == 'VM':
                    sim_scores_vm.append(prompt_median)
                    sim_scores_vm_max.append(prompt_max)
                elif mem_type == 'TM':
                    sim_scores_tm.append(prompt_median)
                    sim_scores_tm_max.append(prompt_max)
                else:
                    print(f'Invalid memorization type {mem_type}')
            else:
                print('No memorization type provided')

        rtpt.step()

    if not sim_scores:
        raise RuntimeError(
            'No similarity scores were computed: every prompt was skipped because '
            'its images could not be loaded.'
        )

    median, deviation = _median_and_mad(sim_scores)
    median_max, deviation_max = _median_and_mad(sim_scores_max)

    result_dict = {
        'overall_median': {'median': median, 'deviation': deviation},
        # 'median_max' is the historical key and now holds the median of the
        # per-prompt maxima (it used to receive the plain median by mistake).
        # 'median' mirrors it so the entry matches the 'vm_max'/'tm_max' layout.
        'overall_max': {'median_max': median_max, 'deviation': deviation_max, 'median': median_max},
    }

    print('Similarity score (Max, All): {:.2f}\\pm {:.2f}'.format(median_max, deviation_max))

    # Statistics restricted to the prompts of each memorization type.
    typed_scores = (
        ('vm', 'VM', sim_scores_vm, sim_scores_vm_max),
        ('tm', 'TM', sim_scores_tm, sim_scores_tm_max),
    )
    for key, label, medians, maxima in typed_scores:
        if not medians:
            continue
        type_median, type_deviation = _median_and_mad(medians)
        type_median_max, type_deviation_max = _median_and_mad(maxima)
        result_dict[f'{key}_median'] = {'median': type_median, 'deviation': type_deviation}
        result_dict[f'{key}_max'] = {'median': type_median_max, 'deviation': type_deviation_max}
        print('Similarity score (Max, {}): {:.2f}\\pm {:.2f}'.format(
            label, type_median_max, type_deviation_max))

    return result_dict

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser(
        description='Compute similarity scores between generated and reference images.')
    # Both folders are mandatory: without them os.listdir(None) would silently
    # fall back to the current working directory.
    parser.add_argument('-f', '--folder', type=str, required=True, help='Folder 1 containing the images')
    parser.add_argument('-r', '--reference', type=str, required=True, help='Folder 2 containing the reference images')
    # Restricting the choices reports a typo at parse time instead of after the
    # folders have been scanned.
    parser.add_argument('-m', '--method', default='sscd', type=str, choices=['clip', 'sscd'], help='Method to compute the similarity score. Select one of [clip, sscd]. (Default: sscd)')
    parser.add_argument('-n', '--name', default='XX', type=str, help='RTPT user name (Default: XX)')
    parser.add_argument('--num_samples', default=10, type=int, help='Number of samples per prompt to compute the similarity score (Default: 10)')
    parser.add_argument('-p', '--prompts', type=str, help='csv file containing the prompts')

    args = parser.parse_args()

    compute_similarity_scores(args)