import json
import logging
import os
import shutil
import subprocess
import sys
from argparse import ArgumentParser
from pathlib import Path

from pyserini.search.lucene import LuceneSearcher
from tqdm.auto import tqdm

# Configure root logging at INFO level; each record is rendered as
# "<timestamp> <level name> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used by the code below.
logger = logging.getLogger(__name__)


def _load_instances(instances_files):
    """Read every JSON-lines file in *instances_files* into one list of records."""
    instances = []
    for instances_file in instances_files:
        # Context manager so the handle is closed even if a line fails to parse.
        with open(instances_file, encoding="utf-8") as handle:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            instances.extend(json.loads(line) for line in handle if line.strip())
    return instances


def _index_is_readable(index_path):
    """Return True if LuceneSearcher can open the index stored at *index_path*."""
    try:
        LuceneSearcher(str(index_path))
    except Exception:
        # Any failure to open means the index is unusable. Deliberately not a
        # bare ``except`` so Ctrl-C / SystemExit still propagate.
        return False
    return True


def main(instances_files, documents_dir, indexes_dir):
    """Build one Lucene index per unique (repo, base_commit) pair in the instances.

    Args:
        instances_files: Paths of JSON-lines files. Every record must provide
            the keys ``"repo"`` and ``"base_commit"``.
        documents_dir: Root directory expected to contain
            ``<repo with "/" replaced by "__">/<commit>/documents.jsonl``.
        indexes_dir: Root directory under which indexes are written, using the
            same ``<repo>/<commit>`` layout.

    A pair whose index directory already holds an index that ``LuceneSearcher``
    can open is skipped. A non-empty but unreadable index directory is deleted
    and rebuilt. Failures for individual pairs are logged and collected rather
    than raised, so one bad pair does not stop the remaining ones; a summary of
    the failed pairs is logged at the end.
    """
    documents_dir = Path(documents_dir).resolve().absolute()
    indexes_dir = Path(indexes_dir).resolve().absolute()
    instances = _load_instances(instances_files)
    repo_commits = {
        (instance["repo"], instance["base_commit"]) for instance in instances
    }
    failed = []
    # Use the interpreter running this script: it is the one that successfully
    # imported pyserini, whereas the first "python" on PATH may be missing or
    # belong to a different environment.
    python = sys.executable or "python"
    # Sorted so progress and log output are deterministic across runs.
    for repo, commit in tqdm(sorted(repo_commits), desc="Generating indexes"):
        try:
            repo_dirname = repo.replace("/", "__")
            documents_path = Path(documents_dir, repo_dirname, commit)
            index_path = Path(indexes_dir, repo_dirname, commit)
            # Check the input first so a missing corpus does not leave an
            # empty index directory behind.
            if not (documents_path / "documents.jsonl").exists():
                raise FileNotFoundError(
                    f"Documents path {documents_path}/documents.jsonl does not exist"
                )
            index_path.mkdir(parents=True, exist_ok=True)
            if any(index_path.iterdir()):
                # Opening an existing index is faster than creating a new one.
                if _index_is_readable(index_path):
                    logger.info(f"Skipping {repo} {commit} because index exists")
                    continue
                # Delete the invalid index and recreate it from scratch.
                logger.info(f"Deleting {index_path} because it is invalid")
                shutil.rmtree(index_path)
                index_path.mkdir()
            subprocess.run(
                [
                    python,
                    "-m",
                    "pyserini.index",
                    "--collection",
                    "JsonCollection",
                    "--generator",
                    "DefaultLuceneDocumentGenerator",
                    "--threads",
                    "1",
                    "--input",
                    str(documents_path),
                    "--index",
                    str(index_path),
                    "--storePositions",
                    "--storeDocvectors",
                    "--storeRaw",
                ],
                check=True,
                capture_output=True,
            )
        except Exception as e:
            logger.error(f"Failed to process {repo} {commit}")
            logger.error(e)
            # CalledProcessError carries the indexer's captured stderr, which
            # is where the actual cause of the failure is reported.
            stderr = getattr(e, "stderr", None)
            if stderr:
                if isinstance(stderr, bytes):
                    stderr = stderr.decode("utf-8", errors="replace")
                logger.error(stderr)
            failed.append((repo, commit))
    logger.info(f"Failed to process {len(failed)} repos")
    logger.info(failed)


if __name__ == "__main__":
    # Command-line entry point: each flag maps one-to-one onto a keyword
    # argument of main().
    cli = ArgumentParser()
    # One or more instance files may be supplied after the flag.
    cli.add_argument(
        "--instances_files",
        required=True,
        nargs='+',
        type=str,
        help="File containing instances",
    )
    # The two directory options share the same shape, so register them together.
    for flag, description in (
        ("--documents_dir", "Directory where retrieval data is stored"),
        ("--indexes_dir", "Directory where indexes are stored"),
    ):
        cli.add_argument(flag, required=True, type=str, help=description)
    namespace = cli.parse_args()
    main(**vars(namespace))
