from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm


def main(num_shards: int = 1, shard_id: int = 0) -> None:
    """Score video embeddings against their text embedding for one shard.

    Reads the annotation CSV, keeps every ``num_shards``-th row starting at
    row ``shard_id``, and for each kept row takes the dot product between
    every row of the video embedding stored at ``<videos_dir>/<pth2>.pth``
    and the precomputed embedding of the row's ``txt2`` text. The scores,
    rounded to four decimals, are stored as one list per row in a new
    ``scores`` column, and the shard is written to
    ``annotation/<dataset>-scores_<shard_id>-<num_shards>.csv``.

    Rows whose video embedding file does not exist get an empty list.

    Args:
        num_shards: Total number of shards the annotation file is split
            into; must be at least 1.
        shard_id: Zero-based index of the shard to process; must satisfy
            ``0 <= shard_id < num_shards``.

    Raises:
        ValueError: If ``num_shards`` or ``shard_id`` is out of range.
        KeyError: If a row's ``txt2`` has no entry in the text-embedding
            file (only checked for rows whose video embedding exists).
    """
    # Fail fast, before the expensive loads below. A zero step makes the
    # slice raise, and an out-of-range shard_id would silently produce a
    # subset that overlaps another shard or skips rows.
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    if not 0 <= shard_id < num_shards:
        raise ValueError(
            f"shard_id must be in [0, {num_shards}), got {shard_id}"
        )

    # Dataset currently processed. A previous run used the stem
    # "webvid2m.1c6M-full" with "datasets/WebVid/2M/blip-vid-embs-large-all/".
    dataset = "WebVid8M.2c4k-manual"
    videos_dir = Path("datasets/WebVid/8M/blip-vid-embs-large-all/")

    df = pd.read_csv(f"annotation/{dataset}.csv")

    # NOTE: torch.load unpickles its input; only point this at trusted files.
    txt2_embs = torch.load(videos_dir / f"{dataset}.pth")
    txt2emb = dict(zip(txt2_embs["texts"], txt2_embs["feats"]))

    # Copy the strided slice so that adding a column below writes to an
    # independent frame rather than to a view of ``df``.
    shard = df.iloc[shard_id::num_shards].copy()

    scores = []
    for row in tqdm(shard.itertuples(), total=len(shard)):
        emb_path = videos_dir / f"{row.pth2}.pth"
        if not emb_path.exists():
            # Keep the row, but mark it as having no scores.
            scores.append([])
            continue

        txt_emb = txt2emb[row.txt2]
        vid_emb = torch.load(emb_path)

        # One dot product per row of vid_emb (presumably one per frame --
        # the "f" axis of the einsum) against the single text vector.
        row_scores = torch.einsum("fe,e->f", vid_emb, txt_emb).tolist()
        scores.append([round(score, 4) for score in row_scores])

    shard["scores"] = scores
    shard.to_csv(
        f"annotation/{dataset}-scores_{shard_id}-{num_shards}.csv",
        index=False,
    )


if __name__ == "__main__":
    # Script entry point: takes two required positional integers on the
    # command line -- first the shard count, then the shard index -- and
    # hands them to main().
    import argparse

    cli = argparse.ArgumentParser(description="")
    for positional in ("num_shards", "shard_id"):
        cli.add_argument(positional, type=int)
    options = cli.parse_args()

    main(num_shards=options.num_shards, shard_id=options.shard_id)
