import sys
from pathlib import Path

import pandas as pd
from s2_filter_duplicates import filter_duplicates
from tqdm import tqdm

sys.path.append("..")
from utils_retrieval import Index, TxtEmbeddings


def find_similar_sentences(sentence, candidates):
    """Return the candidates that differ from ``sentence`` by exactly one word.

    Both sides are compared after lower-casing, removing every ``.`` and
    splitting on whitespace. A candidate is kept only when it has the same
    number of words as ``sentence`` and exactly one position holds a
    different word. Candidates are returned in their original (unnormalised)
    form and in input order.
    """

    def _words(text):
        # Shared normalisation: case-insensitive, periods ignored.
        return text.lower().replace(".", "").split()

    reference = _words(sentence)
    matches = []
    for candidate in candidates:
        tokens = _words(candidate)
        if len(tokens) == len(reference):
            mismatches = sum(a != b for a, b in zip(reference, tokens))
            if mismatches == 1:
                matches.append(candidate)
    return matches


def main(num_shards: int, shard_id: int, num: int = 1000):
    """Find one-word-apart caption pairs for one shard and write them to CSV.

    Relies on the module-level globals ``txts_embs``, ``txts_idx`` and
    ``dataset_dir``, which are set up in the ``__main__`` section of this
    script.

    Args:
        num_shards: Number of contiguous shards the caption list is split into.
        shard_id: Zero-based shard to process; the last shard
            (``num_shards - 1``) also takes the remainder captions.
        num: Number of results requested from ``txts_idx.get_idxs`` for each
            caption.

    Output:
        ``<dataset_dir>/similar-sentences/txt/similar_sentences_shard_<id>.csv``
        with columns ``txt1``, ``txt2``, ``sim_txt``. No file is written when
        the table is empty after ``filter_duplicates``.
    """
    total_caps = len(txts_embs.captions)
    caps_per_shard = total_caps // num_shards
    start_idx = shard_id * caps_per_shard
    end_idx = start_idx + caps_per_shard
    if shard_id == num_shards - 1:
        # Integer division may leave a remainder; the last shard absorbs it.
        end_idx = total_caps

    table = []
    for cap_txt in tqdm(txts_embs.captions[start_idx:end_idx]):
        cap_emb = txts_embs[cap_txt]
        if cap_emb is None:
            # Report captions without an embedding and skip them.
            print(cap_txt)
            continue

        # Do not consider captions with only one word. Checked before the
        # index query so no retrieval work is spent on discarded captions.
        if len(cap_txt.split()) <= 1:
            continue

        ret_idxs, _ = txts_idx.get_idxs(cap_emb, num)
        ret_txts = [txts_idx.idx2data(idx) for idx in ret_idxs]

        for txt2 in find_similar_sentences(cap_txt, ret_txts):
            if cap_txt in txts_embs and txt2 in txts_embs:
                sim = txts_embs[cap_txt] @ txts_embs[txt2]
            else:
                # Fall back to 0 when either embedding is unavailable.
                sim = 0
            table.append((cap_txt, txt2, sim))

    df = pd.DataFrame(table, columns=["txt1", "txt2", "sim_txt"])
    out_dir = dataset_dir / "similar-sentences/txt"
    # The output path is nested two levels deep, so the intermediate
    # "similar-sentences" directory must be created as well.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_pth = str(out_dir / f"similar_sentences_shard_{shard_id}.csv")
    df = filter_duplicates(df)
    if len(df) > 0:
        df.to_csv(out_pth, index=False)


def parse_shard_ids(spec):
    """Expand a shard specification such as ``"1,3,5-10"`` into a list of ints.

    ``spec`` is a comma-separated list whose items are either a single
    integer or an inclusive ``start-end`` range. Surrounding whitespace and
    empty items (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: If an item is not an integer or a ``start-end`` range.
    """
    shard_ids = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            start, end = token.split("-")
            shard_ids.extend(range(int(start), int(end) + 1))
        else:
            shard_ids.append(int(token))
    return shard_ids


if __name__ == "__main__":
    import argparse

    # Parse and validate the command line first so that --help and malformed
    # arguments do not pay for loading the embeddings and the index.
    parser = argparse.ArgumentParser()
    parser.add_argument("num_shards", type=int)
    parser.add_argument(
        "shard_ids",
        type=str,
        help="comma-separated shard ids or a range of ids (e.g. 1,3,5-10)",
    )
    args = parser.parse_args()

    if args.num_shards < 1:
        parser.error("num_shards must be a positive integer")
    try:
        shard_ids = parse_shard_ids(args.shard_ids)
    except ValueError:
        parser.error(f"invalid shard_ids specification: {args.shard_ids!r}")
    out_of_range = [i for i in shard_ids if not 0 <= i < args.num_shards]
    if out_of_range:
        # main() treats shard ids as zero-based offsets into the caption list.
        parser.error(
            f"shard ids must be in [0, {args.num_shards - 1}]; got {out_of_range}"
        )

    # Dataset to process (alternatives used with this script:
    # "conceptual_captions", "WebVid/2M").
    dataset = "WebVid/8M"
    dataset_dir = Path(f"./{dataset}")
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # main() reads dataset_dir, txts_embs and txts_idx as module-level globals.
    clip_dir = dataset_dir / "clip-retrieval/cr-vitl14"
    txts_embs = TxtEmbeddings(clip_dir / "txts-embeddings")
    txts_idx = Index(clip_dir / "txts-index")

    for shard_id in shard_ids:
        main(args.num_shards, shard_id)
