import contextlib
import logging
from logging import Logger
from pathlib import Path
from typing import List, Optional, Tuple

from pandas import DataFrame

from compute_result.result_store.base import ResultStore
from compute_result.result_store.sqlite_utils import (
    get_sqlalchemy_conn,
    get_conn,
    get_cursor,
    list_run_result,
    is_table_exist,
    min_max_from_space,
    update_min,
    update_max,
    list_tables,
    list_min_max_problems,
    store_run,
    store_step,
    TABLE_NAME_BREAK,
    remove_table,
    find_algorithms_in_table,
    last_record_from_run_result,
    store_df,
    extract_df,
)
from compute_result.typing import Run, ProblemSpace, RunResult, Point, ResultRow
from problems.types import Suites
from utils.algorithms_data import Algorithms

# Column names for loss data. None of them is referenced in this part of the
# module; presumably they name the columns of the "_test_losses" tables built by
# construct_loss_table_name (TODO confirm against the loss writer/reader).
EPOCH_LOSS_COL_NAME = "epoch"
INNER_EPOCH_LOSS_COL_NAME = "inner_epoch"
LOSS_COL_NAME = "loss"


def construct_table_name(run: Run, problem_space: ProblemSpace) -> str:
    """Build the SQLite table name that holds one run on one problem space.

    The name is the literal prefix ``"run"`` followed by the algorithm's
    ``.value``, the run name, the suite's ``.value``, the function number, the
    dimension and the instance. All fields are joined with ``TABLE_NAME_BREAK``.
    ``extract_run_from_table_name`` splits such a name back into its parts.
    """
    algorithm, name_of_run = run
    suite, func_number, dim, instance = problem_space

    components = (
        "run",
        algorithm.value,
        name_of_run,
        suite.value,
        func_number,
        dim,
        instance,
    )
    # Format every component exactly as the f-string interpolation would.
    return TABLE_NAME_BREAK.join(f"{component}" for component in components)


def construct_loss_table_name(run: Run, problem: ProblemSpace):
    """Return the name of the test-loss table for ``run`` on ``problem``.

    This is the regular run table name from ``construct_table_name`` with the
    suffix ``"_test_losses"`` appended.
    """
    base_name = construct_table_name(run, problem)
    return base_name + "_test_losses"


def extract_run_from_table_name(table_name: str) -> Tuple[Run, ProblemSpace]:
    """Parse a run table name back into its ``(run, problem_space)`` pair.

    Splitting ``table_name`` on ``TABLE_NAME_BREAK`` must give exactly seven
    fields: a leading prefix (ignored), the algorithm, the run name, the suite,
    the function number, the dimension and the instance. Any other field count
    raises ``ValueError`` from the unpacking. The ``Algorithms``/``Suites``/``int``
    conversions can also raise for malformed fields.

    Returns:
        ``((Algorithms, run_name), (Suites, func_number, dim, instance))``.
    """
    fields = table_name.split(TABLE_NAME_BREAK)
    _prefix, raw_algorithm, run_name, raw_suite, raw_func, raw_dim, raw_instance = fields

    parsed_run = (Algorithms(raw_algorithm), run_name)
    parsed_problem = (
        Suites(raw_suite),
        int(raw_func),
        int(raw_dim),
        int(raw_instance),
    )
    return parsed_run, parsed_problem


class SQLiteStorage(ResultStore):
    """``ResultStore`` backed by a single SQLite database file.

    Each (run, problem space) pair is kept in its own table whose name comes
    from ``construct_table_name``. Most methods are thin wrappers that pass
    ``self.db_path`` (and sometimes ``self.logger``) to the helpers in
    ``compute_result.result_store.sqlite_utils``.
    """

    def __init__(self, data_path: Path, logger: Optional[Logger] = None):
        """Create a store for the SQLite file at ``data_path``.

        Args:
            data_path: Path of the database file. Nothing is opened or created here.
            logger: Logger for progress messages; defaults to this module's logger.
        """
        self.db_path = data_path
        self.logger = logger or logging.getLogger(__name__)

    @contextlib.contextmanager
    def get_sqlalchemy_conn(self):
        """Yield the connection from ``sqlite_utils.get_sqlalchemy_conn`` for this database."""
        with get_sqlalchemy_conn(self.db_path) as conn:
            yield conn

    @contextlib.contextmanager
    def get_conn(self):
        """Yield the connection from ``sqlite_utils.get_conn`` for this database."""
        with get_conn(self.db_path) as conn:
            yield conn

    @contextlib.contextmanager
    def get_cursor(self):
        """Yield the cursor from ``sqlite_utils.get_cursor`` for this database."""
        with get_cursor(self.db_path) as cursor:
            yield cursor

    @staticmethod
    def _is_run_table(table_name: str) -> bool:
        """Return True if ``table_name`` has the ``"run"`` prefix used by run tables."""
        return table_name.startswith("run")

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote ``name`` as an SQLite identifier, doubling embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    def storage_exists(self):
        """Return whether the database file exists on disk."""
        return self.db_path.exists()

    def run_result(self, run: Run, problem_space: ProblemSpace) -> RunResult:
        """Load the stored results of ``run`` on ``problem_space`` via ``list_run_result``."""
        table_name = construct_table_name(run, problem_space)
        return list_run_result(self.db_path, table_name)

    def is_run_exists(self, run: Run, problem_space: ProblemSpace):
        """Return whether the table for ``run`` on ``problem_space`` exists."""
        table_name = construct_table_name(run, problem_space)
        return is_table_exist(self.db_path, table_name)

    def min_max_from_space(
        self, problem_space: ProblemSpace
    ) -> Tuple[float, float, Point, Point]:
        """Return the stored extrema for ``problem_space`` from ``sqlite_utils.min_max_from_space``.

        NOTE(review): presumably ``(min_value, max_value, min_point, max_point)``;
        confirm the order against ``sqlite_utils``.
        """
        return min_max_from_space(self.db_path, problem_space)

    def update_min(self, problem_space: ProblemSpace, min_value: float, min_point: Point):
        """Forward a new minimum candidate for ``problem_space`` to ``sqlite_utils.update_min``."""
        return update_min(self.db_path, self.logger, problem_space, min_value, min_point)

    def update_max(self, problem_space: ProblemSpace, max_value: float, max_point: Point):
        """Forward a new maximum candidate for ``problem_space`` to ``sqlite_utils.update_max``."""
        return update_max(self.db_path, self.logger, problem_space, max_value, max_point)

    def remove_run(self, run: Run, problem_space: ProblemSpace):
        """Remove the table for ``run`` on ``problem_space`` via ``remove_table``."""
        table_name = construct_table_name(run, problem_space)
        return remove_table(self.db_path, self.logger, table_name)

    def store_step(
        self,
        run: Run,
        problem_space: ProblemSpace,
        budget: int,
        value: float,
        point: Point,
        algorithm_name: str,
    ):
        """Store one optimization step in the run's table via ``sqlite_utils.store_step``."""
        table_name = construct_table_name(run, problem_space)
        store_step(self.db_path, table_name, budget, value, point, algorithm_name)

    def store_run(self, run: Run, problem_space: ProblemSpace, run_result: RunResult):
        """Store a full run result, after filtering it with ``_filter_multiple_budget``.

        ``_filter_multiple_budget`` is not defined in this class; presumably it is
        inherited from ``ResultStore`` (TODO confirm). The original and filtered
        sizes are logged.
        """
        self.logger.info(f"Storing run {run}, {problem_space} {len(run_result)}")
        run_result_df = self._filter_multiple_budget(run_result, return_df=True)
        self.logger.info(f"Filter {run}, {problem_space} to {len(run_result_df)}")

        table_name = construct_table_name(run, problem_space)
        return store_run(self.db_path, table_name, run_result_df)

    def list_min_max_problems(self) -> List[ProblemSpace]:
        """Return the problem spaces reported by ``sqlite_utils.list_min_max_problems``."""
        return list_min_max_problems(self.db_path)

    def list_runs(self) -> List[Tuple[Run, ProblemSpace]]:
        """Parse every table with the ``"run"`` prefix into a ``(run, problem_space)`` pair."""
        return [
            extract_run_from_table_name(table_name)
            for table_name in list_tables(self.db_path)
            if self._is_run_table(table_name)
        ]

    def rename_run(self, run: Run, new_run_name: str):
        """Rename every table belonging to ``run`` so it carries ``new_run_name``.

        Only tables with the ``"run"`` prefix are parsed, the same rule
        ``list_runs`` uses. Metric tables and other non-run tables are skipped,
        because ``extract_run_from_table_name`` can raise on them. If a table for
        the target run and problem already exists, it is removed before the
        rename.

        Renaming a run to its current name does nothing. Without this check the
        "existing target" would be the source table itself: it would be dropped
        and the following ``ALTER TABLE`` would fail, losing the data.
        """
        if new_run_name == run[1]:
            self.logger.info(f"Run {run} already named '{new_run_name}', nothing to rename")
            return

        with self.get_cursor() as cursor:
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            # Materialise the names first; the same cursor is reused for ALTER TABLE.
            table_names = [row[0] for row in cursor.fetchall()]

            for table_name in table_names:
                # NOTE(review): names from construct_loss_table_name also start
                # with "run", but their "_test_losses" suffix makes the instance
                # field unparsable. Confirm they never share this database.
                if not self._is_run_table(table_name):
                    continue
                curr_run, problem = extract_run_from_table_name(table_name)
                if curr_run != run:
                    continue

                new_run = (curr_run[0], new_run_name)
                new_table_name = construct_table_name(new_run, problem)
                if self.is_run_exists(new_run, problem):
                    self.logger.info(f"Removing run {new_run}-{problem}")
                    self.remove_run(new_run, problem)
                # new_run_name is caller input, so quote both identifiers properly.
                cursor.execute(
                    f"ALTER TABLE {self._quote_identifier(table_name)} "
                    f"RENAME TO {self._quote_identifier(new_table_name)}"
                )
                self.logger.info(f"Renamed table '{table_name}' to '{new_table_name}'")

    def all_algorithms_in_run(self, run: Run, problems: List[ProblemSpace]) -> List[str]:
        """Return the algorithms that ``find_algorithms_in_table`` reports for ``run``.

        ``problems`` is currently unused; it is kept for interface compatibility.
        """

        def table_name_analyzer(table_name: str) -> Optional[Tuple[Run, ProblemSpace]]:
            # Non-run tables map to None instead of being parsed.
            if not self._is_run_table(table_name):
                return None
            return extract_run_from_table_name(table_name)

        return find_algorithms_in_table(self.db_path, run, table_name_analyzer)

    def problem_final_value_for_run(self, run: Run, problem: ProblemSpace) -> ResultRow:
        """Return the row from ``last_record_from_run_result`` for the run's table."""
        return last_record_from_run_result(self.db_path, construct_table_name(run, problem))

    def store_metric(self, metric_name: str, data: DataFrame):
        """Store ``data`` under ``metric_name`` via ``store_df`` (presumably as a table of that name)."""
        store_df(data, self.db_path, metric_name)

    def get_metric(self, metric_name: str) -> DataFrame:
        """Load the DataFrame stored under ``metric_name`` via ``extract_df``."""
        return extract_df(self.db_path, metric_name)

    def get_metrics_types(self, metric_initial: str) -> List[str]:
        """Return the names of all tables that start with ``metric_initial``."""
        tables = list_tables(self.db_path)
        return [table for table in tables if table.startswith(metric_initial)]
