import os
import json
import argparse
import numpy as np
import openai
from datasets import load_dataset
# from alpaca_farm.auto_annotations import alpaca_leaderboard
import datasets
import re

from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)
import os
# os.environ["HF_DATASETS_CACHE"] = "/home/sjh/.cache/huggingface/datasets"

# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

# OpenAI client configuration.
# The credential is read from the environment so that no secret is stored in
# the repository. NOTE(review): the key that used to be hard-coded here has
# been committed to source control and should be treated as leaked (rotate it).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# The endpoint is not a secret; keep the previous value as the default but
# allow it to be overridden without editing the file.
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.dwyu.top/v1")

# Lookup table: dataset name -> metric function imported from `metrics`.
# `scorer` selects the entry for the dataset being evaluated and calls it as
# `metric(prediction, ground_truth, all_classes=all_classes)`, so every value
# here must accept that call shape.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
    "knowledge_memorization": qa_f1_score,
    "knowledge_understanding": qa_f1_score,
    "longform_qa": rouge_score,
    "finance_qa": rouge_score,
}

def parse_args(args=None):
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line.

    Returns:
        An ``argparse.Namespace`` with one attribute, ``input_dir``: the
        directory containing the ``<dataset>.jsonl`` prediction files.
    """
    parser = argparse.ArgumentParser(description="Evaluate texts generated by every method")

    parser.add_argument(
        "--input_dir",
        type=str,
        default="/data2/tsq/WaterBench/pred/llama2-7b-chat-4k_no_g0.5_d5.0",
        help="directory that holds the <dataset>.jsonl prediction files",
    )
    # Forward the caller-supplied list; it used to be silently ignored, which
    # made the `args` parameter useless for programmatic calls and tests.
    return parser.parse_args(args)

# def scorer_e(dataset, predictions, answers, lengths, all_classes):
#     scores = {"0-4k": [], "4-8k": [], "8k+": []}
#     for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
#         score = 0.
#         if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
#             prediction = prediction.lstrip('\n').split('\n')[0]
#         for ground_truth in ground_truths:
#             score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
#         if length < 4000:
#             scores["0-4k"].append(score)
#         elif length < 8000:
#             scores["4-8k"].append(score)
#         else:
#             scores["8k+"].append(score)
#     for key in scores.keys():
#         scores[key] = round(100 * np.mean(scores[key]), 2)
#     return scores
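# Post-processing approximation of the decoding option `no_repeat_ngram_size=4`
# (note: the defaults of `clean_repetitions` below use word 3-grams).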
def clean_repetitions(
    text: str,
    word_ngram_size: int = 3,
    char_ngram_size: int = 6,
    min_word_repeat: int = 2,
    min_char_repeat: int = 3,
    enable_char_level: bool = True,
    enable_regex_patterns: bool = True,
    preserve_spaces: bool = True,
) -> str:
    """Run the full repetition-cleaning pipeline over ``text``.

    The stages run in a fixed order: regex-based collapsing of obvious
    repeats (optional), word-level n-gram truncation, character-level n-gram
    truncation (optional), and finally whitespace normalisation.

    Args:
        text: Input text.
        word_ngram_size: Size of the word-level n-gram window.
        char_ngram_size: Size of the character-level n-gram window.
        min_word_repeat: Occurrence count of a word n-gram that triggers
            truncation.
        min_char_repeat: Occurrence count of a character n-gram that triggers
            truncation.
        enable_char_level: Whether to run the character-level stage.
        enable_regex_patterns: Whether to run the regex stage.
        preserve_spaces: Passed to the character-level stage; when true each
            space-separated chunk is processed on its own.

    Returns:
        The cleaned text. Empty or whitespace-only input is returned as is.
    """
    # Nothing to clean: hand falsy / blank input back untouched.
    if not (text and text.strip()):
        return text

    # Stage 1: obvious repeats caught by regular expressions.
    cleaned = _clean_regex_patterns(text) if enable_regex_patterns else text

    # Stage 2: word-level n-gram check.
    cleaned = _clean_word_repetitions(cleaned, word_ngram_size, min_word_repeat)

    # Stage 3: character-level n-gram check.
    if enable_char_level:
        cleaned = _clean_char_repetitions(
            cleaned, char_ngram_size, min_char_repeat, preserve_spaces
        )

    # Stage 4: collapse leftover whitespace.
    return _normalize_spaces(cleaned)

def _clean_regex_patterns(text: str) -> str:
    """使用正则表达式清理明显的重复模式"""
    
    # 清理单个字符的长重复 (如 "aaaaaa" -> "aa")
    text = re.sub(r'(.)\1{4,}', r'\1\1', text)
    
    # 清理短字符串的重复 (如 "abcabcabc" -> "abc")
    # 匹配2-10个字符的重复模式，重复3次以上
    for pattern_len in range(2, 11):
        pattern = r'(.{' + str(pattern_len) + r'})\1{2,}'
        text = re.sub(pattern, r'\1', text)
    
    # 清理单词的直接重复 (如 "word word word" -> "word")
    text = re.sub(r'\b(\w+)(\s+\1\b){2,}', r'\1', text)
    
    return text

def _clean_word_repetitions(text: str, ngram_size: int, min_repeat: int) -> str:
    """清理单词级别的n-gram重复"""
    words = text.split()
    if len(words) < ngram_size:
        return text
    
    ngram_counts = {}
    result = []
    
    for i in range(len(words)):
        if i < ngram_size - 1:
            result.append(words[i])
        else:
            current_ngram = tuple(words[i-ngram_size+1:i+1])
            ngram_counts[current_ngram] = ngram_counts.get(current_ngram, 0) + 1
            
            if ngram_counts[current_ngram] >= min_repeat:
                # 发现重复，停止添加
                break
            else:
                result.append(words[i])
    
    return ' '.join(result)

def _clean_char_repetitions(text: str, ngram_size: int, min_repeat: int, preserve_spaces: bool) -> str:
    """清理字符级别的n-gram重复"""
    if len(text) < ngram_size:
        return text
    
    # 如果需要保留空格结构，分别处理每个词
    if preserve_spaces:
        parts = text.split(' ')
        cleaned_parts = []
        
        for part in parts:
            if part:  # 非空部分
                cleaned_part = _clean_char_repetitions_core(part, ngram_size, min_repeat)
                cleaned_parts.append(cleaned_part)
            else:
                cleaned_parts.append(part)
        
        return ' '.join(cleaned_parts)
    else:
        return _clean_char_repetitions_core(text, ngram_size, min_repeat)

def _clean_char_repetitions_core(text: str, ngram_size: int, min_repeat: int) -> str:
    """字符级别重复检查的核心逻辑"""
    if len(text) < ngram_size:
        return text
    
    ngram_counts = {}
    result = []
    
    for i in range(len(text)):
        if i < ngram_size - 1:
            result.append(text[i])
        else:
            current_ngram = text[i-ngram_size+1:i+1]
            ngram_counts[current_ngram] = ngram_counts.get(current_ngram, 0) + 1
            
            if ngram_counts[current_ngram] >= min_repeat:
                # 发现重复，停止添加
                break
            else:
                result.append(text[i])
    
    return ''.join(result)

def _normalize_spaces(text: str) -> str:
    """标准化空格"""
    # 移除多余的空格，但保留单个空格
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

# 便利函数，提供常用的预设配置
def clean_repetitions_strict(text: str) -> str:
    """Strict preset of ``clean_repetitions``.

    Uses smaller n-gram windows (2 words / 4 characters) and a repeat
    threshold of 2 at both levels, so repetition is detected more readily
    than with the defaults.
    """
    strict_settings = {
        "word_ngram_size": 2,
        "char_ngram_size": 4,
        "min_word_repeat": 2,
        "min_char_repeat": 2,
        "enable_char_level": True,
        "enable_regex_patterns": True,
    }
    return clean_repetitions(text, **strict_settings)

# Debug accumulator filled by `scorer`; it is module-level so it can be dumped
# after a run.
score_list = []


def scorer(dataset, predictions, answers, all_classes):
    """Return the average best-match score of ``predictions`` on a 0-100 scale.

    Each prediction is scored against every one of its reference answers with
    the metric registered for ``dataset`` in ``dataset2metric``; the best
    reference score is kept. The result is the mean over ``len(predictions)``
    multiplied by 100 and rounded to two decimals.

    Args:
        dataset: Dataset name, used as key into ``dataset2metric``.
        predictions: Sequence of predicted strings.
        answers: Sequence parallel to ``predictions``; each item is an
            iterable of reference answers for that prediction.
        all_classes: Forwarded to the metric as the ``all_classes`` keyword.

    Returns:
        The rounded score, or ``0.0`` when there are no predictions.

    Raises:
        KeyError: If ``dataset`` has no entry in ``dataset2metric``.
    """
    num_predictions = len(predictions)
    if num_predictions == 0:
        # Avoid ZeroDivisionError on an empty prediction file.
        return 0.0

    metric = dataset2metric[dataset]
    # For these datasets only the first non-empty line of the output is scored.
    first_line_only = dataset in ["trec", "triviaqa", "samsum", "lsht"]

    total_score = 0.
    for prediction, ground_truths in zip(predictions, answers):
        if first_line_only:
            prediction = prediction.lstrip('\n').split('\n')[0]
        # prediction = clean_repetitions(prediction)
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, metric(prediction, ground_truth, all_classes=all_classes))
            # NOTE(review): this records the running best after *each*
            # reference, so score_list can hold several entries per
            # prediction — confirm that is intended before relying on it.
            score_list.append(score)
        total_score += score

    return round(100 * total_score / num_predictions, 2)

def alpacafarm_score(prompts, predictions, model_name):
    """Score ``predictions`` with the AlpacaFarm leaderboard annotator.

    Loads the "eval" split of the ``tatsu-lab/alpaca_farm`` /
    ``alpaca_farm_evaluation`` dataset, pairs its i-th row with
    ``predictions[i]``, and passes the resulting records to
    ``alpaca_leaderboard``.

    NOTE(review): ``alpaca_leaderboard`` is not imported in the visible part
    of this file (its import at the top is commented out), so calling this
    function as written raises ``NameError`` unless the name is provided
    elsewhere.

    Args:
        prompts: Not used by the current implementation.
        predictions: Model outputs, indexed in the same order as the rows of
            the evaluation split; must contain at least as many items as the
            split has rows.
        model_name: Stored as the "generator" of every record and passed as
            ``name`` to ``alpaca_leaderboard``.

    Returns:
        The "win_rate" value of the first row of the leaderboard result,
        rounded to two decimals.
    """
    # outputs should be a list of json as such:
    # [{'instruction': 'What are the names of some famous actors that started their careers on Broadway?', 'input': '', 'output': 'Some famous actors that started their careers on Broadway are Hugh Jackman, Meryl Streep, Denzel Washington, Audra McDonald, and Lin-Manuel Miranda.', 'generator': 'gpt-3.5-turbo-0301', 'dataset': 'helpful_base', 'datasplit': 'eval'},
    # ...]
    my_outputs = []
    
    alpaca_eval_data = load_dataset("tatsu-lab/alpaca_farm", "alpaca_farm_evaluation")["eval"]
    # Disabled alternative: build the records from a local WaterBench jsonl
    # file instead of the Hugging Face dataset.
    # alpaca_eval_data = []
    # with open("data/WaterBench/5-1_alpacafarm.jsonl", "r") as f:
    #     for line in f:
    #         alpaca_eval_data.append(json.loads(line))
    # for i, item in enumerate(alpaca_eval_data):
    #     prompt = item['context']
    #     _input = item['input'] if 'input' in item else ''
    #     prediction = predictions[i]
    #     my_outputs.append({
    #         "instruction": prompt,
    #         "input": _input,
    #         "output": prediction,
    #         "generator": model_name,
    #         "dataset": "alpaca_farm",
    #         "datasplit": "eval"
    #     })
    for i, line in enumerate(alpaca_eval_data):
        # json_obj = json.loads(line)
        prompt = line["instruction"]
        _input = line["input"]
        # Predictions are matched to dataset rows purely by position.
        prediction = predictions[i]
        my_outputs.append({"instruction": prompt, "input": _input, "generator": model_name, "output": prediction})
    print("my_outputs[0] is:", my_outputs[0])
    df_results = alpaca_leaderboard(
        path_or_all_outputs=my_outputs,
        name=model_name,
        is_add_reference_methods=False,
        annotators_config = "greedy_gpt4/configs.yaml"
    )
    print(df_results)
    # Only the first row's win rate is reported.
    score = round(df_results["win_rate"].iloc[0], 2)
    # score = df_results.to_string(float_format="%.2f")
    return score


def main():
    """Score every known dataset found in ``--input_dir`` and save the results.

    For each dataset name in the fixed evaluation list, reads
    ``<input_dir>/<dataset>.jsonl`` (one JSON object per line with the keys
    "pred" and "answers"; the first object also provides "all_classes"),
    scores it with ``scorer`` and finally writes all scores to
    ``<input_dir>/eval/result.json``. Missing or empty files are skipped.
    """
    args = parse_args()
    files = os.listdir(args.input_dir)
    save_dir = os.path.join(args.input_dir, "eval")
    os.makedirs(save_dir, exist_ok=True)
    print("Evaluating on:", files)

    data_names = ["knowledge_memorization", "knowledge_understanding", "longform_qa",
                  "finance_qa", "hotpotqa", "lcc", "multi_news", "qmsum", "alpacafarm"]
    scores = {}
    for dataset in data_names:
        data_path = os.path.join(args.input_dir, dataset + ".jsonl")
        if not os.path.exists(data_path):
            print(f"File {data_path} does not exist, skipping...")
            continue

        # Decode every line exactly once; blank lines are ignored.
        with open(data_path, "r", encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        if not records:
            print(f"File {data_path} is empty, skipping...")
            continue

        if dataset == "alpacafarm":
            # AlpacaFarm scoring (alpacafarm_score) is disabled. Do not record
            # anything: previously the score of the preceding dataset was
            # stored under this key by accident.
            print("alpacafarm scoring is disabled, skipping...")
            continue

        predictions = [record["pred"] for record in records]
        answers = [record["answers"] for record in records]
        all_classes = records[0]["all_classes"]
        print(f"predictions[0] is: {predictions[0]}")

        scores[dataset] = scorer(dataset, predictions, answers, all_classes)

    # save
    out_path = os.path.join(save_dir, "result.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)

    # Debug aid: the scores collected in `score_list` by `scorer` can be dumped too.
    # with open('output2.json', 'w', encoding='utf-8') as f:
    #     json.dump(score_list, f, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    main()