#!/usr/bin/env python3

"""
    Use LLM to validate TargetQuestion and TargetSQL
"""
import sys
import logging
import argparse

import pandas as pd
from dotenv import dotenv_values
from genai.schema import TextGenerationParameters
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
from langchain_ibm import WatsonxLLM

# Module-level logger named after this module; main() sets up the root
# logging configuration through logging.basicConfig.
logger = logging.getLogger(__name__)


def createWatsonxLLM(config, params: TextGenerationParameters):
    """
    Build a WatsonxLLM client from env settings and generation parameters.

    Args:
        config: settings loaded from the env file; the keys
            WATSONX_PROJECT_ID, WATSONX_APIKEY, WATSONX_URL and
            WATSONX_MODEL_ID are read from it
        params: generation parameters whose fields are copied into the
            WatsonX text-generation parameter dictionary
    Returns:
        The newly created WatsonxLLM instance
    """
    # Look the settings up in a fixed order (project, key, url, model).
    project_id = config["WATSONX_PROJECT_ID"]
    api_key = config["WATSONX_APIKEY"]
    api_url = config["WATSONX_URL"]
    model_id = config["WATSONX_MODEL_ID"]

    logger.info("Using WatsonX LLM model '%s' for validation", model_id)

    # Translate the parameter object into the dictionary keyed by the
    # GenTextParamsMetaNames constants that is handed to WatsonxLLM.
    # NOTE: params.return_options is not forwarded (RETURN_OPTIONS was left
    # disabled in this mapping).
    generation_params = {
        GenTextParamsMetaNames.DECODING_METHOD: params.decoding_method,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: params.max_new_tokens,
        GenTextParamsMetaNames.MIN_NEW_TOKENS: params.min_new_tokens,
        GenTextParamsMetaNames.TEMPERATURE: params.temperature,
        GenTextParamsMetaNames.TOP_K: params.top_k,
        GenTextParamsMetaNames.TOP_P: params.top_p,
        GenTextParamsMetaNames.LENGTH_PENALTY: params.length_penalty,
        GenTextParamsMetaNames.RANDOM_SEED: params.random_seed,
        GenTextParamsMetaNames.REPETITION_PENALTY: params.repetition_penalty,
        GenTextParamsMetaNames.STOP_SEQUENCES: params.stop_sequences,
        GenTextParamsMetaNames.TIME_LIMIT: params.time_limit,
        GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: params.truncate_input_tokens,
    }

    return WatsonxLLM(
        model_id=model_id,
        url=api_url,
        apikey=api_key,
        project_id=project_id,
        params=generation_params,
    )


def fill_template(question, sql):
    """
    Build the Correct/Incorrect evaluation prompt for one question/SQL pair.

    Args:
        question: natural language question to evaluate
        sql: SQL statement the question is supposed to describe
    Returns:
        The prompt text with the question and the SQL substituted in
    """
    # The prompt is returned exactly as written below, including the leading
    # indentation of every line.
    return f"""
    You are an evaluation agent. Your task is to evaluate whether the natural language question below accurately depicts the SQL that follows. By your best evaluation, say "Correct" if the NL (natural language) question matches the SQL, and "Incorrect" if they do not match. Make sure to look for the following rules.
    
    Rules:
    1. Anything that is a requirement in the SQL should exist in the natural language question, like specific constraints or conditions that cannot be assumed. It need not be exact. Approximate is fine.
    2. The mentioned column names in the SQL should occur in the natural language question. It need not be exact but even an approximate way in which it is mentioned is alright. It just needs to be referenced in some natural language way.
    3. Do not offer any explanation. Only say "Correct" or "Incorrect"
    4. Do not be too strict on your evaluation. Give some benefit of the doubt to account for some ambiguity.
    5. Do not penalize for an understandable representation of an obscure column or table name. For example, if "sr" is written as service request, consider it correct.
    
    Here are some examples of Evaluations, followed by the actual problem:
    
    ### Example 1:
    Natural Language Question: 
    What is the status of the asset with the smallest asset UID?
    
    SQL:
    SELECT asset.children, asset.status FROM asset ORDER BY asset.assetuid LIMIT 1
    
    Evaluation:
    Incorrect
    
    ### Example 2:
    Natural Language Question:
    How many work orders are there?
    
    SQL:
    SELECT COUNT(*) FROM workorder
    
    Evaluation:
    Correct
    
    ### Actual Problem
    Natural Language Question:
    {question}
    
    SQL:
    {sql}
    
    Is the natural language question Correct or Incorrect for the SQL? Just respond with "Correct" or "Incorrect", with no additional formatting.
    Evaluation: 
    """


def explanation_template(question, sql):
    """
    Build the prompt asking why a question was flagged as incorrect.

    Args:
        question: natural language question that was flagged incorrect
        sql: SQL statement the question is supposed to describe
    Returns:
        The prompt text with the question and the SQL substituted in
    """
    # The prompt is returned exactly as written below, including the leading
    # indentation of every line.
    return f"""
    You are an error explainer agent. Your task is: Given a SQL, and a question that is a natural language representation of that sql that is flagged incorrect, provide a reason for why it is flagged incorrect.
    Some reasons may be, for example, a column that is in the SQL is not mentioned in the Natural Language question, or a table that is in the SQL is not mentioned in the natural language question, or a condition that is in the SQL is not represented in the natural languag, etc.
    Please pick from the following categories. Categories:
    - missing_column : when the NL (natural language) question does not include a column name that is mentioned in the SQL query, either in the SELECT or WHERE clause,
    - missing_table : when the NL (natural language) question does not mention the table that is being referenced in the SQL, either in direct or indirect way, 
    - missing_constraint : when there is a constraint, like show only one, or order that is present in the SQL but not mentioned in the NL question, 
    - missing_condition : when there is a condition like greater than or less than or any filtering done in the SQL that is not mentioned in the NL question, 
    - other : any other issue you see why the NL question is an incorrect representation of the SQL, 
    - correct : if the NL question is actually a good representation of the SQL, say 'correct' if it was mislabeled as incorrect.
    
    Do not add any explanation, only respond with the categories that are applicable to this specific evaluation. Only respond with the category type. If more than one category apply as the reason for incorrectness, separate them by comma. If none of the categories apply, feel free to write your own assessment of why it is incorrect. If you evaluate it is actually correct, say "correct".
    
    Here are some examples of explanation category, followed by the actual problem:
    ### Example 1
    Incorrect Natural Language Question:
    What is the status of the asset with the smallest asset UID?
    
    SQL:
    SELECT asset.children, asset.status FROM asset ORDER BY asset.assetuid LIMIT 1
    
    Pick from the above categories for the reason for incorrectness of the natural language question. Category: 
    missing_column
    
    ### Example 2
    Incorrect Natural Language Question:
    How many assets have a total accuracy of 1?
    
    SQL:
    SELECT COUNT(*) FROM asset WHERE asset.pluscsumeu = (SELECT workorder.taskid FROM workorder ORDER BY workorder.wosequence DESC LIMIT 1)
    
    Pick from the above categories for the reason for incorrectness of the natural language question. Category: 
    missing_column, missing_table, missing_condition

    ### Actual Problem:
    Incorrect Natural Language Question:
    {question}
    
    SQL:
    {sql}
    
    Pick from the above categories for the reason for incorrectness of the natural language question. Only respond with the category types, with no additional formatting. Category: 
    """


def run(config, queryfile, outputfile, startrow=None, endrow=None, rowindex=None):
    """
    Validate the selected question/SQL pairs of a CSV file with the LLM.

    Reads queryfile, asks the model whether each TargetQuestion matches its
    TargetSQL and stores the raw answer in an LLMValidation column. When the
    answer does not contain "Correct", a follow-up explanation is requested
    and stored in an LLMExplanation column. Rows that are not processed keep
    "NA" in both columns. The whole table is then written to outputfile.

    Args:
        config: settings passed through to createWatsonxLLM
        queryfile: path of the input CSV; the TargetQuestion and TargetSQL
            columns are read from it
        outputfile: path of the CSV file to write
        startrow: 1-based first row to process; earlier rows are skipped.
            Defaults to None (no lower bound).
        endrow: 1-based last row to process; later rows are not visited.
            Defaults to None (no upper bound).
        rowindex: 1-based number of the single row to process.
            Defaults to None (no single-row restriction).
    """
    generation_settings = TextGenerationParameters(
        decoding_method="greedy",
        max_new_tokens=2000,
        random_seed=2,
        repetition_penalty=1.1,
        temperature=0.05,
        top_p=0.1,
    )
    llm = createWatsonxLLM(config, params=generation_settings)

    logger.info("Reading queries file '%s'", queryfile)
    queries = pd.read_csv(queryfile)

    # Pre-fill both result columns so unprocessed rows are marked "NA".
    row_count = queries.shape[0]
    queries["LLMValidation"] = ["NA"] * row_count
    queries["LLMExplanation"] = ["NA"] * row_count
    verdict_col = queries.columns.get_loc("LLMValidation")
    reason_col = queries.columns.get_loc("LLMExplanation")

    for pos, record in queries.iterrows():
        # Row numbers given to this function are 1-based.
        row_number = pos + 1

        # Range selection: skip rows before the start, stop after the end.
        if startrow and row_number < startrow:
            continue
        if endrow and row_number > endrow:
            break
        # Single-row selection: skip until the row is reached, stop after it.
        if rowindex and row_number < rowindex:
            continue
        if rowindex and row_number > rowindex:
            break

        question = record["TargetQuestion"]
        sql = record["TargetSQL"]

        logger.info("==========================================================")
        logger.info("Target Question(%d): '%s'", row_number, question)
        logger.info("Target SQL(%d): %s", row_number, sql)

        verdict = llm.invoke(fill_template(question, sql))
        queries.iat[pos, verdict_col] = verdict

        logger.info("Question(%d):Correct/Incorrect:%s", row_number, verdict)

        # The check is case-sensitive: "Incorrect" does not contain "Correct".
        if "Correct" in verdict:
            continue

        reason = llm.invoke(explanation_template(question, sql))
        queries.iat[pos, reason_col] = reason
        logger.info("Question(%d):Explanation:%s", row_number, reason)

    logger.info("Writing output file '%s'", outputfile)
    queries.to_csv(outputfile, index=False)


def parse_args(argv=None):
    """
    Parse the command line options.

    Args:
        argv: list of argument strings to parse; when None, the arguments
            of the running process (sys.argv[1:]) are used
    Returns:
        The argparse.Namespace with queryfile, startrow, endrow, row, env
        and outputfile attributes
    """
    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    parser = argparse.ArgumentParser(
        prog=__name__,
        description="Use LLM to Validate TargetSQL and TargetQuestion",
    )

    # Inputs
    parser.add_argument(
        "-q",
        "--queryfile",
        type=str,
        dest="queryfile",
        required=True,
        help="Input queries CSV file",
    )

    # settings
    parser.add_argument(
        "--start",
        type=int,
        dest="startrow",
        default=None,
        help="Only process questions starting from the specifed row",
    )

    parser.add_argument(
        "--end",
        type=int,
        dest="endrow",
        default=None,
        help="Last question row to process",
    )

    parser.add_argument(
        "--row",
        type=int,
        dest="row",
        default=None,
        help="Only process question for the specified row",
    )

    parser.add_argument(
        "-e",
        "--env",
        type=str,
        default=".env",
        help="Env file to load settings/credentials, default(.env)",
    )

    # Outputs
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        dest="outputfile",
        required=True,
        help="LLM validated query CSV file",
    )

    # Parse the explicit argument list; calling parse_args() without it would
    # silently fall back to sys.argv and ignore what the caller passed in.
    return parser.parse_args(argv)


def main(argv=None):
    """
    Command line entry point: configure logging, parse options and run.

    Args:
        argv: optional list of command line arguments, forwarded to
            parse_args
    Returns:
        0 on success, 1 when the input query file and the output file are
        given as the same path
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Suppress INFO messages for some very chatty packages. The level is set
    # on each logger by name, so it also applies when the package has not
    # created its logger yet.
    for noisy_name in (
        "httpx",
        "ibm_watsonx_ai",
        "ibm_watsonx_ai.client",
        "ibm_watsonx_ai.wml_resource",
    ):
        logging.getLogger(noisy_name).setLevel(logging.ERROR)

    args = parse_args(argv)
    config = dotenv_values(args.env)

    # Refuse to overwrite the input file with the results.
    if args.queryfile == args.outputfile:
        logger.error("Input query file '%s' cannot be the same as the outputfile",
                     args.queryfile)
        return 1

    run(config=config, queryfile=args.queryfile, outputfile=args.outputfile,
        startrow=args.startrow, endrow=args.endrow, rowindex=args.row)
    return 0


# Run main() only when this file is executed as a script, and pass its
# return value to sys.exit().
if __name__ == "__main__":
    sys.exit(main())
