from pathlib import Path

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from data.embs_dataset import VideoDataset
from models.blip_cir import blip_cir

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_blip_config(model="base"):
    """Build the BLIP settings dictionary for the requested model size.

    Args:
        model: ``"base"`` or ``"large"``. Any other value yields only the
            settings shared by both sizes (``image_size``, ``queue_size``,
            ``alpha``, ``k_test``, ``negative_all_rank``); size-specific keys
            such as ``"pretrained"`` and ``"vit"`` are then absent.

    Returns:
        A newly created ``dict``; mutating it does not affect later calls.
    """
    per_model = {
        "base": {
            # NOTE: the URL must not carry trailing whitespace, otherwise the
            # space becomes part of the requested URL / derived file name.
            "pretrained": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth",
            "vit": "base",
            "batch_size_train": 32,
            "batch_size_test": 16,
            "vit_grad_ckpt": True,
            "vit_ckpt_layer": 4,
            "init_lr": 1e-5,
        },
        "large": {
            "pretrained": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth",
            "vit": "large",
            "batch_size_train": 16,
            "batch_size_test": 32,
            "vit_grad_ckpt": True,
            "vit_ckpt_layer": 12,
            "init_lr": 5e-6,
        },
    }

    # Copy so callers never share (and cannot corrupt) the table above.
    config = dict(per_model.get(model, {}))

    # Settings common to every model size.
    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config


def _save_tensor(tensor, path):
    """Save ``tensor`` to ``path``, creating any missing parent directories."""
    # parents=True so that a video id containing several nested path
    # components does not fail with FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, path)


@torch.no_grad()
def main(args):
    """Encode video frames with the BLIP visual encoder and save the features.

    For every batch from ``VideoDataset`` the frames are passed through
    ``model.visual_encoder``; the first token of each frame embedding is
    projected with ``model.vision_proj`` and L2-normalised. Per video, frames
    whose index is ``-1`` are dropped, and the remaining features are written
    under ``args.save_dir`` as ``<video_id>.pth`` in up to three variants.

    Args:
        args: Namespace providing ``video_dir``, ``videos_txt_path``,
            ``num_shards``, ``shard_id``, ``batch_size``, ``num_workers``,
            ``model_type``, ``save_dir`` (a ``Path``) and the boolean flags
            ``save_mean``, ``save_middle`` and ``save_all``.
    """
    dataset = VideoDataset(
        video_dir=args.video_dir,
        video_pths=args.videos_txt_path,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    print("Creating model")
    config = get_blip_config(args.model_type)
    model = blip_cir(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )

    model = model.to(device)
    model.eval()

    # One output directory per requested variant; created up front so an
    # unwritable destination fails before any encoding work is done.
    mean_dir = middle_dir = all_dir = None
    if args.save_mean:
        mean_dir = args.save_dir / f"blip-vid-embs-{args.model_type}-mean"
        mean_dir.mkdir(parents=True, exist_ok=True)
    if args.save_middle:
        middle_dir = args.save_dir / f"blip-vid-embs-{args.model_type}-middle"
        middle_dir.mkdir(parents=True, exist_ok=True)
    if args.save_all:
        all_dir = args.save_dir / f"blip-vid-embs-{args.model_type}-all"
        all_dir.mkdir(parents=True, exist_ok=True)

    for video_ids, f_idxs, frames in tqdm(loader):
        frames = frames.to(device)
        bs, nf, c, h, w = frames.shape
        # Fold the frame axis into the batch axis so all frames are encoded
        # in a single forward pass.
        frames = frames.view(bs * nf, c, h, w)
        frm_embs = model.visual_encoder(frames)
        # Only the first token of each embedding is projected
        # (presumably the [CLS] token -- TODO confirm against the encoder).
        frm_feats = F.normalize(model.vision_proj(frm_embs[:, 0, :]), dim=-1).cpu()
        frm_feats = frm_feats.view(bs, nf, -1)

        for video_id, f_idx, frm_feat in zip(video_ids, f_idxs, frm_feats):
            # Drop the features whose frame index is -1 (presumably padding
            # added by the dataset -- TODO confirm).
            valid = f_idx > -1
            if not bool(valid.any()):
                continue
            frm_feat = frm_feat[valid]

            if args.save_mean:
                _save_tensor(frm_feat.mean(dim=0), mean_dir / f"{video_id}.pth")
            if args.save_middle:
                middle_feat = frm_feat[len(frm_feat) // 2]
                _save_tensor(middle_feat, middle_dir / f"{video_id}.pth")
            if args.save_all:
                _save_tensor(frm_feat, all_dir / f"{video_id}.pth")


if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate them, then run
    # the feature extraction.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("video_dir", type=Path)
    parser.add_argument("save_dir", type=Path)
    parser.add_argument("videos_txt_path", type=Path)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--model_type", type=str, default="large", choices=["base", "large"]
    )
    parser.add_argument("--num_shards", type=int, default=1)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--save_middle", action="store_true")
    parser.add_argument("--save_mean", action="store_true")
    parser.add_argument("--save_all", action="store_true")

    args = parser.parse_args()

    # Validate with parser.error rather than assert: assert statements are
    # removed when Python runs with -O, which would silently skip the checks.
    if not args.video_dir.exists():
        parser.error(f"{args.video_dir} does not exist")
    if not args.videos_txt_path.exists():
        parser.error(f"{args.videos_txt_path} does not exist")
    if not (args.save_middle or args.save_mean or args.save_all):
        parser.error("At least one save option must be activated")

    # Create the output directory only once the inputs are known to be valid,
    # so a rejected invocation leaves nothing behind on disk.
    args.save_dir.mkdir(parents=True, exist_ok=True)

    main(args)
