import os
import json
import shutil
from urllib.parse import urlparse
import pygit2
import re
from tree_sitter import Parser, Language
import importlib
import tiktoken
import difflib
import concurrent.futures
from func_timeout import func_set_timeout

import tempfile
import pathlib

# Maps a grammar name (used to build the ``tree_sitter_<name>`` module name)
# to the file extension it handles. Files that have no extension are matched
# by their full name instead (e.g. "Dockerfile", "Makefile", "go.mod").
# Insertion order matters: when several names share an extension, the first
# one listed is the one picked by a reverse lookup.
LANGUAGE_EXTENSIONS = {
    "ada": ".adb",
    "agda": ".agda",
    "apex": ".cls",
    "apexcode": ".cls",
    "json": ".json",
    "bash": ".sh",
    "beancount": ".beancount",
    "capnp": ".capnp",
    "c": ".c",
    "cpp": ".cpp",
    "c-sharp": ".cs",
    "cel": ".cel",
    "clojure": ".clj",
    "cmake": ".cmake",
    "cobol": ".cob",
    "commonlisp": ".lisp",
    "css": ".css",
    "cuda": ".cu",
    "dart": ".dart",
    "d": ".d",
    "dockerfile": "Dockerfile",
    "dot": ".dot",
    "elixir": ".ex",
    "elm": ".elm",
    "emacs-lisp": ".el",
    "eno": ".eno",
    "erb": ".erb",
    "erlang": ".erl",
    "fennel": ".fnl",
    "fish": ".fish",
    "formula": ".formula",
    "fortran": ".f90",
    "gitattributes": ".gitattributes",
    "gitignore": ".gitignore",
    "gleam": ".gleam",
    "glsl": ".glsl",
    "go": ".go",
    "go-mod": "go.mod",
    "go-work": "go.work",
    "graphql": ".graphql",
    "hack": ".hack",
    "haskell": ".hs",
    "hcl": ".hcl",
    "html": ".html",
    "ispc": ".ispc",
    "java": ".java",
    "javascript": ".js",
    "jq": ".jq",
    "json5": ".json5",
    "julia": ".jl",
    "just": ".just",
    "kotlin": ".kt",
    "lalrpop": ".lalrpop",
    "latex": ".tex",
    "lean": ".lean",
    "llvm": ".ll",
    "llvm-machineir": ".mir",
    "llvm-mlir": ".mlir",
    "llvm-tablegen": ".td",
    "lua": ".lua",
    "magik": ".magik",
    "makefile": "Makefile",
    "markdown": ".md",
    "meson": ".meson",
    "motorola-68000-assembly": ".asm",
    "nginx": ".nginx",
    "nim": ".nim",
    "nix": ".nix",
    "noir": ".nr",
    "objective-c": ".m",
    "ocaml": ".ml",
    "odin": ".odin",
    "ohm": ".ohm",
    "org": ".org",
    "p4": ".p4",
    "pascal": ".pas",
    "perl": ".pl",
    "perl-pod": ".pod",
    "php": ".php",
    "portable-game-notation": ".pgn",
    "powershell": ".ps1",
    "protocol-buffers": ".proto",
    "python": ".py",
    "qml": ".qml",
    "quakec": ".qc",
    "racket": ".rkt",
    "rasi": ".rasi",
    "re2c": ".re2c",
    "regex": ".re",
    "rego": ".rego",
    "restructuredtext": ".rst",
    "r": ".r",
    "robot": ".robot",
    "ruby": ".rb",
    "rust": ".rs",
    "scala": ".scala",
    "scheme": ".scm",
    "scss": ".scss",
    "s-expressions": ".sexp",
    "smali": ".smali",
    "sourcepawn": ".sp",
    "sparql": ".rq",
    "sql": ".sql",
    "postgresql": ".pgsql",
    "sqlite": ".sqlite",
    "ssh": ".ssh",
    "supercollider": ".scd",
    "svelte": ".svelte",
    "swift": ".swift",
    "systemrdl": ".rdl",
    "tact": ".tact",
    "thrift": ".thrift",
    "todo": ".todo",
    "toml": ".toml",
    "tree-sitter-query": ".tsq",
    "turtle": ".ttl",
    "twig": ".twig",
    "typescript": ".ts",
    "ungrammar": ".ungrammar",
    "usd": ".usd",
    "verilog": ".v",
    "vhdl": ".vhd",
    "vue": ".vue",
    "wasm": ".wasm",
    "wdl": ".wdl",
    "wgsl": ".wgsl",
    "yaml": ".yaml",
    "yang": ".yang",
    "yuck": ".yuck",
    "zig": ".zig",
}


class GitRepository:
    def __init__(self, repo_url, clone_dir):
        self.repo_url = self.construct_repo_url(repo_url)
        self.clone_dir = clone_dir
        self.repo = self.clone_repo(self.repo_url, self.clone_dir)
        self.commits_info = None  
        self.language_modules = None
        self.enc = None

    def construct_repo_url(self, repo_url):
        if not repo_url.startswith("https://github.com/"):
            repo_url = "https://github.com/" + repo_url.strip("/").replace("_", "/")
        return repo_url

    def clone_repo(self, repo_url, clone_dir=None):
        parsed_url = urlparse(repo_url)
        path_parts = parsed_url.path.strip("/").split("/")
        if len(path_parts) == 2:
            org, repo_name = path_parts
            repo_dir = f"{org}_{repo_name.split('.')[0]}"
        else:
            raise ValueError(
                "Invalid repo_url format. Expected '{org}/{repo}' or 'https://github.com/{org}/{repo}'"
            )
        print(f"org: {org}, repo_name: {repo_name}, repo_url: {repo_url}")
        if clone_dir is not None:
            repo_dir = os.path.join(clone_dir, repo_dir)

        try:
            if not os.path.exists(repo_dir):
                os.makedirs(repo_dir)

                
                def clone_repo_func(repo_url, repo_dir):
                    return pygit2.clone_repository(repo_url, repo_dir)

                def clone_repo_with_timeout(repo_url, repo_dir, timeout):
                    with concurrent.futures.ThreadPoolExecutor() as executor:
                        future = executor.submit(clone_repo_func, repo_url, repo_dir)
                        try:
                            result = future.result(timeout=timeout)
                            return result
                        except Exception as e:
                            raise e

                
                repo = clone_repo_with_timeout(repo_url, repo_dir, timeout=50)
            else:
                repo = pygit2.Repository(repo_dir)
                
        except Exception as e:
            print(f"Error cloning repository: {str(e)}")
            repo = None
        return repo

    def get_partial_content(self, blob, max_length=1024):
        data = blob.data
        if len(data) > max_length:
            
            half_length = max_length // 2
            start = data[:half_length].decode("utf-8", errors="ignore")
            end = data[-half_length:].decode("utf-8", errors="ignore")
            return f"{start}...[CONTENT TRUNCATED]...{end}"
        else:
            return data.decode("utf-8", errors="ignore")

    def get_file_contents_for_commit(
        self, commit_sha, max_file_size=1024 * 1024
    ):  
        commit = self.repo.get(commit_sha)
        tree = commit.tree

        def get_tree_recursive(tree, path_prefix=""):
            files_info = []

            for entry in tree:
                item_path = f"{path_prefix}{entry.name}"
                item_type = entry.type_str
                item_id = entry.id
                
                if item_type == "blob":
                    blob = self.repo.get(item_id)
                    try:
                        content = blob.data.decode("utf-8")
                        content = re.sub(
                            r"data:image\/[a-zA-Z]+;base64,[a-zA-Z0-9+/=]+",
                            "data:image/<binary content>",
                            content,
                        )
                        if blob.size > max_file_size:
                            content = self.get_partial_content(
                                blob, max_length=1024
                            )  
                            content = f"File too large to display fully. Size: {blob.size} bytes. The following are the beginning and end contents of the file:\n{content}"
                    except UnicodeDecodeError:
                        content = "<binary content>"
                    files_info.append(
                        {
                            "path": item_path,
                            "type": str(item_type),
                            "id": str(item_id),
                            "content": content,
                        }
                    )
                elif item_type == "tree":
                    sub_tree = self.repo.get(item_id)
                    sub_tree_files = get_tree_recursive(
                        sub_tree, path_prefix=f"{item_path}/"
                    )
                    files_info.extend(sub_tree_files)

            return files_info

        return get_tree_recursive(tree)

    def get_file_contents_before_commit(self, commit_sha):
        commit = self.repo.get(commit_sha)
        if len(commit.parents) == 0:
            print(f"Commit {commit_sha} has no parent, it might be the initial commit.")
            return None  

        parent_commit = commit.parents[0]

        return self.get_file_contents_for_commit(parent_commit.id)

    def get_patch_info_for_commit(self, commit_sha):
        commit = self.repo.get(commit_sha)

        if len(commit.parents) == 0:
            print(f"Commit {commit_sha} has no parent.")
            return None

        parent_tree = commit.parents[0].tree
        current_tree = commit.tree
        diff = self.repo.diff(parent_tree, current_tree)

        commit_info = {
            "commit_sha": str(commit.id),
            "author_name": commit.author.name,
            "author_email": commit.author.email,
            "date": commit.commit_time,
            "commit_message": commit.message.strip(),
            "parents": [str(parent.id) for parent in commit.parents],
            "stats": {
                "total": diff.stats.insertions + diff.stats.deletions,
                "additions": diff.stats.insertions,
                "deletions": diff.stats.deletions,
                "files_changed": diff.stats.files_changed,
            },
            "files_diff": [],
        }

        for patch in diff:
            stats = patch.line_stats

            change_info = {
                "filename": patch.delta.new_file.path,
                "status": patch.delta.status_char(),
                "additions": stats[1],
                "deletions": stats[2],
                "changes": stats[1] + stats[2],
                "patch": patch.text if not patch.delta.is_binary else None,
            }
            commit_info["files_diff"].append(change_info)

        return commit_info

    def get_commits_info(self):
        """Yield patch information for the commits reachable from HEAD.

        Commits are visited with ``pygit2.GIT_SORT_TIME`` ordering. Commits
        for which ``get_patch_info_for_commit`` returns nothing (those without
        a parent) are skipped.

        :return: Generator of commit-info dicts.
        """
        # Iterate the walker directly instead of materializing the whole
        # history first, so the first result is available immediately and the
        # commit objects are not all held in memory at once.
        for commit in self.repo.walk(self.repo.head.target, pygit2.GIT_SORT_TIME):
            commit_info = self.get_patch_info_for_commit(str(commit.id))
            if commit_info:
                yield commit_info
        
        

    def filter_commits_by_changes(
        self,
        min_files_changed=None,
        max_files_changed=None,
        min_total_changes=None,
        max_total_changes=None,
    ):
        """
        Filter commits based on the number of files changed and the total number of line changes.

        :param min_files_changed: Minimum number of files changed required.
        :param max_files_changed: Maximum number of files changed allowed.
        :param min_total_changes: Minimum number of total changes (additions + deletions) required.
        :param max_total_changes: Maximum number of total changes allowed.
        :return: List of commit information dictionaries that match the criteria.
        """
        if self.commits_info is None:
            self.commits_info = self.get_commits_info()
        all_commits = self.commits_info

        def is_within_range(value, min_val, max_val):
            return (min_val is None or value >= min_val) and (
                max_val is None or value <= max_val
            )

        
        
        
        
        
        for commit in all_commits:
            if is_within_range(
                commit["stats"]["files_changed"], min_files_changed, max_files_changed
            ) and is_within_range(
                commit["stats"]["total"], min_total_changes, max_total_changes
            ):
                yield commit

        

    def simplify_commit_data(self, commit):
        """
        Simplify commit data by extracting only commit_message, and filename and patch from files_diff.

        :param commits: List of commit dictionaries.
        :return: List of simplified commit dictionaries.
        """
        
        
        simplified_commit = {
            
            "commit_message": commit["commit_message"],
            "files_diff": [
                {
                    "filename": file["filename"],
                    "status": file["status"],
                    "patch": file["patch"],
                }
                for file in commit["files_diff"]  
            ],
        }
        

        return simplified_commit

    def simplify_code_data(self, files_data):
        """
        Simplify code data by extracting only the path and content from each file's information.

        :param files_data: List of dictionaries containing file information.
        :return: List of simplified file dictionaries with only path and content.
        """
        simplified_data = []
        for file_info in files_data:
            if "path" in file_info and "content" in file_info:
                simplified_file = {
                    "path": file_info["path"],
                    "content": file_info["content"],
                }
                simplified_data.append(simplified_file)
        return simplified_data

    def generate_patch(self, commit) -> str:  
        
        
        
        

        patch_content = ""
        for file_diff in commit["files_diff"]:
            if file_diff["patch"]:
                patch_content += (
                    f"diff --git a/{file_diff['filename']} b/{file_diff['filename']}\n"
                )
                patch_content += file_diff["patch"]

        return patch_content

    def load_language_modules(self, file_names):
        """Import the tree-sitter grammar modules needed for ``file_names``.

        Each file is keyed by its extension, or by its base name when it has
        no extension (e.g. ``Makefile``). The key is looked up in
        ``LANGUAGE_EXTENSIONS`` and the matching ``tree_sitter_<name>`` module
        is imported.

        :param file_names: Iterable of file paths.
        :return: Dict mapping each recognised key to the value returned by
            the grammar module's language factory. Unknown keys and missing
            modules are reported with ``print`` and left out.
        """

        def get_language_module(file_name):
            ext = os.path.splitext(file_name)[1]
            # Extension-less files are identified by their whole name.
            return ext or os.path.basename(file_name)

        # Reverse index built once; setdefault keeps the first grammar listed
        # for an extension shared by several names.
        lang_by_ext = {}
        for lang, ext in LANGUAGE_EXTENSIONS.items():
            lang_by_ext.setdefault(ext, lang)

        modules = {}
        for ext in {get_language_module(file_name) for file_name in file_names}:
            lang = lang_by_ext.get(ext)
            if not lang:
                print(f"No language found for extension {ext}")
                continue

            # Module names cannot contain hyphens ("c-sharp" -> "c_sharp").
            identifier = lang.replace("-", "_")
            module_name = f"tree_sitter_{identifier}"
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                print(f"Module {module_name} not found, skipping.")
                continue

            # Most grammar packages expose ``language()``; packages bundling
            # several grammars expose ``language_<name>()`` instead.
            factory = getattr(module, "language", None) or getattr(
                module, f"language_{identifier}", None
            )
            if factory is None:
                print(f"Module {module_name} exposes no language factory, skipping.")
                continue
            modules[ext] = factory()
        return modules

    def get_language_by_extension(self, file_extension):
        
        if file_extension in self.language_modules:
            return self.language_modules[file_extension]
        else:
            print(f"No language support for {file_extension}, skipping.")
            return None

    def get_changed_functions_in_commit(self, commit_sha):
        """Attribute the lines changed by a commit to functions of the parent.

        The commit is diffed against its first parent. For each non-binary
        file the pre-change source is parsed with tree-sitter and every
        removed/modified line is assigned to the functions whose line range
        contains it.

        :param commit_sha: Identifier passed to ``self.repo.get``.
        :return: List of dicts with keys ``file``, ``function``,
            ``content_all`` and ``content_change``. Entries with a function
            name hold lists of line texts; entries with ``function`` None hold
            ``{old_line_number: text}`` dicts for what lies outside any
            function (or for the whole file when no grammar is available).
            Returns ``[]`` for a commit without parents. Files that do not
            exist in the parent tree are skipped.
        """
        commit = self.repo.get(commit_sha)

        if len(commit.parents) == 0:
            print(f"Commit {commit} has no parent, it might be the initial commit.")
            return []

        parent_commit = commit.parents[0]
        diff = self.repo.diff(parent_commit, commit)
        self.language_modules = self.load_language_modules(
            [patch.delta.new_file.path for patch in diff]
        )

        changes = []
        for patch in diff:
            if patch.delta.is_binary:
                continue

            file_path = patch.delta.new_file.path
            # Same key as load_language_modules: the extension, or the base
            # name for extension-less files such as "Makefile".
            language_key = os.path.splitext(file_path)[1] or os.path.basename(
                file_path
            )
            changed_lines, all_lines = self.get_changed_lines_from_patch(patch.hunks)

            language = self.get_language_by_extension(language_key)
            if not language:
                changes.append(
                    {
                        "file": file_path,
                        "function": None,
                        "content_all": all_lines,
                        "content_change": changed_lines,
                    }
                )
                continue

            try:
                parent_blob = self.repo.get(
                    parent_commit.tree[patch.delta.old_file.path].id
                )
            except KeyError:
                # The file does not exist in the parent (newly added).
                continue

            # Normalize to valid UTF-8 so that neither the parser nor later
            # snippet decoding fails on stray bytes; newlines are unaffected,
            # so row numbers stay aligned with the diff.
            parent_bytes = parent_blob.data.decode("utf-8", errors="replace").encode(
                "utf-8"
            )

            parser = Parser()
            parser.language = Language(language)
            tree = parser.parse(parent_bytes)
            functions = self.extract_functions(tree.root_node, parent_bytes)

            remained_changed_lines = dict(changed_lines)
            remained_all_lines = dict(all_lines)

            for _snippet, start_row, end_row, func_name in functions:
                # tree-sitter rows are 0-based, diff line numbers are 1-based.
                first_line, last_line = start_row + 1, end_row + 1

                func_changes = []
                for line, text in changed_lines.items():
                    if first_line <= line <= last_line:
                        func_changes.append(text)
                        # A line may sit in nested functions and be claimed
                        # more than once, hence the default.
                        remained_changed_lines.pop(line, None)

                func_all = []
                for line, text in all_lines.items():
                    if first_line <= line <= last_line:
                        func_all.append(text)
                        remained_all_lines.pop(line, None)

                if func_changes:
                    changes.append(
                        {
                            "file": file_path,
                            "function": func_name,
                            "content_all": func_all,
                            "content_change": func_changes,
                        }
                    )

            # Whatever was not claimed by a function.
            changes.append(
                {
                    "file": file_path,
                    "function": None,
                    "content_all": remained_all_lines,
                    "content_change": remained_changed_lines,
                }
            )

        return changes

    def deserialize(self, dir, code):
        """
        :dir: the git repo dir
        :code: should follow this schema
        [
            { "path": "path/to/file", "content": "file content" },
            ...
        ]
        """
        
        for root, dirs, files in os.walk(dir):
            
            top = pathlib.Path(root).relative_to(dir).parts
            if top and top[0] == ".git":
                continue
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                if name != ".git":
                    shutil.rmtree(os.path.join(root, name))

        
        for file in code:
            file_path = pathlib.Path(os.path.join(dir, file["path"]))
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(file["content"])

    def get_changed_functions_in_diff(
        self, origin_code, buggy_code, *, delete_temp: bool = False
    ):
        """Report the functions that differ between two code snapshots.

        A throwaway repository is created in a temporary directory, the buggy
        snapshot is committed first and the original snapshot second, and the
        second commit is analysed with ``get_changed_functions_in_commit``.

        :param origin_code: List of ``{"path": ..., "content": ...}`` dicts.
        :param buggy_code: List with the same schema.
        :param delete_temp: Currently unused; the temporary directory is
            always removed when the method returns.
        :return: The result of ``get_changed_functions_in_commit`` for the
            commit holding ``origin_code``.
        """

        def commit(repo: pygit2.Repository, msg: str):
            index = repo.index
            index.add_all()
            index.write()
            tree = index.write_tree()
            author = pygit2.Signature("author", "author")
            committer = pygit2.Signature("committer", "committer")
            # The first commit of a fresh repository has no parent.
            parents = [] if repo.head_is_unborn else [repo.head.target]
            commit_id = repo.create_commit("HEAD", author, committer, msg, tree, parents)
            return str(commit_id)

        with tempfile.TemporaryDirectory() as temp:
            repo = pygit2.init_repository(temp, initial_head="master")
            assert os.path.exists(os.path.join(temp, ".git", "objects"))

            self.deserialize(temp, buggy_code)
            buggy_sha = commit(repo, "buggy")
            print("buggy_sha: ", buggy_sha)

            self.deserialize(temp, origin_code)
            origin_sha = commit(repo, "origin")
            print("origin_sha: ", origin_sha)

            # Bind a helper instance to the temporary repository directly.
            # Going through the constructor would start a fresh clone of
            # repo_url into a subdirectory, which does not contain the two
            # commits created above.
            that = GitRepository.__new__(GitRepository)
            that.repo_url = self.repo_url
            that.clone_dir = temp
            that.repo = repo
            that.commits_info = None
            that.language_modules = None
            that.enc = None
            return that.get_changed_functions_in_commit(origin_sha)

    def apply_patchs_to_buggycode(self, repo_path, patch_content):
        def get_diff(repo, commit_sha):
            commit = repo.get(commit_sha)
            
            if len(commit.parents) == 0:
                print(f"Commit {commit} has no parent, it might be the initial commit.")
                return []

            parent_commit = commit.parents[0]
            diff = self.repo.diff(parent_commit, commit)
            return diff

        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        

        with tempfile.TemporaryDirectory() as temp:  
            
            repo = pygit2.init_repository(temp, initial_head="master")
            assert os.path.exists(os.path.join(temp, ".git", "objects"))
            
            self.deserialize(temp, buggy_code)
            buggy_sha = commit(repo, "buggy")
            print("buggy_sha: ", buggy_sha)
            
            self.deserialize(temp, origin_code)
            origin_sha = commit(repo, "origin")
            print("origin_sha: ", origin_sha)
            
            that = GitRepository(self.repo_url, temp)  
            

        
        
        
        
        
        
        
        
        
        
        
        
        
        
        

    def extract_functions(self, node, code):
        functions = []
        if node.type == "function_definition":
            start = node.start_byte
            end = node.end_byte
            snippet = code[start:end].decode("utf-8")
            function_name = None
            for child in node.children:
                if child.type == "function_declarator":
                    function_name = child.text.decode("utf-8")
                    break
            
            functions.append(
                (snippet, node.start_point[0], node.end_point[0], function_name)
            )
            
        for child in node.children:
            functions.extend(self.extract_functions(child, code))
        return functions

    def get_changed_lines_from_patch(self, hunks):  
        changed_lines = {}
        all_lines = {}
        current_new_line_number = 0
        current_old_line_number = 0

        for hunk in hunks:
            current_new_line_number = hunk.new_start
            current_old_line_number = hunk.old_start

            for line in hunk.lines:
                if line.origin == "+":
                    current_new_line_number += 1
                elif line.origin == "-":
                    changed_lines[current_old_line_number] = line.content
                    all_lines[current_old_line_number] = line.content
                    current_old_line_number += 1
                elif line.origin == " ":
                    all_lines[current_old_line_number] = line.content
                    current_new_line_number += 1
                    current_old_line_number += 1

        return changed_lines, all_lines

    def get_token(self, code, encoding="o200k_base"):
        if self.enc is None:
            
            self.enc = tiktoken.get_encoding(encoding)
        return len(self.enc.decode(str(code).encode()))
        

    def save_to_json(self, data, json_file_path):
        with open(json_file_path, "w", encoding="utf-8") as json_file:
            json.dump(data, json_file, ensure_ascii=False, indent=4)
        print(f"Information has been saved to {json_file_path}")

    def load_from_json(self, json_file_path):
        with open(json_file_path, "r", encoding="utf-8") as json_file:
            data = json.load(json_file)
        print(f"Information has been loaded from {json_file_path}")
        return data

    def generate_diff(self, buggy_code, origin_code):
        buggy_dict = {file["path"]: file["content"] for file in buggy_code}
        origin_dict = {file["path"]: file["content"] for file in origin_code}

        patch = ""

        for path in origin_dict:
            if path in buggy_dict:
                if origin_dict[path] != buggy_dict[path]:
                    diff = difflib.unified_diff(
                        buggy_dict[path].splitlines(keepends=True),
                        origin_dict[path].splitlines(keepends=True),
                        fromfile=f"a/{path}",
                        tofile=f"b/{path}",
                    )
                    
                    patch += "".join(diff)
            else:
                diff = difflib.unified_diff(
                    [],
                    origin_dict[path].splitlines(keepends=True),
                    fromfile=f"a/{path}",
                    tofile=f"b/{path}",
                )
                patch += "".join(diff)

        for path in buggy_dict:
            if path not in origin_dict:
                diff = difflib.unified_diff(
                    buggy_dict[path].splitlines(keepends=True),
                    [],
                    fromfile=f"a/{path}",
                    tofile=f"b/{path}",
                )
                patch += "".join(diff)

        return patch

    def get_file_contents_from_dir(self, dir_path, max_file_size=1024 * 1024):
        files_info = []

        for root, _, files in os.walk(dir_path):
            for file in files:
                file_path = os.path.join(root, file).replace("\\", "/")
                if (
                    ".jpg" in file or ".png" in file or ".JPG" in file
                ):  
                    content = "data:image/<binary content>"
                    continue
                try:
                    with open(file_path, "rb") as f:
                        data = f.read()
                        if len(data) > max_file_size:
                            content = self.get_partial_content(
                                file_path, max_length=1024
                            )
                            content = f"File too large to display fully. Size: {len(data)} bytes. The following are the beginning and end contents of the file:\n{content}"
                        else:
                            content = data.decode("utf-8", errors="ignore")
                            content = re.sub(
                                r"data:image\/[a-zA-Z]+;base64,[a-zA-Z0-9+/=]+",
                                "data:image/<binary content>",
                                content,
                            )
                except UnicodeDecodeError:
                    content = "<binary content>"
                except Exception as e:
                    content = f"Error reading file: {str(e)}"

                files_info.append(
                    {
                        "path": file_path.replace("buggycode/", "").replace(
                            "origincode/", ""
                        ),
                        "content": content,
                    }
                )

        return files_info


def main():
    """Ad-hoc manual run: open a repository and dump the files of a directory.

    The values below are placeholders and must be replaced before running:
    ``repo_url`` has to name an actual ``org/repo`` (the bare GitHub root is
    rejected by ``clone_repo`` with ``ValueError``).
    """
    repo_url = "https://github.com/"
    clone_dir = "tmp"
    specific_commit_sha = "1234"

    print("Initializing GitRepository object...")
    git_repo = GitRepository(repo_url, clone_dir)
    print(f"Getting changed functions for commit {specific_commit_sha}...")

    # The original literal "tmp\" ended in a backslash that escaped the
    # closing quote, which made the whole module fail to parse.
    str_a = git_repo.get_file_contents_from_dir(clone_dir)

    print(str_a)


if __name__ == "__main__":
    main()
