#!/usr/bin/env python3

import logging
import sys
import argparse
import json


from sqlalchemy import create_engine, inspect, text

# from sqlalchemy import INTEGER, NUMERIC, VARCHAR, TIMESTAMP
from sqlalchemy.sql.sqltypes import (
    INTEGER,
    REAL,
    VARCHAR,
    TIMESTAMP,
    TEXT,
    DOUBLE,
    DATE,
)

logger = logging.getLogger(__name__)


def _quote_identifier(name):
    """Return *name* wrapped in double quotes for use as an SQL identifier.

    Any double quote inside the name is doubled so that a table or column
    whose name contains a quote cannot terminate the identifier early.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _profile_column(connection, table_name, column):
    """Build the property dict for one reflected column.

    Args:
        connection: Open SQLAlchemy connection used to run the profiling
            queries.
        table_name: Name of the table that owns the column.
        column: Column description as returned by
            ``Inspector.get_columns`` (reads the ``"name"`` and ``"type"``
            keys).

    Returns:
        A dict whose content depends on the column type:

        * INTEGER / REAL / DOUBLE: ``type`` (lower-cased type string),
          ``min``, ``max`` and ``mean``.
        * TIMESTAMP / DATE: ``type`` = ``"date"``, ``min`` and ``max``.
        * VARCHAR / TEXT: ``type`` = ``"text"`` and ``values`` holding up
          to 50 distinct values.
        * anything else: an empty dict (an error is logged).
    """
    column_name = column["name"]
    column_type = column["type"]
    quoted_column = _quote_identifier(column_name)
    quoted_table = _quote_identifier(table_name)
    profile = {}

    if isinstance(column_type, (INTEGER, REAL, DOUBLE)):
        profile["type"] = str(column_type).lower()
        query = text(
            f"SELECT MIN({quoted_column}), MAX({quoted_column}), "
            f"AVG({quoted_column}) FROM {quoted_table}"
        )
        # An aggregate without GROUP BY always yields exactly one row.
        row = connection.execute(query).fetchone()
        profile["min"] = row[0]
        profile["max"] = row[1]
        profile["mean"] = row[2]
    elif isinstance(column_type, (TIMESTAMP, DATE)):
        profile["type"] = "date"
        query = text(
            f"SELECT MIN({quoted_column}), MAX({quoted_column}) "
            f"FROM {quoted_table}"
        )
        row = connection.execute(query).fetchone()
        profile["min"] = row[0]
        profile["max"] = row[1]
    elif isinstance(column_type, (VARCHAR, TEXT)):
        profile["type"] = "text"
        # Sample at most 50 distinct values of the column.
        query = text(
            f"SELECT DISTINCT {quoted_column} FROM {quoted_table} LIMIT 50"
        )
        profile["values"] = [row[0] for row in connection.execute(query)]
    else:
        logger.error(
            "Unhandled column type '%s' for column '%s'",
            column_type,
            column_name,
        )

    return profile


def _foreign_key_edges(inspector, table_name):
    """Return one ``foreignKey`` relationship dict per FK column pair.

    Composite foreign keys are expanded column by column, pairing each
    constrained column with the referred column at the same position.
    """
    edges = []
    for fk in inspector.get_foreign_keys(table_name):
        target_table = fk["referred_table"]
        for src_key, target_key in zip(
            fk["constrained_columns"], fk["referred_columns"]
        ):
            edges.append(
                {
                    "source": f"{table_name}.{src_key}",
                    "target": f"{target_table}.{target_key}",
                    "type": "foreignKey",
                }
            )
    return edges


def create_abstract_graph(db_uri):
    """Reflect a database and describe it as an abstract graph.

    Every table becomes an entity whose ``type`` is the table name, whose
    ``id`` (only present when the table has a primary key) is the
    comma-joined list of primary-key columns, and whose ``properties`` map
    each column name to a small profile (see ``_profile_column``). Foreign
    keys are reported in ``fk_relationships``; ``relationships`` is always
    returned empty by this function.

    Args:
        db_uri: SQLAlchemy database URL to connect to.

    Returns:
        A dict with the keys ``"entities"``, ``"relationships"`` and
        ``"fk_relationships"``.

    Raises:
        Exception: Any error raised while profiling a column is logged and
            re-raised unchanged.
    """
    entities = []
    relationships = []
    fk_relationships = []

    engine = create_engine(db_uri)
    try:
        inspector = inspect(engine)
        # The context manager guarantees the connection is released even
        # when a profiling query fails.
        with engine.connect() as connection:
            for table_name in inspector.get_table_names():
                entity = {}
                primary_keys = inspector.get_pk_constraint(table_name)
                if primary_keys and primary_keys["constrained_columns"]:
                    entity["id"] = ",".join(
                        primary_keys["constrained_columns"]
                    )
                entity["type"] = table_name

                # One property entry per column of the table.
                properties = {}
                for column in inspector.get_columns(table_name):
                    column_name = column["name"]
                    try:
                        properties[column_name] = _profile_column(
                            connection, table_name, column
                        )
                    except Exception as ex:
                        logger.error("Failed query exececution:%s", str(ex))
                        # Bare raise keeps the original traceback intact.
                        raise
                entity["properties"] = properties
                entities.append(entity)

                fk_relationships.extend(
                    _foreign_key_edges(inspector, table_name)
                )
    finally:
        # Close pooled connections so the database file is not kept open.
        engine.dispose()

    return {
        "entities": entities,
        "relationships": relationships,
        "fk_relationships": fk_relationships,
    }


def run(dbfile, graphfile):
    """Build the abstract graph of an SQLite file and save it as JSON.

    Args:
        dbfile: Path of the SQLite database file to inspect.
        graphfile: Path of the JSON file to write (UTF-8, indented by 2).
    """
    # The graph is built before the output file is opened for writing.
    graph = create_abstract_graph(f"sqlite:///{dbfile}")

    with open(graphfile, "w", encoding="utf-8") as out:
        json.dump(graph, out, indent=2)


def parse_args(argv=None):
    """Parse the command line options.

    Args:
        argv: List of argument strings to parse. When ``None``, the
            arguments are taken from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``env`` (optional,
        defaults to ``None``), ``dbfile`` and ``graphfile`` (both required).
    """
    program_name = __name__
    program_desc = "Create abstract graph from an sqlite database"

    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    parser = argparse.ArgumentParser(prog=program_name, description=program_desc)

    # Inputs
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        default=None,
        help="Env file to load settings/credentials, default(.env)",
    )

    parser.add_argument(
        "--target-db-file",
        type=str,
        dest="dbfile",
        required=True,
        help="Target SQLITE DB file",
    )

    # Outputs
    parser.add_argument(
        "--abstract-graph",
        type=str,
        dest="graphfile",
        required=True,
        help="AbstractGraph output json file",
    )

    # Parse the caller-supplied list; calling parse_args() with no argument
    # would ignore ``argv`` and always read the process command line.
    return parser.parse_args(argv)


def main(argv=None):
    """Script entry point: set up logging, parse options, write the graph.

    Args:
        argv: Argument list handed to ``parse_args``; ``None`` by default.

    Returns:
        ``None``.
    """
    log_settings = {
        "level": logging.INFO,
        "format": "%(levelname)s: %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**log_settings)

    options = parse_args(argv)
    # NOTE: options.env is parsed but not consumed here.
    run(options.dbfile, options.graphfile)


# Run main() only when executed as a script; its return value becomes the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
