import os
import sys
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, cohen_kappa_score, classification_report

from models import MODELS
from models import *
from openai import OpenAI

# Root directory of the datasets, relative to the working directory.
DATA_PATH = "../data"

# API credentials. Each key is read from the environment when the variable is
# set, so real secrets do not have to be committed to the source file; the
# original placeholder values are kept as fallbacks.
your_openai_key = os.environ.get("OPENAI_API_KEY", "xxx")
your_gemini_key = os.environ.get("GEMINI_API_KEY", "")
your_mistral_key = os.environ.get("MISTRAL_API_KEY", "")

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with `seed`."""
    # Same order as before: stdlib `random`, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def create_model(model_name, args):
    """Instantiate and return the model registered under `model_name`.

    The class name is read from `MODELS[model_name]["model_class"]` and
    resolved as an attribute of this module, then called with
    `(model_name, args)`.

    Raises:
        ValueError: if `model_name` is not a key of `MODELS`.
    """
    if model_name not in MODELS:
        raise ValueError(f"Unknown Model '{model_name}'")

    # The class must be reachable as a module-level name of this file.
    this_module = sys.modules[__name__]
    model_cls = getattr(this_module, MODELS[model_name]["model_class"])
    return model_cls(model_name, args)

def map_answer_to_label(answers):
    """Map textual answers to integer labels.

    Matching is a case-insensitive substring test, tried in this order:
    "yes" -> 0, "not related" -> 2, "no" -> 1. Anything else is invalid and
    maps to -1. "not related" must be tested before "no" because it contains
    "no" as a substring.

    Entries that are not strings (for example None, or NaN coming from an
    empty DataFrame cell) are treated as invalid instead of raising.

    Args:
        answers: iterable of answer texts.

    Returns:
        A list of labels, one per entry of `answers`. The number of invalid
        labels is printed as a side effect.
    """
    labels = []
    for answer in answers:
        # Lower-case once; non-string entries fall through to the invalid label.
        text = answer.lower() if isinstance(answer, str) else ""
        if "yes" in text:
            labels.append(0)
        elif "not related" in text:
            labels.append(2)
        elif "no" in text:
            labels.append(1)
        else:  # invalid labels
            labels.append(-1)
    invalid_labels = labels.count(-1)
    print("# of invalid labels: ", invalid_labels)

    return labels

def _score_predictions(ground_truth, prediction):
    """Return accuracy and micro-averaged precision/recall/F1, rounded to 4 decimals."""
    return {
        "accuracy": round(accuracy_score(ground_truth, prediction), 4),
        "precision": round(precision_score(ground_truth, prediction, average="micro"), 4),
        "recall": round(recall_score(ground_truth, prediction, average="micro"), 4),
        "f1": round(f1_score(ground_truth, prediction, average="micro"), 4),
    }


def compute_metrics(results, metrics=("accuracy",)):
    """Compute the metrics for the LLM evaluation results and return them.

    Some tasks are multi-class problems and some are binary; all scores are
    micro-averaged so the same code handles both.

    Args:
        results: DataFrame with the columns "answer", "label",
            "eval_repeats", "current_repeat" and "value". "answer" and
            "label" are mapped to integer labels with `map_answer_to_label`.
        metrics: kept for backward compatibility; it is not used, all
            metrics are always computed.

    Returns:
        A dict holding, for each of "accuracy", "precision", "recall", "f1"
        and "kappa", the list of per-repeat (or per-pair, for kappa) scores
        plus "avg_<name>" and "std_<name>" summaries. Per-value scores are
        only printed, not returned.
    """
    prediction = map_answer_to_label(results["answer"].tolist())
    ground_truth = map_answer_to_label(results["label"].tolist())

    # Pull the columns out as plain lists so rows are matched by position.
    # Indexing the Series with an integer is label-based and breaks when the
    # DataFrame index is not 0..n-1 (e.g. after filtering).
    repeat_ids = results["current_repeat"].tolist()
    values = results["value"].tolist()

    metric_results = {
        "accuracy": [],
        "precision": [],
        "recall": [],
        "f1": []
    }

    repeat_predictions = []
    repeat_times = int(results["eval_repeats"].iloc[0])
    for repeat in range(repeat_times):
        current_prediction = [p for p, r in zip(prediction, repeat_ids) if r == repeat]
        current_ground_truth = [g for g, r in zip(ground_truth, repeat_ids) if r == repeat]
        repeat_predictions.append(current_prediction)

        scores = _score_predictions(current_ground_truth, current_prediction)
        report = classification_report(current_ground_truth, current_prediction)
        print(f"classification report for repeat {repeat}:\n {report}")

        for name, score in scores.items():
            metric_results[name].append(score)

    # compute Kappa score to measure the agreement between the repeated evaluations
    metric_results["kappa"] = []
    for i in range(repeat_times):
        for j in range(i + 1, repeat_times):
            kappa = cohen_kappa_score(repeat_predictions[i], repeat_predictions[j])
            metric_results["kappa"].append(round(kappa, 4))

    # Snapshot the keys first: the loop adds avg_/std_ entries to the dict.
    for name in list(metric_results):
        scores = metric_results[name]
        if scores:
            metric_results[f"avg_{name}"] = round(float(np.mean(scores)), 4)
            metric_results[f"std_{name}"] = round(float(np.std(scores)), 4)
        else:
            # No samples (e.g. kappa with a single repeat): keep the previous
            # outputs (0.0 / NaN) without dividing by zero or warning.
            metric_results[f"avg_{name}"] = 0.0
            metric_results[f"std_{name}"] = float("nan")

    # compute the result for each rule
    value_results = {}
    unique_values = sorted(results["value"].unique().tolist())
    for value in unique_values:
        value_prediction = [p for p, v in zip(prediction, values) if v == value]
        value_ground_truth = [g for g, v in zip(ground_truth, values) if v == value]

        # Only the part before the first ":" is used as the display key.
        short_name = value.split(":", 1)[0]
        value_results[short_name] = _score_predictions(value_ground_truth, value_prediction)

        print(f"Metrics for value {short_name}: {value_results[short_name]}")

    return metric_results

def compute_cosine_similarity(v1, v2_list):
    """Compute the cosine similarity between one vector and each row of a list of vectors.

    Args:
        v1: a single vector.
        v2_list: a 2-D collection of vectors, one per row.

    Returns:
        An array with one similarity per row of `v2_list`. A small constant
        (1e-7) is added to the denominator to avoid division by zero.
    """
    query = np.array(v1)
    candidates = np.array(v2_list)
    dot_products = np.dot(query, candidates.T)
    norm_products = np.linalg.norm(query) * np.linalg.norm(candidates, axis=1)
    return dot_products / (norm_products + 1e-7)

def call_llm_apis(prompt, model_name="gpt-4"):
    """Send `prompt` to the API backend that serves `model_name`.

    "gpt-4" and "gpt-35-turbo" go to `call_openai_gpt`, "gemini-pro" to
    `call_gemini` and "mistral-large" to `call_mistral`. Any other name
    returns None without calling anything.
    """
    if model_name in ("gpt-4", "gpt-35-turbo"):
        return call_openai_gpt(prompt, model_name)
    if model_name == "gemini-pro":
        return call_gemini(prompt, model_name)
    if model_name == "mistral-large":
        return call_mistral(prompt, model_name)
    return None

def call_openai_gpt(prompt, model_name="gpt-4"):
    """Send `prompt` as a single user message to an OpenAI chat model.

    The alias "gpt-35-turbo" is translated to the API name "gpt-3.5-turbo"
    and limited to 1024 output tokens; every other model name is passed
    through unchanged with a limit of 2048 tokens.

    Returns:
        The content of the first choice, stripped of surrounding whitespace.
    """
    uses_gpt35_alias = model_name == "gpt-35-turbo"
    api_model = "gpt-3.5-turbo" if uses_gpt35_alias else model_name
    token_limit = 1024 if uses_gpt35_alias else 2048

    client = OpenAI(api_key=your_openai_key)
    completion = client.chat.completions.create(
        model=api_model,
        messages=[{"role": "user", "content": prompt}],
        temperature=1.0,
        top_p=1.0,
        max_tokens=token_limit,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    return completion.choices[0].message.content.strip()

def call_gemini(prompt, model_name="gemini-pro"):
    """Send `prompt` to Gemini and return the stripped text of the reply.

    The model is always "gemini-pro"; `model_name` is accepted so the
    signature matches the other `call_*` helpers but is not used.
    Generation uses temperature 1.0 and at most 512 output tokens, and the
    four listed harm categories are set to "BLOCK_NONE".

    NOTE(review): `genai` is not imported explicitly in this file;
    presumably it is provided by `from models import *` — confirm.

    Returns:
        The response text, stripped of surrounding whitespace.
    """
    genai.configure(api_key=your_gemini_key)
    model = genai.GenerativeModel("gemini-pro")

    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    # Same per-category settings as before, built from the category list.
    safety_settings = [
        {
            "category": category,
            "threshold": "BLOCK_NONE",
            "probability": "NEGLIGIBLE",
        }
        for category in harm_categories
    ]

    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=512,
            temperature=1.0,
        ),
        safety_settings=safety_settings,
    )
    answer = response.text.strip()

    # The answer was previously computed but never returned, so every caller
    # received None.
    return answer

def call_mistral(prompt, model_name="mistral-large"):
    """Send `prompt` as a single user message to Mistral and return the reply.

    The request always targets "mistral-large-latest"; `model_name` is
    accepted but not used. Sampling uses temperature 1.0, top_p 1.0 and at
    most 1024 output tokens.

    Returns:
        The content of the first choice, stripped of surrounding whitespace.
    """
    client = MistralClient(api_key=your_mistral_key)
    user_message = ChatMessage(role="user", content=f"{prompt}")
    completion = client.chat(
        model="mistral-large-latest",
        messages=[user_message],
        temperature=1.0,
        top_p=1.0,
        max_tokens=1024,
    )
    return completion.choices[0].message.content.strip()

def get_text_embeddings(dataset):
    """Embed every training example of `dataset` and write the result as JSONL.

    Supported datasets and the text that is embedded:
      * "beavertails": `train.jsonl`; the text is
        "Question: <prompt>. Response: <response>" with newlines replaced
        by spaces.
      * "value_fulcra": `train.csv`, column "dialogue".
      * "denevil": `train.csv`, column "scene".

    Each text is sent on its own to the OpenAI "text-embedding-3-small"
    model. The output is written to
    `{DATA_PATH}/{dataset}/train_embed.jsonl`, one JSON object per line
    holding the source field(s) and the "embed" vector.

    Raises:
        ValueError: if `dataset` is not one of the supported names. This is
            checked before any API client or file is touched.
    """
    if dataset not in ("beavertails", "value_fulcra", "denevil"):
        raise ValueError(f"Unknown dataset '{dataset}'")

    client = OpenAI(api_key=your_openai_key)

    def embed_text(text):
        # One request per text; the response holds a single embedding.
        return client.embeddings.create(input=[text], model="text-embedding-3-small").data[0].embedding

    line_embeds = []
    if dataset == "beavertails":
        data_path = f"{DATA_PATH}/beavertails/train.jsonl"
        # Context manager so the input handle is closed after reading.
        with open(data_path, "r", encoding="utf-8") as fr:
            lines = [json.loads(raw_line) for raw_line in fr]
        for line in tqdm(lines):
            prompt = line["prompt"]
            response = line["response"]
            text = f"Question: {prompt}. Response: {response}".replace("\n", " ")
            line_embeds.append({"prompt": prompt, "response": response, "embed": embed_text(text)})
    else:
        # Both CSV datasets differ only in the column that holds the text.
        text_column = {"value_fulcra": "dialogue", "denevil": "scene"}[dataset]
        data_df = pd.read_csv(f"{DATA_PATH}/{dataset}/train.csv")
        # Iterating the column gives tqdm a length, so it can show progress.
        for text in tqdm(data_df[text_column]):
            line_embeds.append({text_column: text, "embed": embed_text(text)})

    with open(f"{DATA_PATH}/{dataset}/train_embed.jsonl", "w") as fw:
        for line in line_embeds:
            fw.write(json.dumps(line) + "\n")