import argparse
import ast
import json
import os
import shutil
import subprocess
import uuid

import pandas as pd
from tqdm import tqdm

# When a key lookup would raise KeyError, this class returns a default value instead.
class DefaultDict(dict):
    """Dictionary whose subscript lookups fall back to one fixed value.

    ``d[key]`` returns ``default_value`` for any key that is absent. The
    fallback is only returned, never stored, so ``key in d``, ``d.get(key)``
    and iteration still reflect just the entries that were really inserted.

    :param dict: Initial contents (anything the ``dict`` constructor accepts).
    :param default_value: Value returned by ``d[key]`` for missing keys.
    """

    def __init__(self, dict, default_value):
        # Record the fallback first; the order is irrelevant to behavior.
        self.default_value = default_value
        super(DefaultDict, self).__init__(dict)

    def __missing__(self, key):
        # Invoked by dict.__getitem__ when ``key`` is not present.
        return self.default_value

# Maps a GitHub "owner/name" repository to the folder it is cloned into inside
# a playground directory. Repositories not listed here resolve to "default".
repo_to_top_folder = DefaultDict(
    {
        "django/django": "django",
        "sphinx-doc/sphinx": "sphinx",
        "scikit-learn/scikit-learn": "scikit-learn",
        "sympy/sympy": "sympy",
        "pytest-dev/pytest": "pytest",
        "matplotlib/matplotlib": "matplotlib",
        "astropy/astropy": "astropy",
        "pydata/xarray": "xarray",
        "mwaskom/seaborn": "seaborn",
        "psf/requests": "requests",
        "pylint-dev/pylint": "pylint",
        "pallets/flask": "flask",
    },
    "default",
)

def checkout_commit(repo_path, commit_id):
    """Run ``git checkout <commit_id>`` inside a local repository.

    This is best effort: any failure (git returning a non-zero status, git not
    being installed, ...) is printed and swallowed rather than raised.

    :param repo_path: Path to the local git repository
    :param commit_id: Commit ID to checkout
    :return: None
    """
    # ``-C`` makes git operate on repo_path without changing our own cwd.
    git_command = ["git", "-C", repo_path, "checkout", commit_id]
    try:
        print(f"Checking out commit {commit_id} in repository at {repo_path}...")
        subprocess.run(git_command, check=True)
        print("Commit checked out successfully.")
    except subprocess.CalledProcessError as git_error:
        print(f"An error occurred while running git command: {git_error}")
    except Exception as unexpected_error:
        print(f"An unexpected error occurred: {unexpected_error}")

def clone_repo(repo_name, repo_playground):
    """Clone a GitHub repository into its folder inside ``repo_playground``.

    The target folder name is looked up in ``repo_to_top_folder``. This is best
    effort: any failure is printed and swallowed rather than raised.

    :param repo_name: GitHub repository in ``owner/name`` form.
    :param repo_playground: Directory in which the clone folder is created.
    :return: None
    """
    try:
        remote_url = f"https://github.com/{repo_name}.git"
        destination = f"{repo_playground}/{repo_to_top_folder[repo_name]}"

        print(f"Cloning repository from {remote_url} to {destination}...")
        subprocess.run(["git", "clone", remote_url, destination], check=True)
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as git_error:
        print(f"An error occurred while running git command: {git_error}")
    except Exception as unexpected_error:
        print(f"An unexpected error occurred: {unexpected_error}")

def get_project_structure_from_scratch(
    repo_name, commit_id, instance_id, repo_playground
):
    """Clone a repository at a given commit and return its parsed structure.

    A uniquely named playground directory is created under ``repo_playground``.
    The repository is cloned into it, ``commit_id`` is checked out, and the
    working tree is parsed with ``create_structure``. The whole playground is
    removed again before returning, also when one of the steps raises.

    NOTE: ``clone_repo`` and ``checkout_commit`` only print their errors, so a
    failed clone results in an empty ``structure`` rather than an exception.

    :param repo_name: GitHub repository in ``owner/name`` form.
    :param commit_id: Commit to check out before parsing.
    :param instance_id: Identifier copied verbatim into the result.
    :param repo_playground: Parent directory for the temporary playground.
    :return: Dict with keys ``repo``, ``base_commit``, ``structure`` and
        ``instance_id``.
    """
    # Generate a temporary folder and add a uuid to avoid collisions
    repo_playground = os.path.join(repo_playground, str(uuid.uuid4()))

    # the playground must not exist yet
    assert not os.path.exists(repo_playground), f"{repo_playground} already exists"

    # create playground
    os.makedirs(repo_playground)

    try:
        repo_dir = f"{repo_playground}/{repo_to_top_folder[repo_name]}"
        clone_repo(repo_name, repo_playground)
        checkout_commit(repo_dir, commit_id)
        structure = create_structure(repo_dir)
    finally:
        # Remove the whole uuid playground (not only the clone) so repeated
        # calls do not leave one empty directory behind per invocation.
        shutil.rmtree(repo_playground, ignore_errors=True)

    return {
        "repo": repo_name,
        "base_commit": commit_id,
        "structure": structure,
        "instance_id": instance_id,
    }

def get_before_after_structure_from_scratch(
    repo_name, commit_id, patch, instance_id, repo_playground
):
    """Return the repository structure before and after applying ``patch``.

    The repository is cloned into a uniquely named playground under
    ``repo_playground`` and ``commit_id`` is checked out. The tree is parsed
    once as-is (``structure``) and once after ``git apply`` of ``patch``
    (``after_structure``). The playground is always removed afterwards.

    If any step fails, the original exception propagates to the caller once
    cleanup has run. It is no longer swallowed, which previously led to a
    misleading ``UnboundLocalError`` when the result dict was built.

    :param repo_name: GitHub repository in ``owner/name`` form.
    :param commit_id: Commit to check out before parsing.
    :param patch: Unified diff text to apply on top of ``commit_id``.
    :param instance_id: Identifier copied verbatim into the result.
    :param repo_playground: Parent directory for the temporary playground.
    :return: Dict with keys ``repo``, ``base_commit``, ``structure``,
        ``after_structure``, ``patch`` and ``instance_id``.
    :raises subprocess.CalledProcessError: If ``git apply`` rejects the patch.
    """
    # Generate a temporary folder and add a uuid to avoid collisions
    repo_playground = os.path.join(repo_playground, str(uuid.uuid4()))

    # the playground must not exist yet
    assert not os.path.exists(repo_playground), f"{repo_playground} already exists"

    # create playground
    os.makedirs(repo_playground)

    try:
        repo_dir = f"{repo_playground}/{repo_to_top_folder[repo_name]}"
        clone_repo(repo_name, repo_playground)
        checkout_commit(repo_dir, commit_id)
        structure = create_structure(repo_dir)

        # The patch is written next to the clone; git runs with ``-C repo_dir``,
        # hence the ``../patch.diff`` path below.
        patch_file_path = os.path.join(repo_playground, "patch.diff")
        with open(patch_file_path, "w", encoding="utf-8") as patch_file:
            patch_file.write(patch)
        try:
            subprocess.run(
                [
                    "git",
                    "-C",
                    repo_dir,
                    "apply",
                    "--whitespace=nowarn",
                    "../patch.diff",
                ],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while applying the patch for instance {instance_id}: {e}")
            raise

        # get the after structure
        after_structure = create_structure(repo_dir)
    finally:
        # clean up the clone, the patch file and the uuid directory itself
        shutil.rmtree(repo_playground, ignore_errors=True)

    return {
        "repo": repo_name,
        "base_commit": commit_id,
        "structure": structure,
        "after_structure": after_structure,
        "patch": patch,
        "instance_id": instance_id,
    }

def get_completion_structure_from_scratch(
    repo_name, commit_id, patch, mask_patch, instance_id, repo_playground
):
    """Return the structures obtained after ``patch`` and after ``mask_patch``.

    The repository is cloned into a uniquely named playground under
    ``repo_playground`` and ``commit_id`` is checked out. ``patch`` is applied
    and the tree parsed into ``after_structure``; then ``mask_patch`` is applied
    on top and the tree parsed again into ``structure``. The playground is
    always removed afterwards, including when a patch fails to apply.

    :param repo_name: GitHub repository in ``owner/name`` form.
    :param commit_id: Commit to check out before patching.
    :param patch: Unified diff text applied first.
    :param mask_patch: Unified diff text applied on top of ``patch``.
    :param instance_id: Identifier copied verbatim into the result.
    :param repo_playground: Parent directory for the temporary playground.
    :return: Dict with keys ``repo``, ``base_commit``, ``structure``,
        ``after_structure``, ``patch`` and ``instance_id`` (``mask_patch`` is
        not part of the result).
    :raises subprocess.CalledProcessError: If ``git apply`` rejects either patch.
    """
    # Generate a temporary folder and add a uuid to avoid collisions
    repo_playground = os.path.join(repo_playground, str(uuid.uuid4()))

    # the playground must not exist yet
    assert not os.path.exists(repo_playground), f"{repo_playground} already exists"

    # create playground
    os.makedirs(repo_playground)

    try:
        repo_dir = f"{repo_playground}/{repo_to_top_folder[repo_name]}"
        clone_repo(repo_name, repo_playground)
        checkout_commit(repo_dir, commit_id)

        # Patch files live next to the clone; git runs with ``-C repo_dir``,
        # hence the ``../`` prefix on the paths handed to ``git apply``.
        patch_file_path = os.path.join(repo_playground, "patch.diff")
        with open(patch_file_path, "w", encoding="utf-8") as patch_file:
            patch_file.write(patch)
        try:
            subprocess.run(
                [
                    "git",
                    "-C",
                    repo_dir,
                    "apply",
                    "--whitespace=nowarn",
                    "../patch.diff",
                ],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while applying the patch for instance {instance_id}: {e}")
            raise
        # get the after structure
        after_structure = create_structure(repo_dir)

        # apply mask patch and get the before structure
        mask_patch_file_path = os.path.join(repo_playground, "mask_patch.diff")
        with open(mask_patch_file_path, "w", encoding="utf-8") as mask_patch_file:
            mask_patch_file.write(mask_patch)
        try:
            subprocess.run(
                [
                    "git",
                    "-C",
                    repo_dir,
                    "apply",
                    "--whitespace=nowarn",
                    "../mask_patch.diff",
                ],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while applying the mask patch for instance {instance_id}: {e}")
            raise
        # get the before structure
        structure = create_structure(repo_dir)
    finally:
        # Always clean up: without this a rejected patch left the full clone
        # (and the patch files) on disk.
        shutil.rmtree(repo_playground, ignore_errors=True)

    return {
        "repo": repo_name,
        "base_commit": commit_id,
        "structure": structure,
        "after_structure": after_structure,
        "patch": patch,
        "instance_id": instance_id,
    }

def _definition_record(node, source_lines):
    """Build the name/line-span/text record for one class or function node."""
    return {
        "name": node.name,
        "start_line": node.lineno,
        "end_line": node.end_lineno,
        "text": source_lines[node.lineno - 1 : node.end_lineno],
    }


def _collect_global_blocks(source_lines, covered_lines):
    """Group consecutive lines not in ``covered_lines`` into global blocks."""
    blocks = []
    pending = []  # lines of the run currently being accumulated
    run_start = 0

    def flush():
        run_end = run_start + len(pending) - 1
        blocks.append(
            {
                "name": f"global_block_{run_start}-{run_end}",
                "start_line": run_start,
                "end_line": run_end,
                "text": list(pending),
            }
        )

    for line_no, line in enumerate(source_lines, start=1):
        if line_no in covered_lines:
            # A covered line terminates the current run, if there is one.
            if pending:
                flush()
                pending = []
        else:
            if not pending:
                run_start = line_no
            pending.append(line)

    if pending:
        flush()
    return blocks


def parse_python_file(file_path, file_content=None):
    """Parse Python source and extract its classes, functions and global blocks.

    Every ``ast.ClassDef`` and ``ast.FunctionDef`` reached by ``ast.walk`` is
    reported, so nested definitions are included. Functions sitting directly in
    a class body are listed under that class's ``methods`` and not under the
    functions. ``async def`` definitions are not collected. Lines that belong to
    no collected class or function are grouped into consecutive "global blocks".

    :param file_path: Path to the Python file. It is read only when
        ``file_content`` is None and is otherwise used only in error messages.
    :param file_content: Optional source text to parse instead of reading
        ``file_path``.
    :return: Tuple ``(classes, functions, global_blocks, lines)``. Each record
        is a dict with ``name``, ``start_line``, ``end_line`` (1-based,
        inclusive) and ``text`` (list of source lines); class records also carry
        ``methods``. ``lines`` is the source split into lines. On failure the
        first three items are empty lists. If the file had to be read from disk,
        the last item is the string ``"'''Error while loading file'''"``. If
        ``file_content`` was supplied, the last item is its lines.
    """
    if file_content is None:
        try:
            # Python source is UTF-8 by default; do not depend on the locale.
            with open(file_path, "r", encoding="utf-8") as file:
                file_content = file.read()
            parsed_data = ast.parse(file_content)
        except Exception as e:  # Catch all types of exceptions
            print(f"Error in file {file_path}: {e}")
            return [], [], [], "'''Error while loading file'''"
    else:
        try:
            parsed_data = ast.parse(file_content)
        except Exception as e:  # Catch all types of exceptions
            print(f"Error in file {file_path}: {e}")
            return [], [], [], file_content.splitlines()

    # Split once instead of once per AST node.
    source_lines = file_content.splitlines()

    class_info = []
    function_names = []
    # Methods are tracked by node identity, not by name: a free function that
    # merely shares its name with some method must still be reported.
    method_node_ids = set()

    for node in ast.walk(parsed_data):
        if isinstance(node, ast.ClassDef):
            methods = []
            for child in node.body:
                if isinstance(child, ast.FunctionDef):
                    methods.append(_definition_record(child, source_lines))
                    method_node_ids.add(id(child))
            record = _definition_record(node, source_lines)
            record["methods"] = methods
            class_info.append(record)
        elif isinstance(node, ast.FunctionDef) and id(node) not in method_node_ids:
            function_names.append(_definition_record(node, source_lines))

    # Add every line number lying inside a collected class or function to the
    # set; whatever remains uncovered forms the global blocks.
    covered_lines = set()
    for item in class_info + function_names:
        covered_lines.update(range(item["start_line"], item["end_line"] + 1))

    global_blocks = _collect_global_blocks(source_lines, covered_lines)

    return class_info, function_names, global_blocks, source_lines

def create_structure(directory_path):
    """Build a nested dict mirroring a directory tree, parsing its Python files.

    Directories become nested dicts keyed by name. A ``.py`` file maps to a dict
    with ``classes``, ``functions``, ``globals`` and ``text`` as returned by
    ``parse_python_file``; any other file maps to an empty dict. Files directly
    inside ``directory_path`` sit at the top level of the result (the directory's
    own name is not used as a key).

    :param directory_path: Path to the repository directory.
    :return: A dictionary representing the structure.
    """
    tree = {}

    for current_dir, _, file_names in os.walk(directory_path):
        rel_dir = os.path.relpath(current_dir, directory_path)

        # Walk down to the dict for this directory, creating levels as needed.
        node = tree
        if rel_dir != ".":
            for segment in rel_dir.split(os.sep):
                node = node.setdefault(segment, {})

        for name in file_names:
            if not name.endswith(".py"):
                node[name] = {}
                continue

            # Avoid a "./" prefix when walking the current working directory.
            path = name if current_dir == "." else os.path.join(current_dir, name)
            classes, functions, global_blocks, lines = parse_python_file(path)
            node[name] = {
                "classes": classes,
                "functions": functions,
                "globals": global_blocks,
                "text": lines,
            }

    return tree
