import contextlib
import json
import math
import sqlite3
from logging import Logger
from pathlib import Path
from typing import List, Tuple, Union, Callable

import pandas as pd
from pandas import DataFrame
from sqlalchemy import create_engine

from compute_result.result_store.base import BUDGET_COL, POINT_COL
from compute_result.typing import ProblemSpace, Point, RunResult, Run
from problems.types import Suites

# Presumably the separator placed between the parts of a run table name (the
# name suggests this; it is not referenced in this chunk — confirm with the
# table-name builder/analyzer). Because it contains "-", any table name built
# with it must be quoted when used as an SQL identifier.
TABLE_NAME_BREAK = "_-_"


@contextlib.contextmanager
def get_sqlalchemy_conn(db_path: Path):
    """Yield a SQLAlchemy connection to the SQLite database at ``db_path``.

    A throwaway engine is created for the duration of the ``with`` block. The
    connection is closed and the engine's pool disposed when the block exits,
    including when it exits because of an exception.

    :param db_path: path of the SQLite database file.
    :yields: an open SQLAlchemy ``Connection``.
    """
    engine = create_engine(f"sqlite:///{db_path}", echo=False)
    try:
        # Using the connection as a context manager guarantees close() even if
        # the caller's block raises (the bare yield did not).
        with engine.connect() as conn:
            yield conn
    finally:
        # The engine is private to this call; dispose() releases any pooled
        # DBAPI connection it still holds instead of waiting for GC.
        engine.dispose()


@contextlib.contextmanager
def get_conn(db_path: Path):
    """Yield a ``sqlite3`` connection that commits on successful exit.

    When the ``with`` block completes normally the pending transaction is
    committed. If the block raises, no commit is issued, so closing the
    connection discards the uncommitted changes. The exception then propagates.
    The connection is closed in both cases.

    :param db_path: path of the SQLite database file.
    :yields: an open ``sqlite3.Connection``.
    """
    conn = sqlite3.connect(db_path)
    try:
        yield conn
        conn.commit()
    finally:
        # Previously the connection leaked whenever the caller's block raised.
        conn.close()


def from_problem_space_to_db_readable(problem_space: ProblemSpace):
    """Flatten a problem space into plain values usable as SQL parameters.

    The leading suite member is replaced by its ``.value``; every remaining
    component is passed through unchanged.

    :param problem_space: ``(suite, func_id, dim, instance)`` tuple.
    :returns: tuple ``(suite.value, func_id, dim, instance)``.
    """
    suite, *remaining = problem_space
    return (suite.value, *remaining)


@contextlib.contextmanager
def get_cursor(db_path: Path):
    """Yield a cursor on a connection opened with ``get_conn``.

    On normal exit the underlying connection commits and is closed. The
    cursor itself is closed first.

    :param db_path: path of the SQLite database file.
    :yields: an ``sqlite3.Cursor``.
    """
    with get_conn(db_path) as conn:
        # closing() releases the cursor deterministically, even if the
        # caller's block raises, rather than leaving it to the GC.
        with contextlib.closing(conn.cursor()) as cursor:
            yield cursor


def from_db_to_problem(suite: str, func_id: int, dim: int, instance: int) -> ProblemSpace:
    """Rebuild a problem space from the raw column values of a ``min_max`` row.

    The ``min_max`` table declares ``id``, ``dim`` and ``instance`` as TEXT,
    so they are coerced back to ``int`` here. ``suite`` is converted to the
    ``Suites`` member with that value.

    :returns: ``(Suites(suite), int(func_id), int(dim), int(instance))``.
    """
    suite_member = Suites(suite)
    numeric_fields = tuple(int(field) for field in (func_id, dim, instance))
    return (suite_member, *numeric_fields)


def df_to_tuple_list(df: pd.DataFrame) -> List[Tuple]:
    """Turn the first four columns of each row into a tuple, last row first.

    Columns are accessed positionally, so ``df`` must have at least four
    columns; any extra columns are ignored.

    :param df: source frame.
    :returns: one 4-tuple per row, in reverse row order.
    """
    collected = []
    for _, row in df.iterrows():
        collected.append(tuple(row.iloc[position] for position in range(4)))
    collected.reverse()
    return collected


def create_min_max_table(cursor):
    """Create the ``min_max`` table unless it already exists.

    Each row holds a minimum and a maximum value together with the points
    (stored as TEXT) where they were observed. The row is keyed by the
    problem's suite, id, dim and instance, all declared as TEXT.

    :param cursor: a DB-API cursor on the target SQLite database.
    """
    ddl = """CREATE TABLE IF NOT EXISTS min_max
                         (min_value REAL, min_point TEXT, max_value REAL, max_point TEXT, suite TEXT, id TEXT, dim TEXT, instance TEXT)"""
    cursor.execute(ddl)


def list_run_result(db_path: Path, table_name: str):
    """Load every row of a run table, in ascending budget order.

    The table is read through SQLAlchemy and sorted by ``BUDGET_COL``
    descending. ``df_to_tuple_list`` then reverses it, so the returned list is
    ascending by budget. The ``index`` column that ``DataFrame.to_sql`` writes
    is dropped when present; tables created by ``store_step`` have no such
    column. ``POINT_COL`` values are decoded from JSON.

    :param db_path: path of the SQLite database file.
    :param table_name: run table to read.
    :returns: list of 4-tuples built from the first four remaining columns.
    """
    with get_sqlalchemy_conn(db_path) as conn:
        result_df = pd.read_sql_table(table_name, conn)

    # Post-processing needs no DB access; doing it after the connection is
    # released keeps the connection short-lived.
    result_df = result_df.sort_values(BUDGET_COL, ascending=False)
    # errors="ignore": a plain drop("index") raised KeyError on tables without it.
    result_df = result_df.drop(columns="index", errors="ignore")
    result_df[POINT_COL] = result_df[POINT_COL].apply(json.loads)
    return df_to_tuple_list(result_df)


def last_record_from_run_result(db_path: Path, table_name: str):
    """Return the row with the largest ``budget`` from a run table.

    The ``index`` column that ``DataFrame.to_sql`` writes is dropped when
    present; tables created by ``store_step`` have no such column.
    ``POINT_COL`` is decoded from JSON.

    :param db_path: path of the SQLite database file.
    :param table_name: run table to read.
    :returns: a 4-tuple built from the first four remaining columns.
    :raises IndexError: if the table contains no rows.
    """
    # NOTE(review): the ORDER BY column is hard-coded as "budget" while
    # list_run_result sorts by BUDGET_COL — confirm the two always agree.
    with get_sqlalchemy_conn(db_path) as conn:
        result_df = pd.read_sql_query(
            f"SELECT * FROM `{table_name}` ORDER BY budget DESC LIMIT 1", conn
        )

    result_df = result_df.drop(columns="index", errors="ignore")
    result_df[POINT_COL] = result_df[POINT_COL].apply(json.loads)
    return df_to_tuple_list(result_df)[0]


def is_table_exist(db_path: Path, table_name: str):
    """Return ``True`` if a table named ``table_name`` exists in the database.

    The name is bound as a query parameter, so names containing quotes are
    matched literally instead of breaking (or rewriting) the SQL statement.

    :param db_path: path of the SQLite database file.
    :param table_name: exact table name to look for.
    :returns: whether the table exists.
    """
    with get_cursor(db_path) as c:
        c.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        return c.fetchone() is not None


def min_max_from_space(
    db_path: Path, problem_space: ProblemSpace
) -> Tuple[float, float, Point, Point]:
    """Fetch the stored minimum and maximum for one problem.

    The ``min_max`` table is created first if it does not exist yet.

    :param db_path: path of the SQLite database file.
    :param problem_space: ``(suite, func_id, dim, instance)`` identifying the problem.
    :returns: ``(min_value, max_value, min_point, max_point)`` with both points
        decoded from JSON, or ``(inf, -inf, [], [])`` if the problem has no row.
    """
    lookup_sql = (
        "SELECT min_value, max_value, min_point, max_point FROM min_max "
        "WHERE suite = ? AND id = ? AND dim = ? AND instance = ?"
    )
    with get_cursor(db_path) as cursor:
        create_min_max_table(cursor)
        cursor.execute(lookup_sql, from_problem_space_to_db_readable(problem_space))
        row = cursor.fetchone()

    if row is None:
        return math.inf, -math.inf, [], []

    lowest, highest, lowest_at, highest_at = row
    return lowest, highest, json.loads(lowest_at), json.loads(highest_at)


def update_min(
    db_path: Path,
    logger: Logger,
    problem_space: ProblemSpace,
    min_value: float,
    min_point: Point,
):
    """Store ``min_value`` as the problem's minimum if it beats the recorded one.

    The database file (with parent directories) and the ``min_max`` table are
    created when missing. If the problem has no row yet, one is inserted
    holding this minimum and a placeholder maximum of ``-inf`` / ``"[]"``.
    Otherwise the row is updated only when ``min_value`` is strictly smaller
    than the stored minimum.

    :param db_path: path of the SQLite database file.
    :param logger: logger that receives progress messages.
    :param problem_space: ``(suite, func_id, dim, instance)`` identifying the problem.
    :param min_value: candidate minimum.
    :param min_point: point where ``min_value`` was observed; stored as JSON.
    """
    if not db_path.exists():
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_path.touch()

    logger.info(f"Updating min in {problem_space} for {min_value}")
    # One parameter tuple for every statement: values are bound, never
    # interpolated into the SQL text (the SELECT used to be an f-string).
    key = from_problem_space_to_db_readable(problem_space)

    with get_cursor(db_path) as cursor:
        create_min_max_table(cursor)

        cursor.execute(
            "SELECT min_value FROM min_max "
            "WHERE suite = ? AND id = ? AND dim = ? AND instance = ?",
            key,
        )
        current_min = cursor.fetchone()

        if current_min is None:
            logger.info(f"No point found, adding: {min_value}, {min_point}")
            cursor.execute(
                "INSERT INTO min_max VALUES (?, ?, ?, ?, ?, ?, ? ,?)",
                (min_value, json.dumps(min_point), -math.inf, "[]", *key),
            )
        elif min_value < current_min[0]:
            logger.info(f"New min found {min_value}, {min_point}")
            cursor.execute(
                "UPDATE min_max SET min_value = ?, min_point = ? "
                "WHERE suite = ? AND id = ? AND dim = ? AND instance = ?",
                (min_value, json.dumps(min_point), *key),
            )


def update_max(
    db_path: Path,
    logger: Logger,
    problem_space: ProblemSpace,
    max_value: float,
    max_point: Point,
):
    """Store ``max_value`` as the problem's maximum if it beats the recorded one.

    The database file (with parent directories) and the ``min_max`` table are
    created when missing. If the problem has no row yet, one is inserted
    holding this maximum and a placeholder minimum of ``inf`` / ``"[]"``.
    Otherwise the row is updated only when ``max_value`` is strictly larger
    than the stored maximum.

    :param db_path: path of the SQLite database file.
    :param logger: logger that receives progress messages.
    :param problem_space: ``(suite, func_id, dim, instance)`` identifying the problem.
    :param max_value: candidate maximum.
    :param max_point: point where ``max_value`` was observed; stored as JSON.
    """
    if not db_path.exists():
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_path.touch()

    # Entry log mirrors update_min.
    logger.info(f"Updating max in {problem_space} for {max_value}")
    # One parameter tuple for every statement: values are bound, never
    # interpolated into the SQL text (the SELECT used to be an f-string).
    key = from_problem_space_to_db_readable(problem_space)

    with get_cursor(db_path) as cursor:
        create_min_max_table(cursor)

        cursor.execute(
            "SELECT max_value FROM min_max "
            "WHERE suite = ? AND id = ? AND dim = ? AND instance = ?",
            key,
        )
        current_max = cursor.fetchone()

        if current_max is None:
            logger.info(f"No point found, adding: {max_value}, {max_point}")
            cursor.execute(
                "INSERT INTO min_max VALUES (?, ?, ?, ?, ?, ?, ? ,?)",
                (math.inf, "[]", max_value, json.dumps(max_point), *key),
            )
        elif max_value > current_max[0]:
            logger.info(f"New max found {max_value}, {max_point}")
            cursor.execute(
                "UPDATE min_max SET max_value = ?, max_point = ? "
                "WHERE suite = ? AND id = ? AND dim = ? AND instance = ?",
                (max_value, json.dumps(max_point), *key),
            )


def store_step(
    db_path: Path,
    table_name: str,
    budget: int,
    value: float,
    point: Point,
    algorithm_name: str,
):
    """Append one step (value, budget, point, algorithm) to the table ``table_name``.

    The database file (with parent directories) and the table are created
    when missing. ``point`` is stored JSON-encoded.

    :param db_path: path of the SQLite database file.
    :param table_name: run table to append to.
    :param budget: budget spent when this step was recorded.
    :param value: objective value of the step.
    :param point: evaluated point; serialized with ``json.dumps``.
    :param algorithm_name: name of the algorithm that produced the step.
    """
    if not db_path.exists():
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_path.touch()

    # Quote the identifier, as the other helpers do: a name containing e.g.
    # "-" (as in TABLE_NAME_BREAK) is a syntax error as a bare identifier.
    quoted_table = f"`{table_name}`"

    with get_cursor(db_path) as cursor:
        # ``point`` is declared TEXT because it holds a JSON string; declaring
        # it REAL misdescribes it to schema-reflecting readers (pandas'
        # read_sql_table may try to coerce REAL columns to float).
        cursor.execute(
            f"""CREATE TABLE IF NOT EXISTS {quoted_table}
                             (value REAL, budget REAL, point TEXT, algorithm_name TEXT)"""
        )
        cursor.execute(
            f"INSERT INTO {quoted_table} VALUES (?, ?, ?, ?)",
            (value, budget, json.dumps(point), algorithm_name),
        )


def store_run(
    db_path: Path,
    table_name: str,
    run_result_df: Union[RunResult, pd.DataFrame],
):
    """Write a whole run to ``table_name``, replacing any table of that name.

    If the database file does not exist yet, its parent directories and an
    empty file are created first. The data is written with ``to_sql`` using
    ``if_exists="replace"``.

    :param db_path: path of the SQLite database file.
    :param table_name: destination table.
    :param run_result_df: run data; must provide ``to_sql``.
    """
    database_missing = not db_path.exists()
    if database_missing:
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_path.touch()

    with get_conn(db_path) as connection:
        run_result_df.to_sql(table_name, connection, if_exists="replace")


def list_min_max_problems(db_path: Path) -> List[ProblemSpace]:
    """List every problem that has a row in the ``min_max`` table.

    The table is created first if missing, as ``min_max_from_space`` does. A
    database with no recorded extremes therefore yields ``[]`` instead of
    raising ``sqlite3.OperationalError`` ("no such table").

    :param db_path: path of the SQLite database file.
    :returns: one problem space per stored row.
    """
    with get_cursor(db_path) as cursor:
        create_min_max_table(cursor)
        cursor.execute("SELECT suite, id, dim, instance FROM min_max")
        rows = cursor.fetchall()
    return [from_db_to_problem(*row) for row in rows]


def list_tables(db_path: Path) -> List[str]:
    """Return the names of the run tables in the database.

    A table counts as a run table when its name contains an underscore. The
    bookkeeping ``min_max`` table also passes that test, so it is excluded
    explicitly.

    :param db_path: path of the SQLite database file.
    :returns: run table names, in ``sqlite_master`` order.
    """
    with get_cursor(db_path) as cursor:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        names = [row[0] for row in cursor.fetchall()]

    return [name for name in names if "_" in name and name != "min_max"]


def remove_table(db_path: Path, logger: Logger, table_name: str):
    """Drop ``table_name`` and its ``ix_<table>_index`` index, if the table exists.

    Nothing happens when the table is absent. The existence check binds the
    name as a parameter; the DROP statements quote it with backticks.

    :param db_path: path of the SQLite database file.
    :param logger: logger that receives a message when a table is dropped.
    :param table_name: table to remove.
    """
    with get_cursor(db_path) as cursor:
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        if cursor.fetchone() is None:
            return

        cursor.execute(f"DROP TABLE `{table_name}`")
        # SQLite removes a table's indices with the table. This explicit drop
        # (presumably of the index DataFrame.to_sql creates) is a no-op safety net.
        cursor.execute(f"DROP INDEX IF EXISTS `ix_{table_name}_index`;")
        logger.info(f"Dropped table '{table_name}'")


def find_algorithms_in_table(
    db_path: Path, run: Run, table_name_analyzer: Callable[[str], Tuple[Run, ProblemSpace]]
):
    """Collect the distinct ``algorithm_name`` values stored in a run's tables.

    Every table name in the database is passed to ``table_name_analyzer``. A
    table is used when the analyzer returns a truthy result whose run part
    matches ``run``.

    :param db_path: path of the SQLite database file.
    :param run: ``(algorithm_name, run_name)`` pair to match.
    :param table_name_analyzer: maps a table name to ``(run, problem_space)``;
        a falsy return skips the table.
    :returns: list of unique algorithm names (order unspecified).
    """
    algorithm_name, run_name = run
    with get_cursor(db_path) as cursor:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        matching_tables = []
        for (table,) in cursor.fetchall():
            parsed = table_name_analyzer(table)
            if parsed and parsed[0][0] == algorithm_name and run_name == parsed[0][1]:
                matching_tables.append(table)

        unique_algorithm_names = set()
        for table in matching_tables:
            # DISTINCT lets SQLite deduplicate instead of shipping every step
            # row to Python just to collapse it into a set.
            cursor.execute(f"SELECT DISTINCT algorithm_name FROM `{table}`")
            unique_algorithm_names.update(row[0] for row in cursor.fetchall())

    return list(unique_algorithm_names)


def store_df(df: DataFrame, db_path: Path, table_name: str):
    """Write ``df`` to ``table_name``, replacing any table of that name.

    Missing parent directories of ``db_path`` are created first, as
    ``store_run`` does. Without them ``sqlite3.connect`` cannot open a
    database inside a directory that does not exist yet.

    :param df: frame to persist (written with ``to_sql``).
    :param db_path: path of the SQLite database file.
    :param table_name: destination table.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    with get_conn(db_path) as conn:
        df.to_sql(table_name, conn, if_exists="replace")


def extract_df(db_path: Path, table_name: str) -> DataFrame:
    """Read the entire table ``table_name`` into a DataFrame, unmodified.

    Unlike ``list_run_result``, no column is dropped, sorted or JSON-decoded.
    (An unreachable trailing ``return`` that referenced an undefined name has
    been removed.)

    :param db_path: path of the SQLite database file.
    :param table_name: table to read.
    :returns: the table's contents as returned by ``pd.read_sql_table``.
    """
    with get_sqlalchemy_conn(db_path) as conn:
        return pd.read_sql_table(table_name, conn)
