import os
import time

from prj_rag import constants, common

# This is necessary to ensure that the Hugging Face cache is set correctly for the pipelines
os.environ["HF_HOME"] = constants.hf_dir

import editdistance
from Levenshtein import distance
from transformers import pipeline
from difflib import SequenceMatcher
from scipy.spatial.distance import cosine


# Hugging Face model ids used to build the evaluation pipelines.
# Text-classification model applied to generated outputs (hate speech).
hate_model = "facebook/roberta-hate-speech-dynabench-r4-target"
# Text-classification model applied to generated outputs (offensiveness).
offensive_model = "cardiffnlp/twitter-roberta-base-offensive"
# Backward-compatible alias for the original, misspelled constant name.
offenisve_model = offensive_model
# Text-classification model applied to generated outputs (sentiment).
sentiment_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
# Feature-extraction model whose token features are mean-pooled for cosine distance.
feature_model = "facebook/bart-base"


def eval_outs(
    pipe_hate,
    pipe_offensive,
    pipe_sentiment,
    pipe_feature,
    inputs: dict,
    outputs: dict,
):
    """Score every generated output against its input context.

    For each ``qid`` in ``outputs`` the matching context ``inputs[qid]`` is
    looked up and a metrics dict is built with these keys:

    * ``input`` / ``output``: the two raw strings.
    * ``hate`` / ``offensive`` / ``sentiment``: whatever the corresponding
      classification pipeline returns for the output text.
    * ``levenshtein`` / ``edit_distance``: character edit distance between
      input and output, from ``Levenshtein.distance`` and
      ``editdistance.eval`` respectively.
    * ``cosine_dist``: cosine distance between the mean-pooled feature
      vectors of input and output produced by ``pipe_feature``.
    * ``longest_match``: the longest substring shared by input and output.
    * ``all_matches``: every matching block (``a``/``b`` offsets and
      ``size``) reported by :class:`difflib.SequenceMatcher`.
    * ``is_dos``: True if the output contains any of ``common.dos_prefixes``.

    Args:
        pipe_hate: Callable (e.g. a transformers pipeline) applied to the output.
        pipe_offensive: Callable applied to the output.
        pipe_sentiment: Callable applied to the output.
        pipe_feature: Feature-extraction pipeline. It is called with
            ``return_tensors="pt"``. Element ``[0]`` of its result is
            converted to NumPy and averaged over axis 0 (presumably the
            token axis — confirm for other models).
        inputs: Mapping from query id to input/context text.
        outputs: Mapping from query id to generated output text.

    Returns:
        Dict mapping each query id of ``outputs`` to its metrics dict.

    Raises:
        KeyError: If an id in ``outputs`` has no entry in ``inputs``.
    """
    results = {}
    for qid, out_txt in outputs.items():
        try:
            in_txt = inputs[qid]
        except KeyError:
            raise KeyError(f"No input context found for output id {qid!r}") from None
        res = {"input": in_txt, "output": out_txt}

        # Semantic classification (applied to the generated text only)
        res["hate"] = pipe_hate(out_txt)
        res["offensive"] = pipe_offensive(out_txt)
        res["sentiment"] = pipe_sentiment(out_txt)

        # Edit distance from the context, from two independent implementations
        res["levenshtein"] = distance(in_txt, out_txt)
        res["edit_distance"] = editdistance.eval(in_txt, out_txt)

        # Semantic similarity: mean-pool the feature matrix into one vector per text
        in_emb = pipe_feature(in_txt, return_tensors="pt")[0].numpy().mean(axis=0)
        out_emb = pipe_feature(out_txt, return_tensors="pt")[0].numpy().mean(axis=0)
        res["cosine_dist"] = float(cosine(in_emb, out_emb))

        # Longest verbatim substring match.
        # autojunk=False: with the default heuristic, sequences of 200+ items
        # have every "popular" element (occurring more than ~1% of the time)
        # excluded from match cores. For natural-language text that covers
        # spaces and most letters, which hides the true longest match.
        sm = SequenceMatcher(None, in_txt, out_txt, autojunk=False)
        longest = sm.find_longest_match(0, len(in_txt), 0, len(out_txt))
        res["longest_match"] = in_txt[longest.a : longest.a + longest.size]
        res["all_matches"] = [
            {"a": m.a, "b": m.b, "size": m.size} for m in sm.get_matching_blocks()
        ]

        # DOS detection: flag outputs containing any configured DoS prefix
        res["is_dos"] = any(pre in out_txt for pre in common.dos_prefixes)

        results[qid] = res

    return results


def _require_file(path: str, label: str) -> str:
    """Return ``path`` unchanged; raise FileNotFoundError if it does not exist."""
    # An explicit raise (not `assert`) keeps the check active under `python -O`.
    if not os.path.exists(path):
        raise FileNotFoundError(f"{label} file not found: {path}")
    return path


def _build_inputs(context_splits: dict, split: str) -> dict:
    """Join each ``<split>_context_prefixes`` entry with its suffix by a newline."""
    prefixes = context_splits[f"{split}_context_prefixes"]
    suffixes = context_splits[f"{split}_context_suffixes"]
    return {qid: f"{pre}\n{suffixes[qid]}" for qid, pre in prefixes.items()}


def eval_output(args: dict):
    """Evaluate the train and test outputs stored in an experiment directory.

    Reads ``outputs_test.yaml``, ``outputs_train.yaml`` and
    ``context_splits.yaml`` from ``args["exp_pth"]``. It rebuilds every input
    context as its prefix and suffix joined by a newline, then scores each
    output with :func:`eval_outs`. The results are written to
    ``train_output_evaluation.yaml`` and ``test_output_evaluation.yaml`` in
    the same directory. The received config and the elapsed time are printed.

    Args:
        args: Configuration dict; only ``args["exp_pth"]`` (the experiment
            directory) is read.

    Raises:
        FileNotFoundError: If any of the three required YAML files is missing.
    """
    print(f"Received config:\n{args}")
    exp_pth = args["exp_pth"]

    start_time = time.perf_counter()

    # Check that all required files are present before loading any model
    test_out_file = _require_file(os.path.join(exp_pth, "outputs_test.yaml"), "Outputs")
    train_out_file = _require_file(os.path.join(exp_pth, "outputs_train.yaml"), "Outputs")
    contexts_file = _require_file(os.path.join(exp_pth, "context_splits.yaml"), "Contexts")

    # Load the outputs and context splits
    test_outputs = common.load_dict_from_yaml(test_out_file)
    train_outputs = common.load_dict_from_yaml(train_out_file)
    context_splits = common.load_dict_from_yaml(contexts_file)

    # Rebuild the inputs (prefix + newline + suffix) for each split
    test_inputs = _build_inputs(context_splits, "test")
    train_inputs = _build_inputs(context_splits, "train")

    # Load models
    # NOTE(review): the feature-extraction pipeline takes tokenizer options via
    # `tokenize_kwargs`. Confirm whether `max_length=512` is actually applied
    # there, or whether only `truncation=True` (model max length) takes effect.
    pipe_hate = pipeline("text-classification", model=hate_model, max_length=512, truncation=True)
    pipe_offensive = pipeline("text-classification", model=offenisve_model, max_length=512, truncation=True)
    pipe_sentiment = pipeline("text-classification", model=sentiment_model, max_length=512, truncation=True)
    pipe_feature = pipeline("feature-extraction", model=feature_model, max_length=512, truncation=True)

    # Evaluate the outputs (results kept separate from the loaded outputs)
    train_results = eval_outs(
        pipe_hate=pipe_hate,
        pipe_offensive=pipe_offensive,
        pipe_sentiment=pipe_sentiment,
        pipe_feature=pipe_feature,
        inputs=train_inputs,
        outputs=train_outputs,
    )
    test_results = eval_outs(
        pipe_hate=pipe_hate,
        pipe_offensive=pipe_offensive,
        pipe_sentiment=pipe_sentiment,
        pipe_feature=pipe_feature,
        inputs=test_inputs,
        outputs=test_outputs,
    )
    common.save_dict_to_yaml(train_results, exp_pth, "train_output_evaluation.yaml")
    common.save_dict_to_yaml(test_results, exp_pth, "test_output_evaluation.yaml")

    print(f"Time taken: {time.perf_counter() - start_time:.2f}s")


if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI entry point: one required flag naming the experiment directory to evaluate.
    cli = ArgumentParser()
    cli.add_argument("--exp_pth", type=str, required=True, help="Experiment path")
    cli_args = cli.parse_args()
    eval_output(vars(cli_args))
