import torch
import torch.nn.functional as F
import glob
import os
from metric import MRR, Retrieval_metrics
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import requests

def load_COCO(gamma,
              image_path='./embedding/CLIP_COCO_val_image.pt',
              text_path='./embedding/CLIP_COCO_val_text.pt',
              device='cuda'):
    """Load cached CLIP COCO embeddings, shift the text side, and build similarities.

    The text embeddings are translated by ``gamma`` times the difference between
    the image-embedding mean and the text-embedding mean (both after L2
    normalisation), then re-normalised.  Diagnostic statistics are printed.

    Args:
        gamma: Scale of the mean-difference shift applied to the text embeddings
            (0 leaves them unchanged).
        image_path: ``torch.load``-able file holding a dict of image embeddings.
        text_path: ``torch.load``-able file holding a dict of text embeddings,
            indexed by ``int(key)`` for every key of the image dict.
        device: Device the embeddings are moved to (default ``'cuda'``, the
            current CUDA device, as before).

    Returns:
        ``(i2t_sim, t2i_sim)``: cosine-similarity matrices of shape
        ``(num_image_rows, num_text_rows)`` and its transpose.
    """
    image_dict = torch.load(image_path)
    text_dict = torch.load(text_path)
    # Sorting fixes the row order; the same key order is used for both
    # modalities so row i of the image side pairs with the matching text rows.
    keys = sorted(image_dict.keys())
    # Each dict value must be 2-D (rows = embeddings) for the dim-0 concat and
    # the 2-D einsums below.
    image_emb = torch.cat([image_dict[key] for key in keys], dim=0).to(device)
    text_emb = torch.cat([text_dict[int(key)] for key in keys], dim=0).to(device)
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    image_mean = image_emb.mean(dim=0)
    text_mean = text_emb.mean(dim=0)
    distance = image_mean - text_mean
    print('Gamma: ', gamma)
    print('L2 distance:', torch.sum((distance**2)))

    # Move the text embeddings towards the image-embedding centroid.
    text_emb = text_emb + gamma * distance
    text_emb = F.normalize(text_emb, dim=-1)
    print('refined L2 distance:', torch.sum(((image_emb.mean(dim=0) - text_emb.mean(dim=0)) ** 2)))

    i2t_sim = torch.einsum('nb,tb->nt', image_emb, text_emb)
    # text->image similarity is exactly the transpose; avoid a second full
    # matmul. contiguous() keeps the same memory layout callers got before.
    t2i_sim = i2t_sim.T.contiguous()
    # NOTE(review): these two matrices are only used for the diagnostic prints
    # and can be large (rows x rows).
    t2t_sim = torch.einsum('nb,tb->nt', text_emb, text_emb)
    i2i_sim = torch.einsum('nb,tb->nt', image_emb, image_emb)

    print(i2t_sim.mean(), i2t_sim.max())
    print(t2t_sim.mean(), t2t_sim.max())
    print(i2i_sim.mean(), i2i_sim.max())

    return i2t_sim, t2i_sim

def encode_images(image_path, batch_size):
    """Encode every ``*.JPEG`` file in a directory with the local CLIP ViT-B/32.

    Args:
        image_path: Directory searched (non-recursively) for ``*.JPEG`` files.
        batch_size: Number of images per forward pass; must be >= 1.

    Returns:
        L2-normalised image embeddings, one row per image in sorted-path
        order, on the GPU.

    Raises:
        ValueError: if ``batch_size`` < 1.
        FileNotFoundError: if no ``*.JPEG`` file is found in ``image_path``.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    # Sorted so the row order of the result is reproducible (glob order
    # depends on the filesystem).
    paths_to_images = sorted(glob.glob(os.path.join(image_path, '*.JPEG')))
    num_images = len(paths_to_images)
    if num_images == 0:
        raise FileNotFoundError(f'No *.JPEG files found in {image_path}')
    CLIP_model = CLIPModel.from_pretrained('./clip-vit-base-patch32/clip-vit-base-patch32')
    CLIP_model.cuda()
    processor = CLIPProcessor.from_pretrained('./clip-vit-base-patch32/clip-vit-base-patch32')
    print(num_images)
    # Ceiling division: include a trailing partial batch, but never an empty one.
    batch_num = (num_images + batch_size - 1) // batch_size
    image_embs = []
    # Inference only: without no_grad every batch's autograd graph would be
    # retained via image_embs and GPU memory would grow with the dataset.
    with torch.no_grad():
        for i in range(batch_num):
            paths = paths_to_images[i * batch_size:(i + 1) * batch_size]
            images = []
            for path_to_image in paths:
                # convert() decodes the lazily-opened file into an independent
                # image, so the file handle can be closed immediately.
                with Image.open(path_to_image) as image:
                    images.append(image.convert('RGB'))
            print(f'{i+1}/{batch_num}, {min((i+1)*batch_size, num_images)}/{num_images}')
            # CLIPModel.forward requires text input; a dummy prompt is passed
            # and only the image embeddings are kept.
            inputs = processor(text=['temp'], images=images, return_tensors="pt", padding=True)
            for k in inputs.keys():
                inputs[k] = inputs[k].cuda()
            outputs = CLIP_model(**inputs)
            image_embs.append(outputs.image_embeds)
    image_emb = F.normalize(torch.cat(image_embs, dim=0), dim=-1)
    print(image_emb.shape)

    return image_emb