import json
from collections import defaultdict
from collections.abc import Iterable
from itertools import cycle
from os import PathLike
from os.path import samefile
from pathlib import Path
from shutil import copy2

from loguru import logger
from natsort import natsorted

from app.config import config
from app.agents import agent_select
from app.agents.agent_tester import NoReproductionStep, TestAgent
from app.api import validation
from app.api.review_manage import ReviewManager
from app.api.validation import evaluate_patch
from app.data_structures import BugLocation
from app.log import *
from app.manage import ProjectApiManager
from app.model.common import set_model
from app.task.task import Task
from app.data_structures import MessageThread


def write_patch_iterative_with_review(
    task: Task,
    output_dir: str,
    review_manager: ReviewManager,
    retries=3,
) -> tuple[bool, str | None]:
    """Generate patches with evaluation feedback until one passes.

    Drives ``review_manager.generator()``: every step yields a
    ``(patch_handle, patch_content)`` pair, and the evaluation summary of the
    previous candidate is sent back into the generator so it can refine the
    next candidate.

    Args:
        task: The task the candidate patches are evaluated against.
        output_dir: Output directory forwarded to ``validation.evaluate_patch``.
        review_manager: Provides the patch generator.
        retries: Maximum number of candidates to request and evaluate.

    Returns:
        ``(True, patch_content)`` for the first candidate that passes
        evaluation, otherwise ``(False, None)`` once ``retries`` candidates
        have failed or the generator stops early.
    """
    logger.info("Start generating patches with refinement")
    patch_gen = review_manager.generator()
    # The first send() must be None to start the generator; afterwards the
    # previous evaluation summary is fed back as refinement feedback.
    eval_summary = None
    try:
        for _ in range(retries):
            try:
                patch_handle, patch_content = patch_gen.send(eval_summary)
            except StopIteration:
                break
            logger.info("Reviewer approved patch: {}", patch_handle)

            logger.info("Begin evaluating patch: {}", patch_handle)
            eval_passed, eval_summary = validation.evaluate_patch(
                task, patch_handle, patch_content, output_dir
            )
            if eval_passed:
                logger.info(
                    "Patch {} passed evaluation. Ending patch generation", patch_handle
                )
                return True, patch_content

            logger.info("Patch {} failed evaluation", patch_handle)
    finally:
        # Close the generator on every exit path (success, exhaustion, or an
        # exception from evaluation); close() on a finished generator is a no-op.
        patch_gen.close()

    return False, None
    


def write_patch_iterative(
    task: Task,
    output_dir: str,
    review_manager: ReviewManager,
    retries=3,
) -> tuple[bool, str | None, str | None]:
    """Generate patches (no feedback loop) until one passes evaluation.

    Pulls candidates from ``review_manager.patch_only_generator()``, which
    yields ``(patch_handle, patch_content, response)`` triples, and evaluates
    each with ``validation.evaluate_patch``. Unlike
    ``write_patch_iterative_with_review``, nothing is sent back to the
    generator between candidates.

    Args:
        task: The task the candidate patches are evaluated against.
        output_dir: Output directory forwarded to ``validation.evaluate_patch``.
        review_manager: Provides the patch generator.
        retries: Maximum number of candidates to request and evaluate.

    Returns:
        ``(True, patch_content, response)`` for the first passing candidate,
        otherwise ``(False, None, None)``.
    """
    logger.info("Start generating patches first")
    patch_gen = review_manager.patch_only_generator()
    try:
        for _ in range(retries):
            try:
                # NOTE(review): per the original author, patch_handle may be a
                # patch ID and patch_content a plain code block rather than a
                # git-format diff — confirm against patch_only_generator().
                patch_handle, patch_content, response = patch_gen.send(None)
            except StopIteration:
                break
            logger.info("Generated applicable patch: {}", patch_handle)

            logger.info("Begin evaluating patch: {}", patch_handle)
            eval_passed, _ = validation.evaluate_patch(
                task, patch_handle, patch_content, output_dir
            )
            if eval_passed:
                logger.info(
                    "Patch {} passed evaluation. Ending patch generation", patch_handle
                )
                return True, patch_content, response

            logger.info("Patch {} failed evaluation", patch_handle)
    finally:
        # Close the generator on every exit path, not only on success.
        patch_gen.close()

    return False, None, None



def run_one_task(task: Task, output_dir: str, model_names: Iterable[str]) -> bool:
    """
    Main entry point to run inference on one task.

    Makes up to ``config.overall_retry_limit`` attempts. Each attempt takes
    the next model from ``model_names`` (round-robin), works in its own
    ``<output_dir>/output_<idx>`` directory, and the loop stops as soon as an
    attempt succeeds.

    Args:
        task (Task): The task to solve.
        output_dir (str): Path to the output directory.
        model_names (Iterable[str]): Model names to rotate through across
            attempts; must contain at least one name.

    Returns:
        bool: Currently always True. Patch selection is skipped for now, so
        the success of the individual attempts is not reported here.
    """
    # Materialize first: asserting on a generator/iterator is always truthy,
    # and cycle() over an empty iterable would fail later with StopIteration.
    model_names = list(model_names)
    assert model_names, "model_names must not be empty"
    model_name_cycle = cycle(model_names)

    for idx in range(config.overall_retry_limit):
        model_name = next(model_name_cycle)
        set_model(model_name)
        logger.info("Starting overall retry {} with model {}", idx, model_name)
        out_dir = Path(output_dir, f"output_{idx}")
        out_dir.mkdir(parents=True, exist_ok=True)

        # meta.json is used later by convert_response_to_diff(),
        # so it needs to be copied over
        meta_file = Path(output_dir, "meta.json")
        if meta_file.exists():
            copy2(meta_file, out_dir)

        # A fresh API manager per attempt, writing into this attempt's directory.
        api_manager = ProjectApiManager(task, str(out_dir))

        if _run_one_task(str(out_dir), api_manager, task.get_issue_statement()):
            logger.info("Overall retry {} succeeded; ending workflow", idx)
            break

        logger.info("Overall retry {} failed; proceeding to next retry", idx)

    log_and_print("=====Starting patch selection=====")
    print('skip the patch selection for now')
    return True


def select_patch(task: Task, output_dir: str | PathLike) -> tuple[str, dict]:
    """Choose one extracted patch under ``output_dir`` as the final answer.

    Selection order:
      1. The first candidate (a patch that may pass regression tests) whose
         single review file says ``"patch-correct": "yes"``.
      2. The only candidate, when exactly one passes regression tests.
      3. Agent selection among the candidates, when several pass.
      4. Agent selection among all patches, when none pass.
    If agent selection raises, the last patch of the considered list is used.

    Args:
        task: Task supplying the issue statement and regression evaluation.
        output_dir: Directory searched recursively for
            ``extracted_patch_*.diff`` files.

    Returns:
        ``(selected_patch, result)``. ``selected_patch`` is the chosen path
        relative to ``output_dir``. ``result`` has ``"selected_patch"`` and
        ``"reason"`` keys, plus ``"agent_comment"`` when the agent returned one.

    Raises:
        ValueError: If no ``extracted_patch_*.diff`` file exists.
    """
    patches = natsorted(Path(output_dir).glob("**/extracted_patch_*.diff"))
    if not patches:
        raise ValueError(f"No extracted_patch_*.diff files found under {output_dir}")

    # TODO: These candidate patches must have been dismissed by reviewer. Maybe an
    # assertion should be added to confirm this.
    candidate_patches = [p for p in patches if may_pass_regression_tests(task, p)]

    agent_comment = None
    thread = None

    # NOTE: a majority-vote shortcut over identical patch contents used to sit
    # in the agent branches behind `if False:`; it was dead code and has been
    # dropped in favour of always asking the selection agent.
    approved = _find_reviewer_approved(candidate_patches)
    if approved is not None:
        last_patch = patches[-1]  # `patches` is already natsorted
        assert samefile(
            approved, last_patch
        ), f"{approved} is approved and passes validation, but the last patch was {last_patch}"
        selected_patch = approved
        reason = "reviewer-approved"
    elif len(candidate_patches) > 1:
        selected_patch, reason, agent_comment, thread = _select_with_agent(
            task, candidate_patches, "multiple-pass-regression"
        )
    elif len(candidate_patches) == 1:
        selected_patch = candidate_patches[0]
        reason = "no-agent,single-pass-regression"
    else:
        selected_patch, reason, agent_comment, thread = _select_with_agent(
            task, patches, "none-pass-regression"
        )

    rel_selected_patch = str(selected_patch.relative_to(output_dir))

    result = {
        "selected_patch": rel_selected_patch,
        "reason": reason,
    }

    if agent_comment is not None:
        result["agent_comment"] = agent_comment

    if thread is not None:
        thread.save_to_file(Path(output_dir, "agent_selection.json"))

    return rel_selected_patch, result


def _find_reviewer_approved(candidate_patches: list[Path]) -> Path | None:
    """Return the first candidate whose review file marks it correct, else None."""
    for patch in candidate_patches:
        # extracted_patch_<index>.diff -> <index>
        index = patch.with_suffix("").name.rpartition("_")[2]
        reviews = natsorted(
            list(patch.parent.glob(f"review_p{index}_t*.json")), reverse=True
        )
        if not reviews:
            continue
        assert len(reviews) == 1, patch
        if json.loads(reviews[0].read_text())["patch-correct"] == "yes":
            return patch
    return None


def _select_with_agent(task: Task, patch_files: list[Path], scope: str):
    """Let the selection agent pick from ``patch_files``; fall back to the last on error.

    Returns ``(selected_patch, reason, agent_comment, thread)``; the last two
    are None when the agent call fails.
    """
    try:
        index, agent_comment, thread = agent_select.run(
            task.get_issue_statement(), [p.read_text() for p in patch_files]
        )
        reason = f"agent-selected,{scope}"
    except Exception:
        # Best-effort: keep the pipeline going, but don't hide the failure.
        logger.exception("Agent patch selection failed; falling back to last patch")
        index, agent_comment, thread = -1, None, None
        reason = f"agent-error,{scope}"
    return patch_files[index], reason, agent_comment, thread


def may_pass_regression_tests(task: Task, patch_file: str | PathLike) -> bool:
    """Report whether a patch is expected to cause no additional test failures.

    With validation disabled every patch qualifies. Otherwise a cached
    ``regression_<idx>.json`` next to the patch is trusted when present
    (``<idx>`` is the part of the file stem after its last underscore); if it
    is absent, the project is reset and the patch is evaluated on the spot.
    """
    if config.enable_validation:
        path = Path(patch_file)
        idx = path.with_suffix("").name.rpartition("_")[2]
        cached_result = path.with_name(f"regression_{idx}.json")

        if not cached_result.exists():
            task.reset_project()
            passed, _ = evaluate_patch(
                task, idx, path.read_text(), str(path.parent)
            )
            return passed

        return json.loads(cached_result.read_text())["no_additional_failure"]

    return True


def _run_one_task(
    output_dir: str, api_manager: ProjectApiManager, problem_stmt: str
) -> bool:
    """Run a single attempt: locate the edit target, gather context, write a patch.

    Args:
        output_dir: Directory for this attempt's artifacts.
        api_manager: Initialized manager providing the task and search manager.
        problem_stmt: Issue statement, printed at the start of the attempt.

    Returns:
        True if a generated patch passed evaluation, False otherwise —
        including when the test-refinement path raises NoReproductionStep.
    """
    print_banner("Starting AutoCodeRover on the following task")
    print_issue(problem_stmt)

    repro_result_map = {}
    print_banner("Find Generation Location")
    bug_loc = api_manager.search_manager.locate_target_location()  # BugLocation
    print_acr(str(bug_loc.to_dict()), "Code Generation Location")
    # Downstream components take a list of locations.
    bug_locs = [bug_loc]
    test_agent = TestAgent(bug_locs, api_manager.task, output_dir)

    # After locating the target function, do context search.
    print_banner("Start Context Search")
    # search_msg_thread is the conversation trajectory of the search phase.
    search_msg_thread = api_manager.search_manager.search_iterative()
    # Done with search; dump the tool calls used, for recording.
    api_manager.search_manager.dump_tool_call_layers_to_file()

    logger.debug("Gathered enough information. Invoking write_patch.")
    print_banner("Start Writing Patch")
    search_msg_thread.save_to_file(Path(output_dir, "_conversation_in_rag.json"))
    review_manager = ReviewManager(
        search_msg_thread,
        bug_locs,
        api_manager.search_manager,
        api_manager.task,
        output_dir,
        test_agent,
        repro_result_map,
    )

    # Test refinement is only possible if at least one test list is non-empty.
    has_test = any(
        len(test_list) > 0 for test_list in api_manager.task.test_data.values()
    )

    # Default outcome if the refinement path aborts before producing a result
    # (previously this left `result` unbound -> UnboundLocalError).
    result = False
    if config.test_in_refine and has_test:
        # Refine patches with test feedback; otherwise fall back to plain
        # generation (agent RAG + code generation) below.
        print('HUYIRAN: write patch with test refinement')
        try:
            result, _ = write_patch_iterative_with_review(
                api_manager.task, output_dir, review_manager
            )
        # this exception can arise when writing new reproducers
        except NoReproductionStep:
            log_and_print('HUYIRAN: no steps for testing the task.')
    else:
        print('HUYIRAN: write patch without test refinement')
        if not has_test:
            print('HUYIRAN: generate patch without test refinement due to no parseable test')
        result, _, _ = write_patch_iterative(
            api_manager.task, output_dir, review_manager
        )

    return result
        


if __name__ == "__main__":
    # Intentionally a no-op when run as a script.
    # The commented-out driver below re-runs select_patch() for every task
    # directory under a hard-coded, machine-specific path and writes
    # selected_patch.json into each. Update the path before re-enabling it.
    # NOTE(review): select_patch() returns (relative_path, result_dict), so the
    # value stored under "reason" below would be the whole result dict — confirm
    # that is intended.
    pass
    # from app.raw_tasks import RawSweTask

    # config.enable_validation = True

    # applicable_path = Path(
    #     "/media/media0/haifeng/projects/reverse-prompt/acr-plus/experiment/06-13-docker-val-loop-lite-try-2-rand/applicable_patch/"
    # )
    # task_dirs = list(applicable_path.glob("*"))
    # for task_dir in task_dirs:
    #     meta = json.loads(task_dir.joinpath("meta.json").read_text())
    #     raw_task = RawSweTask(meta["task_id"], meta["setup_info"], meta["task_info"])
    #     task = raw_task.to_task()
    #     selected_patch, reason = select_patch(task, task_dir)

    #     task_dir.joinpath("selected_patch.json").write_text(
    #         json.dumps({"selected_patch": selected_patch, "reason": reason}, indent=4)
    #     )
