"""
Experiment Database Manager with File Locking.
"""
import contextlib
import json
import os
import sqlite3
from typing import List, Dict, Union

import pandas as pd
from filelock import FileLock


def load_experiments(exp_db_path):
    """Load the list of experiment records from the JSON file at ``exp_db_path``.

    A missing file is treated as an empty database and ``[]`` is returned; the
    file is NOT created here (it is first written by ``save_experiments``).

    Args:
        exp_db_path (str): Path to the JSON experiment database.

    Returns:
        The decoded JSON content (a list of experiment dicts when written by
        ``save_experiments``), or ``[]`` if the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not os.path.exists(exp_db_path):
        return []
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(exp_db_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_experiments(experiments, exp_db_path, encoder_cls=None):
    """Save updated experiment statuses to shared storage.

    The records are serialized to a string *before* the file is opened.
    Opening with mode "w" truncates the file immediately, and ``json.dump``
    writes incrementally, so an encoding error raised mid-dump would leave a
    truncated, unreadable database behind. Serializing first means such an
    error leaves the existing file untouched.

    Args:
        experiments (list[dict]): Experiment records to store.
        exp_db_path (str): Path of the JSON file to (over)write.
        encoder_cls: Optional ``json.JSONEncoder`` subclass for values the
            default encoder cannot serialize.

    Raises:
        TypeError: (with the default encoder) if a record contains a value
            that cannot be serialized; the file is not modified in that case.
    """
    payload = json.dumps(experiments, indent=4, cls=encoder_cls)
    with open(exp_db_path, "w", encoding="utf-8") as f:
        f.write(payload)


class JSONExperimentDB:
    """
    Experiment database that summarizes the settings for each run.

    Experiments are stored as a list of dicts in a JSON file. Every public
    method holds a ``FileLock`` on ``exp_db_lock_path`` for its whole
    load-modify-save cycle. An experiment is identified by its values for
    ``param_clms`` and carries a ``'status'`` field: ``add_experiments`` sets
    'pending', ``get_next_experiments`` sets 'running' (and can re-claim
    'error'), and ``mark_experiment_status`` sets any caller-chosen value.
    """

    def __init__(self, exp_db_path, exp_db_lock_path, param_clms: list, encoder_cls=None):
        """
        Args:
            exp_db_path (str): Path to the experiment database.
            exp_db_lock_path (str): Path to the lock file for the experiment database.
            param_clms (list): List of parameter column names used for identifying the experiments.
            encoder_cls: Optional encoder class for custom serialization.
        """
        self.exp_db_path = exp_db_path
        self.exp_db_lock_path = exp_db_lock_path
        self.param_clms = param_clms
        self.encoder_cls = encoder_cls

    def add_experiments(self, new_experiments):
        """
        Add experiments to the database with status 'pending'.

        Experiments whose ``param_clms`` values already exist in the database,
        or appear earlier in ``new_experiments``, are skipped. The caller's
        dicts are not modified; a copy with ``'status': 'pending'`` is stored.

        Args:
            new_experiments (list[dict]): Settings for each run; every dict
                must contain all ``param_clms`` keys.
        """
        with FileLock(self.exp_db_lock_path):
            db_experiments = load_experiments(self.exp_db_path)
            for exp in new_experiments:
                # Appended records take part in later lookups, so duplicates
                # inside the same batch are skipped as well.
                if not self._matching(db_experiments, self._key_of(exp)):
                    db_experiments.append({**exp, "status": "pending"})
            save_experiments(db_experiments, self.exp_db_path, self.encoder_cls)

    def get_next_experiments(self, batch_size, re_run_error=False, **filters):
        """
        Get the parameters for the next batch and mark them as running.

        Args:
            batch_size (int): Number of samples to get. ``0`` returns an empty
                batch.
            re_run_error (bool): If True, re-run the experiments that are marked as error.
            **filters: Filters for the experiments. The key is the column name and the value is the value to filter.

        Returns:
            parameters_batch: List of the claimed experiment dicts (already
            carrying ``'status': 'running'``), in database order.
        """
        with FileLock(self.exp_db_lock_path):
            experiments = load_experiments(self.exp_db_path)
            batch = self._claim_batch(experiments, batch_size, re_run_error, filters)
            save_experiments(experiments, self.exp_db_path, self.encoder_cls)
            return batch

    def mark_experiment_status(self, experiments, status):
        """
        Update the status of the experiments in the database.

        Args:
            experiments (list[dict]): Experiments identified by their
                ``param_clms`` values.
            status (str): New value for the ``'status'`` field.

        Raises:
            ValueError: If an experiment matches no record or several records;
                nothing is saved in that case.
        """
        with FileLock(self.exp_db_lock_path):
            db_experiments = load_experiments(self.exp_db_path)
            for exp in experiments:
                # The match is a reference into db_experiments, so mutating it
                # in place is what gets saved.
                self._find_unique(db_experiments, exp)["status"] = status
            save_experiments(db_experiments, self.exp_db_path, self.encoder_cls)

    def update_experiment_info(self, experiments):
        """
        Update the experiment information in the database.

        Each given dict is merged (``dict.update``) into the stored record with
        the same ``param_clms`` values, so any extra keys are added/overwritten.

        Raises:
            ValueError: If an experiment matches no record or several records;
                nothing is saved in that case.
        """
        with FileLock(self.exp_db_lock_path):
            db_experiments = load_experiments(self.exp_db_path)
            for exp in experiments:
                self._find_unique(db_experiments, exp).update(exp)
            save_experiments(db_experiments, self.exp_db_path, self.encoder_cls)

    def _claim_batch(self, experiments, batch_size, re_run_error, filters):
        """Mark up to ``batch_size`` claimable experiments as 'running' in place and return them.

        'pending' records are claimable, plus 'error' records when
        ``re_run_error`` is true. ``batch_size == 0`` claims nothing; a size the
        batch length can never equal (negative, None) imposes no limit.
        """
        if batch_size == 0:
            # Without this guard `len(batch) == 0` is never hit after the first
            # append, and the whole queue would be claimed.
            return []
        claimable = ("pending", "error") if re_run_error else ("pending",)
        batch = []
        for exp in self._matching(experiments, filters):
            if exp["status"] in claimable:
                exp["status"] = "running"
                batch.append(exp)
                if len(batch) == batch_size:
                    break
        return batch

    def _key_of(self, exp):
        """Return the identifying ``param_clms`` subset of ``exp`` (KeyError if one is missing)."""
        return {clm: exp[clm] for clm in self.param_clms}

    def _find_unique(self, db_experiments, exp):
        """Return the single record in ``db_experiments`` whose ``param_clms`` match ``exp``.

        Raises:
            ValueError: If no record, or more than one record, matches.
        """
        matched = self._matching(db_experiments, self._key_of(exp))
        if len(matched) == 0:
            raise ValueError(f"Experiment not found in the db: {exp}")
        if len(matched) > 1:
            raise ValueError(f"Multiple experiments found in the db: {matched}")
        return matched[0]

    def _search_experiments(self, experiments, **filters):
        """
        From experiments, search experiments that match the filters.
        """
        return self._matching(experiments, filters)

    @staticmethod
    def _matching(experiments, filters):
        """Return the experiments with ``exp[key] == value`` for every item of ``filters``.

        ``filters`` is a plain dict so column names such as 'self' or
        'experiments' cannot collide with keyword arguments. A record missing
        a filtered key raises KeyError.
        """
        return [
            exp for exp in experiments
            if all(exp[key] == value for key, value in filters.items())
        ]


class ExperimentDB:
    """
    A flexible database manager for machine learning experiments.
    Use file locking to ensure safe concurrent access to the SQLite database.
    Supports dynamic experiment parameters (int, float, str) and evaluation metrics.

    Each public operation holds a ``FileLock`` on ``lock_path`` and runs in a
    single SQLite transaction that is committed on success and rolled back if
    an exception escapes; the connection is closed either way.

    NOTE: parameter, metric and update-column names are interpolated into SQL
    as bare identifiers, so they must be trusted, plain SQL column names.

    Attributes:
        db_path (str): Path to the SQLite database file.
        lock_path (str): Path to the lock file for safe concurrent access.
        exp_param_names (Dict[str, str]): Names and SQL types of experiment parameters.
        metric_names (List[str]): Names of evaluation metrics.
    """

    def __init__(self, db_path: str, exp_param_names: Dict[str, str], metric_names: List[str]):
        """
        Initializes the ExperimentDB class and creates the experiments table if it doesn't exist.

        Args:
            db_path (str): Path to the SQLite database.
            exp_param_names (Dict[str, str]): Experiment parameter names with types (e.g., {'layer': 'TEXT', 'noise_level': 'REAL'}).
            metric_names (List[str]): Names of evaluation metrics (e.g., ['accuracy', 'loss']).

        Raises:
            ValueError: If ``exp_param_names`` is empty or uses a type other
                than TEXT, REAL or INTEGER.
        """
        self.db_path = db_path
        self.lock_path = db_path + ".lock"
        self.exp_param_names = exp_param_names
        self.metric_names = metric_names
        self._initialize_database()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection to ``db_path`` while holding the file lock.

        The transaction is committed when the ``with`` body finishes and rolled
        back if it raises; the connection is always closed, so no open write
        transaction can outlive the file lock.
        """
        with FileLock(self.lock_path):
            conn = sqlite3.connect(self.db_path)
            try:
                with conn:  # sqlite3: commit on success, rollback on exception
                    yield conn
            finally:
                conn.close()  # the sqlite3 context manager does not close

    def _create_table_sql(self) -> str:
        """Validate the schema and build the CREATE TABLE IF NOT EXISTS statement.

        Raises:
            ValueError: If there are no parameters or a type is not TEXT/REAL/INTEGER.
        """
        valid_sql_types = {"TEXT", "REAL", "INTEGER"}
        if not self.exp_param_names:
            # The UNIQUE constraint needs at least one column.
            raise ValueError("exp_param_names must contain at least one parameter.")
        for param, dtype in self.exp_param_names.items():
            if dtype.upper() not in valid_sql_types:
                raise ValueError(f"Invalid SQL type for '{param}': {dtype}. Allowed types: {valid_sql_types}")

        # Collect column definitions and join once, so an empty metric list
        # cannot leave a dangling comma in the DDL.
        columns = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
        columns += [f"{param} {dtype}" for param, dtype in self.exp_param_names.items()]
        columns.append("status TEXT DEFAULT 'pending'")
        columns += [f"{metric} REAL" for metric in self.metric_names]
        # UNIQUE over all parameters is what lets INSERT OR IGNORE skip duplicates.
        columns.append(f"UNIQUE ({', '.join(self.exp_param_names)})")
        return "CREATE TABLE IF NOT EXISTS experiments (\n    " + ",\n    ".join(columns) + "\n)"

    def _initialize_database(self):
        """Creates the experiments table dynamically based on provided parameter types."""
        # Validate before opening the database so a bad schema touches nothing.
        sql = self._create_table_sql()
        with self._connect() as conn:
            conn.execute(sql)

    def add_experiments(self, exp_list: List[Dict[str, Union[str, float, int]]]):
        """
        Adds new experiments to the database safely, avoiding duplicates even in multi-process execution.

        Args:
            exp_list (List[Dict[str, Union[str, float, int]]]): List of experiment settings where keys match exp_param_names.

        Raises:
            KeyError: If an experiment lacks a parameter; nothing is inserted.
        """
        param_names = list(self.exp_param_names)
        columns = ", ".join(param_names)
        placeholders = ", ".join("?" for _ in param_names)
        query = (
            f"INSERT OR IGNORE INTO experiments ({columns}, status) "
            f"VALUES ({placeholders}, 'pending')"
        )
        # Build every row up front so a missing key fails before the DB is touched.
        rows = [tuple(exp[param] for param in param_names) for exp in exp_list]
        with self._connect() as conn:
            conn.executemany(query, rows)

    def get_next_experiments(self, batch_size: int, include_error: bool = False, **filters) -> List[Dict[str, Union[str, float, int]]]:
        """
        Fetches the next batch of pending experiments (or errored ones if `include_error=True`), optionally filtering by provided parameters.

        The fetched experiments are marked 'running' in the same transaction.
        Rows are returned in ascending id order (oldest first).

        Args:
            batch_size (int): Number of experiments to fetch (bound to SQL ``LIMIT``).
            include_error (bool): Whether to include experiments that previously encountered errors.
            **filters: Optional key-value pairs for filtering (e.g., layer="conv1", noise_level=0.1).
                Keys that are not experiment parameters are ignored.

        Returns:
            List[Dict[str, Union[str, float, int]]]: A list of dictionaries, each containing experiment parameters and its ID.
        """
        where_clauses = ["status IN ('pending', 'error')" if include_error else "status = 'pending'"]
        values: list = []
        for key, value in filters.items():
            if key in self.exp_param_names:  # Ensure filtering is on valid parameters
                where_clauses.append(f"{key} = ?")
                values.append(value)
        values.append(batch_size)  # bound to the trailing LIMIT ?

        param_names = list(self.exp_param_names)
        # ORDER BY makes the batch deterministic; SQLite gives no order otherwise.
        query = (
            f"SELECT id, {', '.join(param_names)} FROM experiments "
            f"WHERE {' AND '.join(where_clauses)} ORDER BY id LIMIT ?"
        )
        with self._connect() as conn:
            rows = conn.execute(query, values).fetchall()
            conn.executemany(
                "UPDATE experiments SET status = 'running' WHERE id = ?",
                [(row[0],) for row in rows],
            )
        return [{"id": row[0], **dict(zip(param_names, row[1:]))} for row in rows]

    @staticmethod
    def _update_sql(update_values: Dict[str, Union[str, float]]) -> str:
        """Build ``UPDATE ... SET col = ?, ... WHERE id = ?`` for the given columns."""
        set_clause = ", ".join(f"{key} = ?" for key in update_values)
        return f"UPDATE experiments SET {set_clause} WHERE id = ?"

    def _apply_updates(self, updates: List[tuple]):
        """Apply ``(exp_id, {column: value})`` pairs in one locked transaction.

        Pairs with an empty value dict are skipped. If any statement fails,
        the whole batch is rolled back.
        """
        statements = [
            (self._update_sql(values), [*values.values(), exp_id])
            for exp_id, values in updates
            if values
        ]
        if not statements:
            return
        with self._connect() as conn:
            for sql, params in statements:
                conn.execute(sql, params)

    def _update_experiment(self, exp_ids: List[int], update_values: Dict[str, Union[str, float]]):
        """
        Generic method to update any column(s) for a batch of experiments.

        Args:
            exp_ids (List[int]): List of experiment IDs.
            update_values (Dict[str, Union[str, float]]): Dictionary of column names and their new values.
                An empty dict is a no-op (an empty SET clause is invalid SQL).
        """
        if not exp_ids or not update_values:
            return  # No experiments / columns to update
        self._apply_updates([(exp_id, update_values) for exp_id in exp_ids])

    def update_experiment_results(self, exp_ids: List[int], results: List[Dict[str, float]]):
        """
        Updates the database with experiment results and marks experiments as completed.

        All rows are written in a single transaction. A ``'status'`` key inside
        a result dict overrides ``'completed'``.

        Args:
            exp_ids (List[int]): List of experiment IDs.
            results (List[Dict[str, float]]): List of result dictionaries with metric values,
                aligned with ``exp_ids``.

        Raises:
            ValueError: If ``exp_ids`` and ``results`` differ in length; nothing is written.
        """
        exp_ids = list(exp_ids)
        results = list(results)
        if len(exp_ids) != len(results):
            raise ValueError(
                f"Got {len(exp_ids)} experiment ids but {len(results)} result dicts."
            )
        self._apply_updates(
            [(exp_id, {"status": "completed", **result}) for exp_id, result in zip(exp_ids, results)]
        )

    def update_experiment_status(self, exp_ids: List[int], status: str):
        """
        Updates the status of multiple experiments.

        Args:
            exp_ids (List[int]): List of experiment IDs.
            status (str): New status to set ('pending', 'running', 'completed', or 'error').
        """
        self._update_experiment(exp_ids, {"status": status})

    def to_dataframe(self) -> pd.DataFrame:
        """
        Converts the entire experiment database into a pandas DataFrame.

        Returns:
            pd.DataFrame: A DataFrame containing all experiments with parameters, status, and metrics.
        """
        with self._connect() as conn:
            cursor = conn.execute("SELECT * FROM experiments")
            rows = cursor.fetchall()
            column_names = [desc[0] for desc in cursor.description]
        return pd.DataFrame(rows, columns=column_names)


# Example Usage
if __name__ == "__main__":
    demo_db_path = "temp.db"

    def remove_db_files(db_file):
        """Delete the demo database and its lock file if they exist."""
        # The lock file may be absent (some filelock backends remove it on release).
        for path in (db_file, db_file + ".lock"):
            if os.path.exists(path):
                os.remove(path)

    # Start clean: leftovers from an interrupted run would keep old statuses,
    # because add_experiments ignores rows that already exist.
    remove_db_files(demo_db_path)

    db = ExperimentDB(
        db_path=demo_db_path,
        exp_param_names={"layer": "TEXT", "noise_level": "REAL"},
        metric_names=["metric1", "metric2"]
    )

    def show_db():
        """Print every row of the experiments table."""
        conn = sqlite3.connect(db.db_path)
        try:
            rows = conn.execute("SELECT * FROM experiments").fetchall()
        finally:
            conn.close()
        print('All rows in the experiments table:')
        for row in rows:
            print(row)

    # add experiments
    experiments = [
        {"layer": "conv1", "noise_level": 0.1},
        {"layer": "conv1", "noise_level": 0.2},
        {"layer": "conv2", "noise_level": 0.1},
    ]
    db.add_experiments(experiments)
    show_db()

    # get two experiments
    batch = db.get_next_experiments(batch_size=2)
    print('Fetched experiments:', batch)
    print('After fetching experiments:')
    show_db()

    # update the results
    exp_ids = [exp["id"] for exp in batch]
    results = [{"metric1": exp["noise_level"] * 100, "metric2": len(exp["layer"]) * 10} for exp in batch]
    db.update_experiment_results(exp_ids, results)
    print('After updating results:')
    show_db()

    # update the status
    db.update_experiment_status(exp_ids, "error")
    print('After updating status:')
    show_db()

    # revert to pending and fetch all experiments
    db.update_experiment_status(exp_ids, "pending")
    batch = db.get_next_experiments(batch_size=3)
    print('Fetched all experiments:', batch)
    print('After fetching all experiments:')
    show_db()

    # finish all experiments
    exp_ids = [exp["id"] for exp in batch]
    db.update_experiment_status(exp_ids, "completed")
    print('After finishing all experiments:')
    show_db()

    # try to get more experiments
    batch = db.get_next_experiments(batch_size=3)
    assert not batch, "Expected no more experiments to fetch."

    # convert to DataFrame
    df = db.to_dataframe()
    print('Database as DataFrame:')
    print(df)

    # cleanup (os is already imported at module level)
    remove_db_files(db.db_path)
    print('Database files removed.')


