import os
import torch
import random
import pandas as pd
from PIL import Image
import clip.clip as clip
from torchvision import transforms


def setSeed(seed):
    """Seed Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Value handed to ``random.seed`` and ``torch.manual_seed``.
    """
    # Same order as before: the stdlib generator first, then torch.
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)

# Seed the RNGs once at import time. The same value is also embedded in the
# names of the cache files written by the collect* functions below.
seed = 3407
setSeed(seed)

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("[Work on {}]".format(device))

# Checkpoint names noted by the author: ViT-B-32; RN50x64
model_name = 'ViT-B-32'
model_path = '/home_new/shiliangliang/taorui/retrieval/cache/model/{}.pt'
# clip.load hands back the network and the image preprocessing callable.
model, preprocess = clip.load(
    model_path.format(model_name),
    device=device,
)
model.eval()
print("[Clip Model Loaded]")


# Image folder and the tab-separated annotation file. The file is read
# without a header row; its two columns are named 'image' and 'caption'.
image_dir_path = '/home_new/shiliangliang/taorui/retrieval/flickr30k/flickr30k-images/'
annotations = pd.read_table(
    './results_20130124.token',
    sep='\t',
    header=None,
    names=['image', 'caption'],
)
# ============================================
# Dataset size: 31783 images, 158915 captions (5 per image)
# ============================================


def collectImageEmbedding(start, end, bs):
    """Encode the images referenced by annotation rows and save the features.

    Rows ``[start, end)`` of the module-level ``annotations`` table are
    scanned in chunks of ``bs`` rows. Only rows whose ``image`` field ends in
    ``'0'`` are used (presumably the first caption row of each image -- TODO
    confirm against the token file); the last two characters of the field are
    stripped to obtain the file name under ``image_dir_path``. Features are
    divided by their norm along the last dimension, concatenated in row order
    and written with ``torch.save`` to
    ``./cache/flickr30k_cache/image_embeddings(<model_name>)<seed>.pt``.

    Args:
        start: First annotation row to scan (inclusive).
        end: Row at which scanning stops (exclusive).
        bs: Number of annotation rows per chunk. The image batch is the
            subset of those rows that end in ``'0'``.

    Raises:
        RuntimeError: If no row in the range selects an image.
    """
    to_tensor = transforms.ToTensor()
    # Positional slicing; the table is read with the default 0..n-1 index.
    image_column = annotations['image']
    all_image_embeddings = []
    for i in range(start, end, bs):
        batch_end = min(i + bs, end)
        image_names = [
            name[:-2]
            for name in image_column.iloc[i:batch_end]
            if name[-1] == '0'
        ]
        if not image_names:
            # torch.stack() rejects an empty list, and a chunk can hold no
            # '...0' row when bs is small or not aligned to the row groups.
            continue

        batch_imgs = []
        for img_name in image_names:
            # convert() returns a loaded copy, so the file can be closed here.
            with Image.open(image_dir_path + img_name) as img:
                rgb = img.convert('RGB')
            # NOTE(review): `preprocess` receives a tensor here, not a PIL
            # image. Kept exactly as before -- confirm that the loaded
            # `preprocess` pipeline expects tensor input.
            batch_imgs.append(preprocess(to_tensor(rgb)))
        batch_imgs = torch.stack(batch_imgs).to(device)
        print(i, "| image preprocess shape:", batch_imgs.shape, end=", ")
        with torch.no_grad():
            image_features = model.encode_image(batch_imgs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            print("feature shape: {}".format(image_features.shape))
            all_image_embeddings.append(image_features)

        # Free memory: drop the per-batch names and release cached blocks.
        # The features themselves stay alive through all_image_embeddings.
        del batch_imgs, image_features
        torch.cuda.empty_cache()

    if not all_image_embeddings:
        raise RuntimeError(
            "No image rows (ids ending in '0') found in annotation rows "
            "[{}, {})".format(start, end)
        )
    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    print("All image features shape:", all_image_embeddings.shape)

    save_path = './cache/flickr30k_cache/image_embeddings({}){}.pt'.format(model_name, seed)
    # Create the cache folder up front so a finished run is not lost on save.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(all_image_embeddings, save_path)


def collectCaptionEmbedding(start, end, bs):
    """Encode captions with the text encoder and save features and strings.

    Rows ``[start, end)`` of the module-level ``annotations`` table are
    processed in chunks of ``bs`` captions. Features are divided by their
    norm along the last dimension and concatenated in row order. Two files
    are written with ``torch.save`` under ``./cache/flickr30k_cache/``:

    * ``caption_embeddings(<model_name>)<seed>.pt`` -- the feature tensor;
    * ``caption_strings(<model_name>)<seed>.pt`` -- the captions as a list of
      UTF-8 encoded ``bytes``; after loading, restore them with
      ``[s.decode('utf-8') for s in ...]``.

    Args:
        start: First annotation row to encode (inclusive).
        end: Row at which encoding stops (exclusive).
        bs: Number of captions per batch.

    Raises:
        RuntimeError: If the range contains no caption.
    """
    # Positional slicing; the table is read with the default 0..n-1 index.
    caption_column = annotations['caption']
    captionnum = 0
    all_captions = []
    all_caption_embeddings = []
    for i in range(start, end, bs):
        batch_end = min(i + bs, end)
        batch_captions = caption_column.iloc[i:batch_end].tolist()
        if not batch_captions:
            # Range runs past the end of the table; nothing left to encode.
            continue
        all_captions.extend(batch_captions)
        # truncate=True: presumably cuts over-long captions to the model's
        # context length instead of raising -- see clip.tokenize.
        text_tokens = clip.tokenize(batch_captions, truncate=True).to(device)
        print(i, "| caption preprocess shape:", text_tokens.shape, end=", ")
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            print("feature shape: {}".format(text_features.shape))
            all_caption_embeddings.append(text_features)

        captionnum += len(batch_captions)
        # Free memory: drop the per-batch names and release cached blocks.
        # The features themselves stay alive through all_caption_embeddings.
        del text_tokens, text_features
        torch.cuda.empty_cache()

    if not all_caption_embeddings:
        raise RuntimeError(
            "No captions found in annotation rows [{}, {})".format(start, end)
        )
    all_caption_embeddings = torch.cat(all_caption_embeddings, dim=0)
    print("All caption features shape:", all_caption_embeddings.shape)
    # The two numbers are counted independently and should always agree.
    print("All caption number:", captionnum, len(all_captions))

    embedding_path = './cache/flickr30k_cache/caption_embeddings({}){}.pt'.format(model_name, seed)
    strings_path = './cache/flickr30k_cache/caption_strings({}){}.pt'.format(model_name, seed)
    # Create the cache folder up front so a finished run is not lost on save.
    os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
    torch.save(all_caption_embeddings, embedding_path)
    coded_captions = [s.encode('utf-8') for s in all_captions]
    torch.save(coded_captions, strings_path)


def countTooLongCaption(captions=None, max_length=77):
    """Print how many captions exceed ``max_length`` characters.

    Two lines are printed: the number of over-long captions and their mean
    length. When no caption exceeds the threshold the mean is reported as 0.

    NOTE(review): ``len()`` counts characters. If 77 is meant to mirror the
    text encoder's context length, that limit is usually expressed in tokens,
    so this is only a rough proxy -- confirm the intent.

    Args:
        captions: Iterable of caption strings. Defaults to the ``caption``
            column of the module-level ``annotations`` table.
        max_length: Captions strictly longer than this are counted.
    """
    if captions is None:
        captions = annotations['caption']

    too_long_length = [len(text) for text in captions if len(text) > max_length]
    print("Too long caption number:", len(too_long_length))
    if too_long_length:
        print("average length:", sum(too_long_length) / len(too_long_length))
    else:
        # Nothing to average; avoid dividing by zero.
        print("average length:", 0)


def checkLossCaption(image_ids=None):
    """Check that the image ids come in consecutive groups numbered 0..4.

    The ids are walked in steps of five. For each group the last character of
    every id is collected, and any group that is not exactly
    ``['0', '1', '2', '3', '4']`` is printed as ``Wrong: <first row> <suffixes>``.
    A shorter trailing group is reported the same way. The number of groups
    is printed at the end.

    Author's note (translated): 31784 * 5 = 158920 is five more than the
    actual number of captions, so this walk looks for the image that is
    missing captions. It turned out there are only 31783 image paths, so the
    image embedding also has shape (31783, 1024).

    Args:
        image_ids: Iterable of image id strings. Defaults to the ``image``
            column of the module-level ``annotations`` table.
    """
    if image_ids is None:
        image_ids = annotations['image']
    ids = list(image_ids)

    group_size = 5
    expected = ['0', '1', '2', '3', '4']
    imagenum = 0
    for i in range(0, len(ids), group_size):
        # Slicing clamps at the end, so a short last group cannot overrun.
        cap_nums = [image_id[-1] for image_id in ids[i:i + group_size]]
        imagenum += 1
        if cap_nums != expected:
            print("Wrong:", i, cap_nums)
    print("Total image number:", imagenum)


if __name__ == '__main__':
    # Alternative entry points, currently disabled; uncomment the one to run.
    # collectImageEmbedding(0, 158915, 5*100)
    # collectCaptionEmbedding(0, 158915, 500)
    # countTooLongCaption()
    checkLossCaption()
