import os
import sys
import re
import jsonlines
import json
import time

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from module_06_evaluation.prompt_template import *
from utils.model_loader import Model, DS_Model,DS2_Model, lora_Model, VLLM_Model
from utils.prompt_template import user_prompt_template
from reader.llama_reader import llama_reader, llama_reader_batch
from reader.ds_reader import ds_reader, ds_reader_batch
from reader.ds2_reader import ds2_reader, ds2_reader_batch
from reader.vllm_reader import vllm_reader, vllm_reader_batch
from utils.set_random_seed import set_random_seed
from utils.json_reader import jsonl_loader, json_loader
from evaluate import load
from openai import OpenAI

# Module-level OpenAI client used by gpt_judge().
# The key is read from the OPENAI_API_KEY environment variable so that no
# credential has to be written into the source file. When the variable is
# unset the previous value (an empty string) is used, so constructing the
# client behaves exactly as before in that case.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "")
)

def qwen_judge(
    base_path="results/domain/llama_base_fiqa_200.jsonl",
    finetune_path="results/domain/llama_fiqa_v3_s4000_fiqa_200.jsonl",
    output_path="results/domain/fiqa_qwen_evaluation.jsonl",
    model_name="qwen_32b_model",
    path_to_yml="configs/config.yml",
    batch_size=200,
):
    """Judge base vs. fine-tuned responses with a locally served vLLM model.

    Records from ``base_path`` and ``finetune_path`` are paired by position.
    For every pair a judging prompt is built from
    ``domain_user_prompt_template`` and all prompts are answered in batches
    by ``vllm_reader_batch``. A reply containing "answer 1" (case-insensitive)
    is counted for the fine-tuned model; otherwise a reply containing
    "answer 2" is counted for the base model; any other reply is not counted.
    NOTE(review): this presumes the template presents the fine-tuned response
    as "Answer 1" and the base response as "Answer 2" -- confirm against
    ``domain_user_prompt_template``.

    Each judged prompt and the raw reply are written to ``output_path`` as
    JSON lines (keys ``prompt`` and ``result``); the tallies and the elapsed
    time are printed.

    Args:
        base_path: JSONL file with the base model outputs. Each record must
            provide ``prompt``, ``ground_truth`` and ``response``.
        finetune_path: JSONL file with the fine-tuned model outputs. Each
            record must provide ``response``.
        output_path: Destination JSONL file (overwritten).
        model_name: Model entry passed to ``VLLM_Model``.
        path_to_yml: Configuration file passed to ``VLLM_Model``.
        batch_size: Batch size forwarded to ``vllm_reader_batch``.

    Raises:
        ValueError: If the reader returns fewer replies than prompts.
    """
    start = time.time()

    # load the model
    language_model = VLLM_Model(model_name, path_to_yml)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    # read the content of the evaluation dataset
    json_base = jsonl_loader(base_path)
    json_finetune = jsonl_loader(finetune_path)

    # Build one judging prompt per (base, fine-tuned) pair. zip() stops at the
    # shorter file, so the questions are collected here as well: the write
    # loop below must cover exactly the prompts that were actually generated.
    system_prompt = "You are a helpful AI assistant."
    questions = []
    query_list = []
    system_prompt_list = []
    for base_output, finetune_output in zip(json_base, json_finetune):
        question = base_output["prompt"]
        questions.append(question)
        query_list.append(
            domain_user_prompt_template.format(
                question=question,
                ground_truth=base_output["ground_truth"],
                finetune_answer=finetune_output["response"],
                base_answer=base_output["response"],
            )
        )
        system_prompt_list.append(system_prompt)

    output_list = vllm_reader_batch(
        model, tokenizer, params, query_list, system_prompt_list, batch_size=batch_size
    )
    if len(output_list) < len(query_list):
        raise ValueError(
            f"Expected {len(query_list)} judge replies, got {len(output_list)}"
        )

    finetune_pattern = re.compile(r'answer 1', re.IGNORECASE)
    base_pattern = re.compile(r'answer 2', re.IGNORECASE)
    evaluation_result = {"finetune": 0, "base": 0}

    # The context manager closes the writer; no explicit close is needed.
    with jsonlines.open(output_path, mode="w") as w:
        for query, raw_output in zip(questions, output_list):
            output = raw_output.strip()
            # "answer 1" is checked first, so a reply naming both answers is
            # counted for the fine-tuned model.
            if finetune_pattern.search(output):
                evaluation_result["finetune"] += 1
            elif base_pattern.search(output):
                evaluation_result["base"] += 1

            w.write({"prompt": query, "result": output})

    print(f"Evaluation Result: {evaluation_result}")
    end = time.time()
    print(f"Time taken: {end-start}")

def bertscore(
    base_path="results/domain/llama_base_lifestyle_200.jsonl",
    finetune_path="results/domain/llama_lifestyle_v3_s4000_lifestyle_200.jsonl",
    output_path="results/domain/lifestyle_bertscore_evaluation.jsonl",
    lang="en",
):
    """Score base and fine-tuned responses against the ground truth with BERTScore.

    Records from ``base_path`` and ``finetune_path`` are paired by position.
    For each pair the ``bertscore`` metric (loaded through ``evaluate.load``)
    is computed for the fine-tuned and the base response against the same
    ground truth. The per-sample precision/recall/F1 of both models are
    written to ``output_path`` as JSON lines, and the averages over all
    scored pairs are printed, followed by the elapsed time.

    Args:
        base_path: JSONL file with the base model outputs. Each record must
            provide ``prompt``, ``ground_truth`` and ``response``.
        finetune_path: JSONL file with the fine-tuned model outputs. Each
            record must provide ``response``.
        output_path: Destination JSONL file (overwritten).
        lang: Language code forwarded to the metric's ``compute`` call.
    """
    start = time.time()

    # read the content of the evaluation dataset
    json_base = jsonl_loader(base_path)
    json_finetune = jsonl_loader(finetune_path)

    # Named `metric` so the local does not shadow this function's own name.
    metric = load("bertscore")
    metric_names = ("precision", "recall", "f1")
    finetuned_score = {name: 0.0 for name in metric_names}
    base_score = {name: 0.0 for name in metric_names}
    num_samples = 0

    # The context manager closes the writer; no explicit close is needed.
    with jsonlines.open(output_path, mode="w") as w:
        for base_output, finetune_output in zip(json_base, json_finetune):
            question = base_output["prompt"]
            ground_truth = base_output["ground_truth"]
            finetune_answer = finetune_output["response"]
            base_answer = base_output["response"]

            # Index 0 of every result list is the fine-tuned answer, index 1
            # the base answer; both are compared with the same reference.
            results = metric.compute(
                predictions=[finetune_answer, base_answer],
                references=[ground_truth, ground_truth],
                lang=lang,
            )
            w.write({
                "prompt": question,
                "ground_truth": ground_truth,
                "finetune_answer": finetune_answer,
                "base_answer": base_answer,
                "precision_finetune": results["precision"][0],
                "recall_finetune": results["recall"][0],
                "f1_finetune": results["f1"][0],
                "precision_base": results["precision"][1],
                "recall_base": results["recall"][1],
                "f1_base": results["f1"][1],
            })

            # accumulate the scores
            for name in metric_names:
                finetuned_score[name] += results[name][0]
                base_score[name] += results[name][1]
            num_samples += 1

    # Average over the pairs that were actually scored. zip() stops at the
    # shorter file, so len(json_base) could over-count, and an empty input
    # would otherwise divide by zero (averages are then reported as 0.0).
    if num_samples:
        for name in metric_names:
            finetuned_score[name] /= num_samples
            base_score[name] /= num_samples

    print(f"Finetuned BERTScore: {finetuned_score}")
    print(f"Base BERTScore: {base_score}")
    end = time.time()
    print(f"Time taken: {end-start}")


def gpt_judge(
    base_path="results/domain/llama_base_lifestyle_200.jsonl",
    finetune_path="results/domain/llama_lifestyle_v3_s4000_lifestyle_200.jsonl",
    output_path="results/domain/lifestyle_gpt_evaluation.jsonl",
    model="gpt-4o",
    seed=42,
):
    """Judge base vs. fine-tuned responses with an OpenAI chat model.

    Records from ``base_path`` and ``finetune_path`` are paired by position.
    For every pair a judging prompt is built from
    ``domain_user_prompt_template`` and sent through the module-level
    ``client`` as a streamed chat completion. A reply containing "answer 1"
    (case-insensitive) is counted for the fine-tuned model; otherwise a reply
    containing "answer 2" is counted for the base model; any other reply is
    not counted.
    NOTE(review): this presumes the template presents the fine-tuned response
    as "Answer 1" and the base response as "Answer 2" -- confirm against
    ``domain_user_prompt_template``.

    The output file is only written after every request has completed. Each
    judged prompt and the reply are stored as JSON lines (keys ``prompt`` and
    ``result``); the tallies and the elapsed time are printed.

    Args:
        base_path: JSONL file with the base model outputs. Each record must
            provide ``prompt``, ``ground_truth`` and ``response``.
        finetune_path: JSONL file with the fine-tuned model outputs. Each
            record must provide ``response``.
        output_path: Destination JSONL file (overwritten).
        model: Chat model name passed to the OpenAI API.
        seed: Sampling seed passed to the OpenAI API.
    """
    start = time.time()
    # read the content of the evaluation dataset
    json_base = jsonl_loader(base_path)
    json_finetune = jsonl_loader(finetune_path)

    system_prompt = "You are a helpful AI assistant."

    # (question, judge reply) pairs. Keeping the question next to its reply
    # guarantees the write loop below covers exactly the pairs that were
    # judged, even when the two input files differ in length.
    judged = []
    for base_output, finetune_output in zip(json_base, json_finetune):
        question = base_output["prompt"]
        user_prompt = domain_user_prompt_template.format(
            question=question,
            ground_truth=base_output["ground_truth"],
            finetune_answer=finetune_output["response"],
            base_answer=base_output["response"],
        )

        stream = client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model=model,
            stream=True,
            seed=seed,
        )

        pieces = []
        for chunk in stream:
            # Skip chunks that carry no choice instead of indexing into an
            # empty list.
            if not chunk.choices:
                continue
            delta_text = chunk.choices[0].delta.content
            if delta_text is not None:
                pieces.append(delta_text)

        judged.append((question, "".join(pieces).strip()))

    finetune_pattern = re.compile(r'answer 1', re.IGNORECASE)
    base_pattern = re.compile(r'answer 2', re.IGNORECASE)
    evaluation_result = {"finetune": 0, "base": 0}

    # The context manager closes the writer; no explicit close is needed.
    with jsonlines.open(output_path, mode="w") as w:
        for query, output in judged:
            # "answer 1" is checked first, so a reply naming both answers is
            # counted for the fine-tuned model.
            if finetune_pattern.search(output):
                evaluation_result["finetune"] += 1
            elif base_pattern.search(output):
                evaluation_result["base"] += 1

            w.write({"prompt": query, "result": output})

    print(f"Evaluation Result: {evaluation_result}")
    end = time.time()
    print(f"Time taken: {end-start}")

# Script entry point: call the project's seed helper with 42, then run the
# GPT-based judge. qwen_judge() and bertscore() are not invoked here.
if __name__ == '__main__':
    set_random_seed(42)
    gpt_judge()


