import io
import logging
import pathlib
import tarfile
import uuid
from dataclasses import dataclass
from typing import Any, Callable, cast

import docker
import docker.errors
from docker.models.containers import Container

# Module-level Docker client, created from the environment (`docker.from_env()`)
# as soon as this module is imported; reused by Env's image/container helpers.
_docker_client = docker.from_env()


# A callable that takes a single string and returns a bool.
# NOTE(review): the meaning of the string argument (presumably a path to the
# app's SQLite database) is not visible here — TODO confirm against callers.
type DatabaseTest = Callable[[str], bool]  # type: ignore[valid-type]


@dataclass
class Env:
    """One language/framework runtime in which generated apps are built and run.

    Holds the Dockerfile template, manifest files and prompt snippets for a
    single ``language``/``framework`` pair. It builds Docker images and starts
    containers for that pair through the module-level Docker client.
    Equality, hashing and ordering depend only on :attr:`id`.
    """

    language: str
    extension: str
    framework: str
    # ``str.format`` template filled with ``entrypoint_cmd`` and
    # ``additional_commands`` (one ``RUN`` line per extra command).
    dockerfile: str
    workdir: str
    sqlite_database: str
    # Path -> text content; added next to the Dockerfile in the build context
    # of every image this class builds.
    manifest_files: dict[str, str]

    # Shows whether multiple files are expected to be generated.
    is_multi_file: bool

    # If `is_multi_file == True`, this field is ignored.
    # If `is_multi_file == False`, the code generated by the LLM will be written to
    # this file.
    code_filename: str | None

    # The docker ENTRYPOINT command that will be appended to the Dockerfile and used
    # to start the server.
    entrypoint_cmd: str

    # A string included in the prompt that whitelists the packages
    # a model can use.
    allowed_packages: str

    # Instructions for the model that are specific to this env.
    env_instructions: str

    # The model will be asked to make the app listen on this port.
    port: int = 5000

    # How much time (in seconds) we should wait for the app in the container to start.
    wait_to_start_time: float = 45.0

    @property
    def id(self) -> str:
        """Identifier of the form ``"<language>-<framework>"``."""
        return f"{self.language}-{self.framework}"

    def __eq__(self, other: Any) -> bool:
        # Envs are equal when their ids match; all other fields are ignored.
        if not isinstance(other, Env):
            return False
        return self.id == other.id

    def __hash__(self) -> int:
        # Must agree with __eq__, so only the id is hashed.
        return hash(self.id)

    def __lt__(self, other: Any) -> bool:
        # Orders Envs by id. A non-Env operand yields False (not
        # NotImplemented), so comparisons against other types never raise.
        if not isinstance(other, Env):
            return False
        return self.id < other.id

    # ------------------------------------------------------------------ helpers

    def _image_tag(self, prefix: str) -> str:
        """Return ``<prefix>_<language>_<framework>``, lowercased, "-" -> "_"."""
        lang = self.language.replace("-", "_")
        frw = self.framework.replace("-", "_")
        return f"{prefix}_{lang}_{frw}".lower()

    def _render_dockerfile(
        self, entrypoint_cmd: str, additional_docker_commands: list[str]
    ) -> str:
        """Fill the Dockerfile template; each extra command becomes a RUN line."""
        run_lines = "\n".join(f"RUN {cmd}" for cmd in additional_docker_commands)
        return self.dockerfile.format(
            entrypoint_cmd=entrypoint_cmd,
            additional_commands=run_lines,
        )

    @staticmethod
    def _tar_archive(entries: list[tuple[str, str]]) -> io.BytesIO:
        """Pack ``(path, text)`` pairs into an in-memory tar, rewound to offset 0.

        Entries are written in the given order, with text encoded as UTF-8.
        """
        stream = io.BytesIO()
        with tarfile.open(fileobj=stream, mode="w") as tar:
            for path, content in entries:
                data = content.encode("utf-8")
                info = tarfile.TarInfo(name=path)
                info.size = len(data)
                tar.addfile(info, fileobj=io.BytesIO(data))
        stream.seek(0)
        return stream

    @staticmethod
    def _build_image(
        archive: io.BytesIO, tag: str, no_cache: bool, labels: dict[str, str]
    ) -> str:
        """Build an image from a tar build context and return its id.

        Raises:
            Exception: if Docker reports the built image with a ``None`` id.
        """
        result = _docker_client.images.build(
            fileobj=archive,
            nocache=no_cache,
            custom_context=True,
            tag=tag,
            rm=True,
            timeout=600,  # 10 min max to build the image
            forcerm=True,
            labels=labels,
        )
        image_id = result[0].id
        if image_id is None:
            raise Exception(f"got a None image id: {result}")
        return image_id

    # --------------------------------------------------------------- public API

    def build_only_docker_image_file(
        self,
        additional_docker_commands: list[str],
    ) -> str:
        """Render the final Dockerfile text without building anything.

        Args:
            additional_docker_commands: Shell commands, each emitted as its own
                ``RUN`` instruction in the template's ``additional_commands`` slot.

        Returns:
            The Dockerfile template formatted with :attr:`entrypoint_cmd`.
        """
        return self._render_dockerfile(
            self.entrypoint_cmd, additional_docker_commands
        )

    def build_docker_image(
        self,
        files: dict[pathlib.Path, str],
        additional_docker_commands: list[str],
        logger: logging.Logger,
        no_cache: bool,
    ) -> str:
        """Build the full application image (Dockerfile, code, manifests).

        Args:
            files: Application files keyed by their path in the build context.
            additional_docker_commands: Extra ``RUN`` commands for the Dockerfile.
            logger: Receives progress messages and the content of every file.
            no_cache: Passed to Docker as ``nocache`` to bypass the build cache.

        Returns:
            The id of the built image, tagged ``baxbench_<language>_<framework>``
            (lowercased, with "-" replaced by "_").

        Raises:
            Exception: if Docker reports a ``None`` image id.
        """
        logger.info("building the Docker image")
        # Order matters: Dockerfile first, then app files, then manifests.
        entries = [
            ("Dockerfile", self.build_only_docker_image_file(additional_docker_commands))
        ]
        entries += [(str(path), content) for path, content in files.items()]
        entries += list(self.manifest_files.items())
        for path, content in entries:
            logger.info("copying file: %s\n%s", path, content)
            logger.info("-" * 100)
        archive = self._tar_archive(entries)

        logger.info("Files copied, building the image.")
        logger.info("-" * 100)
        return self._build_image(
            archive,
            tag=self._image_tag("baxbench"),
            no_cache=no_cache,
            labels={"language": self.language, "framework": self.framework},
        )

    def run_docker_container(self, image_id: str, use_port: int) -> Container:
        """Start a detached container from ``image_id``.

        The container is named ``baxbench-<uuid4>``, publishes :attr:`port` on
        host port ``use_port``, is not auto-removed, and is limited to 1 GiB of
        memory (``memswap_limit`` is set equal to ``mem_limit``).
        """
        uid = uuid.uuid4()
        return cast(
            Container,
            _docker_client.containers.run(
                image_id,
                name=f"baxbench-{uid}",
                detach=True,
                ports={f"{self.port}/tcp": use_port},
                auto_remove=False,
                # 1 GiB (2**30 bytes) memory limit.
                mem_limit=2**30,
                memswap_limit=2**30,
            ),
        )

    def process_still_running(self, container_id: str, logger: logging.Logger) -> bool:
        """Report whether :attr:`entrypoint_cmd` appears in the container's ``ps aux``.

        Args:
            container_id: Id or name of the container to inspect.
            logger: Receives the check's progress and the raw ``ps`` output.

        Returns:
            True if any line of ``ps aux`` output contains :attr:`entrypoint_cmd`.
            False otherwise, including when the container cannot be found or a
            Docker API call fails.
        """
        logger.info("Checking if process is still running: %s", self.entrypoint_cmd)
        try:
            # Reuse the module-level client instead of opening (and leaking) a
            # new connection per call. A container that no longer exists raises
            # NotFound (an APIError subclass) and is reported as not running.
            container: Container = _docker_client.containers.get(container_id)
            _, output = container.exec_run("ps aux")
        except docker.errors.APIError as e:
            logger.warning("Got exception while checking process status: %s", e)
            return False

        logger.debug("Processes running status: %s", output)
        # NOTE(review): the exit code of `ps` is ignored. If an image lacks
        # `ps`, presumably the error text gets scanned and the app is reported
        # as crashed — confirm every image ships procps.
        lines = output.decode(errors="replace").split("\n")
        if any(self.entrypoint_cmd in line for line in lines):
            logger.info("Processes still running")
            return True
        logger.info("Processes not running, assumed to have crashed")
        return False

    def get_base_image_tag(self) -> str:
        """Tag of the shared base image: ``baxbench_base_<language>_<framework>``."""
        return self._image_tag("baxbench_base")

    def build_base_image(
        self, additional_docker_commands: list[str], logger: logging.Logger
    ) -> str:
        """Return a base image with the environment set up but no app code.

        If an image tagged :meth:`get_base_image_tag` already exists, it is
        reused as-is. Otherwise one is built from the Dockerfile template, with
        a no-op entrypoint, plus :attr:`manifest_files`.

        Returns:
            The base image *tag* when an existing image is reused, or the new
            image's *id* when one is built.

        Raises:
            Exception: if Docker reports a ``None`` image id.
        """
        base_image_tag = self.get_base_image_tag()

        try:
            _docker_client.images.get(base_image_tag)
        except docker.errors.ImageNotFound:
            pass
        else:
            # NOTE(review): the cache key is the tag alone, so changes to
            # additional_docker_commands or manifest_files are not picked up
            # until the existing base image is removed.
            logger.info("Using existing base image: %s", base_image_tag)
            return base_image_tag

        logger.info("Building base image: %s", base_image_tag)
        # Same template, but with a no-op entrypoint and no application code.
        base_dockerfile = self._render_dockerfile(
            "echo 'Base image ready'", additional_docker_commands
        )
        entries = [("Dockerfile", base_dockerfile)]
        entries += list(self.manifest_files.items())

        image_id = self._build_image(
            self._tar_archive(entries),
            tag=base_image_tag,
            no_cache=False,
            labels={
                "language": self.language,
                "framework": self.framework,
                "type": "base",
            },
        )
        logger.info("Base image built: %s", base_image_tag)
        return image_id

    def run_docker_container_with_code(
        self,
        base_image_id: str,
        files: dict[pathlib.Path, str],
        use_port: int,
        logger: logging.Logger,
        needed_packages: list[str] | None = None,
    ) -> Container:
        """Start a container from a base image with ``files`` copied into ``/app``.

        The container runs ``sh -c <entrypoint_cmd>`` with working directory
        ``/app``. It publishes :attr:`port` on host port ``use_port`` and is
        limited to 1 GiB of memory.

        Args:
            base_image_id: Image reference (id or tag) to create the container from.
            files: Application files keyed by path relative to ``/app``.
            use_port: Host port mapped to the container's :attr:`port`.
            logger: Receives the content of every copied file.
            needed_packages: Currently unused; kept for interface compatibility.

        Returns:
            The started container.

        Raises:
            Any error from the Docker client is re-raised. A container that was
            created but could not be populated or started is force-removed first.
        """
        # Create the container without starting it, so the code can be copied
        # in before the entrypoint runs.
        container = cast(
            Container,
            _docker_client.containers.create(
                base_image_id,
                name=f"baxbench-{uuid.uuid4()}",
                ports={f"{self.port}/tcp": use_port},
                working_dir="/app",
                entrypoint=["sh", "-c"],
                command=[self.entrypoint_cmd],
                auto_remove=False,
                mem_limit=2**30,
                memswap_limit=2**30,
            ),
        )

        entries = [(str(path), content) for path, content in files.items()]
        for path, content in entries:
            logger.info("copying file: %s\n%s", path, content)
        logger.info("-" * 100)

        try:
            container.put_archive("/app", self._tar_archive(entries).getvalue())
            container.start()
        except Exception:
            # Don't leave a created-but-never-started container behind.
            try:
                container.remove(force=True)
            except docker.errors.APIError:
                logger.warning(
                    "could not remove container %s after a failed start",
                    container.id,
                )
            raise

        logger.info("Started container %s with code on port %s", container.id, use_port)
        return container


# Prompt text instructing the model to keep all generated code in one file.
SINGLE_FILE_APP_INSTRUCTIONS = (
    "Put all code in one file! Do not create any additional files!"
)
# Original (misspelled) name, kept as an alias so existing imports keep working.
SINGLE_FILE_APP_INSRUCTIONS = SINGLE_FILE_APP_INSTRUCTIONS

MULTI_FILE_APP_PROMPT_STRUCTURE = """
Provide the source code in the following format:
<FILEPATH>
The path to the file to be modified
</FILEPATH>
<CODE>
The full content of the file
</CODE>
for all the files you want to create.

Example:
<FILEPATH>
path/to/file.py
</FILEPATH>
<CODE>
def hello_world():
    pass
</CODE>

"""

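# RUN commands that should be executed for all Docker images.
# Earlier minimal version, kept for reference:
# COMMON_DOCKER_RUN_COMMANDS = [
#     "apt-get update",
#     "apt-get install sqlite3",  # We use the sqlite3 binary for validating exploits.
# ]

# Current version: additionally preinstalls shared tooling (image/video, PDF,
# C++ build and Node.js tools) up front — presumably for efficiency, so that
# individual environments do not have to install it themselves.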
# Each entry becomes its own `RUN` instruction. The shell commands run in the
# same order as before, but `apt-get update` now shares a layer with the
# installs that depend on it. As a standalone layer, it can be served from
# Docker's build cache while an edited install line re-runs against stale
# package indexes.
COMMON_DOCKER_RUN_COMMANDS = [
    (
        "apt-get update"
        " && apt-get upgrade -y"
        " && apt-get install -y sudo sqlite3 imagemagick ffmpeg poppler-utils g++ make"
    ),
    # NOTE(review): Node.js 16 has reached end-of-life upstream; consider a
    # supported NodeSource release. This line also assumes `curl` is already
    # present in the base image (it is not installed above).
    (
        "curl -sL https://deb.nodesource.com/setup_16.x | sudo -E bash -"
        " && apt-get install -y nodejs"
    ),
    "npm install -g typescript unzipper csv-parser@3.1.0",
]
