"""
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
"""
import argparse
import datetime
import json
import os
import random
import time
from collections import OrderedDict
from pathlib import Path

import numpy as np
import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F
import wandb

import utils
from data import create_dataset, create_loader, create_sampler
from models.blip_cir import blip_cir
from utils import cosine_lr_schedule


def train(model, data_loader, optimizer, epoch, device, config, output_dir=None):
    """Run one training epoch and return the averaged meters.

    Args:
        model: the (possibly DDP-wrapped) CIR model. It is called as
            ``model(ref_img, tar_img, caption, alpha=..., idx=...)`` and must
            return a scalar loss tensor.
        data_loader: yields ``(ref_img, tar_img, caption, idx)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index. During epoch 0, ``alpha`` is ramped
            linearly from 0 towards ``config["alpha"]``.
        device: device that ``ref_img``, ``tar_img`` and ``idx`` are moved to.
        config: run configuration. Reads ``alpha``, ``wandb`` and
            ``wandb_log_train``.
        output_dir: directory for mid-epoch checkpoints. When ``None``, falls
            back to the script-level ``args.output_dir`` (previous behaviour).
            That global is only looked up when a checkpoint is actually saved.

    Returns:
        dict mapping each meter name (``lr``, ``loss``) to its global average,
        formatted with three decimals.
    """
    model.train()

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of training params: {n_parameters/1_000_000:,.1f}M")

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "loss", utils.SmoothedValue(window_size=1, fmt="{value:.4f}")
    )
    header = "Train Epoch: [{}]".format(epoch)
    print_freq = 50
    num_batches = len(data_loader)

    # Loop-invariant logging switches (wandb only runs on the main process).
    use_wandb = config["wandb"] and utils.is_main_process()
    log_train_examples = use_wandb and config["wandb_log_train"]

    for i, (ref_img, tar_img, caption, idx) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        ref_img = ref_img.to(device, non_blocking=True)
        tar_img = tar_img.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)

        # Linear warm-up of alpha over the first epoch.
        if epoch > 0:
            alpha = config["alpha"]
        else:
            alpha = config["alpha"] * min(1, i / num_batches)

        if log_train_examples and i == 0:
            # Preview up to 64 (reference | target) pairs side by side along dim 2
            # (the width axis, assuming CHW images -- TODO confirm).
            wandb.log(
                {
                    "Train-examples": [
                        wandb.Image(
                            torch.cat([r_img, t_img], dim=2), caption=f"{k}: {cap}"
                        )
                        for k, (r_img, t_img, cap) in enumerate(
                            zip(ref_img, tar_img, caption[:64])
                        )
                    ]
                }
            )

        loss = model(ref_img, tar_img, caption, alpha=alpha, idx=idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() is a device sync).
        loss_value = loss.item()

        if use_wandb:
            # A single call per step: every wandb.log call advances wandb's step
            # counter, so separate calls would split these metrics over 3 steps.
            wandb.log({"Loss": loss_value, "Alpha": alpha, "Epoch": epoch})

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Very long epochs also get periodic checkpoints (including at i == 0).
        if num_batches > 10_000 and i % 500 == 0:
            if output_dir is None:
                # Backward compatibility: the script's module-level ``args``.
                output_dir = args.output_dir
            print("Saving mid checkpoint...")
            utils.save_on_master(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "config": config,
                },
                os.path.join(output_dir, f"mid-checkpoint-{str(i).zfill(4)}.pth"),
            )

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {
        k: "{:.3f}".format(meter.global_avg)
        for k, meter in metric_logger.meters.items()
    }


def _mask_reference_matches(sim_q2t, ref_img_ids, tar_img_ids, value):
    """Set ``sim_q2t[i, j] = value`` wherever ``ref_img_ids[i] == tar_img_ids[j]``.

    Vectorised replacement for a Python double loop over every (query,
    candidate) pair. Ids are mapped to integer codes through a dict, so they
    must be hashable.

    Args:
        sim_q2t: array of shape ``(len(ref_img_ids), len(tar_img_ids))``;
            modified in place.
        ref_img_ids: reference-image id of each query (row).
        tar_img_ids: target-image id of each candidate (column).
        value: score written into the matching cells.

    Returns:
        The boolean mask of the overwritten cells.
    """
    codes = {}
    # setdefault assigns a fresh code (the current dict size) to unseen ids only.
    ref_codes = np.array(
        [codes.setdefault(x, len(codes)) for x in ref_img_ids], dtype=np.int64
    )
    tar_codes = np.array(
        [codes.setdefault(x, len(codes)) for x in tar_img_ids], dtype=np.int64
    )
    same_image = ref_codes[:, None] == tar_codes[None, :]
    sim_q2t[same_image] = value
    return same_image


@torch.no_grad()
def evaluation(model, data_loader, device, config):
    """Embed the validation split and score every query against every target.

    Query feature: captions are tokenised and their first token is replaced by
    ``model.tokenizer.enc_token_id``. The text encoder then cross-attends to the
    reference-image embeddings, and its first output token is projected with
    ``model.text_proj``. Target feature: the first token of
    ``model.visual_encoder(tar_img)`` projected with ``model.vision_proj``.
    Both features are L2-normalised.

    Cells where the candidate target image equals the query's reference image
    are set to -10 so they are never retrieved. On the main process with
    ``config["wandb"]`` set, the top-1 prediction for the first 10 queries is
    logged to wandb.

    Returns:
        ``(sim_q2t, sim_t2q)`` numpy arrays, where ``sim_t2q`` is the transpose
        of ``sim_q2t``. Row/column order follows ``data_loader``.
    """
    model.eval()

    print("Computing features for evaluation...")
    start_time = time.time()

    # Raw images/captions are only needed for the wandb preview; avoid holding
    # the whole split in CPU memory otherwise.
    log_examples = utils.is_main_process() and config["wandb"]

    tar_img_feats = []
    query_feats = []
    ref_imgs = []
    tar_imgs = []
    pair_ids = []
    captions = []
    for ref_img, tar_img, caption, pair_id in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())
        if log_examples:
            ref_imgs.append(ref_img)
            tar_imgs.append(tar_img)
            captions.extend(caption)

        ref_img = ref_img.to(device)
        tar_img = tar_img.to(device)

        ref_img_embs = model.visual_encoder(ref_img)
        ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(device)

        text = model.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)

        # Swap the first token for the encoder token before cross-attending to
        # the reference image.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
        query_embs = model.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
        query_feats.append(query_feat.cpu())

        # Encode the target image
        tar_img_embs = model.visual_encoder(tar_img)
        tar_img_feat = F.normalize(model.vision_proj(tar_img_embs[:, 0, :]), dim=-1)
        tar_img_feats.append(tar_img_feat.cpu())

    query_feats = torch.cat(query_feats, dim=0)
    tar_img_feats = torch.cat(tar_img_feats, dim=0)

    ref_img_ids = [data_loader.dataset.pairid2ref[pair_id] for pair_id in pair_ids]
    tar_img_ids = [data_loader.dataset.pairid2tar[pair_id] for pair_id in pair_ids]

    sim_q2t = (query_feats @ tar_img_feats.t()).cpu().numpy()
    # A candidate showing the query's own reference image is never a valid answer.
    _mask_reference_matches(sim_q2t, ref_img_ids, tar_img_ids, -10)
    sim_t2q = np.ascontiguousarray(sim_q2t.T)

    # Log examples
    if log_examples:
        ref_imgs = torch.cat(ref_imgs, dim=0)
        tar_imgs = torch.cat(tar_imgs, dim=0)

        wandb_img = []
        # One panel per query: reference | predicted target | ground-truth target,
        # concatenated along dim 2 (width, assuming CHW -- TODO confirm).
        for index, score in enumerate(sim_q2t[:10]):
            pred_idx = int(np.argmax(score))
            wandb_img.append(
                wandb.Image(
                    torch.cat(
                        [ref_imgs[index], tar_imgs[pred_idx], tar_imgs[index]], dim=2
                    ),
                    caption=f"{index}: {captions[index]}",
                )
            )
        wandb.log({"Val-predictions": wandb_img})

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    return sim_q2t, sim_t2q


@torch.no_grad()
def itm_eval(scores_q2t):
    """Compute query->target Recall@{1,5,10} from a query-by-target score matrix.

    Row ``q`` holds query ``q``'s scores against every candidate. The
    ground-truth target of query ``q`` is taken to be column ``q``.

    Args:
        scores_q2t: 2-D numpy array of query-to-target scores.

    Returns:
        dict with ``tar_r1``, ``tar_r5``, ``tar_r10`` (percentages), their mean
        ``tar_r_mean``, and ``r_mean`` (the same mean, used by ``main`` for
        best-checkpoint selection).
    """
    num_queries = scores_q2t.shape[0]
    ranks = np.zeros(num_queries)

    for query, row in enumerate(scores_q2t):
        # 0-based position of the ground-truth column in the descending ranking.
        descending = np.argsort(row)[::-1]
        ranks[query] = np.flatnonzero(descending == query)[0]

    # Percentage of queries whose ground truth lands in the top-k.
    recall_at = {
        k: 100.0 * int((ranks < k).sum()) / len(ranks) for k in (1, 5, 10)
    }
    mean_recall = (recall_at[1] + recall_at[5] + recall_at[10]) / 3

    return {
        "tar_r1": recall_at[1],
        "tar_r5": recall_at[5],
        "tar_r10": recall_at[10],
        "tar_r_mean": mean_recall,
        "r_mean": mean_recall,
    }


@torch.no_grad()
def test(model, data_loader, device, config):
    """Rank the test gallery for every query and build the submission dicts.

    The loader yields ``(ref_img, caption, pair_id)`` and provides no target
    images. The gallery is therefore the set of distinct reference images in
    the split, and each one's projected first-token embedding is its candidate
    feature. A query's own reference image is scored -100 so it is never
    retrieved.

    Args:
        model: unwrapped CIR model (tokenizer, text/visual encoders, projections).
        data_loader: test loader whose ``dataset`` exposes ``pairid2ref`` and
            ``pairid2members``.
        device: device used for feature extraction.
        config: not read; kept for signature parity with ``evaluation``.

    Returns:
        ``(recalls, recalls_subset)``: dicts that start with ``"version"`` and
        ``"metric"`` entries and then map ``str(pair_id)`` to, respectively, the
        top-50 gallery ids and the top-3 gallery ids restricted to that pair's
        ``pairid2members`` (all as strings).
    """
    model.eval()

    print("Computing features for test...")
    start_time = time.time()

    tar_img_feats = []
    query_feats = []
    pair_ids = []
    for ref_img, caption, pair_id in data_loader:
        pair_ids.extend(pair_id.numpy().tolist())

        ref_img = ref_img.to(device)

        ref_img_embs = model.visual_encoder(ref_img)
        ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(device)

        text = model.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(device)

        # Swap the first token for the encoder token before cross-attending to
        # the reference image.
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
        query_embs = model.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=ref_img_embs,
            encoder_attention_mask=ref_img_atts,
            return_dict=True,
        )
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
        query_feats.append(query_feat.cpu())

        # Gallery candidates are the reference images themselves, so reuse the
        # reference embeddings as candidate features.
        tar_img_feat = F.normalize(model.vision_proj(ref_img_embs[:, 0, :]), dim=-1)
        tar_img_feats.append(tar_img_feat.cpu())

    query_feats = torch.cat(query_feats, dim=0)
    tar_img_feats = torch.cat(tar_img_feats, dim=0)
    assert len(query_feats) == len(tar_img_feats) == len(pair_ids)
    img_ids = [data_loader.dataset.pairid2ref[pair_id] for pair_id in pair_ids]

    # Deduplicate the gallery: one column per image id, using the first row
    # in which that id appeared.
    gallery_col = OrderedDict()  # image id -> gallery column
    gallery_rows = []  # feature row backing each gallery column
    for row, img_id in enumerate(img_ids):
        if img_id not in gallery_col:
            gallery_col[img_id] = len(gallery_rows)
            gallery_rows.append(row)
    gallery_feats = tar_img_feats[gallery_rows]

    # One matrix product instead of a Python loop over every (query, image) cell.
    sims_q2t = (query_feats @ gallery_feats.t()).double().numpy()
    # Never retrieve the query's own reference image.
    sims_q2t[
        np.arange(len(img_ids)), [gallery_col[img_id] for img_id in img_ids]
    ] = -100

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluation time {}".format(total_time_str))

    recalls = {"version": "rc2", "metric": "recall"}
    recalls_subset = {"version": "rc2", "metric": "recall_subset"}

    target_imgs = np.array(list(gallery_col.keys()))

    assert len(sims_q2t) == len(pair_ids)
    for pair_id, query_sims in zip(pair_ids, sims_q2t):
        ranked_imgs = target_imgs[np.argsort(query_sims)[::-1]]
        recalls[str(pair_id)] = [str(x) for x in ranked_imgs[:50]]

        members = data_loader.dataset.pairid2members[pair_id]
        recalls_subset[str(pair_id)] = [
            str(x) for x in ranked_imgs if x in members
        ][:3]

    return recalls, recalls_subset


def main(args, config):
    """Build data, model and optimizer, then run test, evaluation or training.

    Args:
        args: parsed command-line namespace. Reads ``config``, ``output_dir``,
            ``device``, ``seed``, ``distributed``, ``gpu``, ``test`` and
            ``evaluate``. ``gpu`` is not defined by this script's parser, so it
            is presumably set by ``utils.init_distributed_mode`` (TODO confirm).
        config: configuration dict loaded from the ``--config`` YAML file.

    Outputs written under ``args.output_dir``:
        - ``--test``: ``recalls.json`` and ``recalls_subset.json`` (main process
          only). The epoch loop is then skipped.
        - otherwise: ``checkpoint_best.pth`` whenever validation ``r_mean`` ties
          or beats the best so far. Also one JSON line per epoch in ``log.txt``
          when training, or a single line in ``evaluate.txt`` under
          ``--evaluate``, which validates once without training.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # wandb
    # Only the main process creates a wandb run. The run is named after the
    # config file's stem and keeps its files under args.output_dir.
    if config["wandb"] and utils.is_main_process():
        config_name = Path(args.config).stem
        wandb.init(
            project=config["wandb_project"],
            config=config,
            dir=args.output_dir,
            name=config_name,
        )

    # fix the seed for reproducibility
    # The seed is offset by the process rank so ranks get different streams.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # NOTE(review): cudnn.benchmark selects kernels by timing, which may make
    # runs non-bit-reproducible despite the seeding above -- confirm intended.
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    train_dataset, val_dataset, test_dataset = create_dataset(config["dataset"], config)

    # Only the training set gets a per-rank sampler. Val/test get None because
    # evaluation and testing run on the main process only.
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [
            None,
            None,
        ]
    else:
        samplers = [None, None, None]

    # Train uses batch_size_train; val and test both use batch_size_test.
    train_loader, val_loader, test_loader = create_loader(
        [train_dataset, val_dataset, test_dataset],
        samplers,
        batch_size=[config["batch_size_train"]] + [config["batch_size_test"]] * 2,
        num_workers=[4, 4, 4],
        is_trains=[True, False, False],
        collate_fns=[None, None, None],
    )

    #### Model ####
    print("Creating model")
    model = blip_cir(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )

    model = model.to(device)

    # Keep a handle to the bare model for evaluation, testing and the best
    # checkpoint. Under DDP it is reachable as model.module.
    model_without_ddp = model
    if args.distributed:
        # unused_parameters = torch.distributed.get_rank() == 0
        # print("unused_parameters", unused_parameters)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        # find_unused_parameters=unused_parameters,
        # static_graph=True,
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=config["init_lr"],
        weight_decay=config["weight_decay"],
    )

    # Best validation r_mean seen so far and the epoch that reached it.
    best = 0
    best_epoch = 0

    # --test: dump retrieval rankings for the test split (main process only).
    # The epoch loop below then exits on its first iteration.
    if utils.is_main_process() and args.test:
        recalls, recalls_subset = test(model_without_ddp, test_loader, device, config)

        with open(os.path.join(args.output_dir, "recalls.json"), "w") as f:
            json.dump(recalls, f)

        with open(os.path.join(args.output_dir, "recalls_subset.json"), "w") as f:
            json.dump(recalls_subset, f)

    print("Start training")
    start_time = time.time()

    for epoch in range(0, config["max_epoch"]):
        if args.test:
            break

        # --evaluate skips training and validates once (see the break below).
        if not args.evaluate:
            if args.distributed:
                # Presumably a DistributedSampler: reseed its shuffle per epoch.
                train_loader.sampler.set_epoch(epoch)

            # Set this epoch's learning rate from init_lr/min_lr over max_epoch
            # (schedule shape defined in utils.cosine_lr_schedule).
            cosine_lr_schedule(
                optimizer,
                epoch,
                config["max_epoch"],
                config["init_lr"],
                config["min_lr"],
            )

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        # Validation runs only on the main process, so score_val_q2t is bound
        # only there; the identical guard below relies on that.
        if utils.is_main_process():
            score_val_q2t, _ = evaluation(model_without_ddp, val_loader, device, config)

        if utils.is_main_process():
            val_result = itm_eval(score_val_q2t)
            print(val_result)

            # NOTE(review): each separate wandb.log call may advance wandb's step
            # counter; consider logging these four metrics in a single call.
            if config["wandb"]:
                wandb.log({"val/val_r1": val_result["tar_r1"]})
                wandb.log({"val/val_r5": val_result["tar_r5"]})
                wandb.log({"val/val_r10": val_result["tar_r10"]})
                wandb.log({"val/val_r_mean": val_result["r_mean"]})

            # ">=" lets a later epoch with an equal score overwrite the
            # checkpoint. This also runs under --evaluate.
            if val_result["r_mean"] >= best:
                save_obj = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "config": config,
                    "epoch": epoch,
                }
                torch.save(
                    save_obj, os.path.join(args.output_dir, "checkpoint_best.pth")
                )
                best = val_result["r_mean"]
                best_epoch = epoch

            if args.evaluate:
                log_stats = {
                    **{f"val_{k}": v for k, v in val_result.items()},
                }
                log_stats["time"] = str(
                    datetime.timedelta(seconds=int(time.time() - start_time))
                )
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                # train_stats is bound here: this branch only runs when the
                # training step above executed this epoch.
                log_stats = {
                    **{f"train_{k}": v for k, v in train_stats.items()},
                    **{f"val_{k}": v for k, v in val_result.items()},
                    "epoch": epoch,
                    "best_epoch": best_epoch,
                }
                log_stats["time"] = str(
                    datetime.timedelta(seconds=int(time.time() - start_time))
                )
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        # --evaluate validates a single time.
        if args.evaluate:
            break

        # Other ranks wait here while the main process validates.
        if args.distributed:
            dist.barrier()
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``. This converter gives such flags their intended
    meaning. Values that are already ``bool`` (e.g. defaults) pass through.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/retrieval_flickr.yaml")
    parser.add_argument("--output_dir", default="output/Retrieval_flickr")
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )
    parser.add_argument("--distributed", default=True, type=_str2bool)
    args = parser.parse_args()

    # yaml.Loader (unlike the safe loader) can construct arbitrary Python
    # objects; acceptable for a local, trusted config, never for untrusted input.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if not args.test:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        # Snapshot the resolved config next to the run's outputs.
        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        if not Path(args.output_dir).exists():
            # Output dir does not exist (presumably no fine-tuned run there):
            # write results to <parent>/zero-shot/<run dir name> instead.
            # ``.name`` keeps dotted run names intact (``.stem`` would cut them).
            output_dir = Path(args.output_dir)
            output_dir = output_dir.parent / "zero-shot" / output_dir.name
            args.output_dir = str(output_dir)
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        config["wandb"] = False
        # Use checkpoint_best.pth from the output directory when present...
        pretrained_pth = os.path.join(args.output_dir, "checkpoint_best.pth")
        if os.path.exists(pretrained_pth):
            config["pretrained"] = pretrained_pth

        # ...but a checkpoint next to the config file takes precedence and also
        # redirects outputs into that directory.
        pretrained_pth = Path(args.config).parent / "checkpoint_best.pth"
        if pretrained_pth.exists():
            config["pretrained"] = str(pretrained_pth)
            args.output_dir = str(pretrained_pth.parent)

    main(args, config)

    if config["wandb"] and utils.is_main_process():
        wandb.finish()
