import json
import os
import ast
import jedi
from git import Repo
from pathlib import Path
from utils import ContextManager, is_test
from tqdm.auto import tqdm
from argparse import ArgumentParser
from tempfile import TemporaryDirectory
import logging

# Configure logging once at import time: records at INFO and above are emitted
# with a timestamp and level name in front of the message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)


def file_name_and_contents(filename, relative_path):
    """Encode a file as its relative path followed by its full text.

    Args:
        filename: Path of the file to read from disk.
        relative_path: String placed on the first line of the result.

    Returns:
        ``relative_path``, a newline, then the file's contents verbatim.
    """
    header = relative_path + "\n"
    with open(filename, "r") as source_file:
        return header + source_file.read()


def file_name_and_documentation(filename, relative_path):
    """Encode a file as its relative path plus the docstrings found in it.

    The file is parsed with ``ast``. The module docstring (if any) comes
    first, followed by one ``name`` / docstring entry for every function,
    async function and class that has a docstring, in ``ast.walk`` order.
    If anything goes wrong while reading or parsing, the error is logged and
    the raw file contents are appended instead.

    Args:
        filename: Path of the Python file to read.
        relative_path: String placed on the first line of the result.

    Returns:
        The assembled text.
    """
    documented_kinds = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    parts = [relative_path + "\n"]
    try:
        with open(filename, "r") as source_file:
            tree = ast.parse(source_file.read())

        # Module-level docstring goes directly after the path line.
        module_doc = ast.get_docstring(tree)
        if module_doc:
            parts.append(f"{module_doc}")

        # Visit every node and keep only documented definitions.
        for candidate in ast.walk(tree):
            if not isinstance(candidate, documented_kinds):
                continue
            doc = ast.get_docstring(candidate)
            if doc:
                parts.append(f"\n\n{candidate.name}\n{doc}")
    except Exception as e:
        logger.error(e)
        logger.error(f"Failed to parse file {str(filename)}. Using simple filecontent.")
        # Fall back to the plain file text, appended to whatever was gathered.
        with open(filename, "r") as source_file:
            parts.append(source_file.read())
    return "".join(parts)


def file_name_and_docs_jedi(filename, relative_path):
    """Encode a file as its relative path plus names and docstrings from jedi.

    The result starts with ``relative_path``, then the module's full name and
    docstring, then one entry (full name, optional docstring) for each
    definition jedi reports in the file. A name is skipped when:

    * it lives in a builtin module,
    * ``goto`` returns nothing for it or resolves it to a different module, or
    * it is a ``statement`` or ``param`` whose parent is not the module itself.

    Per-name failures are best-effort: the name is skipped and processing
    continues. If the analysis as a whole fails, the error is logged and the
    relative path followed by the raw source is returned instead.

    Args:
        filename: Path of the Python file to analyse.
        relative_path: String placed on the first line of the result.

    Returns:
        The assembled text.
    """
    text = relative_path + "\n"
    with open(filename, "r") as f:
        source_code = f.read()
    try:
        script = jedi.Script(source_code, path=filename)
        module = script.get_context()
        docstring = module.docstring()
        text += f"{module.full_name}\n"
        if docstring:
            text += f"{docstring}\n\n"
        names = [
            name
            for name in script.get_names(
                all_scopes=True, definitions=True, references=False
            )
            if not name.in_builtin_module()
        ]
        for name in names:
            try:
                origins = name.goto(follow_imports=True)
                # Nothing to resolve to: skip explicitly instead of relying on
                # an IndexError being swallowed.
                if not origins:
                    continue
                if origins[0].module_name != module.full_name:
                    continue
                if name.parent().full_name != module.full_name:
                    if name.type in {"statement", "param"}:
                        continue
                full_name = name.full_name
                text += f"{full_name}\n"
                docstring = name.docstring()
                if docstring:
                    text += f"{docstring}\n\n"
            except Exception:
                # Best-effort per name; unlike a bare ``except`` this lets
                # KeyboardInterrupt and SystemExit propagate.
                continue
    except Exception as e:
        logger.error(e)
        logger.error(f"Failed to parse file {str(filename)}. Using simple filecontent.")
        return f"{relative_path}\n{source_code}"
    return text


# Registry of document encoders, keyed by the name accepted for
# ``document_encoding_style``. Each value is called as
# ``func(filename, relative_path)`` and returns the text stored for that file.
DOCUMENT_ENCODING_FUNCTIONS = {
    "file_name_and_contents": file_name_and_contents,
    "file_name_and_documentation": file_name_and_documentation,
    "file_name_and_docs_jedi": file_name_and_docs_jedi,
}


def _load_instances(instances_files):
    """Read every JSONL instances file and return all records in one list."""
    instances = []
    for instances_file in instances_files:
        # Context manager so the handle is closed; blank lines are skipped so
        # a trailing newline does not break decoding.
        with open(instances_file) as f:
            instances.extend(json.loads(line) for line in f if line.strip())
    return instances


def main(instances_files, documents_dir, document_encoding_style, token=None):
    """Write one ``documents.jsonl`` per (repo, base_commit) found in the instances.

    For each distinct pair, the repository is cloned from the ``swe-bench``
    GitHub organisation into a temporary directory (once per repo), and every
    ``*.py`` file for which ``is_test`` is false is encoded and appended to
    ``<documents_dir>/<repo with "/" -> "__">/<commit>/documents.jsonl`` as a
    ``{"id": relative_path, "contents": text}`` JSON line.

    Resuming: an ``incomplete.item`` marker sits next to the output while it
    is being written. An existing ``documents.jsonl`` without the marker is
    treated as finished and skipped; with the marker, ids already present are
    not written again.

    Failures for a (repo, commit) pair are logged with a traceback and
    collected; the run continues with the next pair.

    Args:
        instances_files: Paths of JSONL files whose records have ``repo`` and
            ``base_commit`` keys.
        documents_dir: Root directory for the generated output.
        document_encoding_style: Key into ``DOCUMENT_ENCODING_FUNCTIONS``.
        token: Credential embedded in the clone URL. Defaults to the
            ``GITHUB_TOKEN`` environment variable, or ``"git"`` if unset.

    Raises:
        KeyError: If ``document_encoding_style`` is not a registered encoder.
    """
    document_encoding_func = DOCUMENT_ENCODING_FUNCTIONS[document_encoding_style]
    documents_dir = Path(documents_dir).resolve().absolute()
    if token is None:
        token = os.environ.get("GITHUB_TOKEN", "git")
    instances = _load_instances(instances_files)
    repo_commits = {
        (instance["repo"], instance["base_commit"]) for instance in instances
    }
    failed = []
    tmp_root = "/scratch" if os.path.exists("/scratch") else "/tmp"
    with TemporaryDirectory(dir=tmp_root) as root_dir:
        for repo, commit in tqdm(repo_commits, desc="Generating documents files"):
            try:
                repo_name = repo.replace("/", "__")
                output_dir = Path(documents_dir, repo_name, commit)
                documents_file = output_dir / "documents.jsonl"
                progress_file = output_dir / "incomplete.item"

                # Decide whether there is anything to do before paying for a
                # clone: finished outputs need no repository at all.
                read_documents = set()
                if documents_file.exists():
                    if not progress_file.exists():
                        continue
                    with open(documents_file, "r") as f:
                        read_documents = {json.loads(line)["id"] for line in f}

                repo_dir = os.path.join(root_dir, repo_name)
                if not os.path.exists(repo_dir):
                    repo_url = f"https://{token}@github.com/swe-bench/{repo_name}.git"
                    logger.info(f"Cloning {repo}")
                    Repo.clone_from(repo_url, repo_dir)

                # NOTE(review): ContextManager presumably checks out `commit`
                # for the duration of the block -- confirm in utils.
                with ContextManager(repo_dir, commit) as cm:
                    output_dir.mkdir(parents=True, exist_ok=True)
                    repo_root = Path(cm.repo_path)
                    python_files = []
                    for filename in repo_root.rglob("*.py"):
                        if is_test(filename.as_posix()):
                            continue
                        # relative_to is robust to trailing slashes and other
                        # non-normalised forms of repo_path, unlike slicing the
                        # string by the prefix length.
                        rel_filename = str(filename.relative_to(repo_root))
                        if rel_filename in read_documents:
                            continue
                        python_files.append((filename, rel_filename))
                    # Mark the output as in progress until every file is written.
                    progress_file.touch()
                    with open(documents_file, "a") as f:
                        for filename, relative_path in tqdm(
                            python_files, leave=False, desc=f"Processing {repo}-{commit}"
                        ):
                            content = document_encoding_func(filename, relative_path)
                            # Flush per record so an interrupted run leaves
                            # whole lines behind for the resume logic.
                            print(
                                json.dumps({"id": relative_path, "contents": content}),
                                file=f,
                                flush=True,
                            )
                    progress_file.unlink()
            except Exception:
                # logger.exception keeps the traceback for later diagnosis.
                logger.exception("Failed to process %s %s", repo, commit)
                failed.append((repo, commit))
    logger.info(f"Failed to process {len(failed)} repos")
    logger.info(failed)


if __name__ == "__main__":
    # Command-line entry point: each option below becomes a keyword argument
    # of main() with the same name.
    option_specs = (
        (
            "--instances_files",
            dict(
                type=str,
                nargs="+",
                required=True,
                help="File containing instances",
            ),
        ),
        (
            "--document_encoding_style",
            dict(
                type=str,
                default="file_name_and_contents",
                choices=DOCUMENT_ENCODING_FUNCTIONS.keys(),
                help="Preprocessing function for encoding documents.",
            ),
        ),
        (
            "--documents_dir",
            dict(
                type=str,
                help="Directory where retrieval data is stored",
                required=True,
            ),
        ),
        (
            "--token",
            dict(
                type=str,
                help="Github token to use for cloning repositories",
                required=False,
            ),
        ),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    main(**vars(cli.parse_args()))
