import numpy as np
import re
import random
import torch
from openai import OpenAI
from global_vars import PROMPT, TASK_NAME
def set_seed(seed: int = 42):
    """Seed every random number generator the project uses, for reproducible runs.

    Args:
        seed: Value handed to each seeding function (default 42).
    """
    seeders = (
        random.seed,                 # Python's built-in `random` module
        np.random.seed,              # NumPy's global RNG
        torch.manual_seed,           # torch's default generator
        torch.cuda.manual_seed_all,  # all CUDA devices (multi-GPU)
    )
    for seed_fn in seeders:
        seed_fn(seed)


def reformulate_dialog(dialog):
    """Convert a column-oriented dialog into a list of chat-message dicts.

    ``dialog["messages"]`` must hold parallel ``"role"`` and ``"content"``
    sequences. They are paired position by position; as with ``zip``, any
    surplus items in the longer sequence are dropped.

    Returns:
        list[dict]: ``[{"role": ..., "content": ...}, ...]`` in original order.
    """
    columns = dialog["messages"]
    return [
        {"role": speaker, "content": text}
        for speaker, text in zip(columns["role"], columns["content"])
    ]

def mean_absolute_error(y_true, y_pred):
    """
    Compute Mean Absolute Error (MAE) between ground truth and predictions.

    Args:
        y_true (list[int] or list[float]): Ground truth values
        y_pred (list[int] or list[float]): Predicted values

    Returns:
        float: MAE score, or ``nan`` when both inputs are empty (the mean of
        nothing is undefined; this mirrors ``np.mean`` on an empty array).

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length.")
    if len(y_true) == 0:
        # Without this guard the division below raises ZeroDivisionError,
        # e.g. when every prediction was filtered out upstream.
        return float("nan")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)



def normalized_inverse_error(y_true, y_pred):
    """
    Compute a soft score in [0, 1] where closer predictions get higher scores.
    Formula: s = average( 1 / (1 + |pred - true| / max(1, true)) )

    The error is normalized by the *ground truth* value, so the metric is not
    symmetric: pass ``y_true`` first and ``y_pred`` second.

    Args:
        y_true (list[int] or list[float]): Ground truth values
        y_pred (list[int] or list[float]): Predicted values

    Returns:
        float: Mean normalized inverse error score, or ``nan`` when both inputs
        are empty.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length.")
    if len(y_true) == 0:
        # Averaging zero scores is undefined; avoid ZeroDivisionError.
        return float("nan")
    # max(1, t) keeps the denominator >= 1 so true values of 0 (or below)
    # do not cause division by zero or sign flips.
    scores = [1 / (1 + abs(t - p) / max(1, t)) for t, p in zip(y_true, y_pred)]
    return sum(scores) / len(scores)

def accuracy(y_true, y_pred):
    """
    Compute accuracy (fraction of exact element-wise matches) between ground
    truth and predictions.

    Args:
        y_true (list): Ground truth values
        y_pred (list): Predicted values

    Returns:
        float: Fraction of positions where ``y_true[i] == y_pred[i]``, or
        ``nan`` when both inputs are empty.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length.")
    if len(y_true) == 0:
        # np.mean of an empty array also yields nan, but emits a RuntimeWarning.
        return float("nan")
    return float(np.mean(np.array(y_true) == np.array(y_pred)))


# Standalone integer 0-20; digits that are part of a longer word or number
# (e.g. "25", "a5") are not matched because of the \b word boundaries.
_SMALL_INT_PATTERN = re.compile(r'\b(?:[0-9]|10|11|12|13|14|15|16|17|18|19|20)\b')


def re_response_number_extraction(output):
    """Return the first standalone integer in [0, 20] found in ``output``.

    Args:
        output (str): Free-form model response text.

    Returns:
        int: The extracted integer, or the sentinel ``-100`` if none is found.
    """
    found = _SMALL_INT_PATTERN.search(output)
    return -100 if found is None else int(found.group())

def llm_as_judge(output, task_name, api_key=None, model="gpt-4o-mini",
                 temperature=0.5, max_tokens=100):
    """
    Ask an OpenAI chat model to judge a raw model output for the given task.

    Args:
        output (str): Raw model output, substituted into the task's judge
            prompt through its ``model_output`` placeholder.
        task_name: One of ``TASK_NAME.OPINION_COUNTING``,
            ``TASK_NAME.OPINION_MATCHING`` or ``TASK_NAME.POLARITY_CHECK``.
        api_key (str | None): OpenAI API key. ``None`` lets the OpenAI client
            use its default lookup (the ``OPENAI_API_KEY`` environment variable).
        model (str): Chat model name (default ``"gpt-4o-mini"``).
        temperature (float): Sampling temperature (default 0.5).
        max_tokens (int): Completion token limit (default 100).

    Returns:
        str: The judge's reply with surrounding whitespace stripped, or ``""``
        if the reply carries no text content.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    if task_name == TASK_NAME.OPINION_COUNTING:
        template = PROMPT.llm_as_judge_opinion_counting
    elif task_name == TASK_NAME.OPINION_MATCHING:
        template = PROMPT.llm_as_judge_opinion_matching
    elif task_name == TASK_NAME.POLARITY_CHECK:
        template = PROMPT.llm_as_judge_polarity_check
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError(f"Unsupported task_name for llm_as_judge: {task_name!r}")
    prompt = template.format(model_output=output)

    # A hard-coded empty string is sent as the key verbatim; passing None
    # instead lets the SDK fall back to the OPENAI_API_KEY environment variable.
    client = OpenAI(api_key=api_key)

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # message.content may be None (e.g. refusals); avoid AttributeError on strip().
    content = response.choices[0].message.content
    return (content or "").strip()


# Whole-word "pro"/"con" (optionally plural), case-insensitive. The word
# boundaries stop matches inside words such as "considering" or "problem".
_PRO_CON_PATTERN = re.compile(r"\b(pro|con)s?\b", re.IGNORECASE)


def re_response_pro_con_extraction(output):
    """
    Extract the first polarity label ("pro" or "con") from a response.

    Args:
        output (str | None): Free-form response text.

    Returns:
        str | None: ``"pro"`` or ``"con"`` (lower-cased) for the first
        whole-word occurrence of pro/pros/con/cons, or ``None`` if there is
        none or ``output`` is empty/None.
    """
    if not output:
        return None
    match = _PRO_CON_PATTERN.search(output)
    return match.group(1).lower() if match else None

def opinion_matching_evaluation(pred, gt):
    """
    Score opinion-matching responses by exact-match accuracy.

    Each raw response in ``pred`` is parsed with
    ``re_response_number_extraction``. Responses with no extractable number are
    dropped, together with their aligned ground-truth entry, before scoring.

    Args:
        pred (list[str]): Raw model responses.
        gt (list): Ground-truth answers aligned with ``pred``. Presumably ints
            comparable to the extracted numbers — TODO confirm against the data.

    Returns:
        dict: ``accuracy`` (``nan`` if nothing could be extracted),
        ``valid_count``, and the filtered, aligned ``pred`` / ``gt`` lists.
    """
    # re_response_number_extraction signals "no number found" with -100.
    not_found = -100
    # Separate names: reusing `pred`/`gt` here would discard the inputs.
    valid_pred = []
    valid_gt = []
    for response, label in zip(pred, gt):
        extracted = re_response_number_extraction(response)
        if extracted is None or extracted == not_found:
            continue
        valid_pred.append(extracted)
        valid_gt.append(label)

    return {
        "accuracy": accuracy(valid_gt, valid_pred) if valid_pred else float("nan"),
        "valid_count": len(valid_pred),
        "pred": valid_pred,
        "gt": valid_gt,
    }

def opinion_counting_evaluation(pred, gt):
    """
    Score opinion-counting responses with accuracy, MAE and normalized inverse error.

    Each raw response in ``pred`` is parsed with
    ``re_response_number_extraction``. Responses with no extractable number are
    dropped, together with their aligned ground-truth count, before scoring.

    Args:
        pred (list[str]): Raw model responses.
        gt (list[int]): Ground-truth counts aligned with ``pred``.

    Returns:
        dict: ``accuracy``, ``mae``, ``normalized_inverse_error`` (all ``nan``
        if nothing could be extracted), ``valid_count``, and the filtered,
        aligned ``pred`` / ``gt`` lists.
    """
    # re_response_number_extraction signals "no number found" with -100.
    not_found = -100
    # Separate names: reusing `pred`/`gt` here would discard the inputs.
    valid_pred = []
    valid_gt = []
    for response, label in zip(pred, gt):
        extracted = re_response_number_extraction(response)
        if extracted is None or extracted == not_found:
            continue
        valid_pred.append(extracted)
        valid_gt.append(label)

    if valid_pred:
        metrics = {
            "accuracy": accuracy(valid_gt, valid_pred),
            "mae": mean_absolute_error(valid_gt, valid_pred),
            # Ground truth first: this metric normalizes by the true count.
            "normalized_inverse_error": normalized_inverse_error(valid_gt, valid_pred),
        }
    else:
        # No parsable responses: metrics are undefined (MAE would divide by zero).
        metrics = dict.fromkeys(
            ("accuracy", "mae", "normalized_inverse_error"), float("nan")
        )

    return {
        **metrics,
        "valid_count": len(valid_pred),
        "pred": valid_pred,
        "gt": valid_gt,
    }

def polarity_check_evaluation(pred, gt):
    """
    Score polarity (pro/con) responses by exact-match accuracy.

    Each raw response in ``pred`` is parsed with
    ``re_response_pro_con_extraction``. Responses with no pro/con label
    (extractor returns ``None``) are dropped, together with their aligned
    ground-truth label, before scoring.

    Args:
        pred (list[str]): Raw model responses.
        gt (list[str]): Ground-truth labels aligned with ``pred``. Assumed to be
            lower-case ``"pro"``/``"con"`` to match the extractor — TODO confirm.

    Returns:
        dict: ``accuracy`` (``nan`` if nothing could be extracted),
        ``valid_count``, and the filtered, aligned ``pred`` / ``gt`` lists.
    """
    # Separate names: reusing `pred`/`gt` here would discard the inputs.
    valid_pred = []
    valid_gt = []
    for response, label in zip(pred, gt):
        extracted = re_response_pro_con_extraction(response)
        if extracted is None:
            continue
        valid_pred.append(extracted)
        valid_gt.append(label)

    return {
        "accuracy": accuracy(valid_gt, valid_pred) if valid_pred else float("nan"),
        "valid_count": len(valid_pred),
        "pred": valid_pred,
        "gt": valid_gt,
    }


if __name__ == "__main__":
    # Smoke check: "0" is the lower bound of the number extractor's range.
    smoke_input = "0"
    print(re_response_number_extraction(smoke_input))