import json
import re
from collections import defaultdict
from collections.abc import Generator
from copy import deepcopy
from pathlib import Path
from typing import TypeAlias
import random
from loguru import logger
from tenacity import retry, stop_after_attempt
from app.task.task import *
from app.task.raw_tasks import *
from app.utils import *
from app.agents.agent_common import InvalidLLMResponse
from typing import Dict, List, Optional
from app.data_structures import BugLocation, MessageThread, ReproResult
from app.log import print_acr, print_reproducer
from app.model.gpt import common
from app.task.task import Task
from app.config import config
import subprocess
import shlex


# System message that frames the model as a tester of the target functions.
SYSTEM_PROMPT = (
    "You are an experienced software engineer responsible for testing the following target functions"
)

# First user request asking the model for a standalone `self_tester.py`.
# NOTE: `TestAgent._write_test` rewrites the sentence starting with
# "The file would be put in the root directory of the project and executed"
# via `str.replace`, so that exact wording must stay intact.
INITIAL_REQUEST = (
    "Please try to write a standalone python file `self_tester.py` to test"
    " the target function. Put the file in a code block.\n\n"
    "The file would be put in the root directory of the project and executed"
    " by `python3 self_tester.py`. The script should raise an `AssertionError` when"
    " the target function does not achieve the intended functionality and print a stack trace of the error."
    " The script should also exit with code 0 when the target function pass the test cases.\n\n"
    "Please use the following function to print the stack trace, so that the line numbers"
    " of the statements are shown clearly:\n"
    "```\n"
    "def print_stacktrace(e: Exception):\n"
    # Each import needs its own line; without the trailing "\n" the helper
    # shown to the model is not valid Python.
    "    import traceback\n"
    "    import sys\n"
    "    tb = traceback.extract_tb(e.__traceback__)\n"
    '    print("Traceback (most recent call last):", file=sys.stderr)\n'
    "    for frame in tb:\n"
    "        line_number = frame.lineno\n"
    '        code_context = frame.line.strip() if frame.line else "Unknown"\n'
    "        print(f'  File \"{frame.filename}\"', file=sys.stderr)\n"
    '        print(f"    {line_number}: {code_context}", file=sys.stderr)\n'
    '    print(f"{e.__class__.__name__}: {e}", file=sys.stderr)\n'
    "```\n"
    "Remember, In the begining, the function body of the target function is empty."
)


class NoReproductionStep(RuntimeError):
    """Signals that the issue statement offers no steps to reproduce it."""


# Identifier of a generated test; TestAgent builds it as str(request index).
TestHandle: TypeAlias = str


class TestAgent:
    def __init__(
        self, buglocs: list[BugLocation], task: RawLocalTask, task_dir: str
    ) -> None:
        """Create an agent with empty test, feedback and result registries.

        Args:
            buglocs: Locations of the target code. Kept as ``buglocs_list``;
                ``get_raw_pytest_output`` reads its first element, so an
                indexable sequence is expected.
            task: The task this agent works on.
            task_dir: Directory that receives raw responses, conversation
                dumps and saved test artifacts.
        """
        self.task = task
        self.task_dir = task_dir
        self.buglocs_list = buglocs

        # Index of the most recent LLM request; its string form is the handle
        # under which the resulting test is registered.
        self._request_idx: int = -1
        # Per-handle storage for the raw model response, the extracted test
        # source and the feedback messages attached to that test.
        self._responses: dict[TestHandle, str] = {}
        self._tests: dict[TestHandle, str] = {}
        self._feedbacks: dict[TestHandle, list[str]] = defaultdict(list)
        # Handles in registration order, split by whether the test reproduced.
        self._history: list[TestHandle] = []
        self._non_repro_history: list[TestHandle] = []

        # State of the pytest-based refinement flow: the chosen test cases,
        # their raw execution results and the feedback text built from them.
        self.selected_test_cases_list: list[dict] = []
        self.test_cases_output_list: list[dict] = []
        self.test_parsed_feedback_list: list[str] = []
        self.general_feedback: str = ""
        

    def write_reproducing_test_without_feedback(
        self, retries: int = 3
    ) -> tuple[TestHandle, str, ReproResult]:
        """Generate a test through ``_write_reproducing_test`` with a feedback window of one.

        Args:
            retries: Maximum number of generation attempts.

        Returns:
            The ``(handle, test_content, repro_result)`` triple produced by
            ``_write_reproducing_test``.
        """
        outcome = self._write_reproducing_test(retries=retries, num_feedbacks=1)
        return outcome

    def write_reproducing_test_with_feedback(
        self, max_feedbacks: int = 1, retries: int = 5
    ) -> tuple[TestHandle, str, ReproResult]:
        """Generate a test, showing the model up to ``max_feedbacks`` earlier tests.

        Args:
            max_feedbacks: Upper bound on the number of earlier test handles
                whose feedback is replayed to the model.
            retries: Maximum number of generation attempts.

        Returns:
            The ``(handle, test_content, repro_result)`` triple produced by
            ``_write_reproducing_test``.
        """
        outcome = self._write_reproducing_test(max_feedbacks, retries)
        return outcome

    def add_feedback(self, handle: TestHandle, feedback: str) -> None:
        """Attach a feedback message to an already registered test.

        Args:
            handle: Handle of the test the feedback refers to.
            feedback: Message to replay to the model on later attempts.

        Raises:
            ValueError: If no test was registered under ``handle``.
        """
        if handle not in self._tests:
            # Interpolate the handle into the message; passing it as a second
            # positional argument left the "{}" placeholder unformatted.
            raise ValueError(f"test {handle} does not exist")

        self._feedbacks[handle].append(feedback)

    def _write_reproducing_test(
        self, num_feedbacks: int, retries: int
    ) -> tuple[TestHandle, str, ReproResult]:
        """Ask the model for a test until one is reported as reproduced.

        Every attempt stores the raw response and the conversation in the task
        directory. A response without a usable test is skipped; a test that
        runs but is not flagged as reproduced is registered together with
        feedback built from its execution result.

        Args:
            num_feedbacks: Number of earlier handles offered as feedback.
            retries: Maximum number of attempts.

        Returns:
            ``(handle, test_content, repro_result)`` of the first test whose
            result has ``reproduced`` set.

        Raises:
            InvalidLLMResponse: If no attempt produced such a test.
        """
        for _attempt in range(retries):
            handles = self._select_feedback_handles(num_feedbacks)
            response, test_content, thread = self._write_test(handles, self.task)

            self._request_idx += 1
            idx = self._request_idx
            out_dir = Path(self.task_dir)

            print_reproducer(response)
            (out_dir / f"test_raw_{idx}.md").write_text(response)
            thread.save_to_file(out_dir / f"conv_test_{idx}.json")

            if test_content is not None:
                repro_result = self.task.execute_reproducer(test_content)
                print_acr(str(repro_result))

                if repro_result.reproduced:
                    accepted = self._register_reproducing_test(response, test_content)
                    return accepted, test_content, repro_result

                rejected = self._register_non_reproducing_test(
                    response, test_content, repro_result
                )
                logger.info("registered non reproducing test {}", rejected)

        raise InvalidLLMResponse(
            f"Failed to write a test for target function in {retries} attempts"
        )

    @classmethod
    def _issue_has_reproduction_steps(
        cls, issue_statement: str
    ) -> tuple[bool, MessageThread]:
        """Ask the model whether the task's target function can be tested.

        Args:
            issue_statement: Task description shown to the model.

        Returns:
            The boolean verdict read from the model's JSON answer, and a copy
            of the conversation with that answer appended.

        The query is attempted up to three times; any exception raised while
        calling the model or parsing its answer (invalid JSON, missing key,
        non-boolean value) triggers another attempt.
        """
        prefix_thread = MessageThread()
        prefix_thread.add_system(SYSTEM_PROMPT)

        prefix_thread.add_user(f"Here is a code generation task:\n\n{issue_statement}")

        key = "is-testable"
        # Adjacent literals are concatenated verbatim, so every sentence needs
        # its own trailing space to avoid run-together words in the prompt.
        prefix_thread.add_user(
            "Tell me whether the target function in this task can be tested. "
            "Remember, before you generate the code, "
            "the function body of the target function in the project is empty. "
            "If the execution results indicates the target function is empty "
            "or some packages in the enviroment are missing, you can still sometimes consider it testable. "
            "We don't have to be very strict about the test file. "
            "Your answer should take the following Json format:\n"
            "```\n"
            "{\n"
            f'    "{key}": ...\n'
            "}\n"
            "```\n"
            f'where "{key}" should be either `true` or `false`.'
        )

        @retry(stop=stop_after_attempt(3))
        def query_and_parse():
            response, *_ = common.SELECTED_MODEL.call(
                prefix_thread.to_msg(), response_format="json_object"
            )

            result = json.loads(response)[key]

            if not isinstance(result, bool):
                raise InvalidLLMResponse

            # Copy so that a retried attempt starts from the unmodified prefix.
            thread = deepcopy(prefix_thread)
            thread.add_model(response)

            return result, thread

        return query_and_parse()

    def _select_feedback_handles(self, max_num_feedbacks: int) -> list[TestHandle]:
        """Pick at most ``max_num_feedbacks`` handles to replay as feedback.

        Handles from ``_history`` take priority (most recent first to be
        kept); remaining slots are filled with the most recent handles from
        ``_non_repro_history``, which are placed before the history handles in
        the result.

        Args:
            max_num_feedbacks: Maximum number of handles to return. Zero or a
                negative value yields an empty list.

        Returns:
            A new list of selected handles.
        """
        # Guard first: ``seq[-0:]`` is the whole sequence, so a limit of zero
        # would otherwise return the entire history.
        if max_num_feedbacks <= 0:
            return []

        if max_num_feedbacks <= len(self._history):
            return self._history[-max_num_feedbacks:]

        # At least one slot is left here; slicing clamps when fewer
        # non-reproducing handles exist than slots remain.
        num_non_repro = max_num_feedbacks - len(self._history)
        return [
            *self._non_repro_history[-num_non_repro:],
            *self._history,
        ]

    def _write_test(
        self, history_handles: list[TestHandle] | None = None,
        task: RawLocalTask | None = None
    ) -> tuple[str, str | None, MessageThread]:
        """Query the model once for a test file.

        The conversation starts from the issue thread. When any of the given
        handles has feedback, the initial request is added, followed by each
        such handle's earlier response and its feedback messages; the initial
        request is then added (again) as the final user message.

        Args:
            history_handles: Handles of earlier tests to replay; ``None``
                means none.
            task: Task whose ``project_name`` is used to build the project
                path mentioned in the request. Defaults to ``self.task``.

        Returns:
            ``(response, test_content, thread)`` where ``test_content`` is the
            result of ``convert_response_to_test`` on the response (``None``
            when no single test could be extracted) and ``thread`` is the
            conversation sent to the model.
        """
        history_handles = history_handles or []
        # The parameter is optional; fall back to the agent's own task instead
        # of failing on ``None.project_name``.
        task = self.task if task is None else task

        thread = self._construct_init_thread()
        project_path = os.path.join(config.REPOCOD_WORK_DIR, task.project_name)
        init_request = INITIAL_REQUEST.replace("The file would be put in the root directory of the project and executed", f"The file would be put in the root directory of the project: `{project_path}` and executed")
        if any(handle in self._feedbacks for handle in history_handles):
            thread.add_user(init_request)
        for handle in history_handles:
            if feedbacks := self._feedbacks.get(handle, []):
                thread.add_model(self._responses[handle], [])
                for feedback in feedbacks:
                    thread.add_user(feedback)
            else:
                logger.warning("test {} does not have a feedback; skipping", handle)
        thread.add_user(init_request)

        if not history_handles:
            print_acr(init_request)
        response, *_ = common.SELECTED_MODEL.call(thread.to_msg())

        return response, self.convert_response_to_test(response), thread

    def _construct_init_thread(self) -> MessageThread:
        """Build a fresh thread holding the system prompt and the task's issue statement."""
        issue_text = self.task.get_issue_statement()

        conversation = MessageThread()
        conversation.add_system(SYSTEM_PROMPT)
        conversation.add_user(f"Here is an issue:\n\n{issue_text}")
        return conversation

    def _register_reproducing_test(
        self, response: str, test_content: str
    ) -> TestHandle:
        """Record a reproducing test under the current request index.

        Args:
            response: Raw model response the test was extracted from.
            test_content: Source of the test.

        Returns:
            The handle the test was stored under.
        """
        new_handle = str(self._request_idx)

        # A handle must be brand new in every registry it is about to enter.
        for registry in (
            self._responses,
            self._feedbacks,
            self._tests,
            self._history,
        ):
            assert new_handle not in registry

        self._history.append(new_handle)
        self._tests[new_handle] = test_content
        self._responses[new_handle] = response

        return new_handle

    def _register_non_reproducing_test(
        self, response: str, test_content: str, repro_result: ReproResult
    ) -> TestHandle:
        """Record a non-reproducing test and derive feedback from its run.

        Args:
            response: Raw model response the test was extracted from.
            test_content: Source of the test.
            repro_result: Execution result used to build the feedback text.

        Returns:
            The handle the test was stored under.
        """
        new_handle = str(self._request_idx)

        # A handle must be brand new in every registry it is about to enter.
        for registry in (
            self._responses,
            self._feedbacks,
            self._tests,
            self._non_repro_history,
        ):
            assert new_handle not in registry

        self._responses[new_handle] = response
        self._tests[new_handle] = test_content
        self._non_repro_history.append(new_handle)

        feedback_text = self._feedback_from_repro_result(repro_result)
        self._feedbacks[new_handle].append(feedback_text)

        return new_handle

    def _feedback_from_repro_result(self, repro_result: ReproResult) -> str:
        """Render an execution result as a feedback message for the model.

        The message states that the test did not reproduce the issue and
        lists the exit code, standard output and standard error, separated by
        blank lines.
        """
        paragraphs = [
            "This test did not reproduce the issue.",
            f"The test execution exited with code {repro_result.returncode}.",
            f"Standard output: {repro_result.stdout}",
            f"Standard error: {repro_result.stderr}",
        ]
        return "\n\n".join(paragraphs)

    @classmethod
    def convert_response_to_test(cls, response: str) -> str | None:
        """Extract the test source from a model response.

        Args:
            response: Raw model response in markdown.

        Returns:
            The content of the only code block, or of the first of two code
            blocks when the second merely contains the command that runs the
            test file. ``None`` for any other shape.
        """
        blocks = extract_markdown_code_blocks(response)

        if len(blocks) == 1:
            return blocks[0]

        # The request asks for `self_tester.py`; the older `reproducer.py`
        # name is still tolerated.
        run_commands = ("python3 self_tester.py", "python3 reproducer.py")
        if len(blocks) == 2 and blocks[1].strip() in run_commands:
            return blocks[0]

        return None

    def save_test(self, handle: TestHandle) -> None:
        """Write the test stored under ``handle`` to ``reproducer_<handle>.py`` in the task directory."""
        destination = Path(self.task_dir, f"reproducer_{handle}.py")
        destination.write_text(self._tests[handle])


    
    def select_test_cases(self, test_num = config.test_num, test_case_selection_mode = config.test_case_selection_mode):
        """Choose the test cases to run and cache the choice on the task.

        If the task already carries a non-empty ``selected_test_list`` it is
        reused as is. Otherwise all lists in ``task.test_data`` are flattened
        (in dict order) and narrowed according to the mode:

        * ``'all'``: every test.
        * ``"call_distance"``: the first ``test_num`` tests.
        * ``'random'``: a random sample of ``test_num`` tests.
        * ``"fail_signal"``: tests with a truthy ``'fail_signal'``, randomly
          sampled when there are more than ``test_num``.
        * ``'complexity'``: tests with non-negative ``'complexity'``, lowest
          first, at most ``test_num``.
        * ``"diff_call"``: tests with distinct ``direct_call_info``
          ``(file, line)`` pairs first, topped up with the skipped ones.

        Any other mode leaves the current selection untouched.

        Args:
            test_num: Requested number of tests; capped at the number of
                available tests.
            test_case_selection_mode: One of the modes above.

        Returns:
            The selected list, also stored in ``self.selected_test_cases_list``
            and ``task.selected_test_list``.
        """
        task = self.task
        if len(task.selected_test_list) > 0:
            # Selection was already made during context retrieval.
            self.selected_test_cases_list = task.selected_test_list
            return self.selected_test_cases_list

        candidates = [
            test
            for tests_at_distance in self.task.test_data.values()
            for test in tests_at_distance
        ]
        # Never ask for more tests than exist.
        limit = min(len(candidates), test_num)
        mode = test_case_selection_mode

        if mode == 'all':
            self.selected_test_cases_list = candidates

        elif mode == "call_distance":
            self.selected_test_cases_list = candidates[:limit]

        elif mode == 'random':
            self.selected_test_cases_list = random.sample(candidates, limit)

        elif mode == "fail_signal":
            # Prefer tests that yield exception or assertion messages.
            flagged = [test for test in candidates if test['fail_signal']]
            self.selected_test_cases_list = (
                flagged if limit >= len(flagged) else random.sample(flagged, limit)
            )

        elif mode == 'complexity':
            # Least complex tests first; negative complexity is excluded.
            ranked = sorted(
                (test for test in candidates if test['complexity'] >= 0),
                key=lambda test: test["complexity"],
            )
            self.selected_test_cases_list = (
                ranked if limit >= len(ranked) else ranked[:limit]
            )

        elif mode == "diff_call":
            # First pass keeps one test per distinct direct call site.
            known_sites = set()
            chosen = self.selected_test_cases_list = []
            skipped = []
            for test in candidates:
                call_info = test.get("direct_call_info", {})
                site = (call_info.get("file"), call_info.get("line"))
                if site in known_sites:
                    skipped.append(test)
                else:
                    known_sites.add(site)
                    chosen.append(test)
                if len(chosen) >= limit:
                    break
            # Second pass tops up with tests whose call site was already seen.
            if len(chosen) < limit:
                chosen.extend(skipped[:limit - len(chosen)])

        task.selected_test_list = self.selected_test_cases_list
        return self.selected_test_cases_list
    
    
    # functions newly added by 
    def execute_command_docker(self, cmd_str, container, workdir, timeout = 300):
        """Run a shell command inside the task's container through ``docker exec``.

        The command is executed as ``bash -c <cmd_str>`` in the container
        named ``self.task.container_name``.

        Args:
            cmd_str: Shell command line to execute.
            container: Unused; kept so existing callers keep working.
            workdir: Working directory inside the container.
            timeout: Seconds to wait before giving up.

        Returns:
            An object with ``exit_code`` (the process return code) and
            ``output``, a ``(stdout_bytes, stderr_bytes)`` tuple.

        Raises:
            RuntimeError: If the command times out or cannot be run at all.
                A non-zero exit code is reported through ``exit_code`` and
                does not raise.
        """
        class ExecResult:
            def __init__(self, exit_code, output):
                self.exit_code = exit_code
                self.output = output  # Tuple (stdout_bytes, stderr_bytes)

        # An argument list needs no shell quoting: each element reaches
        # docker exactly as given.
        argv = [
            "docker", "exec",
            "-w", workdir,
            self.task.container_name,
            "bash", "-c", cmd_str,
        ]
        try:
            result = subprocess.run(
                argv,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=timeout
            )
        except subprocess.TimeoutExpired as e:
            raise RuntimeError(f"[Timeout] Command timed out after {timeout} seconds:\n{cmd_str}") from e
        except Exception as e:
            raise RuntimeError(f"[Exception] Unexpected error while executing command:\n{cmd_str}\n{e}") from e

        return ExecResult(
            exit_code=result.returncode,
            output=(result.stdout, result.stderr)
        )
        
    
    
    # functions newly added by 
    def get_raw_pytest_output(self, patch_content: str):
        """Apply a patch, run the selected tests in the container and collect raw results.

        Steps: the patch is applied to the local project with ``git apply``;
        the file of the first bug location is copied into the container; the
        local repository is cleaned again; each selected test is run with
        pytest inside the container; finally the container's working tree is
        reset.

        Args:
            patch_content: Diff text to apply to the local project.

        Returns:
            One dict per selected test with ``'exit_code'``, ``"stdout"`` and
            ``"stderr"``; also stored in ``self.test_cases_output_list``. A
            test whose execution raised (for example on timeout) is recorded
            with exit code ``-1`` and the error text in both streams.

        Raises:
            subprocess.CalledProcessError: If ``git apply`` fails.
        """
        if len(self.selected_test_cases_list) == 0:
            print('No test function selected. Call select_test_cases first.')
            self.selected_test_cases_list = self.select_test_cases()
        task = self.task
        container = build_container(task.container_name)
        workdir = os.path.join(config.REPOCOD_WORK_DIR, task.project_name)

        if task.project_name == 'plotly.py':
            workdir = os.path.join(workdir, "packages/python/plotly")

        # Apply the patch to the local project; the temporary file must still
        # exist while git reads it.
        with NamedTemporaryFile(buffering=0, suffix=".diff") as f:
            f.write(patch_content.encode())
            subprocess.run(
                ["git", "apply", f.name],
                capture_output=True,
                text=True,
                cwd=task.project_path,
                check=True,
            )

        try:
            # NOTE(review): only the file of the first bug location is copied;
            # a patch touching other files would not reach the container.
            src_file_path = self.buglocs_list[0].abs_file_path
            des_docker_path = os.path.join(
                config.REPOCOD_WORK_DIR,
                task.project_name,
                self.buglocs_list[0].rel_file_path,
            )
            copy_file_to_docker(container, src_file_path, des_docker_path)
        finally:
            # Undo the local patch even when the copy fails.
            with apputils.cd(self.task.local_repo):
                apputils.repo_clean_changes()

        repo_specific_command = ""
        if "xarray" in task.container_name:
            repo_specific_command = "/root/miniconda3/bin/conda run -n xarray-tests "

        test_cases_output_list = []
        try:
            for task_dict in self.selected_test_cases_list:
                test_path_docker = task_dict['base_nodeid']
                cmd = f"{repo_specific_command}pytest {test_path_docker} --tb=short -v"
                try:
                    exec_result = self.execute_command_docker(cmd, container, workdir)
                    exit_code = exec_result.exit_code  # 0: pass, other: error
                    stdout, stderr = exec_result.output
                    stdout = clean_ansi(stdout.decode("utf-8")) if stdout else ""
                    stderr = clean_ansi(stderr.decode("utf-8")) if stderr else ""
                except Exception as e:
                    # Record the failure as a result and continue with the
                    # next test rather than terminating the whole process.
                    exit_code = -1
                    stdout = stderr = f'Exception happened when executing {test_path_docker} in get_raw_pytest_output: {e}'
                    print(stdout)
                test_cases_output_list.append({
                    'exit_code': exit_code,
                    "stdout": stdout,
                    "stderr": stderr,
                })
        finally:
            # Always bring the container's working tree back to a clean state.
            self.execute_command_docker("git reset --hard", container, workdir)
            self.execute_command_docker("git clean -fd", container, workdir)

        self.test_cases_output_list = test_cases_output_list
        return self.test_cases_output_list
            
    
    # functions newly added by 
    def build_test_parsed_feedback_list(self):
        """Turn the raw pytest results into per-test feedback messages.

        Each entry of ``self.test_cases_output_list`` is paired by position
        with the entry of ``self.selected_test_cases_list``. A result with
        exit code 0 yields a "PASSED!" message; anything else yields a
        "FAILED!" message that embeds the extracted failure output.

        Returns:
            ``(feedback_list, all_pass)``: the messages (also stored in
            ``self.test_parsed_feedback_list``) and whether every result had
            exit code 0.

        Raises:
            RuntimeError: If no raw outputs are available. They cannot be
                produced here because running the tests requires the patch
                content.
        """
        def extract_failure_block(output: str) -> str:
            # Prefer pytest's FAILURES section, then its ERRORS section.
            for section in ("FAILURES", "ERRORS"):
                found = re.search(
                    rf"=+ {section} =+\n(.*?)=+ short test summary info =+",
                    output,
                    re.DOTALL
                )
                if found:
                    return found.group(1).strip()
            # Neither section present: fall back to the last 40 lines.
            lines = output.strip().splitlines()
            return "\n".join(lines[-40:]) if len(lines) >= 40 else output.strip()

        if len(self.test_cases_output_list) == 0:
            raise RuntimeError(
                "No raw test output available. "
                "Call get_raw_pytest_output(patch_content) first."
            )

        ALL_PASS: bool = True
        test_parsed_feedback_list = []

        for idx, test_result in enumerate(self.test_cases_output_list):
            selected_test = self.selected_test_cases_list[idx]
            test_case_local_path = selected_test['base_nodeid']
            test_src_code = selected_test['src_code']

            if test_result['exit_code'] == 0:  # passed cases
                test_feedback = f"PASSED! We picked the test case:`{test_case_local_path}`.\nSource Code:\n```\n{test_src_code}\n```\nOur generated code passed successfully!\n\n"
            else:  # failed cases
                ALL_PASS = False
                output_str = '\n'.join([test_result["stdout"], test_result["stderr"]])
                # Keep the full output if the extraction comes back empty.
                output_str = extract_failure_block(output_str) or output_str
                test_feedback = f"FAILED! We picked the test case: `{test_case_local_path}`\nSource Code:\n```\n{test_src_code}\n```\nOur generated code **failed** with the following error:\n```\n{output_str}\n```\n\n"
            test_parsed_feedback_list.append(test_feedback)

        self.test_parsed_feedback_list = test_parsed_feedback_list
        return self.test_parsed_feedback_list, ALL_PASS
        
        
    # function newly added by      
    def save_tests_rawoutputs_parsedfeedback(self, round_num):
        """Dump selected tests, raw outputs and parsed feedback to a JSON file.

        The three lists ``selected_test_cases_list``,
        ``test_cases_output_list`` and ``test_parsed_feedback_list`` are
        combined position by position; entries beyond the shortest list are
        left out.

        Args:
            round_num: Refinement round, used in the file name.

        Returns:
            Path of the written ``selected_tests4refine_<round_num>.json``
            file inside the task directory.
        """
        records = [
            {
                "selected_test": chosen_test,
                "raw_output": run_output,
                "parsed_feedback": feedback_text,
            }
            for chosen_test, run_output, feedback_text in zip(
                self.selected_test_cases_list,
                self.test_cases_output_list,
                self.test_parsed_feedback_list,
            )
        ]

        save_path = Path(self.task_dir, f"selected_tests4refine_{round_num}.json")
        with save_path.open('w') as sink:
            json.dump(records, sink, indent=2)
        return save_path
        
                
            
                
        
    
    
    
def generator(
    issue_statement: str,
) -> Generator[tuple[str, MessageThread, bool], str | None, None]:
    """Repeatedly ask the model for a test, feeding back the caller's evaluation.

    Args:
        issue_statement: Issue text placed at the start of the conversation.

    Yields:
        ``(test_content, thread, ok)``. ``ok`` is ``False`` (with an empty
        ``test_content``) when the response did not contain exactly one code
        block; otherwise ``test_content`` is that block. The same ``thread``
        object is yielded every time and keeps growing.

    After a successful yield the caller must ``send`` an evaluation message;
    it is appended to the conversation before the next attempt.
    """
    thread = MessageThread()
    thread.add_system(SYSTEM_PROMPT)
    thread.add_user(f"Here is an issue:\n\n{issue_statement}")
    thread.add_user(INITIAL_REQUEST)
    print_acr(INITIAL_REQUEST, "reproducer test generation")

    index = 1
    while True:
        # Send the growing thread so the model sees its earlier answers and
        # the feedback appended after them.
        response, *_ = common.SELECTED_MODEL.call(thread.to_msg())

        thread.add_model(response, [])
        print_reproducer(response, desc=f"Try {index}")

        index += 1

        code_blocks = extract_markdown_code_blocks(response)

        if len(code_blocks) != 1:
            _ = yield "", thread, False

            new_prompt = (
                f"Expected 1 code block, got {len(code_blocks)}. Please try again."
            )
        else:
            test_content = code_blocks[0]
            evaluation_msg = yield test_content, thread, True

            assert evaluation_msg is not None

            new_prompt = f"The issue reproduction is incorrect. {evaluation_msg} Please try again."

        thread.add_user(new_prompt)


def extract_markdown_code_blocks(content: str) -> list[str]:
    """Return the bodies of the fenced code blocks found in ``content``.

    A line beginning (after optional whitespace) with three backticks opens a
    block when none is open and closes the open one otherwise. The fence
    lines are excluded; body lines keep their line endings. A block that is
    never closed is dropped.
    """
    opening_fence = re.compile(r"\s*```\w*\s*")
    closing_fence = re.compile(r"\s*```\s*")

    all_lines = content.splitlines(keepends=True)
    bodies = []
    body_start = None  # index of the first body line while inside a block

    for position, text in enumerate(all_lines):
        if body_start is None:
            if opening_fence.match(text):
                body_start = position + 1
        elif closing_fence.match(text):
            bodies.append("".join(all_lines[body_start:position]))
            body_start = None

    return bodies
