#!/usr/bin/env python3

"""_summary_"""
import json
import sys
import os
import argparse
import logging
import random
import requests

from openai import OpenAI

from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
from langchain_ibm import WatsonxLLM

from sqlalchemy import create_engine, inspect, text

from dotenv import dotenv_values
from pydantic import BaseModel
import pandas as pd

# import pprint

from sqlglot import optimizer, parse_one, exp
from sqlglot.optimizer import traverse_scope

from sqlglot.optimizer.qualify import qualify


logger = logging.getLogger("convert_domain.py")

# ## Load source domain
# ## Load target domain
# ## Load metadata of source and target domain
# ## Get template of source query
# ## Match items and params of source template to target domain
# ## Generate target query
# ## Convert target query to meaningful natural language


class TargetSchema(BaseModel):
    """Structured (JSON) answer requested from the LLM.

    Its JSON schema is sent as the ``response_format`` of the chat-completion
    request that generates the target-domain question; only
    ``target_question`` is read back from the reply.
    """

    # natural-language question of the source example
    source_question: str
    # ground-truth SQL of the source example
    source_sql: str
    # SQL of the generated target-domain example
    target_sql: str
    # generated natural-language question for the target domain
    target_question: str


def load_src_schema(db_uri):
    """Reflect a database schema, with sample values, into a plain dict.

    For every table the inspector's column list, primary-key constraint and
    foreign keys are read, and up to 50 distinct sample values are queried for
    every column.

    Args:
        db_uri (str): SQLAlchemy database URL.

    Returns:
        dict: lower-cased table name -> {
            "columns": {lower-cased column name ->
                {"type": str, "primary_key": ..., "sample_values": list}},
            "primary_keys": inspector PK constraint dict,
            "foreign_keys": inspector foreign-key list,
        }

    Raises:
        Exception: any error from a sample-value query is logged and re-raised.
    """
    # Local import: Core constructs let SQLAlchemy quote identifiers and render
    # LIMIT for the engine's own dialect (a hard-coded "..." is a string
    # literal in e.g. MySQL).
    from sqlalchemy import column as sa_column, select, table as sa_table

    engine = create_engine(db_uri)
    schema_dict = {}
    try:
        inspector = inspect(engine)
        # `with` guarantees the connection is returned even if a query fails.
        with engine.connect() as connection:
            for table_name in inspector.get_table_names():
                columns = inspector.get_columns(table_name)
                primary_keys = inspector.get_pk_constraint(table_name)
                foreign_keys = inspector.get_foreign_keys(table_name)
                pk_columns = set(primary_keys.get("constrained_columns") or [])

                column_dict = {}
                for column in columns:
                    column_name = column["name"]
                    # Query to get 50 distinct values for each column
                    query = (
                        select(sa_column(column_name))
                        .select_from(sa_table(table_name))
                        .distinct()
                        .limit(50)
                    )
                    try:
                        result = connection.execute(query).fetchall()
                    except Exception as ex:
                        logger.error("Failed query exececution:%s", str(ex))
                        raise
                    column_dict[column_name.lower()] = {
                        "type": str(column["type"]),
                        # Not every dialect reports "primary_key" per column;
                        # fall back to membership in the PK constraint.
                        "primary_key": column.get(
                            "primary_key", column_name in pk_columns
                        ),
                        "sample_values": [row[0] for row in result],
                    }

                schema_dict[table_name.lower()] = {
                    "columns": column_dict,
                    "primary_keys": primary_keys,
                    "foreign_keys": foreign_keys,
                }
    finally:
        engine.dispose()

    return schema_dict


def convert_source_schema(schema):
    """Flatten a ``load_src_schema`` result into sqlglot's schema mapping.

    Args:
        schema (dict | None): table -> {"columns": {column -> {"type": ...}}}.

    Returns:
        dict | None: table -> {lower-cased column -> type}, or ``None`` when
        ``schema`` is ``None``.
    """
    if schema is None:
        return None

    return {
        table_name: {
            col_name.lower(): col_info["type"]
            for col_name, col_info in schema[table_name]["columns"].items()
        }
        for table_name in schema
    }


def create_service_abstract_graph(source_schema):
    """Placeholder for building a service abstract graph from ``source_schema``.

    Not implemented yet: ``source_schema`` is ignored and a new, empty dict is
    returned on every call.
    """
    return {}


def convert_target_graph(graphfile, keep_entities=None, verbose=False):
    """Load the target Entity/Relationship graph and convert it to the local schema format.

    Only the graph's entities are converted (relationships are ignored). Each
    entity becomes a table keyed by its lower-cased ``type``. Each of its
    ``properties`` becomes a lower-cased column whose type is lower-cased, with
    integer/floating-point types collapsed to ``"numeric"``. Every other
    property attribute (e.g. ``min``, ``max``, ``values``) is copied verbatim.

    Args:
        graphfile (str): path to the target graph JSON file.
        keep_entities (collection of str, optional): lower-case table names to
            keep; when empty or ``None`` every entity is kept.
        verbose (bool): log progress through the module logger.

    Returns:
        dict: table name -> {"id", "table_name", "description", "columns"}.
    """
    numeric_types = {"integer", "bigint", "smallint", "double", "float", "real"}

    if verbose:
        logger.info("Reading target graph file '%s'", graphfile)
    with open(graphfile, "r", encoding="utf-8") as f:
        graph = json.load(f)

    target_schema = {}
    for entity in graph["entities"]:
        table_name = entity["type"].lower()  # force to lowercase

        if keep_entities and table_name not in keep_entities:
            continue

        columns = {}
        for col_name, col_info in entity["properties"].items():
            # Lower-case the type so downstream comparisons against "text",
            # "date", ... work regardless of the graph's casing.
            col_type = str(col_info["type"]).lower()
            column = {"type": "numeric" if col_type in numeric_types else col_type}
            column.update((k, v) for k, v in col_info.items() if k != "type")
            columns[col_name.lower()] = column  # force to lower case

        target_schema[table_name] = {
            "id": entity["id"].lower(),
            "table_name": table_name,
            "description": entity.get("description"),
            "columns": columns,
        }

    if verbose:
        logger.info("Filtered Target schema has '%d' tables:", len(target_schema))
        for table_name, table_info in target_schema.items():
            logger.info(
                "Table '%s' has '%d' columns", table_name, len(table_info["columns"])
            )

    return target_schema


def extract_sql_template(schema, sql, dialect):
    """Turn a source SQL query into a domain-independent template.

    The query is passed through ``fix_source_sql`` to obtain a sqlglot AST.
    Every referenced table, column and column-compared literal is then replaced
    by a placeholder (``TABLE_<i>``, ``COLUMN_<j>``, ``VALUE_<k>``), and all
    ``AS`` alias nodes are removed.

    Args:
        schema (dict): source schema as returned by ``load_src_schema``
            (lower-cased table -> {"columns": {lower-cased column -> {"type": ...}}}).
        sql (str): source SQL query.
        dialect: SQL dialect forwarded to ``fix_source_sql``.

    Returns:
        tuple: ``(table_column_aliases, sql_template)``. ``table_column_aliases``
        maps each source table name to its placeholder, its column placeholders
        and the literal placeholders recorded per column. ``sql_template`` is the
        templatized sqlglot expression.
    """
    ast = fix_source_sql(schema, sql, dialect)

    # find all tables, map source table aliases to actual source table names
    table_alias_dict = {
        table.alias_or_name.lower(): table.name.lower()
        for table in ast.find_all(exp.Table)
    }

    # Collect literals compared against a column, e.g. WHERE country = 'France'.
    # Key on (table, column, value): keying on the value alone dropped one of
    # the mappings when the same value was used with two different columns,
    # leaving a raw literal in the template.
    literal_map = {}
    for literal in ast.find_all(exp.Literal):
        parent = literal.parent
        if parent is None or not isinstance(parent.this, exp.Column):
            continue

        column = parent.this
        column_name = column.name.lower()
        column_table = column.table.lower()
        table_name = table_alias_dict.get(column_table, column_table)
        key = (table_name, column_name, literal.name)
        if key not in literal_map:
            literal_map[key] = schema[table_name]["columns"][column_name]["type"]

    table_column_aliases = create_table_col_aliases(schema, ast)

    # Register a VALUE_<k> placeholder for each literal, numbered per column.
    for (table_name, column_name, literal_value), column_type in literal_map.items():
        values = table_column_aliases[table_name]["columns"][column_name]["values"]
        values.append(
            {
                "literal_value": literal_value,
                "alias": "VALUE_" + str(len(values)),
                "literal_type": column_type,
            }
        )

    # the order of replacement is important: literals, columns, tables
    # (literal and column lookups rely on the original qualifiers)
    sql_template = replace_literals(ast, table_alias_dict, table_column_aliases)
    sql_template = replace_columns(sql_template, table_alias_dict, table_column_aliases)
    sql_template = replace_tables(sql_template, table_column_aliases)

    # remove aliases from template query
    sql_template = sql_template.transform(
        lambda node: node.this if isinstance(node, exp.Alias) else node
    )

    logger.info("Templatized SQL: %s", sql_template)

    return (table_column_aliases, sql_template)


def replace_tables(parsed_sql, table_columns_alias):
    """Swap every table reference for its ``TABLE_<i>`` placeholder, in place.

    The table's identifier node is replaced by the plain placeholder string
    and any alias is dropped.

    Args:
        parsed_sql: sqlglot expression to rewrite.
        table_columns_alias (dict): lower-cased table name -> {"alias": placeholder, ...}.

    Returns:
        The same (mutated) expression.
    """
    for tbl in parsed_sql.find_all(exp.Table):
        placeholder = table_columns_alias[tbl.this.name.lower()]["alias"]
        tbl.this.replace(placeholder)
        # the old alias must go, otherwise "TABLE_0 AS t1" would survive
        tbl.set("alias", None)
    return parsed_sql


def replace_columns(parsed_sql, table_alias_dict, table_columns_alias):
    """Replace every column with ``TABLE_<i>.COLUMN_<j>`` placeholders, in place.

    Args:
        parsed_sql: sqlglot expression whose columns are all table-qualified.
        table_alias_dict (dict): lower-cased table alias -> real table name.
        table_columns_alias (dict): real table name ->
            {"alias": table placeholder, "columns": {column -> {"alias": ...}}}.

    Returns:
        The same (mutated) expression.

    Raises:
        ValueError: if a column has no table qualifier.
    """
    for scope in traverse_scope(parsed_sql):
        for column in scope.columns:
            if not column.table:
                message = f"Cannot find table for column '{column.name}', parent '{column.parent}', '{parsed_sql}', skipping..."
                raise ValueError(message)

            # alias-dict and alias-table keys are lower-cased; match them
            column_table = column.table.lower()
            table_name = table_alias_dict.get(column_table, column_table)
            table_info = table_columns_alias[table_name]
            column_alias = table_info["columns"][column.name.lower()]["alias"]
            column.replace(exp.Column(this=column_alias, table=table_info["alias"]))

    return parsed_sql


def replace_literals(parsed_sql, table_alias_dict, table_column_aliases):
    """Replace column-compared literals with their ``VALUE_<k>`` placeholders, in place.

    Only literals whose parent node's ``this`` is a column are considered, e.g.
    ``country = 'France'``. Stand-alone literals such as ``LIMIT 1`` are left
    untouched. The placeholder keeps the original literal's quoting.

    Args:
        parsed_sql: sqlglot expression to rewrite.
        table_alias_dict (dict): lower-cased table alias -> real table name.
        table_column_aliases (dict): real table name -> {"columns": {column ->
            {"values": [{"literal_value", "alias", ...}, ...]}}}.

    Returns:
        The same (mutated) expression.
    """
    for literal in parsed_sql.find_all(exp.Literal):
        parent = literal.parent
        if parent is None or not isinstance(parent.this, exp.Column):
            continue

        column = parent.this
        # keys of both lookup dicts are lower-cased
        column_table = column.table.lower()
        table_name = table_alias_dict.get(column_table, column_table)
        values = table_column_aliases[table_name]["columns"][column.name.lower()][
            "values"
        ]

        for value in values:
            if literal.name == value["literal_value"]:
                literal.replace(
                    exp.Literal(this=value["alias"], is_string=literal.is_string)
                )
                break

    return parsed_sql


def match_source_target_columns(src_type, target_info):
    """Return True if a target column can stand in for a source column of ``src_type``.

    Both types are lower-cased, stripped of any ``(...)`` size suffix and
    collapsed into a category (``numeric``, ``text``, or kept as is). The
    categories must be equal, and the target column must carry enough
    information to generate a literal: non-null ``min``/``max`` for numeric and
    date columns, more than one entry in ``values`` for text columns. Boolean
    columns never match because no such information exists for them.

    Args:
        src_type: source column type, e.g. ``"VARCHAR(20)"``; converted with ``str()``.
        target_info (dict): target column info with at least a ``"type"`` key.

    Returns:
        bool: whether the target column is a usable substitute.
    """
    numeric_types = {
        "integer",
        "int",
        "smallint",
        "bigint",
        "float",
        "double",
        "double precision",
        "real",
        "numeric",
        "decimal",
    }
    text_types = {
        "text",
        "char",
        "character",
        "varchar",
        "character varying",
        "nchar",
        "nvarchar",
    }

    src_type = str(src_type).lower().split("(")[0].strip()
    target_type = str(target_info["type"]).lower().split("(")[0].strip()

    if src_type in numeric_types:
        src_type = "numeric"
    elif src_type in text_types:
        src_type = "text"
    elif src_type not in ("boolean", "date", "datetime", "timestamp", "null"):
        logger.error("Unknown source column type '%s'", src_type)
        return False

    if target_type in numeric_types:
        target_type = "numeric"
    elif target_type not in ("text", "date", "boolean"):
        logger.error("Unknown target column type '%s'", target_type)
        return False

    if src_type != target_type:
        return False

    if target_type in ("numeric", "date"):
        return (
            target_info.get("min") is not None and target_info.get("max") is not None
        )
    if target_type == "text":
        # "values" may be missing or null in the target graph
        return len(target_info.get("values") or []) > 1

    # boolean: a known type, but there is no range/value list to draw from
    return False


def qualify_columns(expression, schema=None, dialect=None):
    """Qualify table and column references of a sqlglot expression.

    Args:
        expression: sqlglot expression to qualify.
        schema (dict, optional): source schema in the ``load_src_schema`` format.
            It is converted with ``convert_source_schema`` and passed to
            sqlglot, so columns are resolved against the real tables.
            (Previously it was converted but never used.)
        dialect (optional): SQL dialect forwarded to sqlglot; ``None`` keeps
            sqlglot's default.

    Returns:
        The qualified sqlglot expression.
    """
    sqlglot_schema = convert_source_schema(schema) if schema else None
    return optimizer.qualify.qualify(expression, schema=sqlglot_schema, dialect=dialect)


def create_table_col_aliases(schema, ast):
    """Assign ``TABLE_<i>`` / ``COLUMN_<j>`` placeholders to the tables and columns of ``ast``.

    Tables are numbered in the order ``find_all`` visits them. Columns are
    numbered per table in order of first appearance across the query scopes.

    Args:
        schema (dict): source schema as returned by ``load_src_schema``.
        ast: sqlglot expression whose columns are table-qualified.

    Returns:
        dict: lower-cased table name -> {"alias": "TABLE_<i>", "columns":
        {lower-cased column -> {"alias": "COLUMN_<j>", "dataType": type, "values": []}}}.
    """
    # Column qualifiers may be table aliases (e.g. "t1"); resolve them to the
    # real table names that key the result, as the other helpers do.
    alias_to_table = {
        table.alias_or_name.lower(): table.name.lower()
        for table in ast.find_all(exp.Table)
    }

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # COLUMN_<j> numbering is reproducible (set iteration order is not).
    table_columns = dict.fromkeys(
        (alias_to_table.get(col.table.lower(), col.table.lower()), col.name.lower())
        for scope in traverse_scope(ast)
        for col in scope.columns
        if col.table
    )

    table_column_aliases = {}
    for table in ast.find_all(exp.Table):
        table_name = table.name.lower()
        if table_name not in table_column_aliases:
            table_column_aliases[table_name] = {
                "alias": "TABLE_" + str(len(table_column_aliases)),
                "columns": {},
            }

    for table, column in table_columns:
        columns = table_column_aliases[table]["columns"]
        if column in columns:
            continue
        data_type = schema[table]["columns"][column]["type"]
        columns[column] = {
            "alias": "COLUMN_" + str(len(columns)),
            "dataType": data_type,
            "values": [],
        }
    return table_column_aliases


def get_target_question_llm_watsonx(
    llm_client, source_nl_question, source_sql, target_sql, target_schema
):
    """Ask a watsonx-style LLM to rewrite the source question for the target domain.

    Args:
        llm_client: client exposing ``invoke(input=messages)`` that returns the
            generated text as a string.
        source_nl_question (str): natural-language question of the source example.
        source_sql: ground-truth SQL of the source example; converted with ``str()``.
        target_sql: target-domain SQL; converted with ``str()``.
        target_schema: (filtered) target schema; converted with ``str()``.

    Returns:
        str: the first non-empty line of the model output, stripped, or ``""``
        if the model returned nothing.
    """
    messages = [
        {
            "role": "system",
            "content": "Given a source natural language question, its ground truth SQL, a target SQL \
                and a target schema, your TASK is to replace the target variables in the source natural \
                language question and make it into a meaningful and readable question that makes sense \
                in the target domain. When you generate the target natural language question, make sure \
                that it is meaningful for the target domain, and if not, replace the words with a paraphrase \
                so as to have a meaningful target sentence. Make sure not to include variables or names of \
                tables which do not exist. Use the descriptions of the columns to make the sentence more readable, \
                if the column names are obscure in the domain. Do not provide explantions or answers to the generated \
                target question. ",
        },
        {
            "role": "user",
            "content": "Here is the source question: "
            + str(source_nl_question)
            + "\nHere is the ground truth source SQL:"
            + str(source_sql)
            + "\nHere is the schema of the target domain: "
            + str(target_schema)
            + "\nHere is the target SQL:"
            + str(target_sql)
            + "\nGenerated target natural language question: ",
        },
    ]

    response = llm_client.invoke(input=messages)
    # Models may prefix the answer with several blank lines and/or append extra
    # commentary after it; keep only the first non-empty line.
    for line in (response or "").splitlines():
        if line.strip():
            return line.strip()
    return ""


def get_target_question_llm_rits(
    llm_client, llm_model, source_nl_question, source_sql, target_sql, target_schema
):
    """Ask an OpenAI-compatible chat endpoint for the target-domain question.

    The request asks for a JSON reply shaped like ``TargetSchema``; only its
    ``target_question`` field is returned.

    Args:
        llm_client: OpenAI-style client (``chat.completions.create``).
        llm_model (str): model name sent with the request.
        source_nl_question (str): natural-language question of the source example.
        source_sql (str): ground-truth SQL of the source example.
        target_sql: target-domain SQL; converted with ``str()``.
        target_schema: (filtered) target schema; converted with ``str()``.

    Returns:
        str: the ``target_question`` value from the JSON reply.
    """
    system_prompt = "Given a source natural language question, its ground truth SQL, a target SQL \
                and a target schema, your TASK is to replace the target variables in the source natural \
                language question and make it into a meaningful and readable question that makes sense \
                in the target domain. When you generate the target natural language question, make sure \
                that it is meaningful for the target domain, and if not, replace the words with a paraphrase \
                so as to have a meaningful target sentence. Make sure not to include variables or names of \
                tables which do not exist. Use the descriptions of the columns to make the sentence more readable, \
                if the column names are obscure in the domain. Do not provide explantions or answers to the generated \
                target question. "
    user_prompt = "".join(
        [
            "Here is the source question: ",
            source_nl_question,
            "\nHere is the ground truth source SQL:",
            source_sql,
            "\n Here is the schema of the target domain: ",
            str(target_schema),
            "\nHere is the target SQL:",
            str(target_sql),
        ]
    )
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    json_reply_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "target_schema",
            "schema": TargetSchema.model_json_schema(),
        },
    }

    completion = llm_client.chat.completions.create(
        model=llm_model,
        timeout=20,
        messages=chat,
        response_format=json_reply_format,
    )

    reply = json.loads(completion.choices[0].message.content)
    return reply["target_question"]


def filter_target_schema(target_schema, target_sql):
    """Keep only the target tables that ``target_sql`` references.

    ``target_sql`` is rendered to text and parsed again before its tables are
    collected. NOTE(review): presumably needed because the target query is
    built by splicing plain strings into the AST, so ``exp.Table`` nodes may
    be missing there.

    Args:
        target_schema (dict): lower-cased table name -> table info.
        target_sql: target query (string or sqlglot expression).

    Returns:
        dict: the subset of ``target_schema`` for the referenced tables.
    """
    reparsed = parse_one(str(target_sql))
    used_tables = {tbl.name.lower() for tbl in reparsed.find_all(exp.Table)}

    filtered_schema = {
        name: info for name, info in target_schema.items() if name in used_tables
    }

    logger.info("FilteredSchema:%s", filtered_schema.keys())
    return filtered_schema


def get_target_question_llm(
    llm_client,
    llm_model,
    source_nl_question,
    source_sql,
    target_sql,
    target_schema,
    use_rits,
):
    """Generate the target-domain question with the selected LLM backend.

    The target schema is first reduced to the tables used by ``target_sql``.

    Args:
        llm_client: LLM client for the chosen backend.
        llm_model (str): model name; only passed on the RITS path.
        source_nl_question (str): natural-language question of the source example.
        source_sql (str): ground-truth SQL of the source example.
        target_sql: target-domain SQL.
        target_schema (dict): full target schema.
        use_rits (bool): truthy -> OpenAI-style RITS endpoint, falsy -> watsonx.

    Returns:
        str: the generated target question.
    """
    relevant_schema = filter_target_schema(target_schema, target_sql)

    if not use_rits:  # WatsonX model
        return get_target_question_llm_watsonx(
            llm_client, source_nl_question, source_sql, target_sql, relevant_schema
        )

    return get_target_question_llm_rits(
        llm_client,
        llm_model,
        source_nl_question,
        source_sql,
        target_sql,
        relevant_schema,
    )


def _match_table_columns(source_col_types, target_columns):
    """Randomly pair each source column with a distinct compatible target column.

    Returns ``{source column alias: (target column name, target column info)}``,
    or ``None`` as soon as one source column has no compatible partner left.
    """
    available = dict(target_columns)
    matching_columns = {}
    for source_column, source_type in source_col_types.items():
        # choose target columns randomly
        candidates = list(available.items())
        random.shuffle(candidates)
        for target_column, target_info in candidates:
            if match_source_target_columns(source_type, target_info):
                matching_columns[source_column] = (target_column, target_info)
                del available[target_column]  # each target column is used once
                break
        else:
            return None
    return matching_columns


def find_matching_tables_and_columns(table_column_aliases, target_schema):
    """Randomly map each templatized source table onto a compatible target table.

    Target tables are tried in random order. A target table is accepted when
    every source column finds a distinct compatible target column (see
    ``match_source_target_columns``). A target table is reused only when there
    are more source tables than target tables.

    Args:
        table_column_aliases (dict): source table -> {"alias": "TABLE_<i>",
            "columns": {column -> {"alias": "COLUMN_<j>", "dataType": type, ...}}}.
        target_schema (dict): target table name -> {"columns": {column -> info}}.

    Returns:
        dict: "TABLE_<i>" -> (target table name,
        {"COLUMN_<j>": (target column name, target column info)}).

    Raises:
        ValueError: if no compatible target table exists for a source table.
    """
    results = {}
    target_table_used = set()
    reuse_allowed = len(table_column_aliases) > len(target_schema)
    if reuse_allowed:
        logger.warning(
            "Mapping '%i' source tables to '%s' target tables, a target table may get reused",
            len(table_column_aliases),
            len(target_schema),
        )

    for source_table in table_column_aliases.values():
        table_id = source_table["alias"]
        source_col_types = {
            column_info["alias"]: column_info["dataType"]
            for column_info in source_table["columns"].values()
        }

        # Look for a matching table in the target schema, search randomly.
        # A single shuffle already gives a uniformly random order; small lists
        # just keep their original order by chance sometimes.
        target_table_items = list(target_schema.items())
        random.shuffle(target_table_items)

        for target_table_name, target_table in target_table_items:
            # if we have enough target tables, don't reuse tables
            if not reuse_allowed and target_table_name in target_table_used:
                continue

            matching_columns = _match_table_columns(
                source_col_types, target_table["columns"]
            )
            if matching_columns is None:
                continue

            if target_table_name in target_table_used:
                logger.warning(
                    "Reusing target table '%s' for mapping to a source table",
                    target_table_name,
                )
            results[table_id] = (target_table_name, matching_columns)
            target_table_used.add(target_table_name)
            break

        if table_id not in results:
            raise ValueError(
                f"Cannot find a target table for source table '{table_id}' skipping..."
            )

    return results


def map_db_column_type(db_type: str) -> str:
    """Normalize a database column type to ``text``/``integer``/``real``/``date``.

    Matching is case-insensitive and ignores a ``(...)`` size suffix for the
    integer family. Unrecognized types are returned unchanged.

    Args:
        db_type (str): database type name, e.g. ``"varchar(20)"`` or ``"bigint"``.

    Returns:
        str: the normalized type, or ``db_type`` itself when unknown.
    """
    normalized = db_type.lower().strip()
    base = normalized.split("(")[0].strip()

    if normalized.startswith(("char", "varchar", "text", "null")):
        return "text"
    if base in ("integer", "int", "bigint", "smallint", "tinyint"):
        return "integer"
    if normalized.startswith(("float", "double", "real", "numeric", "decimal")):
        return "real"
    if normalized.startswith(("date", "timestamp", "datetime")):
        return "date"

    return db_type


def find_matching_tables_and_columns_service(
    table_column_aliases, source_schema, target_graph
):
    """Map the templatized source items onto the target domain via a remote service.

    Builds an abstract graph of the query's placeholders: tables
    (``TABLE_<i>``), columns (``TABLE_<i>.COLUMN_<j>``) and literal values
    (``TABLE_<i>.COLUMN_<j>.VALUE_<k>``). Columns and values are linked to
    their parent by ``parent`` relationships. ``foreignKey`` relationships are
    added from ``source_schema`` for single-column keys whose both ends are
    used by the query. The graph is POSTed together with the target-domain
    graph (read from the JSON file ``target_graph``) to the mapping service.

    Args:
        table_column_aliases (dict): placeholder mapping (source table ->
            {"alias", "columns": {column -> {"alias", "dataType", "values"}}}).
        source_schema (dict): source schema as returned by ``load_src_schema``;
            its ``foreign_keys`` entries are read.
        target_graph (str): path of the target-domain graph JSON file.

    Returns:
        The ``abstract_to_target`` field of the first service result.
        NOTE(review): its shape is defined by the service; presumably a dict
        keyed by the placeholder ids built here -- confirm.

    Raises:
        ValueError: if the HTTP status is not 200, the result list is empty,
            or the first result's ``success`` flag is false.
    """

    entities = []
    relationships = []

    for src_table, tdata in table_column_aliases.items():
        # table entity, identified by its placeholder (e.g. TABLE_0)
        temp_table = tdata["alias"]
        table_entity = {"id": temp_table, "name": src_table, "type": "table"}
        entities.append(table_entity)

        src_columns = tdata["columns"]
        for src_column, cdata in src_columns.items():
            temp_column = cdata["alias"]
            # normalize the DB type (text/integer/real/date, or unchanged)
            col_type = cdata["dataType"].lower()
            col_type = map_db_column_type(col_type)
            temp_colid = f"{temp_table}.{temp_column}"
            col_entity = {
                "dataType": col_type,
                "id": temp_colid,
                "name": src_column,
                "type": "column",
            }
            entities.append(col_entity)

            # add column to parent table relationship
            relationship = {
                "relationship": "parent",
                "source": temp_colid,
                "target": temp_table,
            }
            relationships.append(relationship)

            # add literal values for the column, if any
            for value in cdata["values"]:
                temp_literal_id = f"{temp_colid}.{value['alias']}"
                literal_entity = {
                    "dataType": col_type,
                    "id": temp_literal_id,
                    "name": value["literal_value"],
                    "type": "value",
                }
                entities.append(literal_entity)

                # add literal to parent relationship
                relationship = {
                    "relationship": "parent",
                    "source": temp_literal_id,
                    "target": temp_colid,
                }
                relationships.append(relationship)

        # add foreign key relationships e.g.
        #    relationship = {
        #        "relationship": "foreignKey",
        #        "source": "TABLE_0.COLUMN_0",
        #        "target": "TABLE_1.COLUMN_0" }

        src_tableschema = source_schema[src_table]

        fkeys = src_tableschema["foreign_keys"]
        for fkey in fkeys:
            # force lowercase column names !!
            local_columns = [k.lower() for k in fkey["constrained_columns"]]
            foreign_table = fkey["referred_table"].lower()
            foreign_columns = [k.lower() for k in fkey["referred_columns"]]

            #  TODO: for now, only handle single column keys
            if len(local_columns) > 1:
                logger.error(
                    "Cannot process multicolumn key '%s' in table '%s'",
                    local_columns,
                    src_table,
                )
                continue
            local_columns = local_columns[0]

            if len(foreign_columns) > 1:
                logger.error(
                    "Cannot process multicolumn foreign key '%s' in table '%s'",
                    foreign_columns,
                    foreign_table,
                )
                continue
            foreign_columns = foreign_columns[0]

            if local_columns not in src_columns:
                continue  # key col not used in current query

            # skip fkey relationship if the foreign table is not being used
            if foreign_table not in table_column_aliases:
                continue

            foreign_table_data = table_column_aliases[foreign_table]

            # the referred column must itself be used by the query
            if foreign_columns not in foreign_table_data["columns"]:
                logger.error(
                    "Cannot find foreign key column '%s' in foreign table '%s'",
                    foreign_columns,
                    foreign_table,
                )
                continue

            local_temp_column = src_columns[local_columns]["alias"]
            local_temp_colid = f"{temp_table}.{local_temp_column}"

            foreign_temp_column = foreign_table_data["columns"][foreign_columns][
                "alias"
            ]
            foreign_temp_table = foreign_table_data["alias"]
            foreign_temp_colid = f"{foreign_temp_table}.{foreign_temp_column}"

            relationship = {
                "relationship": "foreignKey",
                "source": local_temp_colid,
                "target": foreign_temp_colid,
            }

            relationships.append(relationship)

    abstract_graph = {"entities": entities, "relationships": relationships}

    # a single mapping instance is requested; only result[0] is used below
    service_request = {
        "instances": 1,
        "generate_values": True,
        "useForeignKeyInference": False,
        "abstract_graph": abstract_graph,
    }

    logger.info("service_request:%s", json.dumps(service_request))
    # send the target abstract graph to the service also
    with open(target_graph, "r", encoding="utf-8") as f:
        graph = json.load(f)
        service_request["target_domain_abstract_graph"] = graph

    # call the service

    response = requests.post(
        url="<redacted for anonymity>",
        json=service_request,
        timeout=10,
    )

    if response.status_code != 200:
        raise ValueError(
            f"Service request failed with status code '{response.status_code}'"
        )

    # the service replies with a list of results, one per requested instance
    result = response.json()

    if len(result) == 0:
        raise ValueError("Service request returned an empty result")

    result = result[0]
    if not result["success"]:
        raise ValueError("Service request was not successful")

    results = result["abstract_to_target"]


    return results


def create_target_query_from_service_results(sql_template, template_to_target_dict):
    """Fill a templatized SQL query with the target items chosen by the mapping service.

    The template is modified in place and returned. Literals are replaced
    first, then columns, then tables, because literal and column ids are built
    from the placeholder names (``TABLE_<i>.COLUMN_<j>[.VALUE_<k>]``) that the
    later steps overwrite.

    Args:
        sql_template: templatized sqlglot expression.
        template_to_target_dict (dict): service mapping from placeholder ids to
            target info dicts (keys read here: ``target_name``, ``dataType``,
            ``target_table``, ``target_id``, ``target_column``).

    Returns:
        The same expression with placeholders replaced by target names/values.

    Raises:
        ValueError: if a literal has no mapping, a column has no table
            qualifier, or a column mapping lacks ``target_table``.
    """
    # Literal text is always passed as a string; is_string only controls quoting.
    quoted_types = ("text", "date")
    unquoted_types = ("integer", "real")

    # update literals before updating columns
    for literal in sql_template.find_all(exp.Literal):
        parent = literal.parent
        # don't replace stand alone literals e.g. 'LIMIT 1' shown below
        # "SELECT TABLE_0.COLUMN_1, TABLE_0.COLUMN_0 FROM TABLE_0 ORDER BY TABLE_0.COLUMN_2 LIMIT 1"
        if parent is None or not isinstance(parent.this, exp.Column):
            continue

        column_name = parent.this.name
        column_table = parent.this.table
        literal_id = f"{column_table}.{column_name}.{literal.alias_or_name}"

        if literal_id not in template_to_target_dict:
            raise ValueError(
                f"Cannot find literal '{literal_id}' in template_to_target_dict ..."
            )

        target_data = template_to_target_dict[literal_id]
        logger.debug("mapped_literal:%s", target_data)
        target_name = target_data["target_name"]
        target_type = target_data["dataType"]

        if target_type in quoted_types or target_type in unquoted_types:
            literal.replace(
                exp.Literal(
                    this=str(target_name), is_string=target_type in quoted_types
                )
            )
        else:
            logger.error(
                "Unknown literal type '%s' for literal '%s', column '%s'",
                target_type,
                target_name,
                target_data["target_column"],
            )

    # update columns before updating tables
    for column in sql_template.find_all(exp.Column):
        table_name = column.table
        column_id = f"{table_name}.{column.name}"

        if table_name == "":
            raise ValueError(
                f"Empty table name for column '{column_id}' in create_target_query..."
            )

        if column_id not in template_to_target_dict:
            logger.error("Cannot find column '%s' in create_target_query...", column_id)
            continue
        column_target = template_to_target_dict[column_id]
        target_col_name = column_target["target_name"].lower()

        if "target_table" not in column_target:
            msg = f"Service response missing 'target_table' key for target column '{column_id}': {column_target}"
            raise ValueError(msg)

        target_table = column_target["target_table"].lower()

        column.replace(exp.Column(this=target_col_name, table=target_table))

    for table in sql_template.find_all(exp.Table):
        # The template's table placeholder is a plain string stored in
        # Table.this; fall back to the parsed name if the template was
        # re-parsed (Table.name would fail on a str `this`).
        table_id = table.this if isinstance(table.this, str) else table.name

        if table_id not in template_to_target_dict:
            logger.error("Cannot find table '%s' in create_target_query...", table_id)
            continue

        target_id = template_to_target_dict[table_id]["target_id"].lower()
        # NOTE: replaces the whole Table node with a plain string (as before).
        table.replace(target_id)

    return sql_template


def create_target_query(sql_template, template_to_target_dict):
    """Rewrite a templated SQL AST into the target domain (local matching path).

    Literals compared against a template column are replaced with a value
    drawn from the matched target column's metadata, then every column and
    table reference is renamed to its target counterpart.

    Args:
        sql_template: sqlglot expression holding the templated query; it is
            modified in place.
        template_to_target_dict: maps a template table name to a 2-item
            sequence ``(target_table_name, matching_cols)``. ``matching_cols``
            maps a template column name to ``(target_col_name, target_col_info)``.
            ``target_col_info`` has a ``"type"`` key ("text", "numeric" or
            "date") plus ``"values"`` (text) or ``"min"``/``"max"``.

    Returns:
        The same, mutated, ``sql_template``.

    Raises:
        ValueError: if a column's table has no mapping, or a text column has
            no known non-null values to sample from.
        KeyError: if a template column is missing from its table's mapping.
    """
    # Update literals before updating columns: a literal's replacement is
    # looked up through the template table/column it is compared against,
    # which the column pass below renames.
    # The node lists are materialized first because each loop replaces nodes.
    for literal in list(sql_template.find_all(exp.Literal)):
        parent = literal.parent
        # don't replace stand alone literals e.g. 'LIMIT 1' shown below
        # "SELECT TABLE_0.COLUMN_1, TABLE_0.COLUMN_0 FROM TABLE_0 ORDER BY TABLE_0.COLUMN_2 LIMIT 1"
        if parent is None or not isinstance(parent.this, exp.Column):
            continue

        column_name = parent.this.name
        column_table = parent.this.table
        if column_table not in template_to_target_dict:
            logger.error(
                "Cannot find table '%s' in template_to_target mapping", column_table
            )
            continue

        target_table_name = template_to_target_dict[column_table][0]
        matching_cols = template_to_target_dict[column_table][1]
        target_col_name = matching_cols[column_name][0]
        target_col_info = matching_cols[column_name][1]
        target_col_type = target_col_info["type"]

        if target_col_type == "text":
            values = [v for v in target_col_info["values"] if v is not None]
            if not values:
                # random.choice([]) would only raise an opaque IndexError
                raise ValueError(
                    f"No known values for target text column '{target_col_name}' "
                    f"in table '{target_table_name}'"
                )
            # TODO: for now, just pick a value randomly from the known values
            target_value = random.choice(values)
            literal.replace(exp.Literal(this=str(target_value), is_string=True))
        elif target_col_type == "numeric":
            # TODO: for now, pick a random integer within the known [min, max] range
            target_value = random.randint(
                int(target_col_info["min"]), int(target_col_info["max"])
            )
            # note: Literal value must be passed as a string but is_string=False
            literal.replace(exp.Literal(this=str(target_value), is_string=False))
        elif target_col_type == "date":
            # TODO: for now, just use the max date
            target_value = target_col_info["max"]
            # note: Literal value must be passed as a string and is_string=True
            literal.replace(exp.Literal(this=str(target_value), is_string=True))
        else:
            logger.error(
                "Unknown type '%s' for target column '%s' in table '%s'",
                target_col_type,
                target_col_name,
                target_table_name,
            )
            continue

    # update columns before updating tables
    for column in list(sql_template.find_all(exp.Column)):
        table_name = column.table
        if table_name not in template_to_target_dict:
            raise ValueError(
                f"Cannot find table '{table_name}' in template_to_target mapping"
            )
        target_table_name = template_to_target_dict[table_name][0]
        matching_cols = template_to_target_dict[table_name][1]
        target_col_name = matching_cols[column.name][0]
        column.replace(exp.Column(this=target_col_name, table=target_table_name))

    for table in list(sql_template.find_all(exp.Table)):
        if table.this in template_to_target_dict:
            target_table_name = template_to_target_dict[table.this][0]
            table.replace(target_table_name)
        else:
            logger.error(
                "ERROR: Cannot find table '%s' in template_to_target mapping",
                table.name,
            )

    return sql_template


def fix_orderby_sql(ast):
    """Inspect the ORDER BY terms of an AST (debug helper; the tree is not modified).

    For every ``ORDER BY`` item whose clause belongs to a SELECT and whose
    expression is a plain column, this logs at DEBUG level the column next to
    the alias of each SELECT projection. It is presumably groundwork for
    rewriting ORDER BY terms onto projection aliases; confirm before relying on it.

    Args:
        ast: sqlglot expression to inspect.

    Returns:
        ``ast`` itself, unchanged, so callers can write
        ``ast = fix_orderby_sql(ast)`` without losing the tree.
    """
    logger.debug("fix_orderby_sql input: %r", ast)
    for ordered in ast.find_all(exp.Ordered):
        order_clause = ordered.parent
        owner = order_clause.parent if order_clause is not None else None
        logger.debug("ORDER BY clause: %s", order_clause)
        logger.debug("ORDER BY owner: %s", owner)
        if not isinstance(owner, exp.Select):
            continue
        if not isinstance(ordered.this, exp.Column):
            continue
        for projection in owner.expressions:
            logger.debug(
                "ORDER BY column %s vs projection alias '%s'",
                ordered.this,
                projection.alias,
            )
    return ast


def fix_source_sql(schema, sql, dialect):
    """Parse and normalize a source SQL query.

    Steps:
      1. Parse ``sql`` with sqlglot in ``dialect`` and qualify it against the
         schema, without failing on columns that cannot be resolved.
      2. Replace every column's table alias with the real, lower-cased table
         name, so later rewrites need no scope-sensitive handling.
      3. Convert double-quoted "string literals" that sqlglot parsed as
         columns (e.g. ``WHERE Airline = "JetBlue Airways"``) into
         single-quoted string literals.

    Args:
        schema: source schema; when truthy it is passed through
            ``convert_source_schema`` and then used as a mapping of
            table name -> column names. May be None or empty.
        sql: source SQL text.
        dialect: sqlglot dialect name used for parsing and qualification.

    Returns:
        The rewritten sqlglot expression.
    """
    logger.debug("Original source sql: '%s'", sql)

    if schema:
        schema = convert_source_schema(schema)

    ast = parse_one(sql, dialect=dialect)
    logger.debug("Parsed source sql: '%s'", str(ast))

    ast = qualify(ast, schema=schema, dialect=dialect, validate_qualify_columns=False)
    logger.debug("ast:%s", ast)

    # replace all column table-aliases with actual table-names
    # this should remove the need for scope sensitive transformations
    for scope in traverse_scope(ast):
        # Include the tables of enclosing scopes so correlated references
        # (an inner query using an outer alias) resolve; walking outermost
        # first lets inner declarations shadow same-named outer aliases.
        enclosing = []
        current = scope
        while current is not None:
            enclosing.append(current)
            current = current.parent
        table_alias_dict = {}
        for visible_scope in reversed(enclosing):
            for table in visible_scope.tables:
                table_alias_dict[table.alias_or_name.lower()] = table.name.lower()

        for column in scope.columns:
            if not column.table:
                continue
            col_table = table_alias_dict.get(column.table.lower())
            if col_table is None:
                # The qualifier is not a base table in scope (e.g. the alias
                # of a derived table/subquery): keep it as-is rather than
                # failing with a KeyError.
                continue
            column.replace(exp.Column(this=column.name, table=col_table))

    # fix double quoted string literals
    for column in list(ast.find_all(exp.Column)):
        # we only try to fix double quoted string literals in SQL as shown below:
        # "SELECT Country FROM AIRLINES WHERE Airline  =  "JetBlue Airways"
        if not isinstance(column.parent, exp.Predicate):
            continue

        if column.table:
            if not schema:
                # Without a schema a real column cannot be told apart from a
                # mis-parsed literal, so leave qualified columns alone
                # (previously `column.table in None` raised TypeError).
                continue
            # sqlglot style schema: table name -> column names
            if column.table in schema and column.name in schema[column.table]:
                continue

        logger.warning(
            'Fixing double quoted literal "%s" in source SQL to single quotes:"%s"',
            column.name,
            sql,
        )
        column.replace(exp.Literal(this=column.name, is_string=True))

    logger.debug("Fixed source sql: '%s'", str(ast))
    return ast


def create_llm_client_rits(config):
    """Create an OpenAI-compatible client for a model served on the RITS platform.

    Args:
        config: settings mapping (e.g. from ``dotenv_values``). It must provide
            non-None values for ``RITS_SERVED_MODEL_NAME``,
            ``RITS_MODEL_INFERENCE_ENDPOINT`` and ``RITS_API_KEY``.

    Returns:
        tuple: ``(client, model_name)``.

    Raises:
        ValueError: if any required setting is missing or None. The function
            used to log and return None here, which made the caller's
            ``client, model = ...`` unpacking fail with an opaque TypeError.
    """
    required_keys = (
        "RITS_SERVED_MODEL_NAME",
        "RITS_MODEL_INFERENCE_ENDPOINT",
        "RITS_API_KEY",
    )
    # A key that is present but maps to None counts as missing.
    missing = [key for key in required_keys if config.get(key) is None]
    if missing:
        raise ValueError(f"{', '.join(missing)} not found in config")

    api_key = config["RITS_API_KEY"]
    model_name = config["RITS_SERVED_MODEL_NAME"]
    model_inference_endpoint = config["RITS_MODEL_INFERENCE_ENDPOINT"]

    # Very Important: The ending '/v1' is required for 'completions', 'models'..., but not for 'health'!
    base_url = f"{model_inference_endpoint}/v1"

    logger.info("Using RITS LLM model '%s' for question generation", model_name)

    client = OpenAI(
        api_key=api_key, base_url=base_url, default_headers={"RITS_API_KEY": api_key}
    )

    return client, model_name


def create_llm_client_watsonx(config):
    """Build a LangChain ``WatsonxLLM`` from the WatsonX settings in *config*.

    The keys ``WATSONX_PROJECT_ID``, ``WATSONX_APIKEY``, ``WATSONX_URL`` and
    ``WATSONX_MODEL_ID`` are read in that order. A missing key raises KeyError.
    The model is configured for greedy decoding with fixed generation params.
    """
    project_id, api_key, service_url, model_id = [
        config[setting]
        for setting in (
            "WATSONX_PROJECT_ID",
            "WATSONX_APIKEY",
            "WATSONX_URL",
            "WATSONX_MODEL_ID",
        )
    ]

    logger.info("Using WatsonX LLM model '%s' for question generation", model_id)

    meta = GenTextParamsMetaNames
    # NOTE(review): TEMPERATURE / TOP_P / RANDOM_SEED presumably only matter for
    # "sample" decoding, and watsonx documents TIME_LIMIT in milliseconds, so
    # a value of 5 looks very small -- confirm both are intended.
    generation_params = {
        meta.DECODING_METHOD: "greedy",
        meta.MAX_NEW_TOKENS: 3000,
        meta.TEMPERATURE: 0.05,
        meta.TOP_P: 0.1,
        meta.RANDOM_SEED: 10,
        meta.REPETITION_PENALTY: 1.1,
        meta.TIME_LIMIT: 5,
    }

    return WatsonxLLM(
        model_id=model_id,
        url=service_url,
        apikey=api_key,
        project_id=project_id,
        params=generation_params,
    )


def create_llm_client_openai(config):
    """Return an ``(OpenAI client, model name)`` pair configured from *config*.

    ``OPENAI_MODEL_NAME`` is read first, then ``OPENAI_API_KEY``. A missing
    key raises KeyError.
    """
    chosen_model = config["OPENAI_MODEL_NAME"]
    return OpenAI(api_key=config["OPENAI_API_KEY"]), chosen_model


def get_database_dir(questions_file):
    """Locate the directory holding the source databases for a questions file.

    The SPIDER layout (a ``database`` directory next to the questions file)
    is tried first. Otherwise the BIRD layout is used: ``dev_databases`` when
    the file name contains "dev", else ``train_databases`` when it contains
    "train".

    Returns:
        The directory path, or None (after logging an error) when the layout
        cannot be determined or the directory does not exist.
    """
    base_dir = os.path.dirname(questions_file)

    spider_dir = os.path.join(base_dir, "database")
    if os.path.isdir(spider_dir):
        return spider_dir

    questions_name = os.path.basename(questions_file)
    bird_dir = None
    # order matters: a name containing both markers resolves to "dev"
    for marker, subdir in (("dev", "dev_databases"), ("train", "train_databases")):
        if marker in questions_name:
            bird_dir = os.path.join(base_dir, subdir)
            break

    if bird_dir is None:
        logger.error(
            "Cannot determine database directory name from '%s'", questions_file
        )
        return None

    if not os.path.isdir(bird_dir):
        logger.error("Cannot find database directory '%s'", bird_dir)
        return None

    return bird_dir


def run(
    config,
    target_graph,
    spiderfile,
    birdfile,
    results_file,
    startrow=None,
    endrow=None,
    rowindex=None,
    service=True,
    use_rits=False,
):
    """Convert SPIDER/BIRD questions from their source databases to the target domain.

    For each selected question this:
      1. loads the source database schema;
      2. extracts a SQL template;
      3. maps its tables, columns and literals onto the target domain, either
         through the remote mapping service or through local matching;
      4. asks an LLM for a matching natural-language question.

    The collected rows are written to ``results_file`` as CSV.

    Args:
        config: settings mapping (from ``dotenv_values``) with LLM credentials.
        target_graph: target-domain abstract graph file.
        spiderfile: SPIDER questions JSON file, or None when ``birdfile`` is used.
        birdfile: BIRD questions JSON file, used when ``spiderfile`` is None.
        results_file: output CSV path; nothing is written when falsy.
        startrow: first 1-based question index to process (None: from the start).
        endrow: last 1-based question index to process (None: to the end).
        rowindex: process only this 1-based question index (None: all).
        service: use the remote table/col mapping service when True, local
            matching otherwise.
        use_rits: use the RITS-hosted LLM instead of WatsonX.

    Failures on individual questions are logged and that question is skipped.
    """
    if spiderfile is not None:
        questions_file = spiderfile
    else:
        questions_file = birdfile

    # randomness used in matching target tables/cols
    random.seed(30)

    if use_rits:
        llm_client, llm_model = create_llm_client_rits(config=config)
    else:
        llm_client = create_llm_client_watsonx(config=config)
        llm_model = None

    skipped_db_names = [
        #    "student_transcripts_tracking",
        #    "tvshow",
        #    "world_1",
        #    "orchestra",
        #    "perpetrator",
        #    "tracking_grants_for_research",
        #    "aircraft",
        #    "railway",
    ]

    source_database_dir = get_database_dir(questions_file)
    if source_database_dir is None:
        logger.error("Cannot determine source database directory, exiting...")
        return
    logger.info("Using source database directory '%s'", source_database_dir)

    # load target metadata
    # target_schema = convert_target_graph(target_graph, keep_entities = ["asset", "workorder", "sr"])
    target_schema = convert_target_graph(target_graph)

    results = []

    # Go through questions
    logger.info("Reading questions file '%s'", questions_file)
    # use a context manager so the file handle is closed promptly
    with open(questions_file, "r", encoding="utf-8") as questions_fp:
        questions = json.load(questions_fp)

    if service:
        logger.info("Using remote service for table/col/value mapping")

    source_db_name = None
    source_db_index = None
    for index, question in enumerate(questions, 1):

        # keep a per db question index also; it is updated before the row
        # filters below so it stays stable regardless of --start/--end/--row
        db_name = question["db_id"]
        if not source_db_name or source_db_name != db_name:
            source_db_name = db_name
            source_db_index = 1
            logger.info("Processing questions for source database '%s'", source_db_name)
        else:
            source_db_index += 1

        if startrow and startrow > index:
            continue
        if endrow and endrow < index:
            break

        # only handle the specified query
        if rowindex and rowindex > index:
            continue
        if rowindex and rowindex < index:
            break

        if source_db_name in skipped_db_names:
            logger.info(
                "Skipping question number '%i', from '%s' database",
                index,
                source_db_name,
            )
            continue

        try:
            source_nl_question = question["question"]
            # SPIDER stores the gold SQL under "query", BIRD under "SQL"
            source_sql_query = question["query" if spiderfile is not None else "SQL"]

            logger.info(
                "______________________________________________________________________________________"
            )
            logger.info(
                "Processing question number '%i', from '%s/%d' database",
                index,
                source_db_name,
                source_db_index,
            )
            logger.info("Source Question: %s", source_nl_question)
            logger.info("Source SQL: %s", source_sql_query)

            # load metadata of source domain
            source_db_dialect = "sqlite"
            source_db_uri = f"sqlite:///{source_database_dir}/{source_db_name}/{source_db_name}.sqlite"
            source_schema = load_src_schema(source_db_uri)

            table_column_aliases, sql_template = extract_sql_template(
                source_schema, source_sql_query, source_db_dialect
            )

            if service:
                # use table/col mapping service
                template_to_target_dict = find_matching_tables_and_columns_service(
                    table_column_aliases, source_schema, target_graph
                )
                if not template_to_target_dict:
                    continue
                target_sql_query = create_target_query_from_service_results(
                    sql_template, template_to_target_dict
                )
            else:
                template_to_target_dict = find_matching_tables_and_columns(
                    table_column_aliases, target_schema
                )
                target_sql_query = create_target_query(
                    sql_template, template_to_target_dict
                )

            logger.info("Target SQL: %s", target_sql_query)

            target_question = get_target_question_llm(
                llm_client,
                llm_model,
                source_nl_question,
                source_sql_query,
                target_sql_query,
                target_schema,
                use_rits,
            )

            logger.info("Target Question: %s", target_question)

            result = {
                "SourceIndex": index,
                "SourceDBIndex": source_db_index,
                "SourceDB": source_db_name,
                "SourceQuestion": source_nl_question,
                "SourceSQL": source_sql_query,
                "TargetQuestion": target_question,
                "TargetSQL": target_sql_query,
                "EvalSelect": 0,
            }
            results.append(result)

        except Exception as ex:
            # best-effort batch: log and move on to the next question
            logger.error("Error processing question '%i', '%s'", index, str(ex))

    df = pd.DataFrame.from_dict(results)

    if results_file:
        logger.info("Writing results file '%s'", results_file)
        df.to_csv(results_file, index=False)


def parse_args(argv=None):
    """Parse the command-line options of convert_domain.py.

    Args:
        argv: argument list without the program name. When None,
            ``sys.argv[1:]`` is used.

    Returns:
        argparse.Namespace with the attributes ``env``, ``spiderfile``,
        ``birdfile``, ``graphfile``, ``startrow``, ``endrow``, ``row``,
        ``service``, ``rits`` and ``results_file``.
    """
    program_name = "convert_domain.py"
    program_desc = (
        "Convert NL2SQL datasets from src domain to target domain using sqlglot and LLM"
    )

    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    parser = argparse.ArgumentParser(prog=program_name, description=program_desc)

    # Inputs
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        default=None,
        help="Env file to load settings/credentials, default(.env)",
    )

    # --spider and --bird options are mutually exclusive
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--spider-file",
        "-s",
        type=str,
        dest="spiderfile",
        default=None,
        help="Input SPIDER questions file.",
    )

    group.add_argument(
        "--bird-file",
        "-b",
        type=str,
        dest="birdfile",
        default=None,
        help="Input BIRD questions file.",
    )

    # NOTE: because required=True, this default is never actually used.
    parser.add_argument(
        "--target-graph",
        "-g",
        type=str,
        dest="graphfile",
        default="abstract_graph.json",
        required=True,
        help="Target domain Abstract Graph file e.g. abstract_graph.json",
    )

    # settings
    parser.add_argument(
        "--start",
        type=int,
        dest="startrow",
        default=None,
        help="Only process questions starting from the specified row",
    )

    parser.add_argument(
        "--end",
        type=int,
        dest="endrow",
        default=None,
        help="Last question row to process",
    )

    parser.add_argument(
        "--row",
        type=int,
        dest="row",
        default=None,
        help="Only process question for the specified row",
    )

    parser.add_argument(
        "--no-service",
        action="store_false",
        dest="service",
        help="Don't use table/col mapping service",
    )

    parser.add_argument(
        "--use-rits",
        action="store_true",
        dest="rits",
        help="Use LLM from RITS platform (default: WatsonX)",
    )

    # Outputs
    parser.add_argument(
        "--results-file",
        "-r",
        type=str,
        dest="results_file",
        required=True,
        help="Domain mapped results CSV file",
    )

    # Parse the caller's argv explicitly: parser.parse_args() with no argument
    # would silently re-read sys.argv and ignore the list passed in.
    return parser.parse_args(argv)


def main(argv=None):
    """Command-line entry point: configure logging, parse options and run the conversion.

    Args:
        argv: forwarded to ``parse_args``.
    """
    # Plain "LEVEL: message" output. datefmt only takes effect if
    # %(asctime)s is added to the format.
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Raise very chatty, already-registered package loggers to ERROR only.
    noisy_logger_names = (
        "httpx",
        # <redacted for anonymity>
    )
    for registered_name in list(logging.root.manager.loggerDict):
        registered = logging.getLogger(registered_name)
        if registered.name in noisy_logger_names:
            registered.setLevel(logging.ERROR)

    options = parse_args(argv)
    settings = dotenv_values(options.env)

    run(
        settings,
        options.graphfile,
        options.spiderfile,
        options.birdfile,
        options.results_file,
        startrow=options.startrow,
        endrow=options.endrow,
        rowindex=options.row,
        service=options.service,
        use_rits=options.rits,
    )


if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
