import argparse
import csv
import logging
import os

import numpy as np
import pandas as pd
from scipy.stats import permutation_test

from src.data_conversion import load_truth_file, load_prediction_file
from src.evaluation import calculate_score
from src.utility import Status

logger = logging.getLogger(__name__)

# Lower-case substrings whose presence in a model's output marks it as a
# prose reply (asking for missing information, declining, or stating that no
# function applies) rather than a tool call. Several entries appear to be
# deliberate word stems ("is require", "cannot be fulfil", "m missing:") so
# that they also match inflected forms such as "is required" or "I'm missing:".
keywords = {
    "are missing",
    "do not",
    "need to",
    "please provide",
    "is require",
    "cannot be fulfil",
    "none of",
    "would you like",
    "i need",
    "lacks",
    "i cannot",
    "requires the",
    "is missing",
    "could you please",
    "no function available",
    "does not support",
    "function cannot",
    "please supply",
    "missing required",
    "not provided",
    "no available",
    "should i",
    "m missing:",
    "do you want",
    "i can only",
    "i can't",
    "isn't applicable",
    "what's your",
    "i also need",
    "cannot call",
}


def is_likely_tool_call(text: str) -> bool:
    """Heuristically decide whether *text* looks like a tool call.

    Returns False as soon as any entry of ``keywords`` occurs in the
    lower-cased text (case-insensitive substring match), True otherwise.
    ``None`` and the empty string contain no keyword, so they return True.
    """
    if not text:
        return True
    # Lower-case once instead of once per keyword.
    lowered = text.lower()
    return not any(keyword in lowered for keyword in keywords)


def patch_status(text: str, status: Status) -> Status:
    """Refine an evaluation status using the model's raw output *text*.

    * ``Status.NoPrediction`` with non-empty text is changed to
      ``Status.SchemaNotMatch``.
    * ``Status.SchemaNotMatch`` whose text does not look like a tool call
      (see ``is_likely_tool_call``) becomes ``Status.ReasoningFailure``.

    Any other status is returned unchanged. ``text`` may be ``None`` or
    empty; in that case no keyword check is made and the status is kept.
    """
    if status == Status.NoPrediction and text:
        status = Status.SchemaNotMatch
    # Guard on ``text`` so a None/empty text never reaches the keyword check
    # (previously a None text with a SchemaNotMatch status raised).
    if status == Status.SchemaNotMatch and text and not is_likely_tool_call(text):
        return Status.ReasoningFailure
    return status


def parse_args():
    """Parse the command-line options of the side-by-side comparison.

    Returns:
        argparse.Namespace holding the truth file, the control and treatment
        prediction files, the parameter similarity threshold, the report and
        optional summary output paths, and the optional device.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--truth_file", required=True, type=str, help="Truth file")

    # Control and treatment prediction files share the same option shape.
    for flag, description in (
        ("--prediction_file_control", "Prediction file control"),
        ("--prediction_file_treatment", "Prediction file treatment"),
    ):
        arg_parser.add_argument(flag, required=True, type=str, help=description)

    arg_parser.add_argument(
        "--param_similarity_threshold",
        type=float,
        default=0.9,
        help="Parameter value similarity threshold.",
    )
    arg_parser.add_argument(
        "--output_report", type=str, required=True, help="Report tsv file"
    )
    arg_parser.add_argument("--output_summary", type=str, help="Report summary file")
    arg_parser.add_argument("--device", type=str, help="Device to use")
    return arg_parser.parse_args()


# Summary-table row label -> attribute of the per-sample score object that
# feeds its paired permutation test. The labels must match the keys returned
# by calculate_score exactly (presumably including the "marco" spelling);
# otherwise DataFrame.at appends a new row instead of filling an existing one.
_PVALUE_ROWS = (
    ("marco_averaged_P_api", "api_p"),
    ("marco_averaged_R_api", "api_r"),
    ("marco_averaged_F1_api", "api_f1"),
    ("marco_averaged_P_param", "param_p"),
    ("marco_averaged_R_param", "param_r"),
    ("marco_averaged_F1_param", "param_f1"),
)

# Column order of the per-sample TSV report.
_REPORT_FIELDS = (
    "id",
    "control_status",
    "treatment_status",
    "origin_prompt",
    "origin_expected_output",
    "truth_api",
    "prediction_control_raw_api",
    "prediction_treatment_raw_api",
    "prediction_treatment_raw_string",
    "control_predict_api",
    "treatment_predict_api",
    "truth_param",
    "control_predict_param",
    "treatment_predict_param",
    "truth_param_value",
    "control_predict_param_value",
    "treatment_predict_param_value",
)


def _paired_pvalue(control_scores, treatment_scores) -> float:
    """Return the p-value of a paired permutation test on two per-sample score lists.

    The statistic is the mean per-sample difference (control - treatment).
    ``permutation_type="samples"`` treats the two lists as paired observations.
    """
    result = permutation_test(
        (control_scores, treatment_scores),
        # Vectorized statistic: scipy evaluates batches of resamples along
        # ``axis`` instead of calling back into Python once per resample.
        statistic=lambda x, y, axis: np.mean(x - y, axis=axis),
        permutation_type="samples",
        vectorized=True,
    )
    return result.pvalue


def _build_summary(scores_control, scores_treatment, pvalues) -> pd.DataFrame:
    """Build the Control/Treatment/Delta/P-value summary table.

    Rows follow the key order of *scores_control*. Delta is
    ``treatment - control`` when both values are numeric (Python or numpy
    scalars), NaN otherwise. *pvalues* maps row label -> p-value; every
    other row gets NaN.
    """
    numeric = (int, float, np.integer, np.floating)
    labels = list(scores_control.keys())
    df = pd.DataFrame(
        [[scores_control[label], scores_treatment[label]] for label in labels],
        labels,
        ["Control", "Treatment"],
    )
    df["Delta"] = [
        (treat - ctrl)
        if isinstance(ctrl, numeric) and isinstance(treat, numeric)
        else np.nan
        for ctrl, treat in zip(df["Control"], df["Treatment"])
    ]
    df["P-value"] = np.nan
    for label, pvalue in pvalues.items():
        df.at[label, "P-value"] = pvalue
    return df


def _load_raw_strings(prediction_file: str) -> dict:
    """Load the companion raw-output file of a ``.api`` prediction file.

    The companion path is *prediction_file* with ``.api`` removed from the
    file name only, never from directory names. It is loaded with
    ``intermediate_result=True``. Returns an empty dict when the file name has
    no ``.api`` marker or the companion file does not exist.
    """
    directory, filename = os.path.split(prediction_file)
    if ".api" not in filename:
        return {}
    raw_path = os.path.join(directory, filename.replace(".api", ""))
    if not os.path.exists(raw_path):
        logger.warning("Raw output file %s not found; skipping raw strings", raw_path)
        return {}
    return load_prediction_file(raw_path, intermediate_result=True)


def _response(predictions, sample_id) -> str:
    """Return ``predictions[sample_id].response``, or "" if the id is absent."""
    return predictions[sample_id].response if sample_id in predictions else ""


def _summarize_calls(calls):
    """Return (api names, [{api: param names}], [{api: param values}]) for *calls*.

    A ``None`` or empty *calls* (e.g. no prediction) yields three empty lists.
    """
    calls = calls or []
    names = [call.api_name for call in calls]
    params = [{call.api_name: list(call.params.keys())} for call in calls]
    values = [{call.api_name: list(call.params.values())} for call in calls]
    return names, params, values


def _write_report(args, control_dict, treatment_dict) -> None:
    """Write one TSV row per sample id (in control order) to ``args.output_report``."""
    raw_api_control = load_prediction_file(args.prediction_file_control, True)
    raw_api_treatment = load_prediction_file(args.prediction_file_treatment, True)
    raw_string_control = _load_raw_strings(args.prediction_file_control)
    raw_string_treatment = _load_raw_strings(args.prediction_file_treatment)

    # Explicit UTF-8: prompts and model outputs may be non-ASCII.
    with open(args.output_report, "w", newline="", encoding="utf-8") as fout:
        tsv_writer = csv.writer(fout, delimiter="\t", lineterminator="\n")
        tsv_writer.writerow(_REPORT_FIELDS)
        for sample_id, control in control_dict.items():
            treatment = treatment_dict[sample_id]

            truth_api, truth_param, truth_param_value = _summarize_calls(
                control.truth.api_call.api_calls
            )
            control_api, control_param, control_param_value = _summarize_calls(
                control.prediction
            )
            treatment_api, treatment_param, treatment_param_value = _summarize_calls(
                treatment.prediction
            )

            control_raw = _response(raw_api_control, sample_id)
            treatment_raw = _response(raw_api_treatment, sample_id)
            treatment_raw_string = _response(raw_string_treatment, sample_id)

            # Patch both sides by the same rule: prefer the raw output string
            # from the companion file and fall back to the raw API response.
            # Previously control always used the raw API response, while
            # treatment used the raw string, which was "" for every sample
            # when the treatment path had no ".api" marker.
            control_status = patch_status(
                _response(raw_string_control, sample_id) or control_raw,
                control.status,
            )
            treatment_status = patch_status(
                treatment_raw_string or treatment_raw, treatment.status
            )

            tsv_writer.writerow(
                [
                    sample_id,
                    control_status,
                    treatment_status,
                    control.truth.get_raw_prompt(),
                    control.truth.output,
                    truth_api,
                    control_raw,
                    treatment_raw,
                    treatment_raw_string,
                    control_api,
                    treatment_api,
                    truth_param,
                    control_param,
                    treatment_param,
                    truth_param_value,
                    control_param_value,
                    treatment_param_value,
                ]
            )


def main(args):
    """Generate a side-by-side control/treatment report.

    1. Score both prediction files against the truth file with
       ``calculate_score``.
    2. Print a summary table, and write it to ``args.output_summary`` if
       given. The table has Control, Treatment, Delta (treatment - control)
       and paired-permutation p-values for the rows in ``_PVALUE_ROWS``.
    3. Write a per-sample TSV to ``args.output_report``.

    Raises:
        ValueError: if the control and treatment reports cover different ids.
    """
    print(f"Running {os.path.basename(__file__)} with args: {args}")
    truth = load_truth_file(args.truth_file)
    prediction_control = load_prediction_file(args.prediction_file_control)
    prediction_treatment = load_prediction_file(args.prediction_file_treatment)
    scores_control, reports_control, per_sample_control = calculate_score(
        truth,
        prediction_control,
        args.param_similarity_threshold,
        args.device,
    )
    scores_treatment, reports_treatment, per_sample_treatment = calculate_score(
        truth,
        prediction_treatment,
        args.param_similarity_threshold,
        args.device,
    )

    control_dict = {x.id: x for x in reports_control}
    treatment_dict = {x.id: x for x in reports_treatment}

    # Raise rather than assert: assert statements are stripped under -O.
    if control_dict.keys() != treatment_dict.keys():
        only_control = sorted(control_dict.keys() - treatment_dict.keys(), key=str)
        only_treatment = sorted(treatment_dict.keys() - control_dict.keys(), key=str)
        raise ValueError(
            "control and treatment have different keys: "
            f"only in control={only_control}, only in treatment={only_treatment}"
        )
    logger.info(f"{scores_control=}")
    logger.info(f"{scores_treatment=}")

    pvalues = {
        label: _paired_pvalue(
            getattr(per_sample_control, attr), getattr(per_sample_treatment, attr)
        )
        for label, attr in _PVALUE_ROWS
    }
    df = _build_summary(scores_control, scores_treatment, pvalues)

    print(df)
    if args.output_summary is not None:
        df.to_csv(args.output_summary, sep="\t")

    _write_report(args, control_dict, treatment_dict)


if __name__ == "__main__":
    # Set up root logging (INFO level, timestamped) before running the report.
    logging.basicConfig(
        format="[%(asctime)s][%(levelname)s][%(name)s] %(message)s",
        datefmt="%Y%m%d %H:%M:%S",
        level=logging.INFO,
    )
    main(parse_args())
