import sqlite3
import os
import tqdm
import warnings


class SqliteDict:
    """Dict-like wrapper around a single SQLite table.

    Each row maps a key (``key_column``, the table's primary key) to one or
    more value columns. Table and column names are interpolated directly into
    the SQL text, and every ``where`` argument is appended verbatim as a raw
    SQL condition, so they must come from trusted code, never from user
    input. Keys and values themselves are always passed as bound parameters.

    Modes:
        'r' -- read-only; the database file must already exist.
        'w' -- read/write on an existing table.
        'c' -- create the table if missing (optionally enabling WAL), then
               behave like 'w'.

    Connections: with ``keep_open=True`` one connection is opened up front and
    reused; with ``keep_open=False`` a connection is opened per operation and
    closed afterwards, and the streaming ``keys``/``values``/``items`` fall
    back to fully materialised lists.

    Commits: unless ``manual_commit`` is set, each write is committed
    immediately. With ``manual_commit=True`` callers must call ``commit()``
    (or ``close(commit=True)``); ``uncommited_values`` counts writes since the
    last commit. ``close()`` and ``__exit__`` do not commit by default.
    NOTE(review): ``keep_open=False`` together with ``manual_commit=True``
    closes the connection after every statement without committing, so such
    writes are likely lost -- confirm whether this combination is supported.
    """

    def __init__(self, path, mode='r', key_column='key', value_column='value', table='dict', check_same_thread=False, keep_open=True, manual_commit=False, use_wal=True, timeout=20):
        """Open (in mode 'c': create) table ``table`` in the SQLite file ``path``.

        Args:
            path: database file path, passed to ``sqlite3.connect``.
            mode: 'r', 'w' or 'c' (see class docstring); 'c' is stored as 'w'.
            key_column, value_column, table: SQL identifiers (trusted input).
            check_same_thread, timeout: forwarded to ``sqlite3.connect``.
            keep_open: reuse one connection instead of one per operation.
            manual_commit: if True, writes are only committed by ``commit()``.
            use_wal: in mode 'c', switch the database to WAL journal mode.

        Raises:
            AssertionError: ``mode`` is 'r' and ``path`` is not an existing file.
        """
        self.path = path
        self.check_same_thread = check_same_thread
        self.keep_open = keep_open
        self.timeout = timeout
        self.conn = None

        self.manual_commit = manual_commit
        # Number of writes executed since the last commit().
        self.uncommited_values = 0

        self.key_column = key_column
        self.value_column = value_column
        self.table = table

        # Check BEFORE connecting: sqlite3.connect() creates a missing file, so
        # checking afterwards always passed when keep_open was True. An explicit
        # raise (rather than `assert`) survives `python -O` while keeping the
        # AssertionError type existing callers may rely on.
        if mode == 'r' and not os.path.isfile(path):
            raise AssertionError(path)

        if self.keep_open:
            self._connect()

        if mode == 'c':
            self._create_table()
            if use_wal:
                self._set_wal()
            mode = 'w'
        self.mode = mode

    def _set_wal(self):
        """Switch the database to write-ahead-logging journal mode."""
        self._execute("PRAGMA journal_mode=WAL;")

    def __setitem__(self, key, value, value_column=None, cursor=None, commit=True, update_only=False):
        """Store ``value`` under ``key`` in ``value_column`` (default ``self.value_column``).

        By default this is an upsert: a new row is inserted, or only the given
        column is overwritten when the key already exists. With
        ``update_only=True`` it is a plain UPDATE, so a missing key stays
        absent (no error). ``cursor``/``commit`` let ``update()`` batch many
        writes into a single commit. Requires mode 'w' (or 'c').
        """
        assert self.mode == 'w'
        column = self.value_column if value_column is None else value_column
        if update_only:
            query = f'''UPDATE {self.table} SET {column} = ? WHERE {self.key_column} = ?;'''
            params = (value, key)
        else:
            query = f'''
                    INSERT INTO {self.table} ({self.key_column}, {column})
                    VALUES (?, ?)
                    ON CONFLICT({self.key_column}) DO UPDATE SET {column} = excluded.{column};
                    '''
            params = (key, value)
        self._execute_write(query, params, cursor=cursor, commit=commit)

    def _execute_write(self, query, args, cursor=None, commit=True):
        """Run a data-modifying statement while keeping ``uncommited_values`` accurate."""
        # Count the write *before* _execute: _execute may auto-commit, which
        # resets the counter to 0. Counting afterwards left it at 1 after every
        # auto-committed write.
        self.uncommited_values += 1
        try:
            self._execute(query, args=args, cursor=cursor, commit=commit)
        except Exception:
            # The statement did not run, so it is not pending.
            self.uncommited_values -= 1
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Closes without committing; pending manual-commit writes are not committed.
        self.close()

    def __delitem__(self, key):
        """Delete the row for ``key``; deleting a missing key is not an error."""
        assert self.mode == 'w'
        delete_query = f'DELETE FROM {self.table} WHERE {self.key_column} = ?;'
        self._execute_write(delete_query, (key,))

    def _maybe_close(self):
        """Close the connection when operating in per-operation (keep_open=False) mode."""
        if not self.keep_open:
            self.close()

    def __len__(self, where=None):
        """Return the row count, optionally restricted by the raw SQL condition ``where``."""
        count_query = f"SELECT COUNT(*) FROM {self.table}"
        if where:
            count_query += f' WHERE {where}'
        cursor = self._execute(count_query, return_cursor=True)
        length = cursor.fetchone()[0]
        self._maybe_close()
        return length

    def close(self, commit=False):
        """Close the connection (if open), first committing when ``commit`` is True."""
        if commit:
            self.commit()
        if self.conn is not None:
            self.conn.close()
        self.conn = None

    def _get_cursor(self):
        """Return a new cursor, opening the connection if needed."""
        self._connect()
        return self.conn.cursor()

    def _commit(self):
        """Commit unless ``manual_commit`` is set; commit errors become warnings."""
        try:
            if not self.manual_commit:
                self.commit()
        except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
            warnings.warn(str(e))

    def commit(self):
        """Commit the current transaction and reset ``uncommited_values``.

        When no connection is open (e.g. between operations with
        ``keep_open=False``) there is nothing to commit, so only the counter is
        reset instead of failing with AttributeError on ``None.commit()``.
        """
        if self.conn is not None:
            self.conn.commit()
        self.uncommited_values = 0

    def _connect(self):
        """Open ``self.conn`` if it is not already open, and return it."""
        if self.conn is None:
            try:
                self.conn = sqlite3.connect(self.path,
                                            check_same_thread=self.check_same_thread,
                                            timeout=self.timeout)
            except sqlite3.OperationalError:
                # Print the path so failures across many DB files can be told apart.
                print(self.path)
                raise
        return self.conn

    def get_column_names(self):
        """Return the column names of ``self.table`` via ``PRAGMA table_info``."""
        query = f"PRAGMA table_info({self.table});"
        cursor = self._execute(query, return_cursor=True)
        column_names = [col[1] for col in cursor.fetchall()]
        self._maybe_close()
        return column_names

    def update(self, other, value_column=None, progress=False, update_only=False):
        """Write every ``(key, value)`` of ``other.items()`` and commit once at the end.

        All writes share one cursor and are committed together (unless
        ``manual_commit``). ``progress=True`` wraps the iteration in
        ``tqdm.tqdm``.
        """
        cursor = self._get_cursor()
        pairs = tqdm.tqdm(other.items()) if progress else other.items()
        try:
            for key, value in pairs:
                self.__setitem__(key, value, value_column=value_column, cursor=cursor, commit=False, update_only=update_only)
            self._commit()
        finally:
            # _execute leaves a caller-supplied cursor's connection open, so the
            # batch owner closes it once here (no-op when keep_open is True).
            self._maybe_close()

    def __contains__(self, item):
        """Return True if a row with key ``item`` exists."""
        exists_query = f"SELECT EXISTS(SELECT 1 FROM {self.table} WHERE {self.key_column} = ?)"
        cursor = self._execute(exists_query, (item,), return_cursor=True)
        result = cursor.fetchone()[0]
        self._maybe_close()
        return bool(result)

    def __iter__(self):
        """Iterate over keys, as ``dict`` does.

        Without this, ``for k in d`` fell back to the legacy sequence protocol
        and looked up the integer keys 0, 1, 2, ... through ``__getitem__``.
        """
        return self.keys()

    def _get_value_columns(self, value_columns):
        """Return the SQL column list for ``value_columns`` (default ``self.value_column``)."""
        if value_columns is None:
            return self.value_column
        assert isinstance(value_columns, (tuple, list))
        return ', '.join(value_columns)

    @staticmethod
    def _with_where(select_query, where):
        """Append the raw SQL condition ``where`` as a WHERE clause when given."""
        if where is not None:
            select_query += ' WHERE ' + where
        return select_query

    def _get_select_keys_query(self, where):
        return self._with_where(f'SELECT {self.key_column} FROM {self.table}', where)

    def _get_select_values_query(self, value_columns, where):
        value_columns = self._get_value_columns(value_columns)
        return self._with_where(f'SELECT {value_columns} FROM {self.table}', where)

    def _get_select_items_query(self, value_columns, where):
        value_columns = self._get_value_columns(value_columns)
        return self._with_where(f'SELECT {self.key_column}, {value_columns} FROM {self.table}', where)

    def keys(self, where=None):
        """Yield keys, optionally filtered by the raw SQL condition ``where``.

        With ``keep_open=True`` rows stream from a live cursor; otherwise they
        are fetched eagerly via ``all_keys`` (with a warning). The trailing
        ``_maybe_close`` only runs once the generator is exhausted.
        """
        if self.keep_open:
            rows = self._execute(self._get_select_keys_query(where), return_cursor=True)
        else:
            warnings.warn('iterator not supported when keep_open is False')
            rows = self.all_keys(where=where)
        for key in rows:
            yield key[0] if isinstance(key, tuple) and len(key) == 1 else key
        self._maybe_close()

    def all_keys(self, where=None):
        """Return all keys (optionally filtered by raw SQL ``where``) as a list."""
        cursor = self._execute(self._get_select_keys_query(where), return_cursor=True)
        keys = [row[0] for row in cursor.fetchall()]
        self._maybe_close()
        return keys

    def values(self, value_columns=None, where=None):
        """Yield values; a scalar per row for one column, else the row tuple.

        Streaming/fallback behaviour is the same as ``keys``.
        """
        if self.keep_open:
            select_query = self._get_select_values_query(value_columns, where)
            rows = self._execute(select_query, return_cursor=True)
        else:
            warnings.warn('iterator not supported when keep_open is False')
            rows = self.all_values(value_columns, where)
        for value in rows:
            yield value[0] if isinstance(value, tuple) and len(value) == 1 else value
        self._maybe_close()

    def all_values(self, value_columns=None, where=None):
        """Return all values as a list (scalars for one column, tuples otherwise)."""
        select_query = self._get_select_values_query(value_columns=value_columns, where=where)
        cursor = self._execute(select_query, return_cursor=True)
        values = [row[0] if len(row) == 1 else row for row in cursor.fetchall()]
        self._maybe_close()
        return values

    def items(self, value_columns=None, where=None):
        """Yield ``(key, *values)`` tuples; streaming/fallback as in ``keys``."""
        if self.keep_open:
            select_query = self._get_select_items_query(value_columns=value_columns, where=where)
            rows = self._execute(select_query, return_cursor=True)
        else:
            warnings.warn('iterator not supported when keep_open is False')
            rows = self.all_items(value_columns=value_columns, where=where)
        for row in rows:
            yield tuple(row)
        self._maybe_close()

    def all_items(self, value_columns=None, where=None):
        """Return all ``(key, *values)`` rows as a list."""
        select_query = self._get_select_items_query(value_columns=value_columns, where=where)
        cursor = self._execute(select_query, return_cursor=True)
        rows = cursor.fetchall()
        self._maybe_close()
        return rows

    def get(self, key, default=None, value_columns=None):
        """Return ``self[key]`` (for ``value_columns``), or ``default`` if the key is missing."""
        try:
            return self.__getitem__(key, value_columns=value_columns)
        except KeyError:
            return default

    def get_multiple(self, keys, value_columns=None):
        """Fetch several keys with a single query.

        Args:
            keys: list or tuple of keys.
            value_columns: columns to return (default ``[self.value_column]``).

        Returns:
            A list aligned with ``keys``; each entry is a list of that key's
            requested column values.

        Raises:
            KeyError: if any key is not present.

        NOTE(review): one bound parameter is used per key, so very large
        ``keys`` may exceed SQLite's host-parameter limit -- chunk if needed.
        """
        assert isinstance(keys, (tuple, list))
        if not keys:
            return []
        if value_columns is None:
            value_columns = [self.value_column]
        columns = self._get_value_columns([self.key_column, *value_columns])
        placeholders = ', '.join('?' * len(keys))
        select_query = f'SELECT {columns} FROM {self.table} WHERE {self.key_column} in ({placeholders})'
        cursor = self._execute(select_query, keys, return_cursor=True)
        found = {key: values for key, *values in cursor.fetchall()}
        self._maybe_close()
        return [found[key] for key in keys]

    def __getitem__(self, key, value_columns=None):
        """Return the value for ``key``.

        Returns a scalar when one column is selected, otherwise the row tuple
        for ``value_columns``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        select_query = self._get_select_values_query(value_columns, where=f"{self.key_column} = ?")
        # Always bind the key as ONE parameter, even if it is a tuple/list.
        cursor = self._execute(select_query, (key,), return_cursor=True)
        row = cursor.fetchone()
        self._maybe_close()
        if row is None:
            raise KeyError(f'Key "{key}" not found')
        return row[0] if len(row) == 1 else row

    def _execute(self, query, args=tuple(), cursor=None, commit=True, return_cursor=False):
        """Execute ``query`` with bound ``args``; optionally commit / return the cursor.

        A non-sequence ``args`` is treated as a single parameter. If no
        ``cursor`` is given, one is created on (a possibly new) ``self.conn``
        and, unless ``return_cursor`` is set, the connection is closed
        afterwards when ``keep_open`` is False. A caller-supplied ``cursor``
        means the caller owns the connection (e.g. ``update()`` batching
        writes), so it is left open. With ``return_cursor=True`` the caller is
        responsible for calling ``_maybe_close()`` after consuming results.
        """
        if not isinstance(args, (tuple, list)):
            args = (args,)

        owns_cursor = cursor is None
        if owns_cursor:
            cursor = self._get_cursor()
        # conn active at this point
        try:
            cursor.execute(query, args)
        except sqlite3.OperationalError:
            print(self.path)
            raise
        if commit:
            self._commit()
        if return_cursor:
            return cursor
        if owns_cursor:
            self._maybe_close()

    def _create_table(self):
        """Create ``self.table`` (TEXT primary-key column + TEXT value column) if missing."""
        create_table_query = f'''CREATE TABLE IF NOT EXISTS {self.table} (
                                {self.key_column} TEXT PRIMARY KEY,
                                {self.value_column} TEXT
                            );'''
        self._execute(create_table_query)
