import json
import os
import numpy as np
import re
from typing import *
from loguru import logger
from tqdm import tqdm
import traceback


def split_sentences(text):
    """Split ``text`` into trimmed clauses on sentence and clause punctuation.

    The dots of a few known abbreviations are masked before splitting so they
    do not end a clause. A punctuation mark immediately followed by a digit
    (for example the dot in ``3.5``) is not treated as a separator.

    Args:
        text: The text to split.

    Returns:
        A list of non-empty, whitespace-stripped clauses.
    """
    protected_tokens = ('max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.')
    for token in protected_tokens:
        masked = token.replace('.', '<DOT>')
        text = re.sub(re.escape(token), masked, text)

    clauses = []
    for raw_piece in re.split(r'[.!?。！？,;，；](?!\d)', text):
        trimmed = raw_piece.strip()
        if trimmed:
            # Restore the dots that were masked above.
            clauses.append(trimmed.replace('<DOT>', '.'))
    return clauses

def split_period_sentences(text):
    """Split ``text`` into trimmed sentences on ``.`` and ``。`` only.

    The dots of a few known abbreviations are masked before splitting, and a
    period immediately followed by a digit (as in ``1.5``) does not split.

    Args:
        text: The text to split.

    Returns:
        A list of non-empty, whitespace-stripped sentences.
    """
    protected_tokens = ('max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.')
    for token in protected_tokens:
        masked = token.replace('.', '<DOT>')
        text = re.sub(re.escape(token), masked, text)

    result = []
    for raw_piece in re.split(r'[.。](?!\d)', text):
        trimmed = raw_piece.strip()
        if trimmed:
            # Restore the dots that were masked above.
            result.append(trimmed.replace('<DOT>', '.'))
    return result

def match_metric_name(metric: str, sentence: str) -> bool:
    """Return whether ``metric`` occurs in ``sentence`` after normalisation.

    Both strings are reduced to ASCII letters and characters in the
    ``\\u4e00``-``\\u9fa5`` range, then lowercased, before the substring test.
    A metric that normalises to the empty string therefore matches anything.
    """
    def _letters_only(value: str) -> str:
        return re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', value).lower()

    haystack = _letters_only(sentence)
    needle = _letters_only(metric)
    return needle in haystack

def evaluate_trend(answer: str, attribute: dict, cols: List[str]):
    """Score a free-text trend description against the labelled trend.

    The first clause of the answer is checked for the trend category; the
    remaining clauses are scanned for the start value and, unless the label
    type is exactly ``'keep steady'``, the change amplitude.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'type'``, ``'start'`` and ``'amplitude'``.
        cols: Unused here.

    Returns:
        A 4-tuple ``(cate_correct, num_correct, reason_correct, reason_detail)``
        of lists; the last two are always empty.
    """
    sentences = split_sentences(answer)
    if not sentences:
        return [0.0], [0.0], [], []

    def _closeness(predicted, expected):
        # A relative error is unstable for near-zero labels, so those only
        # require the prediction to be near zero as well.
        if abs(expected) < 0.5:
            return 1.0 if abs(predicted) < 0.5 else 0.0
        return max(0.0, min(1.0, 1.0 - abs(predicted - expected) / abs(expected)))

    def _score_keyword(keywords, key):
        # Score the first number in the first numeric clause that mentions
        # one of the keywords; 0.0 when no clause qualifies.
        for sentence in sentences:
            numbers = re.findall(r'-?\d+\.?\d*', sentence)
            if not numbers:
                continue
            if any(keyword in sentence for keyword in keywords):
                return _closeness(float(numbers[0]), attribute[key])
        return 0.0

    # Lowercase once so every category keyword is matched case-insensitively
    # (previously 'steady' was case-sensitive, unlike the other two).
    head = sentences[0].lower()
    trend_type = attribute['type']
    if 'steady' in trend_type:
        cate_correct = 'steady' in head
    elif 'decrease' in trend_type:
        cate_correct = 'decreas' in head
    elif 'increase' in trend_type:
        cate_correct = 'increas' in head
    else:
        cate_correct = False

    num_correct = [_score_keyword(('start',), 'start')]
    if trend_type != 'keep steady':
        num_correct.append(
            _score_keyword(('change value', 'from left to right'), 'amplitude')
        )

    return [cate_correct], num_correct, [], []

def evaluate_season(answer: str, attribute: dict, cols: List[str]):
    """Score a free-text seasonality description against the label.

    The first clause is checked for the periodic / non-periodic category.
    Unless the label type is exactly ``'no periodic fluctuation'``, later
    clauses are scanned for the period and the amplitude.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'type'``, ``'period'`` and ``'amplitude'``.
        cols: Unused here.

    Returns:
        A 4-tuple ``(cate_correct, num_correct, reason_correct, reason_detail)``
        of lists; the last two are always empty.
    """
    sentences = split_sentences(answer)
    if not sentences:
        return [0.0], [0.0], [], []

    head = sentences[0].lower()
    if 'no' in attribute['type']:
        cate_correct = 'no periodic' in head
    else:
        # Match negation as a whole word. A plain substring test for 'no'
        # also fires on words like "pronounced" or "noticeable" and would
        # reject correct answers.
        negated = re.search(r'\b(?:no|not|non\w*)\b', head) is not None
        cate_correct = 'periodic' in head and not negated

    num_correct = []
    if attribute['type'] != 'no periodic fluctuation':
        for keyword, key in (('each period', 'period'), ('amplitude', 'amplitude')):
            score = 0.0
            for sentence in sentences:
                numbers = re.findall(r'-?\d+\.?\d*', sentence)
                if numbers and keyword in sentence:
                    expected = attribute[key]
                    predicted = float(numbers[0])
                    score = max(0.0, min(1.0, 1.0 - abs(predicted - expected) / abs(expected)))
                    break
            num_correct.append(score)

    return [cate_correct], num_correct, [], []

def evaluate_noise(answer: str, attribute: dict, cols: List[str]):
    """Score a free-text noise description against the label.

    The first clause is checked for the noise category. When the label type
    contains ``'noisy'``, later clauses are scanned for a standard deviation.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'type'`` and ``'std'``.
        cols: Unused here.

    Returns:
        A 4-tuple ``(cate_correct, num_correct, reason_correct, reason_detail)``
        of lists; the last two are always empty.
    """
    clauses = split_sentences(answer)
    if len(clauses) == 0:
        return [0.0], [0.0], [], []

    head = clauses[0].lower()
    expected_phrase = 'no noise' if 'almost no' in attribute['type'] else 'noisy'
    category_ok = expected_phrase in head

    std_scores = []
    if 'noisy' in attribute['type']:
        # Standard-deviation check: first numeric clause mentioning it wins.
        found = 0.0
        for clause in clauses:
            values = [float(token) for token in re.findall(r'-?\d+\.?\d*', clause)]
            if not values:
                continue
            lowered = clause.lower()
            if 'standard' in lowered or 'std' in lowered:
                found = max(0.0, min(1.0, 1.0 - abs(values[0] - attribute['std']) / abs(attribute['std'])))
                break
        std_scores.append(found)

    return [category_ok], std_scores, [], []

def _local_relative_score(predicted: float, expected: float) -> float:
    """Return ``1 - relative error`` of ``predicted``, clamped to ``[0, 1]``."""
    if expected == 0:
        # A relative error is undefined for a zero label (e.g. a feature at
        # position 0); only an exact hit scores instead of dividing by zero.
        return 1.0 if predicted == 0 else 0.0
    return max(0.0, min(1.0, 1.0 - abs(predicted - expected) / abs(expected)))


def evaluate_local(answer: str, attribute: dict, cols: List[str]):
    """Score a description of local features against the labelled features.

    The answer is split into facts on semicolons and each fact into clauses.
    A labelled feature is matched by the first fact whose first clause names
    one of the feature's types and which gives a position within 64 of the
    labelled position; that fact's amplitude is then scored too.

    Args:
        answer: Response text to grade.
        attribute: Iterable of feature dicts; each reads ``'type'`` (a string
            or a list of strings), ``'position'`` and ``'amplitude'``. The
            dicts are not modified.
        cols: Unused here.

    Returns:
        A 4-tuple ``(cate_correct, num_correct, reason_correct, reason_detail)``.
        ``cate_correct`` has one bool per feature and ``num_correct`` has a
        position score followed by an amplitude score per feature. The last
        two lists are always empty.
    """
    cate_correct = []
    num_correct = []

    if len(attribute) == 0 and "no " in answer.lower():
        return [True], [], [], []

    # The clause split does not depend on the feature, so do it once.
    facts = [re.split(r'[，。,;；]', fact) for fact in re.split(r'[;；]', answer)]

    for feat in attribute:
        # Normalise to a list locally rather than rewriting the caller's dict.
        feat_types = [feat['type']] if isinstance(feat['type'], str) else feat['type']
        matched_flag = False
        pos_numerical = 0.0
        amp_numerical = 0.0
        for sentences in facts:
            head = sentences[0].lower()
            if not any(name in head for name in feat_types):
                continue
            for sentence in sentences:
                numbers = re.findall(r'-?\d+\.?\d*', sentence)
                if not numbers:
                    continue
                value = float(numbers[0])
                lowered = sentence.lower()
                if 'position' in lowered or 'around point' in lowered:
                    if abs(value - feat['position']) > 64:
                        continue
                    pos_numerical = _local_relative_score(value, feat['position'])
                    matched_flag = True
                if matched_flag and 'amplitude' in lowered:
                    amp_numerical = _local_relative_score(value, feat['amplitude'])
            if matched_flag:
                break
        cate_correct.append(matched_flag)
        num_correct.append(pos_numerical)
        num_correct.append(amp_numerical)

    return cate_correct, num_correct, [], []

def evaluate_shape_correlation(answer: str, attribute: dict, cols: List[str]):
    """Score a yes/no shape-correlation answer against ``attribute['label']``.

    The first clause must contain ``'yes'`` when the label is truthy and
    ``'no'`` otherwise (case-insensitive substring test).

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'label'``.
        cols: Unused here.

    Returns:
        A 4-tuple of lists whose first element holds the single category
        result; the others are empty unless the answer has no clauses.
    """
    clauses = split_sentences(answer)
    if not clauses:
        return [False], [], [0.0], [{}]

    expected_word = 'yes' if attribute['label'] else 'no'
    is_correct = expected_word in clauses[0].lower()
    return [is_correct], [], [], []

def evaluate_local_correlation(answer: str, attribute: dict, cols: List[str]):
    """Score a local-correlation answer against the labelled pairs.

    For a truthy label the first sentence must contain ``'yes'`` and the
    second sentence must list exactly the labelled ``(column, value)`` pairs
    as ``name, value`` facts separated by ``;``. For a falsy label the first
    sentence must contain ``'no'``.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'label'`` and, for truthy labels,
            ``'pair'``.
        cols: Candidate column names matched against each fact.

    Returns:
        A 4-tuple of lists whose first element holds the single category
        result; the others are empty unless the answer has no sentences.
    """
    sentences = split_period_sentences(answer)
    if not sentences:
        return [False], [], [0.0], [{}]

    head = sentences[0].lower()
    if not attribute['label']:
        return ['no' in head], [], [], []
    if 'yes' not in head:
        return [False], [], [], []

    label_cols = set(map(tuple, attribute['pair']))
    answer_cols = set()

    # A bare "Yes." has no detail sentence; treat it as listing no pairs
    # instead of raising IndexError.
    detail = sentences[1] if len(sentences) > 1 else ''
    for fact in detail.split(';'):
        items = fact.strip().split(',')
        if len(items) != 2:
            continue
        for col in cols:
            if match_metric_name(col, items[0].strip()):
                answer_cols.add((col, items[1].strip()))

    return [label_cols == answer_cols], [], [], []

def evaluate_shape_cluster(answer: str, attribute: dict, cols: List[str]):
    """Score the set of columns named in the answer with an F1 measure.

    Only the first sentence of the answer is used; it is split on commas and
    every candidate column matching a piece counts as named.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'cols'``.
        cols: Candidate column names.

    Returns:
        A 4-tuple of lists whose first element holds the F1 score between the
        labelled and named columns (0.0 when both sets are empty).
    """
    expected = set(attribute['cols'])

    parts = split_period_sentences(answer)
    if not parts:
        return [0.0], [], [0.0], [{}]

    named = {
        col
        for piece in parts[0].split(',')
        for col in cols
        if match_metric_name(col, piece.strip())
    }

    # F1 = 2*TP / (2*TP + FP + FN)
    hits = len(expected & named)
    extra = len(named - expected)
    missed = len(expected - named)
    f1 = 2 * hits / (2 * hits + extra + missed) if hits + extra + missed > 0 else 0.0

    return [f1], [], [], []

def evaluate_local_cluster(answer: str, attribute: dict, cols: List[str]):
    """Score ``(column, value)`` pairs named in the answer with an F1 measure.

    Only the first sentence of the answer is used; it is split on ``;`` into
    facts and each fact on its last comma into a name part and a value part.

    Args:
        answer: Response text to grade.
        attribute: Label dict; reads ``'cols'`` and ``'col_idx'`` (the second
            item of each ``col_idx`` entry is paired with the column).
        cols: Candidate column names.

    Returns:
        A 4-tuple of lists whose first element holds the F1 score between the
        labelled and named pairs (0.0 when both sets are empty).
    """
    expected = set(zip(attribute['cols'], [entry[1] for entry in attribute['col_idx']]))

    parts = split_period_sentences(answer)
    if not parts:
        return [0.0], [], [0.0], [{}]

    named = set()
    for fact in parts[0].split(';'):
        name_part, comma, value_part = fact.strip().rpartition(',')
        if not comma:
            # No comma: the fact carries no value and is ignored.
            continue
        for col in cols:
            if match_metric_name(col, name_part.strip()):
                named.add((col, value_part.strip()))

    # F1 = 2*TP / (2*TP + FP + FN)
    hits = len(expected & named)
    extra = len(named - expected)
    missed = len(expected - named)
    f1 = 2 * hits / (2 * hits + extra + missed) if hits + extra + missed > 0 else 0.0

    return [f1], [], [], []

def evaluate_deductive(answer: str, attribute: str, cols: List[str]):
    """Score a deductive-reasoning answer against its label.

    When the label's first clause is ``yes`` or ``no``, the answer's first
    clause must equal it (case-insensitive). Otherwise scoring is delegated
    to ``calculate_ragas_score``.

    Args:
        answer: Response text to grade.
        attribute: Label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds the score and the fourth the
        matching detail record.
    """
    labels = split_sentences(attribute)
    sentences = split_sentences(answer)

    # An empty answer or label yields no clauses; fall back to '' instead of
    # raising IndexError.
    label_head = labels[0].lower().strip() if labels else ''
    if label_head in ('yes', 'no'):
        response_head = sentences[0] if sentences else ''
        matches = response_head.lower().strip() == label_head
        cur_reason_correct = 1.0 if matches else 0.0
        ragas_detail = {"label": labels[0], "response": response_head}
    else:
        # NOTE(review): calculate_ragas_score is not defined in the visible
        # part of this file; it is presumably provided elsewhere.
        cur_reason_correct, ragas_detail = calculate_ragas_score(
            question='According to the previous information, please answer Yes or No and explain it in detail.',
            response=answer,
            label=attribute
        )
    return [], [], [cur_reason_correct], [ragas_detail]

def evaluate_causal(answer: str, attribute: str, cols: List[str]):
    """Score a causal-choice answer by matching its first clause to the label.

    Args:
        answer: Response text to grade.
        attribute: Label text; only its first clause is used.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 when the label's first clause
        is found in the answer's first clause (via ``match_metric_name``) and
        0.0 otherwise, the fourth holds the compared label/response.
    """
    label_parts = split_sentences(attribute)
    answer_parts = split_sentences(answer)

    # An empty answer or label yields no clauses; score it 0.0 instead of
    # raising IndexError.
    label = label_parts[0].lower().strip() if label_parts else ''
    answer_choice = answer_parts[0].lower().strip() if answer_parts else ''

    if label and answer_choice and match_metric_name(label, answer_choice):
        reason_correct = 1.0
    else:
        reason_correct = 0.0
    return [], [], [reason_correct], [{'label': label, 'response': answer_choice}]

def evaluate_MCQ2(answer: str, attribute: str, cols: List[str]):
    """Score a multiple-choice answer: the label must occur in the answer.

    Args:
        answer: Response text to grade.
        attribute: Label text, looked up inside ``answer`` via
            ``match_metric_name``.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 or 0.0, the fourth a
        label/response record.
    """
    score = 1.0 if match_metric_name(attribute, answer) else 0.0
    record = {'label': attribute, 'response': answer}
    return [], [], [score], [record]

def evaluate_uts_reason_pattern_recognition(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_numerical_judgement(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_causal(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_calculation(answer: str, attribute: float, cols: List[str]):
    """Score a numeric answer by its closeness to the numeric label.

    Labels with magnitude above ``1e-2`` are scored as ``1 - relative error``
    clamped to ``[0, 1]``; smaller labels require an absolute error below
    ``1e-2``. Unparseable or non-finite values score 0.0.

    Args:
        answer: Response text, expected to parse as a float.
        attribute: Expected numeric value.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds the score, the fourth a
        label/response record.
    """
    import math

    detail = [{'label': attribute, 'response': answer}]
    try:
        response_value = float(answer)
        label_value = float(attribute)
    except (TypeError, ValueError, OverflowError):
        return [], [], [0.0], detail

    # NaN would slip through the clamp below (min(1.0, nan) returns 1.0) and
    # earn a perfect score, so non-finite values are rejected up front.
    if not (math.isfinite(response_value) and math.isfinite(label_value)):
        return [], [], [0.0], detail

    # Compare on magnitude so negative labels also use the relative error.
    if abs(label_value) > 1e-2:
        result = max(0.0, min(1.0, 1.0 - abs(response_value - label_value) / abs(label_value)))
    else:
        result = 1.0 if abs(response_value - label_value) < 1e-2 else 0.0

    return [], [], [result], detail

def evaluate_mts_reason_pattern_recognition(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_numerical_judgement(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_causal(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_compare_choice(answer: str, attribute: str, cols: List[str]):
    """Score a comparison choice: the answer must occur inside the label.

    Args:
        answer: Response text, looked up inside ``attribute`` via
            ``match_metric_name``.
        attribute: Label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 or 0.0, the fourth a
        label/response record.
    """
    # match_metric_name keeps only letters; an answer with none left (empty,
    # digits, punctuation) would be the empty string, which is a substring of
    # every label. Score that as wrong rather than as a match.
    has_letters = bool(re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', answer))
    if has_letters and match_metric_name(answer, attribute):
        reason_correct = 1.0
    else:
        reason_correct = 0.0

    return [], [], [reason_correct], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_calculation(answer: str, attribute: float, cols: List[str]):
    """Score a numeric answer by its closeness to the numeric label.

    Labels with magnitude above ``1e-2`` are scored as ``1 - relative error``
    clamped to ``[0, 1]``; smaller labels require an absolute error below
    ``1e-2``. Unparseable or non-finite values score 0.0.

    Args:
        answer: Response text, expected to parse as a float.
        attribute: Expected numeric value.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds the score, the fourth a
        label/response record.
    """
    import math

    detail = [{'label': attribute, 'response': answer}]
    try:
        response_value = float(answer)
        label_value = float(attribute)
    except (TypeError, ValueError, OverflowError):
        return [], [], [0.0], detail

    # NaN would slip through the clamp below (min(1.0, nan) returns 1.0) and
    # earn a perfect score, so non-finite values are rejected up front.
    if not (math.isfinite(response_value) and math.isfinite(label_value)):
        return [], [], [0.0], detail

    # Compare on magnitude so negative labels also use the relative error.
    if abs(label_value) > 1e-2:
        result = max(0.0, min(1.0, 1.0 - abs(response_value - label_value) / abs(label_value)))
    else:
        result = 1.0 if abs(response_value - label_value) < 1e-2 else 0.0

    return [], [], [result], detail

def evaluate_compare_pattern_recognition(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def evaluate_compare_causal(answer: str, attribute: str, cols: List[str]):
    """Score an exact, case-insensitive match between answer and label.

    Args:
        answer: Response text to grade.
        attribute: Expected label text.
        cols: Unused here.

    Returns:
        A 4-tuple of lists; the third holds 1.0 on a match and 0.0 otherwise,
        the fourth a label/response record.
    """
    is_match = answer.lower() == attribute.lower()
    score = 1.0 if is_match else 0.0
    return [], [], [score], [{'label': attribute, 'response': answer}]

def ability_type_to_func(ability_type: str):
    """Return the module-level ``evaluate_*`` callable for ``ability_type``.

    Dashes in ``ability_type`` are replaced by underscores and the result is
    prefixed with ``evaluate_`` to form the function name.

    Args:
        ability_type: Ability identifier such as ``'shape-cluster'``.

    Returns:
        The callable bound to that name in this module's globals.

    Raises:
        NameError: If no callable with that name is defined in this module.
    """
    func_name = "evaluate_" + ability_type.replace('-', '_')
    # NOTE(security): this used to be eval(func_name). The name comes from
    # label JSON, so eval could execute arbitrary expressions; a plain
    # globals() lookup resolves the same functions without evaluating code.
    func = globals().get(func_name)
    if not callable(func):
        # NameError is what the former eval() raised for unknown names.
        raise NameError("name '{}' is not defined".format(func_name))
    return func

def reward_qa(answer: str, label: str):
    """Extract the final answer from a response and score it against a label.

    Text up to the last ``</think>`` is dropped, ``<answer>`` tags are
    removed, and the contents of ``\\answer{...}`` are used when present.
    The evaluator is then chosen from the label's ``ability_type``.

    Args:
        answer: Raw response text.
        label: JSON string with keys ``'ability_type'``, ``'attribute'`` and
            ``'cols'``.

    Returns:
        A dict mapping the ability type to the evaluator's 4-tuple
        ``(cate_correct, num_correct, reason_correct, reason_detail)``.
    """
    # Keep only what follows the reasoning section, if there is one.
    if "</think>" in answer:
        answer = answer.split("</think>")[-1].strip()
    answer = answer.replace('<answer>', '').replace('</answer>', '').strip()

    if '\\answer{' in answer:
        # Greedy on purpose: capture up to the last closing brace.
        boxed = re.search(r'\\answer\{(.*)\}', answer, re.DOTALL)
        if boxed is not None:
            answer = boxed.group(1)
        else:
            # No closing brace after the marker (a stray '}' earlier in the
            # text used to cause an IndexError here): take the remainder.
            answer = answer.split('\\answer{')[-1].strip()

    sample = json.loads(label)
    ability = sample['ability_type']
    evaluate_func = ability_type_to_func(ability)
    cate_correct, num_correct, reason_correct, reason_detail = evaluate_func(
        answer, sample['attribute'], sample['cols']
    )

    return {ability: (cate_correct, num_correct, reason_correct, reason_detail)}
