"""
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
"""
import argparse
import datetime
import json
import os
import random
import time
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm

import utils
from data import create_dataset, create_loader, create_sampler, create_test_dataset
from models.blip_cir_embs import blip_cir_embs
from utils import cosine_lr_schedule


@torch.no_grad()
def evaluation_image(model, data_loader, device, config):
    """Score queries built from the reference image alone against target features.

    Each reference image is encoded with ``model.visual_encoder``; the slice at
    index 0 of the second axis is projected by ``model.vision_proj`` and
    L2-normalised. Target features are taken as provided by the loader.

    Args:
        model: Object exposing ``eval()``, ``visual_encoder`` and ``vision_proj``.
        data_loader: Iterable yielding ``(ref_img, tar_feat, _, pair_id, ...)``
            batches. Its ``dataset`` must expose the ``pairid2ref`` and
            ``pairid2tar`` mappings.
        device: Device the reference images are moved to before encoding.
        config: Unused; kept so all evaluation functions share one signature.

    Returns:
        Tuple ``(sim_q2t, sim_t2q)`` of NumPy arrays holding query-to-target and
        target-to-query similarities. Every entry that pairs a query with a
        target whose image id equals the query's reference image id is set
        to -10.
    """
    model.eval()

    print("Computing features for evaluation...")
    start_time = time.time()

    tar_img_feats = []
    query_feats = []
    pair_ids = []
    for ref_img, tar_feat, _, pair_id, *_ in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())

        # Only the features are kept; the raw image batch is not retained.
        ref_img_embs = model.visual_encoder(ref_img.to(device))
        image_feat = F.normalize(model.vision_proj(ref_img_embs[:, 0, :]), dim=-1)
        query_feats.append(image_feat.cpu())

        # Target features arrive precomputed from the loader.
        tar_img_feats.append(tar_feat.cpu())

    query_feats = F.normalize(torch.cat(query_feats, dim=0), dim=-1)
    tar_img_feats = F.normalize(torch.cat(tar_img_feats, dim=0), dim=-1)

    dataset = data_loader.dataset
    ref_img_ids = np.asarray([dataset.pairid2ref[pair_id] for pair_id in pair_ids])
    tar_img_ids = np.asarray([dataset.pairid2tar[pair_id] for pair_id in pair_ids])

    sim_t2q = (tar_img_feats @ query_feats.t()).cpu().numpy()
    sim_q2t = (query_feats @ tar_img_feats.t()).cpu().numpy()

    # Push a query's own reference image to the bottom of its ranking:
    # same_image[i, j] is True when query i's reference id equals target j's id.
    same_image = ref_img_ids[:, None] == tar_img_ids[None, :]
    sim_q2t[same_image] = -10
    sim_t2q[same_image.T] = -10

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    return sim_q2t, sim_t2q


@torch.no_grad()
def evaluation_text(model, data_loader, device, config):
    """Score queries built from the caption alone against target features.

    Captions are tokenized with ``model.tokenizer`` (padded/truncated to 35
    tokens) and encoded by ``model.text_encoder`` with ``mode="text"``. The
    hidden state at position 0 is projected by ``model.text_proj`` and
    L2-normalised. Target features are taken as provided by the loader.

    Args:
        model: Object exposing ``eval()``, ``tokenizer``, ``text_encoder`` and
            ``text_proj``.
        data_loader: Iterable yielding ``(_, tar_feat, caption, pair_id, ...)``
            batches. Its ``dataset`` must expose the ``pairid2ref`` and
            ``pairid2tar`` mappings.
        device: Device the tokenized batch is moved to.
        config: Unused; kept so all evaluation functions share one signature.

    Returns:
        Tuple ``(sim_q2t, sim_t2q)`` of NumPy arrays holding query-to-target and
        target-to-query similarities. Every entry that pairs a query with a
        target whose image id equals the query's reference image id is set
        to -10.
    """
    model.eval()

    print("Computing features for evaluation...")
    start_time = time.time()

    query_feats = []
    tar_img_feats = []
    pair_ids = []
    for _, tar_feat, caption, pair_id, *_ in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())

        text = model.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(device)

        # Text-only mode: the tokenizer's input ids are passed through unchanged.
        text_output = model.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode="text",
        )
        query_feat = F.normalize(
            model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )
        query_feats.append(query_feat.cpu())

        # Target features arrive precomputed from the loader.
        tar_img_feats.append(tar_feat.cpu())

    query_feats = F.normalize(torch.cat(query_feats, dim=0), dim=-1)
    tar_img_feats = F.normalize(torch.cat(tar_img_feats, dim=0), dim=-1)

    dataset = data_loader.dataset
    ref_img_ids = np.asarray([dataset.pairid2ref[pair_id] for pair_id in pair_ids])
    tar_img_ids = np.asarray([dataset.pairid2tar[pair_id] for pair_id in pair_ids])

    sim_t2q = (tar_img_feats @ query_feats.t()).cpu().numpy()
    sim_q2t = (query_feats @ tar_img_feats.t()).cpu().numpy()

    # Push a query's own reference image to the bottom of its ranking:
    # same_image[i, j] is True when query i's reference id equals target j's id.
    same_image = ref_img_ids[:, None] == tar_img_ids[None, :]
    sim_q2t[same_image] = -10
    sim_t2q[same_image.T] = -10

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    return sim_q2t, sim_t2q


@torch.no_grad()
def evaluation(model, data_loader, device, config):
    """Score composed (reference image + caption) queries against target features.

    The reference image is encoded with ``model.visual_encoder`` and handed to
    ``model.text_encoder`` as ``encoder_hidden_states`` together with the
    tokenized caption, whose first token id is replaced by
    ``model.tokenizer.enc_token_id``. The hidden state at position 0 is
    projected by ``model.text_proj`` and L2-normalised. Target features are
    taken as provided by the loader.

    Args:
        model: Object exposing ``eval()``, ``visual_encoder``, ``tokenizer``,
            ``text_encoder`` and ``text_proj``.
        data_loader: Iterable yielding
            ``(ref_img, tar_feat, caption, pair_id, ...)`` batches. Its
            ``dataset`` must expose the ``pairid2ref`` and ``pairid2tar``
            mappings.
        device: Device used for the forward pass.
        config: Unused; kept so all evaluation functions share one signature.

    Returns:
        Tuple ``(sim_q2t, sim_t2q)`` of NumPy arrays holding query-to-target and
        target-to-query similarities. Every entry that pairs a query with a
        target whose image id equals the query's reference image id is set
        to -10.
    """
    model.eval()

    print("Computing features for evaluation...")
    start_time = time.time()

    tar_img_feats = []
    query_feats = []
    pair_ids = []
    for ref_img, tar_feat, caption, pair_id, *_ in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())

        ref_img_embs = model.visual_encoder(ref_img.to(device))
        # All-ones attention mask over every image position, created on `device`.
        ref_img_atts = torch.ones(
            ref_img_embs.size()[:-1], dtype=torch.long, device=device
        )

        text = model.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)

        # Replace the first token id with the tokenizer's encoder token id.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
        query_embs = model.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
        query_feats.append(query_feat.cpu())

        # Target features arrive precomputed from the loader.
        tar_img_feats.append(tar_feat.cpu())

    query_feats = F.normalize(torch.cat(query_feats, dim=0), dim=-1)
    tar_img_feats = F.normalize(torch.cat(tar_img_feats, dim=0), dim=-1)

    dataset = data_loader.dataset
    ref_img_ids = np.asarray([dataset.pairid2ref[pair_id] for pair_id in pair_ids])
    tar_img_ids = np.asarray([dataset.pairid2tar[pair_id] for pair_id in pair_ids])

    sim_t2q = (tar_img_feats @ query_feats.t()).cpu().numpy()
    sim_q2t = (query_feats @ tar_img_feats.t()).cpu().numpy()

    # Push a query's own reference image to the bottom of its ranking:
    # same_image[i, j] is True when query i's reference id equals target j's id.
    same_image = ref_img_ids[:, None] == tar_img_ids[None, :]
    sim_q2t[same_image] = -10
    sim_t2q[same_image.T] = -10

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    return sim_q2t, sim_t2q


@torch.no_grad()
def eval_recall(scores_q2t):
    """Compute recall@{1, 5, 10, 50} from a query-to-target score matrix.

    The correct target for the query in row ``i`` is taken to be column ``i``.

    Args:
        scores_q2t: 2-D array of scores, one row per query.

    Returns:
        Dict with keys ``R1``, ``R5``, ``R10``, ``R50`` (percentages rounded to
        two decimals) and ``r_mean``, the rounded mean of R1, R5 and R10.
    """
    num_queries = scores_q2t.shape[0]
    ranks = np.zeros(num_queries)

    # Position of the matching column once each row is sorted by descending score.
    for row, row_scores in enumerate(scores_q2t):
        descending = np.argsort(row_scores)[::-1]
        ranks[row] = np.flatnonzero(descending == row)[0]

    def percent_within(k):
        return 100.0 * int(np.count_nonzero(ranks < k)) / len(ranks)

    tr1 = percent_within(1)
    tr5 = percent_within(5)
    tr10 = percent_within(10)
    tr50 = percent_within(50)
    tr_mean = (tr1 + tr5 + tr10) / 3

    return {
        "R1": round(tr1, 2),
        "R5": round(tr5, 2),
        "R10": round(tr10, 2),
        "R50": round(tr50, 2),
        "r_mean": round(tr_mean, 2),
    }


@torch.no_grad()
def evaluation_fashioniq(model, data_loader, device, config):
    """Compute label-based recalls for composed queries against stored targets.

    Queries are built as in ``evaluation``: the reference image is encoded,
    fused with the caption through ``model.text_encoder`` and projected by
    ``model.text_proj``. Target embeddings are loaded from the paths in
    ``data_loader.dataset.id2embpth`` for every id in
    ``data_loader.dataset.target_ids`` and L2-normalised.

    Args:
        model: Object exposing ``eval()``, ``visual_encoder``, ``tokenizer``,
            ``text_encoder`` and ``text_proj``.
        data_loader: Iterable yielding ``(ref_img, caption, pair_id)`` batches.
            Its ``dataset`` must expose ``target_ids``, ``id2embpth``,
            ``pairid2ref`` and ``pairid2tar``.
        device: Device used for the forward pass.
        config: Unused; kept so all evaluation functions share one signature.

    Returns:
        The dict produced by ``get_recalls_labels`` for the masked similarity
        matrix, the correct target id of each pair and the target ids.
    """
    model.eval()

    print("Computing features for evaluation...")
    start_time = time.time()

    query_feats = []
    pair_ids = []
    for ref_img, caption, pair_id in tqdm(data_loader):
        pair_ids.extend(pair_id.numpy().tolist())

        ref_img_embs = model.visual_encoder(ref_img.to(device))
        # All-ones attention mask over every image position, created on `device`.
        ref_img_atts = torch.ones(
            ref_img_embs.size()[:-1], dtype=torch.long, device=device
        )

        text = model.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)

        # Replace the first token id with the tokenizer's encoder token id.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
        query_embs = model.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
        query_feats.append(query_feat.cpu())

    # Query features are already normalised per batch above.
    query_feats = torch.cat(query_feats, dim=0)

    dataset = data_loader.dataset
    tar_img_ids = list(dataset.target_ids)
    # map_location keeps loading usable when the embeddings were saved from a GPU.
    tar_img_feats = torch.stack(
        [
            torch.load(dataset.id2embpth[target_id], map_location="cpu").cpu()
            for target_id in tar_img_ids
        ]
    )
    tar_img_feats = F.normalize(tar_img_feats, dim=-1)

    ref_img_ids = np.asarray([dataset.pairid2ref[pair_id] for pair_id in pair_ids])
    tar_img_ids = np.array(tar_img_ids)

    sim_q2t = (query_feats @ tar_img_feats.t()).cpu()

    # Push a query's own reference image to the bottom of its ranking:
    # same_image[i, j] is True when query i's reference id equals target j's id.
    same_image = ref_img_ids[:, None] == tar_img_ids[None, :]
    sim_q2t.masked_fill_(torch.as_tensor(same_image), -10)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    correct_img_ids = [dataset.pairid2tar[pair_id] for pair_id in pair_ids]

    return get_recalls_labels(sim_q2t, correct_img_ids, tar_img_ids)


# From google-research/composed_image_retrieval
def recall_at_k_labels(sim, query_lbls, target_lbls, k=10):
    """Return recall@k (in percent) of labelled queries against labelled targets.

    Args:
        sim: 2-D torch tensor of similarities, one row per query and one column
            per target.
        query_lbls: Sequence with the correct target label of each query.
        target_lbls: Sequence with the label of each target column.
        k: Number of top-ranked targets to inspect per query.

    Returns:
        Percentage of queries whose label appears among their ``k`` most similar
        targets, rounded to two decimals.

    Raises:
        AssertionError: If any query label does not match exactly one target
            label.
    """
    distances = 1 - sim
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.asarray(target_lbls)[sorted_indices.numpy()]

    # Broadcast each query label across its row of ranked target labels.
    query_column = np.asarray(query_lbls)[:, None]
    labels = torch.tensor(sorted_index_names == query_column)

    # Raised explicitly so the check also runs under `python -O`.
    matches_per_query = torch.sum(labels, dim=-1).int()
    if not torch.equal(matches_per_query, torch.ones(len(query_lbls)).int()):
        raise AssertionError(
            "every query label must match exactly one target label"
        )

    return round((torch.sum(labels[:, :k]) / len(labels)).item() * 100, 2)


def get_recalls_labels(
    sims, query_lbls, target_lbls, ks: Optional[List[int]] = None
) -> Dict[str, float]:
    """Evaluate ``recall_at_k_labels`` for several cut-offs.

    Args:
        sims: Similarity matrix forwarded to ``recall_at_k_labels``.
        query_lbls: Correct target label of each query.
        target_lbls: Label of each target.
        ks: Cut-offs to evaluate; defaults to ``[1, 5, 10, 50]`` when ``None``.

    Returns:
        Dict mapping ``"R<k>"`` to the recall at ``k`` for every requested k.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if ks is None:
        ks = [1, 5, 10, 50]
    return {f"R{k}": recall_at_k_labels(sims, query_lbls, target_lbls, k) for k in ks}


import numpy as np


def recall_at_k(similarity_matrix, query_lbls, target_lbls, k):
    """Return the fraction of queries whose label is among their top-k targets.

    Args:
        similarity_matrix: 2-D NumPy array with one row per target and one
            column per query.
        query_lbls: Indexable holding the label of each query column.
        target_lbls: NumPy array holding the label of each target row.
        k: Number of highest-scoring targets inspected per query.

    Returns:
        Mean over all queries of 1.0 (label found in the top k) or 0.0.
    """
    hits = []
    for column in range(similarity_matrix.shape[1]):
        wanted = query_lbls[column]
        # Rows of the k largest scores in this query's column, best first.
        best_rows = np.argsort(similarity_matrix[:, column])[::-1][:k]
        hits.append(1.0 if wanted in target_lbls[best_rows] else 0.0)

    return np.mean(hits)


@torch.no_grad()
def test_cirr(model, data_loader, device, config):
    """Produce ranked retrieval lists for every query pair in ``data_loader``.

    Queries are composed from the reference image and the caption (the image is
    encoded, fused with the caption through ``model.text_encoder`` and projected
    by ``model.text_proj``). The gallery is built from the features yielded by
    the loader, keyed by ``dataset.pairid2ref`` and keeping the first feature
    seen for each id; those features are used as provided, with no
    normalisation here. A query's own reference id is scored -100 so it ranks
    last.

    Args:
        model: Object exposing ``eval()``, ``visual_encoder``, ``tokenizer``,
            ``text_encoder`` and ``text_proj``.
        data_loader: Iterable yielding
            ``(ref_img, tar_feat, caption, pair_id, ...)`` batches. Its
            ``dataset`` must expose ``pairid2ref`` and ``pairid2members``.
        device: Device used for the forward pass.
        config: Unused; kept so all evaluation functions share one signature.

    Returns:
        Tuple ``(recalls, recalls_subset)`` of dicts. Both carry ``"version"``
        and ``"metric"`` entries plus one entry per pair id (as a string):
        ``recalls`` holds the 50 best-ranked gallery ids, ``recalls_subset``
        the 3 best-ranked ids that belong to ``pairid2members[pair_id]``.
    """
    model.eval()

    print("Computing features for test...")
    start_time = time.time()

    tar_img_feats = []
    query_feats = []
    pair_ids = []
    for ref_img, tar_feat, caption, pair_id, *_ in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())

        ref_img_embs = model.visual_encoder(ref_img.to(device))
        # All-ones attention mask over every image position, created on `device`.
        ref_img_atts = torch.ones(
            ref_img_embs.size()[:-1], dtype=torch.long, device=device
        )

        text = model.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)

        # Replace the first token id with the tokenizer's encoder token id.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
        query_embs = model.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
        query_feats.append(query_feat.cpu())

        # Features arrive precomputed from the loader.
        tar_img_feats.append(tar_feat.cpu())

    query_feats = torch.cat(query_feats, dim=0)
    tar_img_feats = torch.cat(tar_img_feats, dim=0)
    assert len(query_feats) == len(pair_ids)

    dataset = data_loader.dataset
    img_ids = [dataset.pairid2ref[pair_id] for pair_id in pair_ids]

    # Gallery: first feature seen for each distinct image id, in encounter order.
    id2emb = OrderedDict()
    for img_id, tar_img_feat in zip(img_ids, tar_img_feats):
        if img_id not in id2emb:
            id2emb[img_id] = tar_img_feat

    target_imgs = np.array(list(id2emb.keys()))
    gallery = torch.stack(list(id2emb.values())).reshape(len(id2emb), -1)

    # One matmul for all query/gallery pairs, then mask each query's own image.
    sims_q2t = (query_feats @ gallery.t()).numpy().astype(np.float64)
    own_image = np.asarray(img_ids)[:, None] == target_imgs[None, :]
    sims_q2t[own_image] = -100

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    recalls = {"version": "rc2", "metric": "recall"}
    recalls_subset = {"version": "rc2", "metric": "recall_subset"}

    assert len(sims_q2t) == len(pair_ids)
    for pair_id, query_sims in zip(pair_ids, sims_q2t):
        ranked_targets = target_imgs[np.argsort(query_sims)[::-1]]

        recalls[str(pair_id)] = [str(x) for x in ranked_targets[:50]]

        members = dataset.pairid2members[pair_id]
        in_subset = [target for target in ranked_targets if target in members]
        recalls_subset[str(pair_id)] = [str(x) for x in in_subset][:3]

    return recalls, recalls_subset


# Not a pytest test despite the `test_` prefix; keep collectors from picking it up.
test_cirr.__test__ = False


@torch.no_grad()
def main_test(args, config):
    """Build the test loader and model, run the requested evaluation, save results.

    Args:
        args: Parsed command-line namespace; reads ``device``, ``seed``,
            ``test_dataset``, ``type``, ``checkpoint`` and ``output_dir``.
        config: Mapping with the model and data settings read below
            (``batch_size_test``, ``pretrained``, ``image_size``, ...).

    Returns:
        The recalls object produced by the selected evaluation routine.

    Raises:
        NotImplementedError: For an unsupported ``args.test_dataset`` or, with
            ``webvid``, an unsupported ``args.type``.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # Seed every RNG with a value offset by the process rank.
    seed = args.seed + utils.get_rank()
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    test_dataset = create_test_dataset(f"{args.test_dataset}_embs", config)

    test_loader = create_loader(
        [test_dataset],
        samplers=[None],
        batch_size=[config["batch_size_test"]],
        num_workers=[4],
        is_trains=[False],
        collate_fns=[None],
    )

    #### Model ####
    print("Creating model")
    model = blip_cir_embs(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
        train_vit=config["train_vit"],
        beta=config.get("beta", 0.0),
    )
    model = model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of training params: {n_parameters/1_000_000:,.1f}M")

    dataset_name = args.test_dataset
    if dataset_name == "cirr":
        recalls, recalls_subset = test_cirr(model, test_loader[0], device, config)
        # The full ranking and the subset ranking go to two sibling files.
        for stem, payload in (("recalls", recalls), ("recalls_subset", recalls_subset)):
            out_path = os.path.join(args.output_dir, f"{stem}-{args.checkpoint}.json")
            with open(out_path, "w") as f:
                json.dump(payload, f)

    elif dataset_name == "webvid":
        # Query mode -> (evaluation routine, output file prefix).
        webvid_modes = {
            "both": (evaluation, "recalls-manual-"),
            "text": (evaluation_text, "recalls-manual-text-"),
            "image": (evaluation_image, "recalls-manual-image-"),
        }
        if args.type not in webvid_modes:
            raise NotImplementedError
        evaluate, file_prefix = webvid_modes[args.type]

        score_val_q2t, _ = evaluate(model, test_loader[0], device, config)
        recalls = eval_recall(score_val_q2t)
        recall_path = os.path.join(
            args.output_dir, f"{file_prefix}{args.checkpoint}.json"
        )

        print(recalls)

        with open(recall_path, "w") as f:
            json.dump(recalls, f)

    elif dataset_name == "fashioniq":
        assert args.type == "both", "Only both is supported for fashioniq"
        recalls = evaluation_fashioniq(model, test_loader[0], device, config)
        print(f"{config['data']}: {recalls['R10'], recalls['R50']}")

        recall_path = os.path.join(
            args.output_dir,
            f"recalls-fashioniq-{config['data']}-{args.checkpoint}.json",
        )

        for key in ("R10", "R50"):
            print(f"{key} {config['data']}", np.mean(recalls[key]).round(2))
        print(recalls)

        with open(recall_path, "w") as f:
            json.dump(recalls, f)

    else:
        raise NotImplementedError

    return recalls


if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the config, pick the
    # checkpoint and output directory, then run the requested test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/retrieval_flickr.yaml")
    parser.add_argument("--output_dir", default="output/Retrieval_flickr")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--checkpoint", default="best", type=str)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument(
        "--test_dataset",
        default="cirr",
        type=str,
        choices=["cirr", "webvid", "fashioniq"],
    )
    parser.add_argument("--type", choices=["both", "text", "image"], default="both")
    args = parser.parse_args()

    # NOTE(review): `yaml.Loader` is not the safe loader; only pass trusted
    # config files to --config.
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # A checkpoint inside the output directory overrides config["pretrained"].
    pretrained_pth = os.path.join(args.output_dir, f"checkpoint_{args.checkpoint}.pth")
    if os.path.exists(pretrained_pth):
        config["pretrained"] = pretrained_pth

    # A checkpoint next to the config file takes precedence and also becomes
    # the output directory.
    pretrained_pth = Path(args.config).parent / f"checkpoint_{args.checkpoint}.pth"
    if pretrained_pth.exists():
        config["pretrained"] = str(pretrained_pth)
        args.output_dir = str(pretrained_pth.parent)
    else:
        output_dir = Path(args.output_dir)
        # parents=True so nested paths such as the default "output/Retrieval_flickr"
        # can be created from scratch.
        output_dir.mkdir(parents=True, exist_ok=True)
        args.output_dir = str(output_dir)

    print(f"Loading pretrained model from {config['pretrained']}")

    categories = ["shirt", "dress", "toptee"]
    if args.test_dataset == "fashioniq" and config["data"] not in categories:
        # No single category configured: evaluate each one and average.
        recalls_10 = []
        recalls_50 = []
        for data in categories:
            config["data"] = data
            print(f"Testing on {data}")
            recalls = main_test(args, config)
            recalls_10.append(recalls["R10"])
            recalls_50.append(recalls["R50"])

        print("MEAN R10:", np.mean(recalls_10).round(2))
        print("MEAN R50:", np.mean(recalls_50).round(2))

    else:
        main_test(args, config)
