"""
The main driver.
"""
import json
import logging
import shutil
from argparse import ArgumentParser
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from itertools import chain
from os.path import abspath
from os.path import join as pjoin
from pathlib import Path
from loguru import logger
import os
import subprocess
from app.config import config
from app import inference, log, task_counter
from app import utils as apputils
from app.model import common
from app.model.register import register_all_models
from app.post_process import (
    get_final_patch_path,
)
from app.task.raw_tasks import RawLocalTask, RawTask
from app.task.task import Task


def initialize_git_repo(repo_path: str):
    """
    Make sure ``repo_path`` is a git repository with all current changes committed.

    Runs ``git init`` when ``repo_path/.git`` is missing. Then, if
    ``git status --porcelain`` reports any change, it stages everything and
    commits with the message "Initial commit". A clean tree is left untouched.

    Args:
        repo_path: directory to turn into / keep as a committed git repo.

    Raises:
        ValueError: if ``repo_path`` is not a directory.
        subprocess.CalledProcessError: if any git command fails.
    """
    if not os.path.isdir(repo_path):
        raise ValueError(f"Provided path {repo_path} is not a directory.")

    def _git(*git_args: str) -> subprocess.CompletedProcess:
        # All git calls run inside repo_path and fail loudly on error.
        return subprocess.run(
            ["git", *git_args],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True,
        )

    # Without a .git directory here, `git status` either fails or reports an
    # enclosing repository, so the repo must be initialized before checking it.
    if not os.path.exists(os.path.join(repo_path, ".git")):
        _git("init")

    status = _git("status", "--porcelain")
    if status.stdout.strip():  # there are uncommitted changes / untracked files
        _git("add", ".")
        _git("commit", "-m", "Initial commit")
    else:
        print(f"repo {repo_path} clean.")


def recover_bk(task_obj, source_root: str = "repos_bk", destination_root: str = "repos"):
    """
    Restore a project's working copy from its backup and commit it with git.

    Copies ``<source_root>/<project_name>`` over ``<destination_root>/<project_name>``.
    Any existing destination is deleted first. Then ``initialize_git_repo``
    is called on the result.

    Args:
        task_obj: task dict; only ``task_obj["project_name"]`` is read.
        source_root: directory holding the pristine project backups (relative
            paths resolve against the current working directory).
        destination_root: directory holding the working copies to overwrite.

    Raises:
        FileNotFoundError: if the backup directory does not exist. This is
            checked before anything is deleted, so the working copy is never
            removed without a replacement.
    """
    project_name = task_obj["project_name"]
    source = os.path.join(source_root, project_name)
    destination = os.path.join(destination_root, project_name)
    if not os.path.isdir(source):
        raise FileNotFoundError(
            f"Backup for project {project_name} not found at {source}"
        )
    if os.path.exists(destination):
        shutil.rmtree(destination)
    shutil.copytree(source, destination)
    initialize_git_repo(destination)
    print(f"Successfully recovered project {project_name}\n")


def prepare_codebase_repoeval(repoeval_task_json: str, local_repo: str):
    """
    Prepare a RepoEval task for the agent.

    Restores the project from its backup. Then it replaces the ground-truth
    function in the target file with a single ``raise NotImplementedError``
    line and commits the repo via ``initialize_git_repo``.

    Args:
        repoeval_task_json: path to the RepoEval task JSON. Fields read from
            ``metadata``: task_id, function_name, fpath_tuple, ground_truth,
            line_no.
        local_repo: root of the local repository containing the target file.

    Exits the process with status 1 if the ground truth is present in the
    file but not at the expected line.
    """

    def recover_bk(task_obj):
        # Restore code-rag-bench/.../repoeval/<project> from the repoeval_bk copy.
        source = "code-rag-bench/generation/scripts/repoeval_bk"
        destination = "code-rag-bench/generation/scripts/repoeval"
        project_name = task_obj["metadata"]["fpath_tuple"][0]
        source = os.path.join(source, project_name)
        destination = os.path.join(destination, project_name)
        # Check the backup before deleting anything, so a missing backup
        # cannot destroy the working copy.
        if not os.path.isdir(source):
            raise FileNotFoundError(
                f"Backup for project {project_name} not found at {source}"
            )
        if os.path.exists(destination):
            shutil.rmtree(destination)
        shutil.copytree(source, destination)
        print(f"Successfully recovered project {project_name}\n")

    def remove_target_func_body(file_content_list: list, whole_target_func: list, replacement: list, start_line_num: int):
        # Replace lines [start_line_num, start_line_num + len(whole_target_func))
        # with `replacement`. start_line_num is used directly as a list index,
        # i.e. assumed 0-based (unlike RepoCod's 1-based start_line) - TODO confirm.
        n = len(whole_target_func)
        if file_content_list[start_line_num:start_line_num + n] == whole_target_func:
            print("removing target")
            return (
                file_content_list[:start_line_num]
                + replacement
                + file_content_list[start_line_num + n:]
            )
        if ''.join(whole_target_func) not in ''.join(file_content_list):
            print('target already removed')
            return file_content_list.copy()
        print('fail to remove target from file! Exit here.')
        # Non-zero status: a bare exit() would report success to the shell.
        raise SystemExit(1)

    with open(repoeval_task_json, 'r') as f:
        task_data = json.load(f)
    task_id = task_data["metadata"]["task_id"]
    func_name = task_data["metadata"]["function_name"]
    rel_target_path = '/'.join(task_data["metadata"]["fpath_tuple"][1:])
    abs_target_path = os.path.join(local_repo, rel_target_path)
    # recover repo
    recover_bk(task_data)
    # replace GT
    with open(abs_target_path, 'r') as f:
        file_content_list = f.readlines()
    whole_target_func_list = task_data["metadata"]["ground_truth"].splitlines(keepends=True)
    # Indent the stub like the first non-blank line of the ground truth.
    correct_indent = 0
    for ele in whole_target_func_list:
        if ele.strip():
            correct_indent = len(ele) - len(ele.lstrip())
            break
    replacement = [f"{' ' * correct_indent}raise NotImplementedError\n"]
    start_line_num = task_data["metadata"]["line_no"]

    edited_content_list = remove_target_func_body(
        file_content_list,
        whole_target_func_list,
        replacement,
        start_line_num
    )
    with open(abs_target_path, 'w') as f:
        f.writelines(edited_content_list)
    # git init local repo
    initialize_git_repo(local_repo)
    print(f'{task_id}: func {func_name} at {start_line_num} in {abs_target_path} is ready.')


def prepare_codebase_repocod(repoeval_task_json: str, local_repo: str):
    """
    Prepare a RepoCod task for the agent.

    Restores the project from its backup (module-level ``recover_bk``). Then it
    replaces the target function with its prompt lines followed by a
    ``raise NotImplementedError`` stub and commits the repo via
    ``initialize_git_repo``.

    Args:
        repoeval_task_json: path to the RepoCod task JSON. Fields read:
            function_name, target_module_path, full_function (list of lines),
            prompt (list of lines), start_line (1-based), project_name
            (via recover_bk).
        local_repo: root of the local repository containing the target file.

    Exits the process with status 1 if the ground-truth function cannot be
    located, or if the edit would leave the file unchanged.
    """

    def remove_target_func_body(file_content_list: list, whole_target_func: list, replacement: list, start_line_num: int):
        # start_line_num is 1-based; convert it to a list index.
        n = len(whole_target_func)
        start_idx = start_line_num - 1
        if file_content_list[start_idx:start_idx + n] == whole_target_func:
            print("removing target")
            return (
                file_content_list[:start_idx]
                + replacement
                + file_content_list[start_idx + n:]
            )
        if ''.join(whole_target_func) not in ''.join(file_content_list):
            print('target already removed')
            return file_content_list.copy()
        print('fail to remove target from file! Exit here.')
        # Non-zero status: a bare exit() would report success to the shell.
        raise SystemExit(1)

    with open(repoeval_task_json, 'r') as f:
        task_data = json.load(f)
    func_name = task_data["function_name"]
    rel_target_path = task_data["target_module_path"]
    abs_target_path = os.path.join(local_repo, rel_target_path)
    # recover repo
    recover_bk(task_data)
    # replace GT
    with open(abs_target_path, 'r') as f:
        file_content_list = f.readlines()
    whole_target_func_list = task_data["full_function"]
    prompt: list = task_data["prompt"]
    # The body is whatever follows the prompt lines; indent the stub like its
    # first non-blank line.
    pure_func_body = whole_target_func_list[len(prompt):]
    correct_indent = 0
    for ele in pure_func_body:
        if ele.strip():
            correct_indent = len(ele) - len(ele.lstrip())
            break
    replacement = prompt.copy()
    replacement.append(f"{' ' * correct_indent}raise NotImplementedError\n")
    start_line_num = task_data["start_line"]  # 1-based
    edited_content_list = remove_target_func_body(
        file_content_list,
        whole_target_func_list,
        replacement,
        start_line_num
    )
    if edited_content_list == file_content_list:
        # Nothing was replaced (e.g. the target was not found in the file).
        print('edited_content_list == file_content_list! EXIT!')
        raise SystemExit(1)

    with open(abs_target_path, 'w') as f:
        f.writelines(edited_content_list)
    # git init local repo
    initialize_git_repo(local_repo)
    task_id = os.path.basename(repoeval_task_json)
    print(f'{task_id}: func {func_name} at {start_line_num} in {abs_target_path} is ready.')



def main():
    """
    CLI entry point.

    Parses arguments and configures output dir, logging and model settings.
    For the ``local-issue`` subcommand it prepares the target codebase (for
    the RepoEval / RepoCod datasets) and runs the task. Afterwards it
    restores the project from its backup via ``recover_bk``.
    """
    register_all_models()
    parser = ArgumentParser()

    subparser_dest_attr_name = "command"
    # All options live on the subparser, so a subcommand is mandatory.
    # Without one, `args.output_dir` would not exist.
    subparsers = parser.add_subparsers(dest=subparser_dest_attr_name, required=True)

    local_parser = subparsers.add_parser("local-issue", help="Run a local issue.")
    set_local_parser_args(local_parser)

    args = parser.parse_args()

    ## common options
    config.output_dir = args.output_dir
    if config.output_dir is not None:
        config.output_dir = abspath(config.output_dir)

    # we can change the number of num_processes here
    num_processes: int = int(args.num_processes)  # default: 8
    # set whether brief or verbose log
    print_stdout: bool = not args.no_print
    log.print_stdout = print_stdout

    # model related
    # `--model` uses action="append" without a default, so it is None when the
    # flag is omitted; treat that as "no models" so the fallback below applies.
    config.models = list(chain.from_iterable(args.model or []))
    if not config.models:
        config.models.append("gpt-4o")
    common.set_model(config.models[0])

    # FIXME: make temperature part of the Model class
    common.MODEL_TEMP = args.model_temperature
    config.conv_round_limit = args.conv_round_limit
    config.only_reproduce = args.reproduce

    subcommand = getattr(args, subparser_dest_attr_name)
    # only try local-issue
    if subcommand == "local-issue":
        local_repo = args.local_repo
        if local_repo is not None:
            local_repo = abspath(local_repo)
        issue_file = args.issue_file
        if issue_file is not None:
            issue_file = abspath(issue_file)
        task_data_path = args.task_data_path
        if task_data_path is not None:
            task_data_path = abspath(task_data_path)
        task_id = args.task_id

        # Strip the ground-truth function from the target repo.
        if config.dataset_name == "RepoEval":
            prepare_codebase_repoeval(task_data_path, local_repo)
        elif config.dataset_name == "RepoCod":
            prepare_codebase_repocod(task_data_path, local_repo)

        task = RawLocalTask(
            task_id,
            local_repo,
            issue_file,
            task_data_path
        )
        groups = {"local": [task]}
        try:
            run_task_groups(groups, num_processes)
        finally:
            # Restore the project from its backup even if the run was
            # interrupted, so the edited copy is not left behind.
            if task_data_path is not None:
                with open(task_data_path, 'r') as f:
                    task_obj = json.load(f)
                # NOTE(review): recover_bk reads task_obj["project_name"]; the
                # RepoCod schema has it, confirm RepoEval task files do too.
                recover_bk(task_obj)


def set_local_parser_args(parser: ArgumentParser) -> None:
    """
    Register the options of the ``local-issue`` subcommand on *parser*.

    The shared task options from ``add_task_related_args`` are added first,
    followed by the local-task options below (all plain string options).
    """
    add_task_related_args(parser)

    # --task-data-path points at a per-task JSON file, e.g.
    # REPOCOD/Tasks4Agents/astropy/astropy_0/astropy_0.json:
    # {
    #     "target_module_path": str, # required; target file path relative to the repo root
    #     "prompt": list[str],
    #     "relavent_test_path": str, # abs path in docker
    #     "full_function": list[str], # GT
    #     "function_name": str,
    #     "project_name": "astropy", # str
    #     "container_name": "repocod_astropy", # str
    #     "start_line": 123, # int 1-based
    #     "end_line": 161, # int, 1-based
    #     "filtered_test_dict": { # required
    #         "0": list[dict], # function is not called in test func. each dict: {test_path_docker, test_path_local, src_code}
    #         "1": list[dict], # possibly called the function, but may not be from the same class
    #         "2": list[dict], # called the target function
    #     }
    # }
    local_options = (
        ("--task-id", "Assign an id to the current local issue task."),
        ("--task-data-path", "Path to json data for this task."),
        ("--local-repo", "Path to a local copy of the target repo."),
        ("--issue-file", "Path to a local issue file."),
    )
    for flag, help_text in local_options:
        parser.add_argument(flag, type=str, help=help_text)


def add_task_related_args(parser: ArgumentParser) -> None:
    """
    Register the options shared by every task-running subcommand on *parser*.

    Covers the output directory, stdout verbosity, model selection and
    temperature, conversation round limit, reproduce-only mode and the
    number of worker processes.
    """
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Path to the directory that stores the run results.",
    )
    parser.add_argument(
        "--no-print",
        action="store_true",
        default=False,
        help="Do not print most messages to stdout.",
    )
    # nargs="+" together with action="append" yields a list of lists
    # (None when the flag is omitted); callers flatten it.
    parser.add_argument(
        "--model",
        type=str,
        choices=list(common.MODEL_HUB.keys()),
        nargs="+",
        action="append",
        help="The model to use.",
    )
    parser.add_argument(
        "--model-temperature",
        type=float,
        default=0.0,
        help="The model temperature to use, for OpenAI models.",
    )
    parser.add_argument(
        "--conv-round-limit",
        type=int,
        default=15,
        help="Conversation round limit for the main agent.",
    )
    # NOTE(review): store_true with default=True means this flag is always
    # True and cannot be turned off from the command line.
    parser.add_argument(
        "--enable-layered",
        action="store_true",
        default=True,
        help="Enable layered code search.",
    )
    parser.add_argument(
        "--reproduce",
        action="store_true",
        default=False,
        help="Special mode to only generate reproducer tests",
    )
    # type=int so a parsed value has the same type as the default.
    parser.add_argument(
        "--num-processes",
        type=int,
        default=8,
        help="Number of processes to run the tasks in parallel.",
    )



def run_task_groups(
    task_groups: Mapping[str, Sequence[RawTask]],
    num_processes: int,
):
    """
    Main entry for running tasks.

    Flattens every group into a single task list and records the total with
    ``task_counter``. It logs a summary, then runs the tasks sequentially when
    ``num_processes == 1`` or hands the groups to
    ``run_task_groups_parallel`` otherwise.
    """
    flat_tasks = [raw_task for group in task_groups.values() for raw_task in group]
    total = len(flat_tasks)

    task_counter.init_total_num_tasks(total)

    summary_lines = [
        f"Total number of tasks: {total}",
        f"Total number of processes: {num_processes}",
        f"Task group info: (number of groups: {len(task_groups)})",
    ]
    summary_lines.extend(
        f"\t{group_name}: {len(members)} tasks"
        for group_name, members in task_groups.items()
    )
    for line in summary_lines:
        log.print_with_time(line)

    if num_processes != 1:
        run_task_groups_parallel(task_groups, num_processes)
    else:
        log.print_with_time("Running in single process mode.")
        run_tasks_serial(flat_tasks)
        log.print_with_time("Finished all tasks sequentially.")

    if config.only_save_sbfl_result:  # not executed
        log.print_with_time("Only saving SBFL results. Exiting.")


def run_tasks_serial(tasks: list[RawTask]) -> None:
    """Run the given tasks one after another, each via ``run_task_in_subprocess``."""
    for raw_task in tasks:
        run_task_in_subprocess(raw_task)


def run_task_groups_parallel(
    task_groups: Mapping[str, Sequence[RawTask]],
    num_processes: int,
):
    """
    Run task groups concurrently in a process pool, one ``run_task_group`` call per group.

    Groups are submitted in descending order of size. The pool size is
    ``num_processes`` capped at the number of groups (and at least 1). Every
    group's outcome is awaited; a group that raises is logged, and the other
    groups keep running.
    """
    num_task_groups = len(task_groups)
    task_counter.init_total_num_task_groups(num_task_groups)
    if num_task_groups == 0:
        # ProcessPoolExecutor(0) and zip(*[]) would both raise here.
        log.print_with_time("No task groups to run.")
        return
    num_processes = max(1, min(num_processes, num_task_groups))
    task_group_ids_items = sorted(
        task_groups.items(),
        key=lambda x: len(x[1]),
        reverse=True,
    )
    log.print_with_time(f"Sorted task groups: {[x[0] for x in task_group_ids_items]}")
    try:
        # Use ProcessPoolExecutor instead of multiprocessing.Pool,
        # to support nested sub-processing
        with ProcessPoolExecutor(num_processes) as executor:
            futures = [
                (group_id, executor.submit(run_task_group, group_id, group_tasks))
                for group_id, group_tasks in task_group_ids_items
            ]
            # Inspect every future so worker exceptions are reported instead
            # of being silently discarded.
            for group_id, future in futures:
                exc = future.exception()  # blocks until this group finishes
                if exc is not None:
                    log.print_with_time(
                        f"Task group {group_id} failed with exception: {exc!r}"
                    )
    finally:
        log.print_with_time("Finishing all tasks in the pool.")


def run_task_group(task_group_id: str, task_group_items: list[RawTask]) -> None:
    """
    Execute every task of one group in order.

    This is the function submitted to the process pool in parallel mode.
    Progress is reported through ``task_counter`` after each task and once
    more when the whole group is done.
    """
    group_size = len(task_group_items)
    log.print_with_time(
        f"Starting process for task group {task_group_id}. Number of tasks: {group_size}."
    )
    # Tasks inside a group never overlap: each finishes before the next starts.
    for raw_task in task_group_items:
        run_task_in_subprocess(raw_task)
        progress_msg = task_counter.incre_task_return_msg()
        log.print_with_time(progress_msg)

    group_done_msg = task_counter.incre_task_group_return_msg()
    log.print_with_time(f"{group_done_msg} Finished task group {task_group_id}.")


def run_task_in_subprocess(task: RawTask) -> None:
    """
    Run one task via ``run_raw_task`` in a new single-worker process pool, and wait for it.

    Some exceptions escape the worker, e.g. failures in ``run_raw_task``
    before its own try block, or a broken pool. These are logged rather than
    silently dropped. They are not re-raised, so callers can continue with
    subsequent tasks.
    """
    with ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(run_raw_task, task)
    # Leaving the `with` block waits for the worker, so the future is done here.
    exc = future.exception()
    if exc is not None:
        log.log_and_always_print(
            f"Task {task.task_id} crashed in its subprocess: {exc!r}"
        )


def run_raw_task(task: RawTask) -> bool:
    """
    High-level entry for running one task.

    Creates ``<config.output_dir>/<task_id>_<timestamp>``, dumps the task's
    metadata into it, runs ``do_inference`` and logs the outcome plus the
    location of the final patch, if any.

    Args:
        - task: The RawTask to run.
    Returns:
        True if inference reported success. False otherwise, including when
        it raised; the exception is logged, not propagated.
    """
    tid = task.task_id
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_dir = pjoin(config.output_dir, f"{tid}_{stamp}")
    print(out_dir)
    apputils.create_dir_if_not_exists(out_dir)
    task.dump_meta_data(out_dir)

    log.log_and_always_print(
        f"============= Running task {tid} =============",
    )

    succeeded = False
    try:
        succeeded = do_inference(task.to_task(), out_dir)
    except Exception as e:
        logger.exception(e)
        status = f"Task {tid} failed with exception: {e}."
    else:
        status = (
            f"Task {tid} completed successfully."
            if succeeded
            else f"Task {tid} failed without exception."
        )
    log.log_and_always_print(status)

    patch_path = get_final_patch_path(out_dir)
    if patch_path is None:
        log.log_and_always_print("No patch generated. You can try running ACR again.")
    else:
        log.log_and_always_print(
            f"Please find the generated patch at: {patch_path}"
        )
    return succeeded



def _remove_outputs_except(output_dir: str, keep_names: frozenset) -> None:
    """Delete subdirectories of output_dir and its top-level files not named in keep_names."""
    for entry in Path(output_dir).iterdir():
        if entry.is_file() and entry.name not in keep_names:
            entry.unlink()
        if entry.is_dir():
            shutil.rmtree(str(entry))


def do_inference(python_task: Task, task_output_dir: str) -> bool:
    """
    Run inference for one prepared task and record its cost.

    Attaches a DEBUG-level loguru file sink at ``<task_output_dir>/info.log``
    and sets up the project before running. If the primary models raise
    ``common.ClaudeContentPolicyViolation``, the project is set up again, the
    output directory is cleaned (keeping the log and RawTask meta files), and
    the task is retried with ``config.backup_model``. The project is always
    reset afterwards.

    Returns:
        The success flag returned by ``inference.run_one_task``.
    """
    apputils.create_dir_if_not_exists(task_output_dir)

    log_file_name = "info.log"
    logger.add(
        pjoin(task_output_dir, log_file_name),
        level="DEBUG",
        format=(
            "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level>"
            " | <level>{message}</level>"
        ),
    )

    started_at = datetime.now()
    python_task.setup_project()
    run_ok = False
    try:
        try:
            run_ok = inference.run_one_task(
                python_task, task_output_dir, config.models
            )
        except common.ClaudeContentPolicyViolation:  # run time error only.
            log.log_and_always_print(
                "Content policy violation. Retry with backup model."
            )
            python_task.setup_project()
            log.log_and_always_print(
                "Removing all files except info.log and meta files."
            )
            # Keep the log plus the meta data files dumped by RawTask.
            _remove_outputs_except(
                task_output_dir,
                frozenset({
                    log_file_name,
                    "meta.json",
                    "problem_statement.txt",
                    "developer_patch.diff",
                }),
            )
            run_ok = inference.run_one_task(
                python_task, task_output_dir, config.backup_model
            )
        # Any other exception skips dump_cost and propagates to the caller
        # after the project is reset below.
        dump_cost(started_at, datetime.now(), task_output_dir)
    finally:
        python_task.reset_project()
    return run_ok


def dump_cost(
    start_time: datetime,
    end_time: datetime,
    task_output_dir: str,
):
    """
    Persist run statistics to ``<task_output_dir>/cost.json``.

    The record holds the current commit hash, the start/end times as epoch
    seconds and the elapsed wall time. It is then merged with the selected
    model's overall execution stats, which win on key clashes.
    """
    usage_stats = common.SELECTED_MODEL.get_overall_exec_stats()
    duration = end_time - start_time
    record = dict(
        commit=apputils.get_current_commit_hash(),
        start_epoch=start_time.timestamp(),
        end_epoch=end_time.timestamp(),
        elapsed_seconds=duration.total_seconds(),
    )
    record.update(usage_stats)

    cost_path = pjoin(task_output_dir, "cost.json")
    with open(cost_path, "w") as out:
        json.dump(record, out, indent=4)


if __name__ == "__main__":
    # Drop every pre-configured loguru handler; file sinks are added per task.
    logger.remove()
    # Only let httpx's logger through at WARNING level and above.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    main()
