import torch
import random
import pandas as pd
from PIL import Image
import clip.clip as clip
import matplotlib.pyplot as plt
from torchvision import transforms

import mlot_list as mlot

def setSeed(seed):
    """Seed Python's ``random`` module and torch's default generator with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

# Global seed: fed to setSeed() and also baked into the cached embedding
# file names that are loaded further down.
seed = 3407
setSeed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[Work on {device}]")

# CLIP backbone; checkpoints used with this script: ViT-B-32, RN50x64.
model_name = 'ViT-B-32'
model_path = '/home_new/shiliangliang/taorui/retrieval/cache/model/{}.pt'
model, preprocess = clip.load(model_path.format(model_name), device=device)
model.eval()
print(f"[Clip Model {model_name} Loaded]")

# Flickr30k data: image folder plus the tab-separated caption file, read
# into two columns named "image" and "caption".
image_dir_path = '/home_new/shiliangliang/taorui/retrieval/flickr30k/flickr30k-images/'
annotations = pd.read_table(
    './results_20130124.token',
    names=['image', 'caption'],
    header=None,
    sep='\t',
)
# ============================================
# Total Image: 31783, Total Caption: 158915 (5 captions per image)
# ============================================


def _recall_counts(hits, ks=(1, 5, 10)):
    """Count queries that have a ground-truth hit among their top-k candidates.

    Args:
        hits: boolean tensor of shape (num_queries, max_k); ``hits[q, r]`` is
            True when the r-th ranked candidate of query q is a ground truth.
        ks: cut-offs to evaluate; each must be <= max_k.

    Returns:
        List with, for each k in ``ks``, the number of queries whose first k
        candidates contain at least one hit.
    """
    return [int(hits[:, :k].any(dim=1).sum().item()) for k in ks]


def baseline():
    """Zero-shot CLIP retrieval on Flickr30k from cached embeddings.

    Loads the cached image and caption embeddings for the module-level
    ``model_name`` / ``seed`` and prints Recall@{1,5,10} for Image->Text and
    Text->Image retrieval. Caption row ``c`` is taken to describe image
    ``c // 5`` (five consecutive captions per image).
    """
    image_embeddings = torch.load('./cache/flickr30k_cache/image_embeddings({}){}.pt'.format(model_name, seed))
    print("image embedding shape:", image_embeddings.shape)
    caption_embeddings = torch.load('./cache/flickr30k_cache/caption_embeddings({}){}.pt'.format(model_name, seed))
    print("caption embedding shape:", caption_embeddings.shape)

    # Rank on raw similarities. A softmax over the caption axis is monotone per
    # image row (fine for Image->Text), but it rescales every image row by its
    # own partition function, which corrupts the per-caption (column) ranking
    # used for Text->Image.
    sim = image_embeddings @ caption_embeddings.T  # shape = (num_images, num_captions)
    print("Similarity shape:", sim.shape)
    num_images, num_captions = sim.shape

    # Image->Text retrieval: image i owns captions 5*i .. 5*i+4.
    top_caps = sim.topk(10, dim=1).indices                      # (num_images, 10)
    first_cap = 5 * torch.arange(num_images, device=top_caps.device).unsqueeze(1)
    i2t_hits = (top_caps >= first_cap) & (top_caps < first_cap + 5)
    i2t_result = [c / num_images for c in _recall_counts(i2t_hits)]
    print("[Baseline {}] Image->Text: R1:{:.4f}, R5:{:.4f}, R10:{:.4f}".format(model_name, *i2t_result))

    del top_caps, first_cap, i2t_hits
    torch.cuda.empty_cache()

    # Text->Image retrieval: caption c's ground-truth image is c // 5, i.e. the
    # image img with 5*img <= c < 5*img + 5.
    top_imgs = sim.topk(10, dim=0).indices.T                    # (num_captions, 10)
    cap_ids = torch.arange(num_captions, device=top_imgs.device).unsqueeze(1)
    t2i_hits = (5 * top_imgs <= cap_ids) & (cap_ids < 5 * top_imgs + 5)
    t2i_result = [c / num_captions for c in _recall_counts(t2i_hits)]
    print("[Baseline {}] Text->Image: R1:{:.4f}, R5:{:.4f}, R10:{:.4f}".format(model_name, *t2i_result))


def _i2t_image_path(ann_id):
    """Return the image file path referenced by annotation row ``ann_id``.

    The ``image`` column ends in a two-character caption suffix whose last
    character is the caption index (presumably ``#0``..``#4``). ``ann_id``
    must point at caption 0 of its image; otherwise the 5-captions-per-image
    layout the caller relies on does not hold.

    Raises:
        ValueError: if row ``ann_id`` is not the first caption of its image.
    """
    image_field = annotations['image'][ann_id]
    if image_field[-1] != '0':
        raise ValueError(
            "annotation row {} ({!r}) is not the first caption of its image".format(ann_id, image_field))
    # Strip the two-character caption suffix to get the file name.
    return image_dir_path + image_field[:-2]


def I2TRetrievalMLOT(transform, bs, num_images=31783):
    """Image->Text retrieval on Flickr30k, re-ranked with multi-layer OT.

    Images are processed in shuffled batches of ``bs``. Each batch is encoded
    with CLIP twice: once as-is and once after ``transform``. Both views are
    scored against every row of the module-level ``all_caption_embeddings``.
    The two similarity matrices become cost matrices for
    ``mlot.multi_sinkhorn_single``. The two resulting transport plans are
    summed, and that sum ranks the captions for each image.

    Args:
        transform: callable applied to each image tensor (the output of
            ``transforms.ToTensor``) to build the augmented view.
        bs: number of images per batch.
        num_images: number of images; image ``img`` owns caption rows
            ``5*img .. 5*img+4``.

    Prints per-batch and overall Recall@{1,5,10}.
    """
    ids = list(range(num_images))
    random.shuffle(ids)
    correct = [0, 0, 0]
    to_tensor = transforms.ToTensor()

    for i in range(0, num_images, bs):
        batch_end = min(i + bs, num_images)
        # Annotation row of each image's first caption.
        batch_ids = [img_id * 5 for img_id in ids[i:batch_end]]
        test_img_paths = [_i2t_image_path(ann_id) for ann_id in batch_ids]

        # Use a context manager so each image file handle is released promptly.
        test_imgs = []
        for img_path in test_img_paths:
            with Image.open(img_path) as img:
                test_imgs.append(to_tensor(img.convert('RGB')))
        arg_imgs = [transform(img) for img in test_imgs]
        test_img_inputs = [preprocess(img) for img in test_imgs]    # preprocess original views
        arg_imgs_inputs = [preprocess(img) for img in arg_imgs]     # preprocess augmented views

        # L2-normalised image features, one row per image (512 wide for ViT-B-32, presumably)
        test_img_feats = model.encode_image(torch.stack(test_img_inputs).to(device))
        test_img_feats /= test_img_feats.norm(dim=-1, keepdim=True)
        arg_img_feats = model.encode_image(torch.stack(arg_imgs_inputs).to(device))
        arg_img_feats /= arg_img_feats.norm(dim=-1, keepdim=True)
        print("image features shape:", test_img_feats.shape, arg_img_feats.shape)

        sim = torch.mm(test_img_feats, all_caption_embeddings.T)
        sim_arg = torch.mm(arg_img_feats, all_caption_embeddings.T)
        print("Similarity shape:", sim.shape, sim_arg.shape)

        del test_imgs, arg_imgs, test_img_inputs, arg_imgs_inputs
        torch.cuda.empty_cache()

        # Multi-layer OT: uniform marginals over the k original and the k
        # augmented images (captions presumably form the middle layer).
        k = batch_end - i
        s, t = torch.ones(k) / k, torch.ones(k) / k
        print("simi max:", sim.max().item(), sim_arg.max().item())
        print("simi min:", sim.min().item(), sim_arg.min().item())
        _simi_max = max(sim.max().item(), sim_arg.max().item())
        _simi_min = min(sim.min().item(), sim_arg.min().item())
        # Similarities -> positive costs, then scaled into (0, 1].
        M = [_simi_max + 0.5 - sim, _simi_max + 0.5 - sim_arg.T]
        _max_num = _simi_max + 0.5 - _simi_min
        M = [Mi / _max_num for Mi in M]
        del sim, sim_arg
        torch.cuda.empty_cache()
        T = mlot.multi_sinkhorn_single(s, t, M, 5e-2, numItermax=1500)
        predict_sim = T[0] + T[1].T
        print("Prediction sim matrix shape:", predict_sim.shape)

        # Top-10 caption ids per image; a hit is any caption owned by that image.
        _, predict_cap_ids = torch.topk(predict_sim, 10, dim=1)
        print("Predicts shape:", predict_cap_ids.shape)
        first_cap = torch.tensor(batch_ids, device=predict_cap_ids.device).unsqueeze(1)
        hits = (predict_cap_ids >= first_cap) & (predict_cap_ids < first_cap + 5)
        batch_correct = [int(hits[:, :n].any(dim=1).sum().item()) for n in (1, 5, 10)]

        correct = [c + b for c, b in zip(correct, batch_correct)]
        # 1-based batch number (batch_end / k was a float and wrong for the
        # final, shorter batch).
        print("[Batch {}] R@1: {}, R@5: {}, R@10: {}".format(i // bs + 1, *[b / k for b in batch_correct]))

    correct = [c / num_images for c in correct]
    print("[Image->Text] [{}]: R@1: {}, R@5: {}, R@10: {}".format(model_name, *correct))


def T2IRetrievalMLOT(bs, num_images=31783):
    """Text->Image retrieval on Flickr30k, re-ranked with multi-layer OT.

    Images are visited in shuffled batches of ``bs``. For each image, two
    different captions are drawn from its five annotation rows
    (``5*img .. 5*img+4``). Both query sets are CLIP-encoded and scored
    against the module-level ``all_image_embeddings``. The two similarity
    matrices become cost matrices for ``mlot.multi_sinkhorn_single``. The two
    resulting transport plans are summed, and that sum ranks the images for
    each query pair.

    Args:
        bs: number of images (query pairs) per batch.
        num_images: number of images; image ``img`` is assumed to be row
            ``img`` of ``all_image_embeddings`` and to own caption rows
            ``5*img .. 5*img+4``.

    Prints per-batch and overall Recall@{1,5,10}.
    """
    ids = list(range(num_images))
    random.shuffle(ids)
    correct = [0, 0, 0]
    for i in range(0, num_images, bs):
        batch_end = min(i + bs, num_images)
        batch_ids = ids[i:batch_end]
        # Two *distinct* captions per image. random.choices samples with
        # replacement, so both queries used to coincide ~20% of the time.
        test_caption_ids = [random.sample(range(5 * img_id, 5 * img_id + 5), 2) for img_id in batch_ids]
        first_test_captions = [annotations['caption'][first] for first, _ in test_caption_ids]
        second_test_captions = [annotations['caption'][second] for _, second in test_caption_ids]

        first_query_tokens = clip.tokenize(first_test_captions, truncate=True).to(device)
        second_query_tokens = clip.tokenize(second_test_captions, truncate=True).to(device)
        first_query_features = model.encode_text(first_query_tokens)
        second_query_features = model.encode_text(second_query_tokens)
        first_query_features /= first_query_features.norm(dim=-1, keepdim=True)
        second_query_features /= second_query_features.norm(dim=-1, keepdim=True)
        print("Text query feature shape:", first_query_features.shape, second_query_features.shape)

        sim = torch.mm(first_query_features, all_image_embeddings.T)
        arg_sim = torch.mm(second_query_features, all_image_embeddings.T)
        print("Similarity shape:", sim.shape, arg_sim.shape)

        # Multi-layer OT: uniform marginals over the k first and k second
        # queries (images presumably form the middle layer).
        k = batch_end - i
        s, t = torch.ones(k) / k, torch.ones(k) / k
        _simi_max = max(sim.max().item(), arg_sim.max().item())
        _simi_min = min(sim.min().item(), arg_sim.min().item())
        print("similarity max=", _simi_max, ", min=", _simi_min)
        # Similarities -> positive costs, then scaled into (0, 1].
        M = [_simi_max + 0.5 - sim, _simi_max + 0.5 - arg_sim.T]
        _max_num = _simi_max + 0.5 - _simi_min
        M = [Mi / _max_num for Mi in M]
        del sim, arg_sim
        torch.cuda.empty_cache()
        T = mlot.multi_sinkhorn_single(s, t, M, 5e-2, numItermax=1500)
        predict_sim = T[0] + T[1].T
        print("Prediction shape:", predict_sim.shape)

        # Top-10 image ids per query pair; a hit is the query's own image.
        _, predict_img_ids = torch.topk(predict_sim, 10, dim=1)
        gt_imgs = torch.tensor(batch_ids, device=predict_img_ids.device).unsqueeze(1)
        hits = predict_img_ids == gt_imgs
        batch_correct = [int(hits[:, :n].any(dim=1).sum().item()) for n in (1, 5, 10)]

        correct = [c + b for c, b in zip(correct, batch_correct)]
        # 1-based batch number (batch_end / k was a float and wrong for the
        # final, shorter batch).
        print("[Batch {}] R@1: {}, R@5: {}, R@10: {}".format(i // bs + 1, *[b / k for b in batch_correct]))

    correct = [c / num_images for c in correct]
    print("[Text->Image] [{}]: R@1: {}, R@5: {}, R@10: {}".format(model_name, *correct))


def randomTransforms(k):
    """Return augmentation number ``k`` from a fixed menu of torchvision transforms.

    ``k`` is used as a sequence index: 0-6 are valid, negative values count
    from the end, and anything else raises IndexError.

    0 always-on horizontal flip, 1 32x32 random crop with 4px padding,
    2 always-on vertical flip, 3 random rotation within +/-30 degrees,
    4 colour jitter, 5 always-on grayscale, 6 3x3 Gaussian blur.
    """
    menu = (
        lambda: transforms.RandomHorizontalFlip(p=1.0),
        lambda: transforms.RandomCrop(32, padding=4),
        lambda: transforms.RandomVerticalFlip(p=1.0),
        lambda: transforms.RandomRotation(30),
        lambda: transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        lambda: transforms.RandomGrayscale(p=1.0),
        lambda: transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
    )
    # Only the requested transform is actually constructed.
    build = menu[k]
    return build()


if __name__ == '__main__':
    # Independent on/off switches (1 = run). Every enabled evaluation runs, in
    # the order Image->Text, Text->Image, Baseline. (The old if/elif/else chain
    # ran at most one test and fell back to the baseline even when its switch
    # was 0.)
    #            Image->Text  Text->Image  Baseline
    test_type = [0,           1,           0]
    batch_size = 250 if model_name == 'RN50x64' else 500

    if test_type[0]:
        with torch.no_grad():
            all_caption_embeddings = torch.load('./cache/flickr30k_cache/caption_embeddings({}){}.pt'.format(model_name, seed))
            print("caption embedding shape:", all_caption_embeddings.shape)
            trans = randomTransforms(3)
            I2TRetrievalMLOT(trans, batch_size)
    if test_type[1]:
        with torch.no_grad():
            all_image_embeddings = torch.load('./cache/flickr30k_cache/image_embeddings({}){}.pt'.format(model_name, seed))
            print("image embedding shape:", all_image_embeddings.shape)
            T2IRetrievalMLOT(batch_size)
    if test_type[2]:
        print("Baseline on", model_name)
        # Evaluation only; no autograd graph is needed.
        with torch.no_grad():
            baseline()