# Our code is a modification of https://github.com/sony/wpse/tree/main .
# We keep the license description in the original code as follows:

# Copyright © 2025 Sony Research Inc.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# ----------------------------------------------------------
# SLIP: https://github.com/facebookresearch/SLIP
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the MIT License
# ----------------------------------------------------------
import argparse
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
import os
import json
import time
import pandas as pd
from PIL import Image
from pycocotools.coco import COCO

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torch.nn.functional as F

from tokenizer import SimpleTokenizer
import utils
import dataset_utils
import losses
from utils import AverageMeter, accuracy


def get_args_parser():
    """Create the command-line parser for the retrieval evaluation script.

    The parser is created with ``add_help=False`` so that it can be used as a
    parent of another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser defining ``--output-dir``,
        ``--distributed``, ``--batch-size``, ``-j/--workers``, ``--resume``,
        ``--gpu``, ``--forced``, ``--task-list`` and ``--nrepeat``.
    """
    evaluation_parser = argparse.ArgumentParser(
        description="zero-shot evaluations",
        add_help=False,
    )
    add = evaluation_parser.add_argument

    # Options are registered in a fixed order so the generated help text
    # lists them in the same sequence.
    add("--output-dir", type=str, default="./", help="output dir")
    add(
        "--distributed",
        action="store_true",
        help="whether evaluating on distributed environment",
    )
    add("--batch-size", type=int, default=256, help="batch_size")
    add(
        "-j",
        "--workers",
        type=int,
        default=10,
        metavar="N",
        help="number of data loading workers per process",
    )
    add("--resume", type=str, default="", help="path to latest checkpoint")
    add("--gpu", type=int, default=0, help="gpu id")
    add(
        "--forced",
        action="store_true",
        help="When true, all evaluations will be performed even if some datasets were already evaluated.",
    )
    add("--task-list", nargs="+", help="target datasets")
    add("--nrepeat", type=int, default=1, help="num of repetition")

    return evaluation_parser


def main(args):
    """Load a trained checkpoint and run the retrieval evaluations.

    Args:
        args: parsed command-line namespace. This function reads
            ``resume``, ``output_dir``, ``gpu``, ``distributed``,
            ``task_list`` and ``nrepeat`` from it. It is first passed through
            ``utils.init_distributed_mode`` and the returned object is used
            afterwards.

    Raises:
        FileNotFoundError: if ``args.resume`` is given but is not a file, or
            if it is empty and ``<output_dir>/checkpoint_best.pt`` does not
            exist.
    """
    args = utils.init_distributed_mode(args)

    # Resolve the checkpoint to evaluate and where the result CSV is written.
    # An explicit --resume path takes precedence over checkpoint_best.pt.
    if args.resume:
        ckpt_path = args.resume
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
        # Results for an explicit checkpoint go into a directory named after
        # the checkpoint file without its extension.
        result_dir = os.path.splitext(ckpt_path)[0]
        os.makedirs(result_dir, exist_ok=True)
        csv_filename = os.path.join(result_dir, "results_retrieval.csv")
    else:
        ckpt_path = os.path.join(args.output_dir, "checkpoint_best.pt")
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"no checkpoint found: {ckpt_path}")
        csv_filename = os.path.join(args.output_dir, "results_retrieval.csv")

    # create model
    # The training configuration is always read from the output directory,
    # also when the checkpoint itself comes from --resume.
    old_args = OmegaConf.load(os.path.join(args.output_dir, "config.yaml"))
    print("=> creating model: {}".format(old_args.model))
    model = instantiate(old_args.model)
    model.cuda(args.gpu)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    map_location = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False)
    model.load_state_dict(ckpt["state_dict"])
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True, bucket_cap_mb=200)
        model._set_static_graph()

    # Report the path that was actually loaded (args.resume is empty when
    # checkpoint_best.pt was picked up from the output directory).
    print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt["epoch"]))

    # load criterion
    # The stored config is updated with the distributed setting of this run
    # before the criterion is instantiated.
    old_args.distributed = bool(args.distributed)
    criterion = instantiate(old_args.criterion)
    criterion.cuda(args.gpu)

    cudnn.benchmark = True

    # The dataset catalog is expected next to this source file.
    cwd = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(cwd, "dataset_catalog.json")) as f:
        catalog = json.load(f)

    # Data loading code
    print("=> creating dataset")
    tokenizer = SimpleTokenizer()
    _, val_transform = dataset_utils.get_img_transform(old_args, mode="pretraining")

    if args.task_list is None:
        task_list = ["cc3m", "mscoco", "flickr30k"]
    else:
        task_list = args.task_list

    # NOTE: oneloop() refers to the module-level name `args` (bound in the
    # __main__ guard), not to the local `args` of this function.
    for _ in range(args.nrepeat):
        oneloop(
            csv_filename, task_list, catalog,
            model, criterion, tokenizer, val_transform,
            old_args
            )


# Task names understood by oneloop().
_RETRIEVAL_TASKS = ("cc3m", "mscoco", "flickr30k")


def _evaluate_retrieval(dataset, model, criterion, topk, args):
    """Run the retrieval validator that matches the type of ``criterion``.

    Returns the ``(i2t, t2i)`` pair produced by the selected validator.
    Raises ``TypeError`` when ``criterion`` is none of the supported losses.
    """
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False, seed=0)
    else:
        sampler = None

    # The isinstance checks are kept in this order on purpose: the loss
    # classes may derive from one another, so the order decides the match.
    if isinstance(criterion, losses.KME_CLIP_Loss):
        return validate_retrieval_kme_clip(dataset, sampler, model, criterion, topk, args)
    if isinstance(criterion, losses.CLIP_WPSE_Loss):
        return validate_retrieval_wpse(dataset, sampler, model, criterion, topk, args)
    if isinstance(criterion, losses.CLIPLoss):
        return validate_retrieval_clip(dataset, sampler, model, topk, args)
    raise TypeError(
        f"unsupported criterion type for retrieval evaluation: {type(criterion).__name__}")


def _average_over_captions(make_dataset, label, model, criterion, topk, args, num_captions=5):
    """Evaluate once per caption slot and average the recall values.

    ``make_dataset(caption_idx)`` must build the dataset that uses the
    ``caption_idx``-th caption of every image. ``label`` is only used in the
    per-slot progress message.
    """
    keys = [f"r{k}" for k in topk]
    i2t_total = dict.fromkeys(keys, 0)
    t2i_total = dict.fromkeys(keys, 0)
    for caption_idx in range(num_captions):
        i2t, t2i = _evaluate_retrieval(make_dataset(caption_idx), model, criterion, topk, args)
        for key in keys:
            i2t_total[key] += i2t[key]
            t2i_total[key] += t2i[key]
        print(f"acc_{label}_i2t_{caption_idx}={i2t}, acc_{label}_t2i_{caption_idx}={t2i}")

    i2t_mean = {key: i2t_total[key] / num_captions for key in keys}
    t2i_mean = {key: t2i_total[key] / num_captions for key in keys}
    return i2t_mean, t2i_mean


def oneloop(csv_filename, task_list, catalog, model, criterion, tokenizer, val_transform, old_args, args=None):
    """Evaluate every task in ``task_list`` once and append the results to a CSV.

    Args:
        csv_filename: path of the CSV file. Existing rows are kept and the new
            rows are appended.
        task_list: iterable of task names; each must be one of ``"cc3m"``,
            ``"mscoco"`` or ``"flickr30k"``.
        catalog: mapping with ``image_dir`` and ``annotation_file`` entries
            under the ``"mscoco"`` and ``"flickr30k"`` keys.
        model: model handed to the retrieval validators.
        criterion: loss object; its type selects the validator.
        tokenizer: tokenizer handed to the datasets.
        val_transform: image transform handed to the datasets.
        old_args: training configuration, forwarded to
            ``dataset_utils.get_dataset`` for the cc3m split.
        args: runtime arguments (``distributed``, ``batch_size``, ``gpu``).
            When omitted, the module-level ``args`` set by the script entry
            point is used, which is what this function always relied on.

    Raises:
        ValueError: if ``task_list`` contains an unknown task name. This is
            checked before any evaluation starts; previously such a name
            silently re-recorded the previous task's numbers.
        RuntimeError: if ``args`` is not given and no module-level ``args``
            exists.
        TypeError: if ``criterion`` is not a supported loss type.
    """
    task_list = list(task_list)
    unknown_tasks = [task for task in task_list if task not in _RETRIEVAL_TASKS]
    if unknown_tasks:
        raise ValueError(
            f"unknown retrieval task(s) {unknown_tasks}; expected a subset of {list(_RETRIEVAL_TASKS)}")

    if args is None:
        args = globals().get("args")
        if args is None:
            raise RuntimeError("oneloop() needs `args`: pass it explicitly or run this file as a script")

    topk = (1, 5, 10)
    records = []

    for d in task_list:
        if d == "cc3m":
            val_dataset_cc3m = dataset_utils.get_dataset(val_transform, tokenizer, old_args, split="validation")
            i2t, t2i = _evaluate_retrieval(val_dataset_cc3m, model, criterion, topk, args)
        elif d == "mscoco":
            image_dir = catalog["mscoco"]["image_dir"]
            annotation_file = catalog["mscoco"]["annotation_file"]

            def make_mscoco(caption_idx, image_dir=image_dir, annotation_file=annotation_file):
                return CocoDataset(
                    root_dir=image_dir,
                    annotation_file=annotation_file,
                    tokenizer=tokenizer,
                    transform=val_transform,
                    index=caption_idx,
                )

            i2t, t2i = _average_over_captions(make_mscoco, "mscoco", model, criterion, topk, args)
        else:  # "flickr30k" (task names were validated above)
            image_dir = catalog["flickr30k"]["image_dir"]
            annotation_file = catalog["flickr30k"]["annotation_file"]

            def make_flickr30k(caption_idx, image_dir=image_dir, annotation_file=annotation_file):
                return Flickr30kDatasetSplit(
                    root_dir=image_dir,
                    json_file=annotation_file,
                    split="test",
                    tokenizer=tokenizer,
                    transform=val_transform,
                    index=caption_idx,
                )

            i2t, t2i = _average_over_captions(make_flickr30k, "flickr", model, criterion, topk, args)

        new_record = pd.DataFrame.from_dict({
            "task": [d] * len(topk),
            "acc": list(i2t.keys()),
            "i2t": list(i2t.values()),
            "t2i": list(t2i.values()),
            "timestamp": [time.ctime()] * len(topk),
        })
        print(new_record)
        records.append(new_record)

    # Concatenate once; an empty task list still yields the "task" column.
    if records:
        results = pd.concat(records, ignore_index=True)
    else:
        results = pd.DataFrame(columns=["task"])

    print("all results:")
    print(results)
    # NOTE(review): nothing here restricts the write to a single process, so
    # in a distributed run every process appears to append to the same file.
    # Confirm whether this should be limited to the main process.
    if os.path.isfile(csv_filename):
        results_prev = pd.read_csv(csv_filename)
        results = pd.concat([results_prev, results])
    results.to_csv(csv_filename, index=False)


def validate_retrieval_clip(valid_dataset, val_sampler, model, topk, args):
    """Measure image-to-text and text-to-image retrieval accuracy.

    Similarities are dot products between L2-normalized image and text
    embeddings returned by ``model.encode_image`` / ``model.encode_text``.

    Args:
        valid_dataset: dataset of image/text pairs; it is wrapped in
            ``IndexTrackingDataset`` so that each sample also carries its
            dataset index.
        val_sampler: sampler passed to both data loaders (``None`` is allowed).
        model: model object; ``utils.get_model(model)`` is called on it first
            (presumably to unwrap a DistributedDataParallel wrapper — confirm
            in ``utils``).
        topk: iterable of ``k`` values for which accuracy is reported.
        args: namespace providing ``batch_size`` and ``gpu``.

    Returns:
        tuple(dict, dict): ``(i2t, t2i)``; each maps ``"r{k}"`` to the
        ``avg`` attribute of the corresponding ``AverageMeter``.
    """
    # One meter per k and per retrieval direction.
    i2t_result, t2i_result = {}, {}
    for k in topk:
        i2t_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")
        t2i_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")

    valid_dataset = IndexTrackingDataset(valid_dataset)
    # NOTE: the loaders use num_workers=0; args.workers is not read here.
    val_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)

    # NOTE(review): model.eval() is not called in this function; confirm the
    # model is in the intended mode when it gets here.
    with torch.no_grad():
        model = utils.get_model(model)
        # First pass: collect every image, text and dataset index served by
        # this loader.
        image_list = []
        text_list = []
        index_list = []
        for i, images, texts in val_loader:
            images = images.cuda(args.gpu, non_blocking=True)
            image_list.append(images)
            texts = texts.cuda(args.gpu, non_blocking=True)
            text_list.append(texts)
            index_list.append(i)

        # NOTE(review): all inputs are concatenated on the GPU and encoded in
        # a single call, so memory use grows with the number of samples.
        image_list, text_list = torch.cat(image_list, dim=0).to(args.gpu), torch.cat(text_list, dim=0).to(args.gpu)
        index_list = torch.cat(index_list, dim=0).to(args.gpu)
        image_embed = model.encode_image(image_list)
        text_embed = model.encode_text(text_list)

        # Presumably gathers the tensors from all processes — confirm in utils.
        image_embed_all, text_embed_all, index_list_all = utils.all_gather_batch([image_embed, text_embed, index_list])
        # eps=0: no lower bound is applied to the norm used for division.
        image_embed_all = F.normalize(image_embed_all, p=2, dim=1, eps=0)
        text_embed_all = F.normalize(text_embed_all, p=2, dim=1, eps=0)
        # Query batch size of the second pass.
        batch_size = 5

        val_loader1 = torch.utils.data.DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)
       
        # Second pass: re-encode small query batches and score them against
        # the full gallery built above.
        for index, images, texts in val_loader1:
            images, texts, index = images.cuda(args.gpu, non_blocking=True), texts.cuda(args.gpu, non_blocking=True), index.cuda(args.gpu, non_blocking=True)
            image_embed = model.encode_image(images)
            text_embed = model.encode_text(texts)
            image_embed = F.normalize(image_embed, p=2, dim=1, eps=0)
            text_embed = F.normalize(text_embed, p=2, dim=1, eps=0)
            logits_per_image, logits_per_text = image_embed @ text_embed_all.t(), text_embed @ image_embed_all.t() 
            
            # identify true label under multi-gpu
            # matches[j, q] is True where gallery entry j has the same dataset
            # index as query q. For each query the smallest such j becomes the
            # label; len(index_list_all) is a "no match" sentinel that is
            # turned into -1.
            expanded_vector = index_list_all.unsqueeze(1)
            expanded_targets = index.unsqueeze(0)
            matches = (expanded_vector == expanded_targets)
            indices = torch.arange(len(index_list_all), device=index_list_all.device).unsqueeze(1).expand_as(matches)
            masked_indices = torch.where(matches, indices, torch.tensor(len(index_list_all), device=index_list_all.device))
            first_matches, _ = torch.min(masked_indices, dim=0)
            labels = torch.where(first_matches < len(index_list_all), first_matches, torch.tensor(-1, device=index_list_all.device))
            
            # Image -> text direction. scaled_all_reduce presumably combines
            # the values across processes — confirm in utils.
            acc_list = accuracy(logits_per_image, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                i2t_result[str(k)].update(acc_list[i].item(), images.size(0))

            # Text -> image direction.
            acc_list = accuracy(logits_per_text, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                t2i_result[str(k)].update(acc_list[i].item(), texts.size(0))

    i2t_result_dict = {f"r{k}": i2t_result[str(k)].avg for k in topk}
    t2i_result_dict = {f"r{k}": t2i_result[str(k)].avg for k in topk}

    return i2t_result_dict, t2i_result_dict


def validate_retrieval_wpse(valid_dataset, val_sampler, model, criterion, topk, args):
    """Measure retrieval accuracy using features built by ``criterion``.

    The encoders return an embedding and a weight per input. Depending on
    ``criterion.enable_linear_kernel`` these are turned into feature vectors
    by ``criterion.comb_feature`` or by
    ``criterion.rff_trick.forward_with_weights``; similarities are dot
    products of those vectors.

    Args:
        valid_dataset: dataset of image/text pairs; it is wrapped in
            ``IndexTrackingDataset`` so that each sample also carries its
            dataset index.
        val_sampler: sampler passed to both data loaders (``None`` is allowed).
        model: model object; ``utils.get_model(model)`` is called on it first.
        criterion: loss object providing ``enable_linear_kernel`` and the
            feature construction methods named above.
        topk: iterable of ``k`` values for which accuracy is reported.
        args: namespace providing ``batch_size`` and ``gpu``.

    Returns:
        tuple(dict, dict): ``(i2t, t2i)``; each maps ``"r{k}"`` to the
        ``avg`` attribute of the corresponding ``AverageMeter``.
    """
    # One meter per k and per retrieval direction.
    i2t_result, t2i_result = {}, {}
    for k in topk:
        i2t_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")
        t2i_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")

    valid_dataset = IndexTrackingDataset(valid_dataset)
    # NOTE: the loaders use num_workers=0; args.workers is not read here.
    val_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)
    # NOTE(review): model.eval() is not called in this function; confirm the
    # model is in the intended mode when it gets here.
    with torch.no_grad():
        model = utils.get_model(model)
        # First pass: collect every image, text and dataset index served by
        # this loader.
        image_list = []
        text_list = []
        index_list = []
        for i, images, texts in val_loader:
            images = images.cuda(args.gpu, non_blocking=True)
            image_list.append(images)
            texts = texts.cuda(args.gpu, non_blocking=True)
            text_list.append(texts)
            index_list.append(i)

        # NOTE(review): all inputs are concatenated on the GPU and encoded in
        # a single call, so memory use grows with the number of samples.
        image_list, text_list = torch.cat(image_list, dim=0).to(args.gpu), torch.cat(text_list, dim=0).to(args.gpu)
        index_list = torch.cat(index_list, dim=0).to(args.gpu)
        image_embed, image_weight = model.encode_image(image_list)
        text_embed, text_weight = model.encode_text(text_list)
        # Build the gallery features with the same method that is used for
        # the queries further below.
        if criterion.enable_linear_kernel:
            z_image, z_text = criterion.comb_feature(image_embed, image_weight, text_embed, text_weight)
        else:
            z_image, z_text = criterion.rff_trick.forward_with_weights(image_embed, image_weight, text_embed, text_weight)

        # Presumably gathers the tensors from all processes — confirm in utils.
        z_image_all, z_text_all, index_list_all = utils.all_gather_batch([z_image, z_text, index_list])
        # Query batch size of the second pass.
        batch_size = 5

        val_loader1 = torch.utils.data.DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)
       
        # Second pass: re-encode small query batches and score them against
        # the full gallery built above.
        for index, images, texts in val_loader1:
            images, texts, index = images.cuda(args.gpu, non_blocking=True), texts.cuda(args.gpu, non_blocking=True), index.cuda(args.gpu, non_blocking=True)
            image_embed, image_weight = model.encode_image(images)
            text_embed, text_weight = model.encode_text(texts)
            if criterion.enable_linear_kernel:
                z_image, z_text = criterion.comb_feature(image_embed, image_weight, text_embed, text_weight)
            else:
                z_image, z_text = criterion.rff_trick.forward_with_weights(image_embed, image_weight, text_embed, text_weight)

            # No normalization is applied here; the features are used exactly
            # as returned by the criterion.
            logits_per_image, logits_per_text = z_image @ z_text_all.t(), z_text @ z_image_all.t() 
            
            # identify true label under multi-gpu
            # matches[j, q] is True where gallery entry j has the same dataset
            # index as query q. For each query the smallest such j becomes the
            # label; len(index_list_all) is a "no match" sentinel that is
            # turned into -1.
            expanded_vector = index_list_all.unsqueeze(1)
            expanded_targets = index.unsqueeze(0)
            matches = (expanded_vector == expanded_targets)
            indices = torch.arange(len(index_list_all), device=index_list_all.device).unsqueeze(1).expand_as(matches)
            masked_indices = torch.where(matches, indices, torch.tensor(len(index_list_all), device=index_list_all.device))
            first_matches, _ = torch.min(masked_indices, dim=0)
            labels = torch.where(first_matches < len(index_list_all), first_matches, torch.tensor(-1, device=index_list_all.device))
            
            # Image -> text direction. scaled_all_reduce presumably combines
            # the values across processes — confirm in utils.
            acc_list = accuracy(logits_per_image, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                i2t_result[str(k)].update(acc_list[i].item(), images.size(0))

            # Text -> image direction.
            acc_list = accuracy(logits_per_text, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                t2i_result[str(k)].update(acc_list[i].item(), texts.size(0))

    i2t_result_dict = {f"r{k}": i2t_result[str(k)].avg for k in topk}
    t2i_result_dict = {f"r{k}": t2i_result[str(k)].avg for k in topk}

    return i2t_result_dict, t2i_result_dict


def validate_retrieval_kme_clip(valid_dataset, val_sampler, model, criterion, topk, args):
    """Measure retrieval accuracy using ``criterion.calc_weighted_kernel``.

    The encoders return an embedding and a weight per input; the weights are
    passed through ``F.softplus``. Similarities between queries and the
    gallery come from ``criterion.calc_weighted_kernel`` together with
    ``model.sigma_inv``.

    Args:
        valid_dataset: dataset of image/text pairs; it is wrapped in
            ``IndexTrackingDataset`` so that each sample also carries its
            dataset index.
        val_sampler: sampler passed to both data loaders (``None`` is allowed).
        model: model object; ``utils.get_model(model)`` is called on it first
            and the result must expose ``sigma_inv``.
        criterion: loss object providing ``calc_weighted_kernel``.
        topk: iterable of ``k`` values for which accuracy is reported.
        args: namespace providing ``batch_size`` and ``gpu``.

    Returns:
        tuple(dict, dict): ``(i2t, t2i)``; each maps ``"r{k}"`` to the
        ``avg`` attribute of the corresponding ``AverageMeter``.
    """
    # One meter per k and per retrieval direction.
    i2t_result, t2i_result = {}, {}
    for k in topk:
        i2t_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")
        t2i_result[str(k)] = AverageMeter(f"Acc@{k}", ":6.2f")

    valid_dataset = IndexTrackingDataset(valid_dataset)
    # NOTE: the loaders use num_workers=0; args.workers is not read here.
    val_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)
    # NOTE(review): model.eval() is not called in this function; confirm the
    # model is in the intended mode when it gets here.
    with torch.no_grad():
        model = utils.get_model(model)
        sigma_inv = model.sigma_inv
        # First pass: collect every image, text and dataset index served by
        # this loader.
        image_list = []
        text_list = []
        index_list = []
        for i, images, texts in val_loader:
            images = images.cuda(args.gpu, non_blocking=True)
            image_list.append(images)
            texts = texts.cuda(args.gpu, non_blocking=True)
            text_list.append(texts)
            index_list.append(i)

        # NOTE(review): all inputs are concatenated on the GPU and encoded in
        # a single call, so memory use grows with the number of samples.
        image_list, text_list = torch.cat(image_list, dim=0).to(args.gpu), torch.cat(text_list, dim=0).to(args.gpu)
        index_list = torch.cat(index_list, dim=0).to(args.gpu)
        image_embed, image_weight = model.encode_image(image_list)
        image_weight = F.softplus(image_weight)
        text_embed, text_weight = model.encode_text(text_list)
        text_weight = F.softplus(text_weight)

        # Presumably gathers the tensors from all processes — confirm in utils.
        image_embed_all, image_weight_all, text_embed_all, text_weight_all, index_list_all = utils.all_gather_batch([image_embed, image_weight, text_embed, text_weight, index_list])
        # Query batch size of the second pass.
        batch_size = 5

        val_loader1 = torch.utils.data.DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True, sampler=val_sampler, drop_last=False)
       
        # Second pass: re-encode small query batches and score them against
        # the full gallery built above.
        for index, images, texts in val_loader1:
            images, texts, index = images.cuda(args.gpu, non_blocking=True), texts.cuda(args.gpu, non_blocking=True), index.cuda(args.gpu, non_blocking=True)
            
            image_embed, image_weight = model.encode_image(images)
            image_weight = F.softplus(image_weight)
            text_embed, text_weight = model.encode_text(texts)
            text_weight = F.softplus(text_weight)

            # calc_weighted_kernel returns two values; only the first of each
            # call is used below (logits_per_text1 and logits_per_image1 are
            # not read afterwards).
            logits_per_image, logits_per_text1 = criterion.calc_weighted_kernel(image_embed, image_weight, text_embed_all, text_weight_all, sigma_inv)
            logits_per_text, logits_per_image1 = criterion.calc_weighted_kernel(text_embed, text_weight, image_embed_all, image_weight_all, sigma_inv)

            # identify true label under multi-gpu
            # matches[j, q] is True where gallery entry j has the same dataset
            # index as query q. For each query the smallest such j becomes the
            # label; len(index_list_all) is a "no match" sentinel that is
            # turned into -1.
            expanded_vector = index_list_all.unsqueeze(1)
            expanded_targets = index.unsqueeze(0)
            matches = (expanded_vector == expanded_targets)
            indices = torch.arange(len(index_list_all), device=index_list_all.device).unsqueeze(1).expand_as(matches)
            masked_indices = torch.where(matches, indices, torch.tensor(len(index_list_all), device=index_list_all.device))
            first_matches, _ = torch.min(masked_indices, dim=0)
            labels = torch.where(first_matches < len(index_list_all), first_matches, torch.tensor(-1, device=index_list_all.device))
            
            # Image -> text direction. scaled_all_reduce presumably combines
            # the values across processes — confirm in utils.
            acc_list = accuracy(logits_per_image, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                i2t_result[str(k)].update(acc_list[i].item(), images.size(0))

            # Text -> image direction.
            acc_list = accuracy(logits_per_text, labels, topk=topk)
            acc_list = utils.scaled_all_reduce(acc_list)
            for i, k in enumerate(topk):
                t2i_result[str(k)].update(acc_list[i].item(), texts.size(0))

    i2t_result_dict = {f"r{k}": i2t_result[str(k)].avg for k in topk}
    t2i_result_dict = {f"r{k}": t2i_result[str(k)].avg for k in topk}

    return i2t_result_dict, t2i_result_dict


class IndexTrackingDataset(torch.utils.data.Dataset):
    """Wrap a dataset so that every sample is prefixed with its index.

    A tuple sample ``(a, b, ...)`` becomes ``(idx, a, b, ...)``; any other
    sample ``x`` becomes ``(idx, x)``.
    """

    def __init__(self, original_dataset):
        self.dataset = original_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Return the index together with the original data.
        sample = self.dataset[idx]
        if not isinstance(sample, tuple):
            sample = (sample,)
        return (idx,) + sample


class CocoDataset(Dataset):
    """COCO image/caption pairs that use one fixed caption slot per image.

    Each item is ``(image, caption)`` where ``caption`` is the ``index``-th
    caption annotation of the image.
    """

    def __init__(self, root_dir, annotation_file, tokenizer, transform, index):
        """
        Args:
            root_dir (string): path of the directory where images are saved
            annotation_file (string): path of the json file for annotations
            tokenizer: callable applied to the selected caption; when ``None``
                the raw caption string is returned
            transform: callable applied to the loaded RGB image; skipped when
                falsy
            index (int): position of the caption to use among the captions of
                each image
        """
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annotation_file)
        self.ids = list(self.coco.imgs.keys())
        self.tokenizer = tokenizer
        self.index = index

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the ``idx``-th image id.

        Raises:
            IndexError: if the image has no caption at position ``self.index``.
        """
        img_id = self.ids[idx]

        # Pick the caption first so that a bad caption index is reported
        # before any image file is opened.
        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=[], iscrowd=None)
        anns = self.coco.loadAnns(ann_ids)
        try:
            caption = anns[self.index]['caption']
        except IndexError:
            raise IndexError(
                f"image {img_id} has {len(anns)} caption(s); "
                f"caption index {self.index} is out of range"
            ) from None

        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        # The context manager closes the underlying file; convert() returns
        # an independent image that remains valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Tokenizer is optional, as in Flickr30kDatasetSplit.
        if self.tokenizer is not None:
            caption = self.tokenizer(caption)

        return image, caption


class Flickr30kDatasetSplit(Dataset):
    """Image/caption pairs of one split, using one fixed caption slot per image.

    The pairs are resolved once in ``__init__``; images whose file does not
    exist under ``root_dir`` are skipped.
    """

    def __init__(self, root_dir, json_file, split, tokenizer, transform, index, max_captions=5):
        """
        Args:
            root_dir (string): path of the directory where images are saved
            json_file (string): path of the json file for annotations; it must
                contain an ``images`` list whose entries have ``split``,
                ``filename`` and ``sentences`` (each with a ``raw`` caption)
            split (string): only entries whose ``split`` equals this value
                are used
            tokenizer: callable applied to the caption; skipped when falsy
            transform: callable applied to the loaded RGB image; skipped when
                falsy
            index (int): position of the caption to use among the first
                ``max_captions`` sentences of each image
            max_captions (int): number of leading sentences considered

        Raises:
            IndexError: if a used image has no caption at position ``index``.
        """

        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_captions = max_captions
        self.split = split
        self.index = index

        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        self.image_caption_pairs = []

        for img_data in self.data['images']:
            if img_data['split'] != split:
                continue

            image_path = os.path.join(self.root_dir, img_data['filename'])
            if not os.path.exists(image_path):
                continue

            candidates = img_data['sentences'][:max_captions]
            try:
                caption = candidates[self.index]['raw']
            except IndexError:
                raise IndexError(
                    f"image {img_data['filename']} has {len(candidates)} usable caption(s); "
                    f"caption index {self.index} is out of range"
                ) from None
            self.image_caption_pairs.append((image_path, caption))

    def __len__(self):
        return len(self.image_caption_pairs)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the ``idx``-th stored pair."""
        img_path, caption = self.image_caption_pairs[idx]
        # The context manager closes the underlying file; convert() returns
        # an independent image that remains valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        if self.tokenizer:
            caption = self.tokenizer(caption)

        return image, caption


if __name__ == "__main__":
    # Script entry point: parse the command line, make sure the output
    # directory exists, then run the evaluation.
    parser = argparse.ArgumentParser("retrieval evaluations", parents=[get_args_parser()])
    # NOTE: `args` is bound at module level here. oneloop() refers to the
    # module-level name `args`, so this variable must keep its name.
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
