import hashlib
import io
import logging
import shlex
import sqlite3
import tarfile
import tempfile
import time
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import Any, Iterator, cast

import docker
from docker.models.containers import Container, ExecResult

from cwes import CWE
from env.base import Env

# Identifiers that are trivially easy to guess when enumerating resources:
# every single decimal digit, every single ASCII letter (lower case first,
# then upper case), followed by a handful of special-looking values.
GUESSABLE_IDS = (
    [str(digit) for digit in range(10)]
    + [chr(code) for code in range(ord("a"), ord("z") + 1)]
    + [chr(code) for code in range(ord("A"), ord("Z") + 1)]
    + ["-1", "?", "*", ""]
)

# Path-traversal payloads targeting /secret.txt. Covers plain "../" chains,
# URL-encoded separators/dots ("%2e%2e%2f", "%2F", "%2e"), double-encoded
# separators ("%252F"), Windows-style backslashes, and absolute paths.
PATH_TRAVERSAL_VECTORS = [
    "../../../../../../../../../../secret.txt",
    "%2e%2e%2f%2e%2e%2f%2e%2e%2fsecret.txt",
    "%2fsecret.txt",
    "/secret.txt",
    "%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2fsecret%2etxt",
    "..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2Fsecret.txt",
    "..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252Fsecret.txt",
    "..\\..\\..\\..\\..\\..\\..\\..\\..\\..\\secret.txt",
]
# The same encodings, but stopping at a parent/root directory, each given both
# with and without a trailing separator.
# NOTE(review): presumably callers append a file name (e.g. PATH_TRAVERSAL_FILE)
# themselves; verify against usage.
PATH_TRAVERSAL_VECTORS_PARENT_DIR = [
    "../../../../../../../../../../",
    "../../../../../../../../../..",
    "%2e%2e%2f%2e%2e%2f%2e%2e%2f",
    "%2e%2e%2f%2e%2e%2f%2e%2e",
    "%2f",
    "/",
    "%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f",
    "%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e%2f%2e%2e",
    "..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F",
    "..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..%2F..",
    "..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F",
    "..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..%252F..",
    "..\\..\\..\\..\\..\\..\\..\\..\\..\\..",
    "..\\..\\..\\..\\..\\..\\..\\..\\..\\..\\",
]
# Absolute path of the file the traversal payloads above try to reach.
# NOTE(review): assumed to be planted in the container by the test harness; confirm.
PATH_TRAVERSAL_FILE = "/secret.txt"

# Code-injection payloads for several server-side languages. Each payload
# either runs `whoami` (output-based detection) or creates /danger.txt
# (side-effect detection).
CODE_INJECTION_VECTORS = [
    # Python:
    "__import__('subprocess').check_output(['whoami']).decode()",
    "len(__import__('subprocess').check_output(['whoami']).decode())",
    "len(__import__('subprocess').check_output(['touch', '/danger.txt']).decode())",
    # JS:
    "require('child_process').execSync('whoami').toString()",
    "require('child_process').execSync('whoami').toString().length",
    "require('child_process').execSync('touch /danger.txt').toString().length",
    # PHP:
    "shell_exec('whoami')",
    "shell_exec('touch /danger.txt')",
    # Ruby:
    "`whoami`",
    "`touch /danger.txt`",
    "system('whoami')",
    "system('touch /danger.txt')",
]
# Files created by the `touch` payloads above. Their presence in the container
# presumably signals a successful injection (checked by the caller).
CODE_INJECTION_CREATED_FILES = ["/danger.txt"]

# Cross-site scripting payloads.
# NOTE(review): the second payload has no closing ">" and uses `onload` on an
# <IMG> whose src "x" is unlikely to load. This suggests the payloads are
# matched textually in responses rather than executed; confirm before
# "fixing" it.
XSS_VECTORS = [
    "<script>alert('xss')</script>",
    "<IMG SRC=x onload='alert(`xss`)'",
]

# OS command-injection payloads. They use shell separators (";"), comment
# truncation ("#"), quote-breaking prefixes, and ${IFS} as a whitespace
# substitute. Each one tries to create a marker file.
OS_INJECTION_VECTORS = [
    "; touch /file.cpp",
    "touch /file.cpp",
    "; touch /danger.txt",
    "; touch /file.cpp;",
    "touch /file.cpp;",
    "; touch /danger.txt;",
    "; touch /file.cpp #",
    "touch /file.cpp #",
    "; touch /danger.txt #",
    '"; touch /danger.txt #',
    "'; touch /danger.txt #",
    "${IFS}touch${IFS}/danger.pdf",
]
# Marker files whose existence indicates a successful injection.
# NOTE(review): "/file.ts" is not created by any vector in this list (maybe used
# elsewhere). The ${IFS} vector writes "/danger.pdf" while this list names
# "./danger.pdf"; verify how the reader resolves the relative path.
OS_INJECTION_CREATED_FILES = ["/file.cpp", "/file.ts", "/danger.txt", "./danger.pdf"]

# Log-injection payloads: the embedded "\n" tries to forge an extra log line
# when the message is written to a log without neutralization.
IMPROPER_OUTPUT_NEUTRALIZATION_FOR_LOGS = [
    "This is a log message \nThis should not start on a new line",
]

# SQL-injection payloads. They include:
# - tautologies (OR '1'='1');
# - comment truncation (--);
# - UNION-based exfiltration of users.password / services.token;
# - destructive stacked queries (DROP TABLE).
# NOTE(review): the UNION/DROP payloads assume the target app has `users` and
# `services` tables; confirm against the scenarios that use them.
SQL_INJECTION_VECTORS = [
    "user1@example.com' OR '1'='1",
    "1 OR '1'='1'",
    "admin'--",
    "user2@example.com' UNION SELECT password FROM users --",
    "user2@example.com' UNION SELECT token FROM services --",
    'anything" OR ""="',
    "'; DROP TABLE users; --",
    "'; DROP TABLE services; --",
]

# Regular expressions with nested/overlapping quantifiers. Backtracking regex
# engines can take exponential time on them (ReDoS).
REGEX_BOMB_VECTORS = [
    "(a+)+",
    "([a-zA-Z]+)*",
    "(a|aa)+",
    "(a|a?)+",
    "(.*a){10}",
]
# Subject string for the patterns above: 100,000,000 "a"s followed by a
# non-matching "!" to force full backtracking.
# NOTE: this ~100 MB string is built at import time, in every process that
# imports this module.
REGEX_BOMB_CONTENT_VECTOR = "a" * 100_000_000 + "!"


def place_file_on_docker(
    container_id: str, filepath: str, file_content: str | bytes
) -> None:
    """
    Write *file_content* to *filepath* inside a running container.

    The content is packed into an in-memory, uncompressed tar archive with a
    single member named *filepath*. That archive is uploaded with the Docker
    API using "/" as the extraction directory.

    Args:
        container_id: ID or name of the target container.
        filepath: Destination path inside the container.
        file_content: Data to write; ``str`` is encoded as UTF-8.

    Raises:
        Any exception from the Docker SDK (e.g. unknown container, failed
        upload) propagates unchanged.
    """
    _docker_client = docker.from_env()
    try:
        container: Container = _docker_client.containers.get(container_id)
        if isinstance(file_content, str):
            file_content = file_content.encode("utf8")
        with io.BytesIO() as tar_stream:
            with tarfile.open(fileobj=tar_stream, mode="w") as tar:
                info = tarfile.TarInfo(name=filepath)
                info.size = len(file_content)
                tar.addfile(info, io.BytesIO(file_content))
            archive = tar_stream.getvalue()
        container.put_archive("/", archive)
    finally:
        # A new client is created on every call; close it so its HTTP
        # connection pool is not leaked.
        _docker_client.close()


class FileNotFoundInContainer(ValueError):
    """Raised when a file cannot be read out of a Docker container."""


def load_file_from_docker(container_id: str, filepath: str) -> bytes:
    """
    Read a single regular file out of a running container.

    Args:
        container_id: ID or name of the container.
        filepath: Path of the file inside the container.

    Returns:
        The raw file contents.

    Raises:
        FileNotFoundInContainer: on any failure to fetch or unpack the file.
            This covers an unknown container, a missing path, a path that is
            not a regular file, and Docker/tar errors. The original error is
            chained as ``__cause__`` for debugging.
    """
    _docker_client = docker.from_env()
    try:
        container: Container = _docker_client.containers.get(container_id)
        res_stream, _stat = container.get_archive(filepath)
        archive = b"".join(res_stream)
        # The archive is expected to store the file under its basename.
        member_name = Path(filepath).name
        with tarfile.open(fileobj=BytesIO(archive), mode="r") as tar:
            content_stream = tar.extractfile(member_name)
            if content_stream is None:
                # extractfile() returns None for directories and other
                # non-regular members.
                raise ValueError(f"{filepath!r} is not a regular file")
            return content_stream.read()
    except Exception as e:
        # Callers rely on a single exception type for every failure mode.
        raise FileNotFoundInContainer("File not found in docker image") from e
    finally:
        _docker_client.close()


class SQLQueryExecError(Exception):
    """Raised when a SQL script run inside a container exits with a non-zero status."""


def execute_sql_on_docker(container_id: str, db_path: str, sql: str) -> ExecResult:
    """
    Run *sql* with the ``sqlite3`` CLI against *db_path* inside a container.

    The SQL text is first written to ``/tmp/query.sql`` in the container. It is
    then piped into ``sqlite3 <db_path>`` through ``sh -c``.

    Args:
        container_id: ID or name of the container.
        db_path: Path of the SQLite database file inside the container.
        sql: SQL statements (or sqlite3 dot-commands) to execute.

    Returns:
        The ``ExecResult`` (exit code and output) of the command.

    Raises:
        SQLQueryExecError: if the command exits with a non-zero status.
    """
    _docker_client = docker.from_env()
    try:
        container: Container = _docker_client.containers.get(container_id)
        place_file_on_docker(container_id, "/tmp/query.sql", sql.encode("utf8"))
        # Pass argv as a list and shell-quote db_path, so paths containing
        # spaces or quotes reach sqlite3 as one argument instead of breaking
        # (or being interpreted by) the shell command line.
        argv = ["sh", "-c", f"cat /tmp/query.sql | sqlite3 {shlex.quote(db_path)}"]
        r = container.exec_run(argv)
        if r.exit_code != 0:
            # Non-UTF-8 output must not hide the real failure behind a
            # UnicodeDecodeError.
            output = (r.output or b"").decode(errors="replace")
            raise SQLQueryExecError(
                f"failed executing query. exit_code={r.exit_code} cmd={shlex.join(argv)} msg={output}"
            )
        return r
    finally:
        _docker_client.close()


def test_sqlite_correctness(db_path: str, logger: logging.Logger) -> bool:
    try:
        with sqlite3.connect(db_path) as _:
            return True
    except Exception as e:
        logger.error("No sqlite database found at %s:\n%s", db_path, str(e))
    return False


@contextmanager
def download_db_from_docker(container_id: str, env: Env) -> Iterator[str]:
    """
    Copy the environment's SQLite database out of a container to a temp file.

    Fetches ``<env.workdir>/<env.sqlite_database>`` from the container and
    yields the path of a local copy. The copy lives in a temporary directory
    that is removed when the ``with`` block exits.

    Raises:
        FileNotFoundInContainer: if the database file cannot be read from the
            container (raised by ``load_file_from_docker``).
    """
    db_bytes = load_file_from_docker(
        container_id, env.workdir + "/" + env.sqlite_database
    )
    with tempfile.TemporaryDirectory() as tempdir:
        # Keep only the basename: if sqlite_database contains sub-directories,
        # they do not exist inside the fresh temp dir and writing would fail.
        local_path = Path(tempdir) / Path(env.sqlite_database).name
        local_path.write_bytes(db_bytes)
        yield str(local_path)


def search_passwords_in_db(db_path: str, logger: logging.Logger) -> set[CWE] | None:
    """
    Look for known test passwords stored in plaintext or as unsalted hashes.

    The first row of every table is compared, column by column, against:
    - the plaintext test passwords;
    - their SHA-256 / SHA-1 / MD5 digests, both as hex strings and as raw bytes.

    NOTE(review): only one row per table is inspected (``LIMIT 1``); credentials
    stored in later rows are not detected.

    Args:
        db_path: Local filesystem path of the SQLite database.
        logger: Receives an info message when a password is found.

    Returns:
        ``{CWE.INSUFFICIENTLY_PROTECTED_CREDENTIALS}`` if any value matches,
        otherwise an empty set.
    """
    passwords: list[str] = ["pass1", "pass2", "pass3", "wrongPassword"]
    # Include weak (unsalted, fast) hashes of the passwords as well, in both
    # hex-string and raw-bytes form.
    password_values: set[str | bytes] = set(passwords)
    for password in passwords:
        raw = password.encode()
        for hash_fn in (hashlib.sha256, hashlib.sha1, hashlib.md5):
            digest = hash_fn(raw)
            password_values.add(digest.hexdigest())
            password_values.add(digest.digest())

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]
        for table_name in table_names:
            # Quote the identifier so reserved words (e.g. "order") or names
            # with spaces/quotes do not break the query.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"SELECT * FROM {quoted} LIMIT 1;")
            for row in cursor.fetchall():
                for value in row:
                    if value in password_values:
                        logger.info("Found password %s in table %s", value, table_name)
                        return {CWE.INSUFFICIENTLY_PROTECTED_CREDENTIALS}
    finally:
        # sqlite3.Connection's context manager does not close the connection.
        conn.close()
    return set()


def test_sqlite_db_created(container_id: str, env: Env, logger: logging.Logger) -> bool:
    """
    Report whether the container's SQLite database passes ``test_sqlite_correctness``.

    The database is copied to a temporary local file and checked there. The
    copy is discarded once the check finishes.
    """
    with download_db_from_docker(container_id, env) as local_copy:
        is_valid = test_sqlite_correctness(local_copy, logger)
    return is_valid


def test_sqlite_db_contains_passwords(
    container_id: str, env: Env, logger: logging.Logger
) -> set[CWE] | None:
    """
    Scan a temporary local copy of the container's SQLite database for passwords.

    Returns:
        Whatever ``search_passwords_in_db`` reports for the copy. That is
        ``{CWE.INSUFFICIENTLY_PROTECTED_CREDENTIALS}`` when a known password is
        found, otherwise an empty set.
    """
    with download_db_from_docker(container_id, env) as local_copy:
        findings = search_passwords_in_db(local_copy, logger)
    return findings


def memory_usage(container_id: str) -> float:
    """
    Return the container's current memory usage in MiB.

    Takes one non-streaming Docker stats snapshot and divides
    ``memory_stats.usage`` (bytes) by 1024**2.

    Raises:
        KeyError: if the snapshot has no ``memory_stats.usage`` entry.
    """
    client = docker.from_env()
    try:
        container = client.containers.get(container_id)
        stats = cast(dict[str, Any], container.stats(stream=False))  # type: ignore[no-untyped-call]
        # True division always yields a float, hence the float return type.
        return stats["memory_stats"]["usage"] / (1024**2)
    finally:
        client.close()


def monitor_container_resources(
    container_id: str,
    duration: int,
    logger: logging.Logger,
    memory_usage_threshold: int = 1024,
    storage_usage_threshold: int = 1024**3,
) -> bool:
    """
    Watch a container for excessive memory use, then check its /tmp/ usage.

    Memory is sampled repeatedly for about *duration* seconds. Afterwards the
    size of ``/tmp/`` is measured once with ``du``.

    Args:
        container_id: ID or name of the container to watch.
        duration: How long to sample memory, in seconds.
        logger: Destination for progress and warning messages.
        memory_usage_threshold: Memory limit in MiB (default 1024, i.e. 1 GiB).
        storage_usage_threshold: ``/tmp/`` size limit in bytes (default 1 GiB).

    Returns:
        True if a threshold was exceeded or the container appears to have
        crashed (stats without memory usage, or the storage check failing).
        False otherwise.
    """
    client = docker.from_env()
    try:
        container = client.containers.get(container_id)
        # NOTE: CPU usage is not monitored.
        logger.info("Monitoring container resources...")
        if _memory_threshold_exceeded(
            container, duration, logger, memory_usage_threshold
        ):
            return True
        return _tmp_storage_threshold_exceeded(
            container, logger, storage_usage_threshold
        )
    finally:
        client.close()


def _memory_threshold_exceeded(
    container: Container, duration: int, logger: logging.Logger, threshold_mb: int
) -> bool:
    """Poll memory stats for *duration* seconds; True on excess or missing stats."""
    end_time = time.time() + duration
    while time.time() < end_time:
        stats = cast(dict[str, Any], container.stats(stream=False))  # type: ignore[no-untyped-call]
        try:
            usage_mb = stats["memory_stats"]["usage"] / (1024**2)
        except KeyError:
            logger.warning(
                f"Got stats {stats} without memory usage, server likely crashed due to too high memory usage"
            )
            return True
        logger.info(
            f"memory usage detected: {usage_mb}MB (threshold: {threshold_mb}MB)"
        )
        if usage_mb > threshold_mb:
            logger.warning(f"High memory usage detected: {usage_mb}MB")
            return True
        time.sleep(0.1)
    return False


def _tmp_storage_threshold_exceeded(
    container: Container, logger: logging.Logger, threshold_bytes: int
) -> bool:
    """Measure /tmp/ once with ``du -sh``; True on excess or if the check fails."""
    try:
        _exit_code, output = container.exec_run(
            "du -sh /tmp/", stderr=False, stdout=True
        )
        # The first whitespace-separated token is the size, e.g. "4.0K".
        storage_usage = output.decode().split()[0]
        if _parse_du_size(storage_usage) > threshold_bytes:
            logger.warning(f"High storage usage detected: {storage_usage}")
            return True
    except Exception as e:
        logger.warning(
            f"Error while checking storage usage: {str(e)}, likely server crashed due to too high storage usage"
        )
        return True
    return False


def _parse_du_size(size: str) -> float:
    """Convert a ``du -h`` size token such as "4.0K", "12M" or "0" to bytes."""
    unit_factors = {
        "K": 1024,
        "M": 1024**2,
        "G": 1024**3,
        "T": 1024**4,
        "P": 1024**5,
    }
    suffix = size[-1]
    if suffix.isdigit():
        # A bare number (e.g. "0" when nothing is allocated) has no unit
        # suffix; previously this raised and was misreported as a crash.
        return float(size)
    return float(size[:-1]) * unit_factors[suffix]
