import os
import torch
import random
from PIL import Image
import clip.clip as clip
from pycocotools.coco import COCO
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


import mlot_simple as mlot


def setSeed(seed):
    """Seed Python's ``random`` module and then PyTorch's global RNG with *seed*."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

seed = 3407
setSeed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("[Work on {}]".format(device))
# Available checkpoints: ViT-B-32; RN50x64
model_name = 'ViT-B-32'
path = '/home_new/shiliangliang/taorui/retrieval/cache/model/{}.pt'
model, preprocess = clip.load(path.format(model_name), device=device)
model.eval()
print("[Clip Model Loaded]")
# Load the cached caption strings so predictions can be compared with ground truth.
# ViT-B/32 caches are stored under the tag 'all' (collectCaptionEmbedding applies the
# same mapping), so the global model_name is remapped; the run* helpers test for 'all'.
if model_name in ['ViT-B-32', 'ViT-B/32']:
    model_name = 'all'


all_captions = torch.load("./cache/caption_strings({}){}.pt".format(model_name, seed))
# Captions are cached as UTF-8 bytes; decode them back to str.
all_captions = [s.decode('utf-8').strip('\n') for s in all_captions]
print("[All Captions Loaded]")

# Image->Text: load every caption embedding.
# map_location places the tensor on the current device directly. Without it, a cache
# written from CUDA tensors cannot be loaded on the CPU fallback chosen above.
all_caption_embeddings = torch.load(
    './cache/caption_embeddings({}){}.pt'.format(model_name, seed),
    map_location=device,
)
print("All caption embeddings shape:", all_caption_embeddings.shape)
# Text->Image: image embeddings are not loaded here; runT2IBaseline and the
# __main__ block load './cache/image_embeddings(...)' where they need it.


class CocoCaptionsDataset(Dataset):
    """Map-style dataset over a COCO captions annotation file.

    Item ``i`` is ``(image_path, captions)``: the on-disk path of the i-th image
    id reported by ``COCO.getImgIds()`` and the list of caption strings
    annotated for it. Images are not opened here; callers open them themselves.

    Args:
        annotation_path: path to a COCO captions annotation JSON file.
        img_folder: directory that the annotation ``file_name`` entries live in.
        preprocess: stored as ``self.preprocess``; not used by this class.
        transform: stored as ``self.transform``; not used by this class.
    """

    def __init__(self, annotation_path, img_folder, preprocess, transform):
        self.coco = COCO(annotation_path)
        self.img_folder = img_folder
        self.preprocess = preprocess
        self.transform = transform
        # Fixed id ordering: dataset index i maps to self.img_ids[i].
        self.img_ids = self.coco.getImgIds()

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        image_id = self.img_ids[idx]
        file_name = self.coco.loadImgs(image_id)[0]['file_name']
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))
        return (
            os.path.join(self.img_folder, file_name),
            [annotation['caption'] for annotation in annotations],
        )


def randomTransforms(k):
    """Return the *k*-th augmentation from a fixed menu of torchvision transforms.

    Despite the name, *k* is an index into the menu below (0-6, with ordinary
    list indexing rules), not a number of transforms to sample. The whole menu
    is built on every call and one entry is returned.
    """
    return [
        transforms.RandomHorizontalFlip(p=1.0),                                          # 0: always mirror left-right
        transforms.RandomCrop(32, padding=4),                                            # 1: 32x32 crop after 4px padding
        transforms.RandomVerticalFlip(p=1.0),                                            # 2: always mirror top-bottom
        transforms.RandomRotation(30),                                                   # 3: rotate within +/-30 degrees
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),   # 4: colour jitter
        transforms.RandomGrayscale(p=1.0),                                               # 5: always grayscale
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),                                    # 6: 3x3 Gaussian blur
    ][k]


def solveMLOT(s, t, M, eps, numItermax):
    """Solve the multi-level OT problem with ``mlot`` and return only the plan.

    A one-element *eps* selects ``mlot.multi_sinkhorn_single(…, eps[0], …)``.
    Any other length selects ``mlot.multi_sinkhorn`` with ``eps[0]`` and
    ``eps[1]``. The solver's log output is discarded.
    """
    if len(eps) != 1:
        plan, _ = mlot.multi_sinkhorn(s, t, M, eps[0], eps[1], numItermax)
        return plan
    plan, _ = mlot.multi_sinkhorn_single(s, t, M, eps[0], numItermax)
    return plan


def checkResult(pre_caps, gt_caps):
    """Score one query's ranked caption predictions at R@1, R@5 and R@10.

    Uses the same hit rule as the inline scoring loops in ``I2Tmain`` and
    ``I2Tbaseline``. Only the rank of the first prediction found in *gt_caps*
    matters: each recall slot is 1 if that rank is below its cutoff, else 0.
    Summing the result over queries and dividing by the number of queries gives
    the usual R@K fractions. Empty *gt_caps* simply scores zero instead of
    dividing by zero. Duplicate predictions cannot push a slot above 1.

    Args:
        pre_caps: predicted caption strings, best first.
        gt_caps: ground-truth caption strings for this query (may be empty).

    Returns:
        float32 tensor ``[hit@1, hit@5, hit@10]`` with entries in {0, 1}.
    """
    correct = torch.zeros(3, dtype=torch.float32)
    gt_set = set(gt_caps)
    for rank, cap in enumerate(pre_caps):
        if cap in gt_set:
            for slot, cutoff in enumerate((1, 5, 10)):
                if rank < cutoff:
                    correct[slot] = 1.0
            break
    return correct


def I2Tmain(model, dataset, transform, preprocess, query_type, k, test_ids=None,
            *, margin=6, eps=5e-2, num_iter=2000):
    """Image->Text retrieval re-ranked with a two-layer (multi-level) OT plan.

    For each of ``k`` query images, both the image and an augmented copy
    (``transform``) are embedded with ``model.encode_image``. Each is scored by
    raw dot product (no L2 normalization here) against every caption embedding
    in the module-level ``all_caption_embeddings``. The two similarity matrices
    become costs ``max_sim + margin - sim``, scaled into (0, 1] and solved
    jointly with ``solveMLOT``. The sum of the two plans then ranks captions. A
    query is a hit at R@K if any of its ground-truth captions is among the top K
    predicted captions (first-hit rule).

    Args:
        model: CLIP model providing ``encode_image``.
        dataset: indexable dataset returning ``(image_path, captions)``.
        transform: augmentation applied to the tensor image to build the
            second query view.
        preprocess: CLIP preprocessing applied to both views.
        query_type: must be ``'image'``.
        k: number of query images.
        test_ids: dataset indices to evaluate; ``k`` are sampled at random
            when ``None``. Its length must equal ``k``.
        margin: constant added to the global max similarity when turning
            similarities into positive costs (default: the original 6).
        eps: entropic regularization passed to ``solveMLOT``.
        num_iter: maximum number of solver iterations.

    Returns:
        ``[R@1, R@5, R@10]`` as fractions of the ``k`` queries.

    Raises:
        AssertionError: on a wrong ``query_type``, a ``test_ids`` length other
            than ``k``, or non-positive costs.
    """
    assert query_type == 'image', "Only support image query now (Image->Text Retrieval)"
    if test_ids is None:
        print("Random test {} id".format(k))
        test_ids = random.sample(range(len(dataset)), k)
    test_data = [dataset[i] for i in test_ids]
    test_img_paths = [item[0] for item in test_data]
    test_captions = [item[1] for item in test_data]
    assert len(test_img_paths) == len(test_captions) == k, "Data length wrong: {} and {}".format(len(test_data), len(test_captions))

    test_imgs = [Image.open(img_path).convert('RGB') for img_path in test_img_paths]
    test_imgs = [transforms.ToTensor()(img) for img in test_imgs]
    arg_imgs = [transform(img) for img in test_imgs]
    # NOTE(review): preprocess is applied to tensors here, while collectImageEmbedding
    # applies it to PIL images. Confirm the local clip preprocess accepts tensors.
    test_img_inputs = [preprocess(img) for img in test_imgs]    # preprocess the original images
    arg_imgs_inputs = [preprocess(img) for img in arg_imgs]     # preprocess the augmented images

    # Image features, shape (k, D), e.g. (k, 512) for ViT-B/32.
    test_img_feats = model.encode_image(torch.stack(test_img_inputs).to(device))
    arg_img_feats = model.encode_image(torch.stack(arg_imgs_inputs).to(device))
    print("image features shape:", test_img_feats.shape, arg_img_feats.shape)

    sim = torch.mm(test_img_feats, all_caption_embeddings.T)
    sim_arg = torch.mm(arg_img_feats, all_caption_embeddings.T)
    print("Similarity shape:", sim.shape, sim_arg.shape)

    del test_imgs, arg_imgs, test_img_inputs, arg_imgs_inputs, test_img_feats, arg_img_feats

    # Cast the two views into one multi-level OT problem.
    # NOTE(review): t has length k although each cost matrix has one column per caption;
    # presumably mlot builds or broadcasts the target marginal itself. Confirm.
    s, t = torch.ones(k) / k, torch.ones(k) / k
    print("simi max:", sim.max().item(), sim_arg.max().item())
    print("simi min:", sim.min().item(), sim_arg.min().item())
    _simi_max = max(sim.max().item(), sim_arg.max().item())
    _simi_min = min(sim.min().item(), sim_arg.min().item())
    # Costs are shifted similarities. The second layer is transposed (captions x images).
    M = [_simi_max + margin - sim, _simi_max + margin - sim_arg.T]
    assert torch.all(M[0] > 0) and torch.all(M[1] > 0), "Similarity matrix should be positive"
    # Divide by the largest possible cost so both matrices lie in (0, 1].
    _max_num = _simi_max + margin - _simi_min
    M = [Mi / _max_num for Mi in M]
    del sim, sim_arg
    torch.cuda.empty_cache()
    T = solveMLOT(s, t, M, [eps], numItermax=num_iter)
    predict_sim = T[0] + T[1].T
    print("Prediction sim matrix shape:", predict_sim.shape)

    # The top 10 columns of each row are the predicted caption ids.
    values, predict_cap_ids = torch.topk(predict_sim, 10, dim=1)
    correct = torch.zeros(3, dtype=torch.float32)     # hit counts for R@1, R@5, R@10

    for i in range(k):
        pre_caps = [all_captions[predict_cap_ids[i][p]] for p in range(10)]
        gt_caps = [cap.strip('\n') for cap in test_captions[i]]
        # Only the rank of the first matching caption matters.
        for j, cap in enumerate(pre_caps):
            if cap in gt_caps:
                if j < 1:
                    correct[0] += 1
                if j < 5:
                    correct[1] += 1
                if j < 10:
                    correct[2] += 1
                break

        if i % 100 == 0:
            print(i, "| predictions:", pre_caps)
            print(i, "| ground truth:", test_captions[i])

    res = [correct[i].item() / k for i in range(3)]
    _print_res = "R@1: {} | R@5: {} | R@10: {}".format(*(round(r, 4) for r in res))

    del s, t, M, T, predict_sim, values, predict_cap_ids
    torch.cuda.empty_cache()

    print(_print_res)
    return res


def I2Tbaseline(k, test_ids=None):
    """Plain CLIP Image->Text retrieval baseline (no OT re-ranking).

    Embeds ``k`` query images with the module-level ``model`` and
    ``preprocess``. It then ranks every cached caption (module-level
    ``all_caption_embeddings`` / ``all_captions``) by similarity. A query is a
    hit at R@K if any of its ground-truth captions is among the top K
    predictions (first-hit rule).

    Args:
        k: number of query images.
        test_ids: indices into the module-level ``dataset``; ``k`` are sampled
            at random when ``None``.

    Returns:
        ``[R@1, R@5, R@10]`` as fractions of the ``k`` queries.
    """
    if test_ids is None:
        print("Random test {} id".format(k))
        test_ids = random.sample(range(len(dataset)), k)
    test_data = [dataset[i] for i in test_ids]
    test_img_paths = [item[0] for item in test_data]
    test_captions = [item[1] for item in test_data]

    test_imgs = [Image.open(img_path).convert('RGB') for img_path in test_img_paths]
    test_imgs = [transforms.ToTensor()(img) for img in test_imgs]
    test_img_inputs = [preprocess(img) for img in test_imgs]  # preprocess the original images

    # Image features: (k, 512) or (k, 1024) depending on the CLIP variant.
    test_img_feats = model.encode_image(torch.stack(test_img_inputs).to(device))
    print("image features shape:", test_img_feats.shape)

    # softmax is monotonic within a row, so it does not change the top-k ranking.
    predict_sim = (test_img_feats @ all_caption_embeddings.T).softmax(dim=-1)
    print("Similarity shape:", predict_sim.shape)

    values, predict_cap_ids = torch.topk(predict_sim, 10, dim=1)
    correct = torch.zeros(3, dtype=torch.float32)  # hit counts for R@1, R@5, R@10

    for i in range(k):
        pre_caps = [all_captions[predict_cap_ids[i][p]] for p in range(10)]
        gt_caps = [cap.strip('\n') for cap in test_captions[i]]
        # Only the rank of the first matching caption matters.
        for j, cap in enumerate(pre_caps):
            if cap in gt_caps:
                if j < 1:
                    correct[0] += 1
                if j < 5:
                    correct[1] += 1
                if j < 10:
                    correct[2] += 1
                break

        if i % 100 == 0:
            print(i, "| predictions:", pre_caps)
            print(i, "| ground truth:", test_captions[i])

    res = [correct[i].item() / k for i in range(3)]
    _print_res = "R@1: {} | R@5: {} | R@10: {}".format(
        round(correct[0].item() / k, 4),
        round(correct[1].item() / k, 4),
        round(correct[2].item() / k, 4)
    )

    # Release per-call tensors only. The module-level all_caption_embeddings must
    # not appear in this del: deleting it anywhere in the function makes it a
    # local name, and the read above would raise UnboundLocalError.
    del test_imgs, test_img_inputs, test_img_feats, predict_sim, values, predict_cap_ids
    torch.cuda.empty_cache()

    print(_print_res)
    return res


def T2Imain(model, dataset, query_type, k, test_ids=None, all_images_embeddings=None,
            *, margin=8, eps=5e-3, num_iter=2000):
    """Text->Image retrieval re-ranked with a two-layer (multi-level) OT plan.

    For each of ``k`` images, two distinct captions are drawn at random. The
    first is the text query and the second is an augmented query view. Both are
    embedded with ``model.encode_text``, L2-normalized, and scored against every
    gallery image embedding. The similarity matrices become costs
    ``margin - sim``, scaled into (0, 1] and solved jointly with ``solveMLOT``.
    The summed plans then rank the images. A query is a hit at R@K if its source
    image's dataset index is among the top K rows. This assumes row i of the
    gallery embeddings is dataset item i, which holds for caches written by
    collectImageEmbedding from an unshuffled dataloader.

    Args:
        model: CLIP model providing ``encode_text``.
        dataset: indexable dataset returning ``(image_path, captions)``; every
            evaluated image needs at least two captions.
        query_type: must be ``'text'``.
        k: number of queries.
        test_ids: dataset indices to evaluate; ``k`` are sampled at random when
            ``None``.
        all_images_embeddings: gallery image embeddings. When ``None`` they are
            loaded from ``./cache/image_embeddings(<model_name>)<seed>.pt``, the
            same file runT2IBaseline loads, instead of a module global that may
            not exist yet.
        margin: constant the similarities are subtracted from (default: the
            original 8). It must exceed every similarity. With L2-normalized
            embeddings sim is at most 1.
        eps: entropic regularization passed to ``solveMLOT``.
        num_iter: maximum number of solver iterations.

    Returns:
        ``[R@1, R@5, R@10]`` as fractions of the ``k`` queries.
    """
    assert query_type == 'text', "Only support text query now (Text->Image Retrieval)"
    if all_images_embeddings is None:
        all_images_embeddings = torch.load(
            './cache/image_embeddings({}){}.pt'.format(model_name, seed),
            map_location=device,
        )
    if test_ids is None:
        print("Random test {} id".format(k))
        test_ids = random.sample(range(len(dataset)), k)
    test_data = [dataset[i] for i in test_ids]
    test_captions = [item[1] for item in test_data]
    # Two distinct captions per image: one query and one augmented view.
    two_query_captions = [random.sample(test_captions[i], 2) for i in range(k)]
    first_query = [capi[0] for capi in two_query_captions]
    second_query = [capi[1] for capi in two_query_captions]

    first_query_tokens = clip.tokenize(first_query).to(device)
    second_query_tokens = clip.tokenize(second_query).to(device)
    first_query_features = model.encode_text(first_query_tokens)
    second_query_features = model.encode_text(second_query_tokens)
    first_query_features /= first_query_features.norm(dim=-1, keepdim=True)
    second_query_features /= second_query_features.norm(dim=-1, keepdim=True)
    print("Text query feature shape:", first_query_features.shape, second_query_features.shape)

    sim = torch.mm(first_query_features, all_images_embeddings.T)
    sim_arg = torch.mm(second_query_features, all_images_embeddings.T)
    print("Similarity shape:", sim.shape, sim_arg.shape)

    del first_query_tokens, second_query_tokens, first_query_features, second_query_features
    torch.cuda.empty_cache()

    # Cast the two query views into one multi-level OT problem.
    # NOTE(review): t has length k although each cost matrix has one column per image;
    # presumably mlot builds or broadcasts the target marginal itself. Confirm.
    s, t = torch.ones(k) / k, torch.ones(k) / k
    print("simi max:", sim.max().item(), sim_arg.max().item())
    print("simi min:", sim.min().item(), sim_arg.min().item())
    _simi_min = min(sim.min().item(), sim_arg.min().item())
    # The second layer is transposed (images x queries).
    M = [margin - sim, margin - sim_arg.T]
    assert torch.all(M[0] > 0) and torch.all(M[1] > 0), "Similarity matrix should be positive"
    # Divide by the largest possible cost so both matrices lie in (0, 1].
    _max_num = margin - _simi_min
    M = [Mi / _max_num for Mi in M]
    del sim, sim_arg
    torch.cuda.empty_cache()
    T = solveMLOT(s, t, M, [eps], numItermax=num_iter)
    predict_sim = T[0] + T[1].T
    print("Prediction sim matrix shape:", predict_sim.shape)

    values, predict_img_ids = torch.topk(predict_sim, 10, dim=1)
    correct = torch.zeros(3, dtype=torch.float32)  # hit counts for R@1, R@5, R@10

    for i in range(k):
        pre_imgs = [predict_img_ids[i][p] for p in range(10)]
        gt_img = test_ids[i]
        if gt_img in pre_imgs[:1]:
            correct[0] += 1
        if gt_img in pre_imgs[:5]:
            correct[1] += 1
        if gt_img in pre_imgs:
            correct[2] += 1

        if i % 50 == 0:
            print(i, "| predictions:", pre_imgs)
            print(i, "| ground truth:", gt_img)

    res = [correct[i].item() / k for i in range(3)]
    _print_res = "R@1: {} | R@5: {} | R@10: {}".format(*(round(r, 4) for r in res))

    del s, t, M, T, predict_sim, values, predict_img_ids
    torch.cuda.empty_cache()

    print(_print_res)
    return res



def T2Ibaseline(model, dataset, query_type, all_images_embeddings, k, test_ids=None):
    """Plain CLIP Text->Image retrieval baseline (no OT re-ranking).

    For each of ``k`` images, one of its captions is drawn at random as the
    text query. The query is L2-normalized and every row of
    *all_images_embeddings* is ranked by similarity to it. A query is a hit at
    R@K if the source image's dataset index is among the top K rows. This
    assumes row i of *all_images_embeddings* is dataset item i.

    Returns:
        ``[R@1, R@5, R@10]`` as fractions of the ``k`` queries.
    """
    assert query_type == 'text', "This is Text->Image Retrieval baseline"
    if test_ids is None:
        print("Random test {} id".format(k))
        test_ids = random.sample(range(len(dataset)), k)

    caption_lists = [dataset[idx][1] for idx in test_ids]
    # One caption per image becomes the query.
    query_captions = [random.sample(caption_lists[q], 1)[0] for q in range(k)]

    query_tokens = clip.tokenize(query_captions).to(device)
    query_features = model.encode_text(query_tokens)
    query_features /= query_features.norm(dim=-1, keepdim=True)
    print("Text query feature shape:", query_features.shape)

    predict_sim = (query_features @ all_images_embeddings.T).softmax(dim=-1)
    print("Similarity shape:", predict_sim.shape)

    values, predict_img_ids = torch.topk(predict_sim, 10, dim=1)
    hits = torch.zeros(3, dtype=torch.float32)  # hit counts for R@1, R@5, R@10
    cutoffs = (1, 5, 10)

    for row in range(k):
        ranked = [predict_img_ids[row][p] for p in range(10)]
        target = test_ids[row]
        for slot, cutoff in enumerate(cutoffs):
            if target in ranked[:cutoff]:
                hits[slot] += 1

        if row % 50 == 0:
            print(row, "| predictions:", ranked)
            print(row, "| ground truth:", target)

    res = [hits[j].item() / k for j in range(3)]
    summary = "R@1: {} | R@5: {} | R@10: {}".format(*(round(r, 4) for r in res))

    del query_tokens, query_features, predict_sim, values, predict_img_ids
    torch.cuda.empty_cache()

    print(summary)
    return res


def collectCaptionEmbedding(dataloader, model_name, seed):
    """Embed every caption yielded by *dataloader* with CLIP and cache the result.

    Each batch is a list of ``(image_path, captions)`` items. Captions are
    flattened in dataloader order, tokenized, encoded with
    ``model.encode_text`` and L2-normalized. Two files are written:
    ``./cache/caption_embeddings(<tag>)<seed>.pt`` holds the stacked features,
    and ``./cache/caption_strings(<tag>)<seed>.pt`` holds the captions as UTF-8
    bytes in the same order (restore them with
    ``[s.decode('utf-8') for s in ...]``). ``<tag>`` is ``'all'`` for
    ViT-B-32 / ViT-B/32, otherwise *model_name*.
    """
    caption_texts = []
    feature_chunks = []
    for batch in dataloader:
        # Flatten the per-image caption lists so every caption is embedded on its own.
        # caption_texts must stay aligned row-for-row with the features.
        flat_caps = [caption for entry in batch for caption in entry[1]]
        caption_texts.extend(flat_caps)
        print("example caps:", flat_caps[0:50:10])

        tokens = clip.tokenize(flat_caps).to(device)
        print("token shape: {}".format(tokens.shape), end=", ")
        with torch.no_grad():
            feats = model.encode_text(tokens)
            feats /= feats.norm(dim=-1, keepdim=True)
            print("feature shape: {}".format(feats.shape))
            feature_chunks.append(feats)

    stacked = torch.cat(feature_chunks, dim=0)
    print("All text features shape:", stacked.shape)

    tag = 'all' if model_name in ['ViT-B-32', 'ViT-B/32'] else model_name
    torch.save(stacked, './cache/caption_embeddings({}){}.pt'.format(tag, seed))
    torch.save([text.encode('utf-8') for text in caption_texts],
               './cache/caption_strings({}){}.pt'.format(tag, seed))


def collectImageEmbedding(dataloader, model_name, seed):
    """Embed every image yielded by *dataloader* with CLIP and cache the result.

    Images are opened from the paths in each batch of ``(image_path, captions)``
    items. They are passed through the module-level ``preprocess`` and
    ``model.encode_image``, L2-normalized, and concatenated in dataloader order,
    so row i is the i-th image yielded. The tensor is saved to
    ``./cache/image_embeddings(<tag>)<seed>.pt``.

    Args:
        dataloader: iterable of batches, each a list of ``(image_path, captions)``.
        model_name: CLIP model name. ``'ViT-B-32'`` / ``'ViT-B/32'`` map to the
            cache tag ``'all'``, matching collectCaptionEmbedding and the loaders.
            Without that mapping the file would be written under a name nothing
            reads.
        seed: seed value embedded in the cache filename.
    """
    all_image_embeddings = []
    for i, item in enumerate(dataloader):
        batch_imgs = [subitem[0] for subitem in item]
        batch_imgs = [Image.open(img_path).convert('RGB') for img_path in batch_imgs]
        # preprocess is applied directly to PIL images here (no ToTensor beforehand).
        batch_imgs = [preprocess(img) for img in batch_imgs]
        batch_imgs = torch.stack(batch_imgs).to(device)
        print("image shape:", batch_imgs.shape, end=", ")
        with torch.no_grad():
            image_features = model.encode_image(batch_imgs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            print("feature shape: {}".format(image_features.shape))
            all_image_embeddings.append(image_features)

    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    print("All image features shape:", all_image_embeddings.shape)

    if model_name in ['ViT-B-32', 'ViT-B/32']:
        model_name = 'all'
    torch.save(all_image_embeddings, './cache/image_embeddings({}){}.pt'.format(model_name, seed))


def countTotalCaptions(dataloader):
    """Count every caption string yielded by *dataloader* and return the total.

    Each batch is a list of ``(image_path, captions)`` items. Captions that
    describe the same image are counted separately. Progress is printed every
    100 batches and the total is printed at the end.

    Returns:
        The total number of caption strings (previously only printed).
    """
    def _count_leaves(node):
        # A str is one caption; anything else is a (possibly nested) container of captions.
        if isinstance(node, str):
            return 1
        return sum(_count_leaves(child) for child in node)

    cnt = 0
    for i, item in enumerate(dataloader):
        cnt += _count_leaves([subitem[1] for subitem in item])
        if i % 100 == 0:
            print("{} | {}".format(i, cnt))
    print("Total captions:", cnt)
    return cnt


def runMLOTI2TRetrieval():
    """Evaluate OT-based Image->Text retrieval over all shuffled images.

    Splits the module-level ``TotalImageNumRange`` into consecutive groups
    (1000 indices for the 'all' cache tag, otherwise 250). Each group is scored
    with ``I2Tmain`` using the Gaussian-blur augmentation
    (``randomTransforms(6)``), and the mean ``[R@1, R@5, R@10]`` across groups
    is printed. Indices past the last full group are not evaluated.
    """
    batch_size = 1000 if model_name == 'all' else 250
    n_groups = len(TotalImageNumRange) // batch_size
    totals = [0, 0, 0]

    with torch.no_grad():
        for group in range(n_groups):
            start = group * batch_size
            test_ids = TotalImageNumRange[start:start + batch_size]
            scores = I2Tmain(model, dataset, randomTransforms(6), preprocess, 'image', batch_size, test_ids=test_ids)
            totals = [acc + sc for acc, sc in zip(totals, scores)]
        averages = [acc / n_groups for acc in totals]
        print("Final result:", [round(x, 4) for x in averages])


def runI2TBaseline():
    """Evaluate the plain CLIP Image->Text baseline over all shuffled images.

    Splits the module-level ``TotalImageNumRange`` into consecutive groups
    (1000 indices for the 'all' cache tag, otherwise 250). Each group is scored
    with ``I2Tbaseline``, and the mean ``[R@1, R@5, R@10]`` across groups is
    printed. Indices past the last full group are not evaluated.
    """
    if model_name == 'all':
        group_size = 1000
    else:
        group_size = 250
    n_groups = len(TotalImageNumRange) // group_size
    running = [0, 0, 0]

    with torch.no_grad():
        for g in range(n_groups):
            ids = TotalImageNumRange[g * group_size:(g + 1) * group_size]
            scores = I2Tbaseline(group_size, test_ids=ids)
            running = [running[m] + scores[m] for m in range(3)]
        mean = [total / n_groups for total in running]
        print("Final result:", [round(v, 4) for v in mean])


def runMLOTT2IRetrieval():
    """Evaluate OT-based Text->Image retrieval over all shuffled images.

    Splits the module-level ``TotalImageNumRange`` into consecutive groups
    (1000 indices for the 'all' cache tag, otherwise 250). Each group is scored
    with ``T2Imain``, and the mean ``[R@1, R@5, R@10]`` across groups is
    printed. Indices past the last full group are not evaluated.
    """
    size = 1000 if model_name == 'all' else 250
    n_groups = len(TotalImageNumRange) // size
    acc = [0, 0, 0]

    with torch.no_grad():
        for g in range(n_groups):
            chunk = TotalImageNumRange[g * size:(g + 1) * size]
            part = T2Imain(model, dataset, 'text', size, test_ids=chunk)
            acc = [a + p for a, p in zip(acc, part)]
        final = [a / n_groups for a in acc]
        print("Final result:", [round(f, 4) for f in final])


def runT2IBaseline():
    """Evaluate the plain CLIP Text->Image baseline over all shuffled images.

    Loads the cached gallery image embeddings once. Consecutive groups of 250
    indices from the module-level ``TotalImageNumRange`` are then scored with
    ``T2Ibaseline``, and the mean ``[R@1, R@5, R@10]`` across groups is
    printed. Indices past the last full group are not evaluated.
    """
    batch_size = 250
    res = [0, 0, 0]
    # map_location: the cache may hold CUDA tensors. Load them straight onto the
    # current device so CPU-only runs do not fail in torch.load.
    all_images_embeddings = torch.load(
        './cache/image_embeddings({}){}.pt'.format(model_name, seed),
        map_location=device,
    )
    print("All image embeddings shape:", all_images_embeddings.shape)
    with torch.no_grad():
        for group in range(len(TotalImageNumRange) // batch_size):
            test_ids = TotalImageNumRange[group*batch_size: (group+1)*batch_size]
            _tmp = T2Ibaseline(model, dataset, 'text', all_images_embeddings, batch_size, test_ids=test_ids)
            res = [res[i] + _tmp[i] for i in range(3)]
        res = [r / (len(TotalImageNumRange) // batch_size) for r in res]
        print("Final result:", list(map(lambda x: round(x, 4), res)))


def demoI2T(dataset, all_images_embeddings, k):
    """Save a grid showing ``k`` random images and two augmentations of each.

    Row 0 shows each image resized to 224x224, row 1 its horizontal flip
    (``randomTransforms(0)``) and row 2 its Gaussian blur
    (``randomTransforms(6)``). Each column's ground-truth captions are printed,
    and the figure is written to ``./demo.png``.

    Args:
        dataset: indexable dataset returning ``(image_path, captions)``.
        all_images_embeddings: currently unused; kept for call compatibility.
        k: number of images (grid columns); works for ``k == 1`` as well.
    """
    test_ids = random.sample(range(len(dataset)), k)
    print("Random test {}".format(test_ids))
    test_data = [dataset[i] for i in test_ids]
    test_img_paths = [item[0] for item in test_data]
    test_captions = [item[1] for item in test_data]

    test_imgs = [Image.open(img_path).convert('RGB') for img_path in test_img_paths]
    test_imgs = [transforms.ToTensor()(img) for img in test_imgs]
    test_imgs = [transforms.Resize((224, 224))(img) for img in test_imgs]
    flip_imgs = [randomTransforms(0)(img) for img in test_imgs]
    blur_imgs = [randomTransforms(6)(img) for img in test_imgs]
    # squeeze=False keeps axs 2-D even when k == 1, so axs[row, col] indexing works.
    fig, axs = plt.subplots(3, k, figsize=(20, 12), squeeze=False)
    rows = (test_imgs, flip_imgs, blur_imgs)
    for ii in range(k):
        for row, imgs in enumerate(rows):
            # Tensors are (C, H, W) from ToTensor; imshow expects (H, W, C).
            axs[row, ii].imshow(imgs[ii].permute(1, 2, 0))
            axs[row, ii].axis('off')
        print("{} | {}".format(ii, test_captions[ii]))
    plt.tight_layout()
    plt.savefig('./demo.png')
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

#     test_img_inputs = [preprocess(img) for img in test_imgs]  # 预处理原图
#     # 生成图片特征 形状均为 (k, 512) 或者 (k, 1024)
#     test_img_feats = model.encode_image(torch.stack(test_img_inputs).to(device))
#     print("image features shape:", test_img_feats.shape)
#
#     predict_sim = (test_img_feats @ all_caption_embeddings.T).softmax(dim=-1)
#     print("Similarity shape:", predict_sim.shape)

#     values, predict_cap_ids = torch.topk(predict_sim, 10, dim=1)
#     for i in range(k):
#         pre_caps = [all_captions[predict_cap_ids[i][p]] for p in range(10)]
#         gt_caps = [cap.strip('\n') for cap in test_captions[i]]


if __name__ == "__main__":
    dataset = CocoCaptionsDataset(
        annotation_path='/mnt/nas/dataset_share/COCO/annotations/captions_val2017.json',
        img_folder='/mnt/nas/dataset_share/COCO/val2017',
        preprocess=preprocess,
        transform=randomTransforms(1)
    )
    # collate_fn keeps each batch as a plain list of (image_path, captions) tuples.
    data_loader = DataLoader(dataset, batch_size=100, shuffle=False, collate_fn=lambda x: x)

    # Shuffled dataset indices; the run* helpers split this list into evaluation groups.
    TotalImageNumRange = list(range(len(dataset)))
    print("Total image number:", len(TotalImageNumRange))
    random.shuffle(TotalImageNumRange)

    # One-off cache builders / diagnostics:
    # collectCaptionEmbedding(data_loader, model_name, seed)
    # collectImageEmbedding(data_loader, model_name, seed)
    # countTotalCaptions(data_loader)

    # Evaluations:
    # runI2TBaseline()
    # runMLOTI2TRetrieval()
    # runT2IBaseline()
    # runMLOTT2IRetrieval()

    # map_location lets a cache written from CUDA tensors load on a CPU-only machine.
    all_images_embeddings = torch.load(
        './cache/image_embeddings({}){}.pt'.format(model_name, seed),
        map_location=device,
    )
    demoI2T(dataset, all_images_embeddings, 5)
