from sqlalchemy import create_engine
from sqlalchemy import text
from utils import log
import uuid


class SqliteAlchemyDatabase:
    """Owns a SQLAlchemy engine bound to a single SQLite database file.

    Attributes set by ``__init__``:
        logger: the logger object returned by ``log.get_loguru()``.
        sqlalchemy_url: the ``"sqlite:///<db>"`` connection URL.
        engine: the SQLAlchemy engine created from ``sqlalchemy_url``.
    """

    def __init__(self, db):
        """Log the initialisation and create the engine for *db*.

        Args:
            db: path of the SQLite database file; it is interpolated verbatim
                into the ``sqlite:///`` URL.
        """
        logger = log.get_loguru()
        self.logger = logger
        logger.info("Init SqliteAlchemyDatabase {}".format(db))

        url = "sqlite:///{}".format(db)
        self.sqlalchemy_url = url

        # Pool sizing for the engine, plus the sqlite3 driver's busy timeout
        # (in seconds), which is forwarded to the DBAPI connect() call.
        engine_options = {
            "pool_size": 64,
            "max_overflow": 96,
            "connect_args": {"timeout": 20},
        }
        self.engine = create_engine(url, **engine_options)

class AgentAlchemyDatabase(SqliteAlchemyDatabase):
    """Per-field read/write helpers for TEXT tables keyed by ``session_id``.

    Tables created by ``init_table`` have a ``session_id TEXT PRIMARY KEY``
    column plus one TEXT column per requested field. Rows are addressed either
    by ``session_id`` or, in the ``*_by_param`` methods, by a dict of exactly
    three ``column -> value`` pairs that must all match.

    Only values are passed as bound SQL parameters. Table names, field names
    and the keys of ``param`` are interpolated into the SQL text, because
    identifiers cannot be bound. They must therefore never come from untrusted
    input.
    """

    @staticmethod
    def _param_keys(param):
        """Return the keys of *param* as a list, requiring exactly three.

        Raises:
            ValueError: if *param* does not contain exactly three keys.
        """
        # An explicit raise (rather than ``assert``) keeps the check active
        # when Python runs with -O.
        if len(param) != 3:
            raise ValueError(
                "param must contain exactly 3 keys, got {}: {}".format(
                    len(param), list(param)
                )
            )
        return list(param)

    @staticmethod
    def _match_clause(keys):
        """Return ``"k1=:k1 and k2=:k2 and ..."``, binding each column to its own name."""
        return " and ".join("{0}=:{0}".format(key) for key in keys)

    @staticmethod
    def _collect_field(rows, feild_name):
        """Return ``{feild_name: value}`` from the last row whose first column is truthy.

        Returns ``{}`` if no row has a truthy first column.
        """
        result = {}
        for row in rows:
            if row[0]:
                result = {feild_name: row[0]}
        return result

    def init_table(self, table_name, feild_list):
        """Create *table_name* if it does not already exist.

        The table gets a ``session_id TEXT PRIMARY KEY NOT NULL`` column plus
        one ``TEXT`` column for each name in *feild_list*.
        """
        columns = ["session_id TEXT PRIMARY KEY NOT NULL"]
        columns.extend("{} TEXT".format(feild) for feild in feild_list)
        sql_text = text(
            "CREATE TABLE IF NOT EXISTS {}({});".format(
                table_name, ", ".join(columns)
            )
        )
        with self.engine.connect() as conn:
            conn.execute(sql_text)
            conn.commit()

    def get_feild_by_id(self, table_name, feild_name, session_id):
        """Return ``{feild_name: value}`` for the row whose id is *session_id*.

        Returns ``{}`` when no row matches or the stored value is falsy
        (e.g. ``NULL`` or an empty string).
        """
        sql_text = text(
            "SELECT {} FROM {} WHERE session_id= :session_id".format(
                feild_name, table_name
            )
        )
        with self.engine.connect() as conn:
            rows = conn.execute(sql_text, {"session_id": session_id})
            # Rows must be consumed while the connection is still open.
            return self._collect_field(rows, feild_name)

    def exist(self, table_name, session_id):
        """Return True if *table_name* contains a row with *session_id*."""
        sql_text = text(
            "SELECT COUNT(*) FROM {} WHERE session_id= :session_id".format(
                table_name
            )
        )
        with self.engine.connect() as conn:
            count = conn.execute(sql_text, {"session_id": session_id}).fetchone()[0]
        return bool(count)

    def save_feild(self, table_name, feild_name, session_id, feild_info):
        """Insert a new row with *session_id* and *feild_name* set to *feild_info*.

        Presumably fails if a row with this ``session_id`` already exists,
        since ``init_table`` declares it PRIMARY KEY.
        """
        sql_text = text(
            "INSERT INTO {} (session_id, {}) VALUES (:session_id, :{})".format(
                table_name, feild_name, feild_name
            )
        )
        values = {"session_id": session_id, feild_name: feild_info}
        with self.engine.connect() as conn:
            conn.execute(sql_text, values)
            conn.commit()

    def update_feild(self, table_name, feild_name, session_id, feild_info):
        """Set *feild_name* to *feild_info* on the row with *session_id*."""
        sql_text = text(
            "UPDATE {} SET {}=:{} WHERE session_id=:session_id".format(
                table_name, feild_name, feild_name
            )
        )
        values = {"session_id": session_id, feild_name: feild_info}
        with self.engine.connect() as conn:
            conn.execute(sql_text, values)
            conn.commit()

    def get_feild_by_param(self, table_name, feild_name, param):
        """Retrieve *feild_name* using three additional key columns.

        Args:
            param: dict of exactly three ``column -> value`` pairs; a row
                matches only if all three are equal.

        Returns:
            ``{feild_name: value}`` from the last matching row whose value is
            truthy, or ``{}`` if there is none.

        Raises:
            ValueError: if *param* does not have exactly three keys.
        """
        keys = self._param_keys(param)
        sql_text = text(
            "SELECT {} FROM {} WHERE {}".format(
                feild_name, table_name, self._match_clause(keys)
            )
        )
        with self.engine.connect() as conn:
            rows = conn.execute(sql_text, param)
            return self._collect_field(rows, feild_name)

    def exist_by_param(self, table_name, param):
        """Return True if any row matches all three ``column -> value`` pairs in *param*.

        Raises:
            ValueError: if *param* does not have exactly three keys.
        """
        keys = self._param_keys(param)
        sql_text = text(
            "SELECT COUNT(*) FROM {} WHERE {}".format(
                table_name, self._match_clause(keys)
            )
        )
        with self.engine.connect() as conn:
            count = conn.execute(sql_text, param).fetchone()[0]
        return bool(count)

    def save_feild_by_param(self, table_name, feild_name, feild_info, param):
        """Insert a row holding *param*'s three key columns and *feild_name*.

        A fresh random ``session_id`` (``uuid4().hex``) is generated for the
        row. *param* itself is left unmodified.

        Returns:
            The generated ``session_id``.

        Raises:
            ValueError: if *param* does not have exactly three keys.
        """
        keys = self._param_keys(param)
        columns = ["session_id", feild_name] + keys
        sql_text = text(
            "INSERT INTO {} ({}) VALUES ({})".format(
                table_name,
                ", ".join(columns),
                ", ".join(":" + column for column in columns),
            )
        )
        # Bind from a copy so the caller's dict keeps exactly its three keys
        # and can be reused with the other *_by_param methods.
        values = dict(param)
        session_id = uuid.uuid4().hex
        values["session_id"] = session_id
        values[feild_name] = feild_info
        with self.engine.connect() as conn:
            conn.execute(sql_text, values)
            conn.commit()
        return session_id

    def update_feild_by_param(self, table_name, feild_name, feild_info, param):
        """Set *feild_name* to *feild_info* on rows matching all three pairs in *param*.

        *param* itself is left unmodified.

        Raises:
            ValueError: if *param* does not have exactly three keys.
        """
        keys = self._param_keys(param)
        sql_text = text(
            "UPDATE {} SET {}=:{} WHERE {}".format(
                table_name, feild_name, feild_name, self._match_clause(keys)
            )
        )
        # Bind from a copy so the caller's dict is not mutated.
        values = dict(param)
        values[feild_name] = feild_info
        with self.engine.connect() as conn:
            conn.execute(sql_text, values)
            conn.commit()
    