from pathlib import Path
import pandas as pd
from tqdm import tqdm
import pandas as pd
from tqdm import tqdm


def _pair_ids(df):
    """Return an order-independent ``"<smaller>&<larger>"`` key per row of *df*.

    The two texts are ordered lexicographically before joining, so the rows
    ``("a", "b")`` and ``("b", "a")`` get the same key ``"a&b"``. A row where
    txt1 or txt2 is missing gets a missing key (``Series.str.cat`` propagates
    NaN when ``na_rep`` is not given).

    NOTE: texts are joined with "&", so texts that themselves contain "&" can
    collide, e.g. ("a&b", "c") and ("a", "b&c") both give "a&b&c".
    """
    # Object dtype keeps ``.str`` usable (and comparisons NaN-safe) even when a
    # column was parsed as float because it contains only missing values.
    txt1 = df["txt1"].astype(object)
    txt2 = df["txt2"].astype(object)
    # Object comparisons involving a missing value evaluate to False, so such
    # rows simply keep their original column order.
    swap = txt1 > txt2
    first = txt1.mask(swap, txt2)
    second = txt2.mask(swap, txt1)
    return first.str.cat(second, sep="&")


def filter_duplicates(df, return_ids=False):
    """Drop rows of *df* whose (txt1, txt2) pair already occurred earlier in *df*.

    Pairs are compared regardless of column order, so ("a", "b") and
    ("b", "a") are duplicates; the first occurrence is kept. All rows with a
    missing txt1 or txt2 share the same missing key, so only the first of them
    is kept.

    Args:
        df: DataFrame with ``txt1`` and ``txt2`` columns. It is not modified.
        return_ids: If True, keep the ``txt_ids`` key column in the result.

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    # assign() returns a copy, so the caller's frame does not gain a column.
    df = df.assign(txt_ids=_pair_ids(df))
    df = df.drop_duplicates(subset=["txt_ids"], ignore_index=True)
    if not return_ids:
        df = df.drop(columns=["txt_ids"])
    return df


def remove_duplicates(df, txt_ids, return_ids=False):
    """Return the rows of *df* whose pair key is not contained in *txt_ids*.

    Rows with a missing txt1 or txt2 are dropped first. An existing
    ``txt_ids`` column is reused as-is; otherwise the key is computed from
    txt1/txt2 regardless of their column order. Duplicates *within* *df* are
    not removed, and *txt_ids* itself is not modified.

    Args:
        df: DataFrame with ``txt1`` and ``txt2`` columns. It is not modified.
        txt_ids: Collection of keys (as produced for ``txt_ids``) to exclude.
        return_ids: If True, keep the ``txt_ids`` key column in the result.

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    df = df.dropna(subset=["txt1", "txt2"])
    if "txt_ids" not in df.columns:
        # assign() on the dropna() result avoids pandas' SettingWithCopyWarning.
        df = df.assign(txt_ids=_pair_ids(df))
    df = df[~df["txt_ids"].isin(txt_ids)].reset_index(drop=True)
    if not return_ids:
        df = df.drop(columns=["txt_ids"])
    return df


def _order_pair_columns(df):
    """Swap txt1/txt2 in place on every row of *df* where txt1 sorts after txt2."""
    swap = df["txt1"] > df["txt2"]
    # .to_numpy() makes the assignment positional; assigning a DataFrame would
    # re-align on column names and silently undo the swap.
    df.loc[swap, ["txt1", "txt2"]] = df.loc[swap, ["txt2", "txt1"]].to_numpy()


def _write_chunks(df, out_dir, n_chunks):
    """Write non-empty *df* to ``similar_sentences_<i>.csv`` files in *out_dir*.

    Row ``r`` of ``n`` goes to group ``r * n_chunks // n``. This yields
    ``min(n, n_chunks)`` contiguous groups whose sizes differ by at most one
    row; files are numbered 0, 1, ... in row order.
    """
    positions = pd.RangeIndex(len(df))
    groups = (positions * n_chunks // len(df)).to_numpy()
    for i, (_, df_chunk) in enumerate(df.groupby(groups, sort=True)):
        df_chunk.to_csv(out_dir / f"similar_sentences_{i}.csv", index=False)


def main(dfs_dir, n_chunks=100):
    """Deduplicate sentence pairs across the CSV files in *dfs_dir*.

    Files are read in sorted path order. A row is kept only if
    ``remove_duplicates`` does not find its ``txt_ids`` key among the keys of
    the rows kept from earlier files. The surviving rows are concatenated.
    Each row's txt1/txt2 are ordered so that txt1 <= txt2. The rows are then
    sorted by txt1 and written to
    ``<dfs_dir>_filtered-duplicates/similar_sentences_<i>.csv`` as at most
    *n_chunks* near-equal chunks.

    NOTE(review): duplicates *within* a single file are not removed here;
    ``filter_duplicates`` used to be called for that but was disabled —
    confirm this is intended.

    Args:
        dfs_dir: Directory containing the input ``*.csv`` files.
        n_chunks: Maximum number of output files (default 100).

    Raises:
        ValueError: If *n_chunks* is smaller than 1.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")

    dfs_dir = Path(dfs_dir)
    dfs_out_dir = dfs_dir.parent / (dfs_dir.name + "_filtered-duplicates")
    dfs_out_dir.mkdir(exist_ok=True)
    paths = sorted(dfs_dir.glob("*.csv"))

    seen_ids = set()
    dfs = []
    for path in tqdm(paths):
        df = pd.read_csv(path)
        if df.empty:
            print(f"Empty df: {path}")
            continue
        df = remove_duplicates(df, seen_ids, return_ids=True)
        # update() grows the set in place; set.union() would copy every id
        # seen so far once per file (quadratic overall).
        seen_ids.update(df["txt_ids"])
        dfs.append(df.drop(columns=["txt_ids"]))

    # pd.concat raises on an empty list, and chunking needs at least one row.
    df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    if df.empty:
        print(f"No rows to write from {dfs_dir}")
        return

    _order_pair_columns(df)
    df = df.sort_values(by=["txt1"], ignore_index=True)
    _write_chunks(df, dfs_out_dir, n_chunks)


if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the
    # directory of CSV files to process; it is parsed as a Path and handed
    # straight to main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("dfs_dir", type=Path)
    main(cli.parse_args().dfs_dir)
