import json
import os
import numpy as np
import re
from typing import *
from loguru import logger
from tqdm import tqdm
import traceback
from multiprocessing import Pool


def split_sentences(text):
    abbreviations = ['max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.']
    
    for abbr in abbreviations:
        escaped_abbr = re.escape(abbr)
        text = re.sub(escaped_abbr, abbr.replace('.', '<DOT>'), text)

    pattern = r'[.!?。！？,;，；](?!\d)'
    sentences = re.split(pattern, text)
    
    sentences = [s.strip().replace('<DOT>', '.') for s in sentences if s.strip()]
    
    return sentences

def split_period_sentences(text):
    abbreviations = ['max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.']
    
    for abbr in abbreviations:
        escaped_abbr = re.escape(abbr)
        text = re.sub(escaped_abbr, abbr.replace('.', '<DOT>'), text)

    pattern = r'[.。](?!\d)'
    sentences = re.split(pattern, text)
    
    sentences = [s.strip().replace('<DOT>', '.') for s in sentences if s.strip()]
    
    return sentences

def match_metric_name(metric: str, sentence: str) -> bool:
    pattern = r'[^\u4e00-\u9fa5a-zA-Z]'
    sentence = re.sub(pattern, '', sentence).lower()
    metric = re.sub(pattern, '', metric).lower()

    return metric in sentence

def evaluate_trend(answer: str, attribute: dict, cols: List[str]):
    cate_correct = False
    sentences = split_sentences(answer)

    if len(sentences) == 0:
        return [0.0], [0.0], [], []

    if 'steady' in attribute['type']:
        if 'steady' in sentences[0]:
            cate_correct = True
    elif 'decrease' in attribute['type']:
        if 'decreas' in sentences[0].lower():
            cate_correct = True
    elif 'increase' in attribute['type']:
        if 'increas' in sentences[0].lower():
            cate_correct = True

    num_correct = []

    # Check start point
    for sentence in sentences:
        float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
        if float_numbers is None or len(float_numbers) == 0:
            continue
        if 'start' in sentence:
            if abs(attribute['start']) < 0.5:
                if abs(float_numbers[0]) < 0.5:
                    num_correct.append(1.0)
                else:
                    num_correct.append(0.0)
            else:
                num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['start']) / abs(attribute['start']))))
            break
    else:
        num_correct.append(0.0)

    # Check amplitude
    if attribute['type'] != 'keep steady':
        for sentence in sentences:
            float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
            if float_numbers is None or len(float_numbers) == 0:
                continue
            if 'change value' in sentence or 'from left to right' in sentence:
                if abs(attribute['amplitude']) < 0.5:
                    if abs(float_numbers[0]) < 0.5:
                        num_correct.append(1.0)
                    else:
                        num_correct.append(0.0)
                else:
                    num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['amplitude']) / abs(attribute['amplitude']))))
                break
        else:
            num_correct.append(0.0)

    return [cate_correct], num_correct, [], []

def evaluate_season(answer: str, attribute: dict, cols: List[str]):
    cate_correct = False
    sentences = split_sentences(answer)

    if len(sentences) == 0:
        return [0.0], [0.0], [], []

    if 'no' in attribute['type']:
        if 'no periodic' in sentences[0].lower():
            cate_correct = True
    else:
        if 'no' not in sentences[0].lower() and 'periodic' in sentences[0].lower():
            cate_correct = True

    num_correct = []

    if attribute['type'] != 'no periodic fluctuation':
        # Check period
        for sentence in sentences:
            float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
            if float_numbers is None or len(float_numbers) == 0:
                continue
            if 'each period' in sentence:
                num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['period']) / abs(attribute['period']))))
                break
        else:
            num_correct.append(0.0)

        # Check amplitude
        for sentence in sentences:
            float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
            if float_numbers is None or len(float_numbers) == 0:
                continue
            if 'amplitude' in sentence:
                num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['amplitude']) / abs(attribute['amplitude']))))
                break
        else:
            num_correct.append(0.0)
    else:
        num_correct = []

    return [cate_correct], num_correct, [], []

def evaluate_noise(answer: str, attribute: dict, cols: List[str]):
    cate_correct = False
    sentences = split_sentences(answer)

    if len(sentences) == 0:
        return [0.0], [0.0], [], []

    if 'almost no' in attribute['type']:
        if 'no noise' in sentences[0].lower():
            cate_correct = True
    else:
        if 'noisy' in sentences[0].lower():
            cate_correct = True

    num_correct = []

    # Check period
    if 'noisy' in attribute['type']:
        for sentence in sentences:
            float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
            if float_numbers is None or len(float_numbers) == 0:
                continue
            if 'standard' in sentence.lower() or 'std' in sentence.lower():
                num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['std']) / abs(attribute['std']))))
                break
        else:
            num_correct.append(0.0)

    return [cate_correct], num_correct, [], []

def evaluate_local(answer: str, attribute: dict, cols: List[str]):
    cate_correct = []
    num_correct = []

    # Split into facts
    for feat in attribute:
        matched_flag = False
        pos_numerical = 0.0
        amp_numerical = 0.0
        for fact in re.split(r'[;；]', answer):
            sentences = re.split(r'[，。,;；]', fact)
            if type(feat['type']) == str:
                feat['type'] = [feat['type']]
            if any(i in sentences[0].lower() for i in feat['type']):
                # Check period and amplitude
                for sentence in sentences:
                    float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence)))
                    if float_numbers is None or len(float_numbers) == 0:
                        continue
                    if 'position' in sentence.lower() or 'around point' in sentence.lower():
                        if abs(float_numbers[0] - feat['position']) > 64:
                            continue
                        pos_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['position']) / abs(feat['position'])))
                        matched_flag = True
                    if matched_flag and 'amplitude' in sentence.lower():
                        amp_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['amplitude']) / abs(feat['amplitude'])))
                if matched_flag:
                    break
        cate_correct.append(matched_flag)
        num_correct.append(pos_numerical)
        num_correct.append(amp_numerical)

    return cate_correct, num_correct, [], []

def evaluate_shape_correlation(answer: str, attribute: dict, cols: List[str]):
    cate_correct = False
    sentences = split_sentences(answer)

    if len(sentences) == 0:
        return [False], [], [0.0], [{}]

    if attribute['label']:
        if 'yes' in sentences[0].lower():
            cate_correct = True
    else:
        if 'no' in sentences[0].lower():
            cate_correct = True

    return [cate_correct], [], [], []

def evaluate_local_correlation(answer: str, attribute: dict, cols: List[str]):
    cate_correct = False
    sentences = split_period_sentences(answer)

    if len(sentences) == 0:
        return [False], [], [0.0], [{}]

    if attribute['label']:
        if 'yes' in sentences[0].lower():
            # Check correlation type
            label_cols = set(map(tuple, attribute['pair']))
            answer_cols = set()

            # Split into facts
            for fact in sentences[1].split(';'):
                items = fact.strip().split(',')
                if len(items) == 2:
                    for col in cols:
                        if match_metric_name(col, items[0].strip()):
                            answer_cols.add((col, items[1].strip()))

            if label_cols == answer_cols:
                cate_correct = True
    else:
        if 'no' in sentences[0].lower():
            cate_correct = True

    return [cate_correct], [], [], []

def evaluate_shape_cluster(answer: str, attribute: dict, cols: List[str]):
    cate_correct = 0.0
    num_correct = []

    label_cols = set(attribute['cols'])
    answer_cols = set()

    sentences = split_period_sentences(answer)

    if len(sentences) == 0:
        return [0.0], [], [0.0], [{}]

    # Split into facts
    for fact in split_period_sentences(answer)[0].split(','):
        fact = fact.strip()
        for col in cols:
            if match_metric_name(col, fact):
                answer_cols.add(col)

    # Calculate f1-score for label and answer
    tp = len(label_cols & answer_cols)
    fp = len(answer_cols - label_cols)
    fn = len(label_cols - answer_cols)
    if tp + fp + fn > 0:
        cate_correct = 2 * tp / (2 * tp + fp + fn)

    return [cate_correct], [], [], []

def evaluate_local_cluster(answer: str, attribute: dict, cols: List[str]):
    cate_correct = 0.0
    num_correct = []

    label_cols = set(zip(attribute['cols'], [i[1] for i in attribute['col_idx']]))
    answer_cols = set()

    sentences = split_period_sentences(answer)

    if len(sentences) == 0:
        return [0.0], [], [0.0], [{}]

    # Split into facts
    for fact in split_period_sentences(answer)[0].split(';'):
        items = fact.strip().split(',')
        if len(items) == 2:
            for col in cols:
                if match_metric_name(col, items[0].strip()):
                    answer_cols.add((col, items[1].strip()))

    # Calculate f1-score for label and answer
    tp = len(label_cols & answer_cols)
    fp = len(answer_cols - label_cols)
    fn = len(label_cols - answer_cols)
    if tp + fp + fn > 0:
        cate_correct = 2 * tp / (2 * tp + fp + fn)

    return [cate_correct], [], [], []

def evaluate_MCQ2(answer: str, attribute: str, cols: List[str]):
    if match_metric_name(attribute, answer):
        reason_correct = 1.0
    else:
        reason_correct = 0.0
    return [], [], [reason_correct], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_pattern_recognition(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_numerical_judgement(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_causal(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_uts_reason_calculation(answer: str, attribute: float, cols: List[str]):
    result = 0.0
    try:
        response_value = float(answer)
        label_value = float(attribute)
        if label_value > 1e-2:
            result = max(0.0, min(1.0, 1.0 - abs(response_value - label_value) / abs(label_value)))
        else:
            result = 1.0 if abs(response_value - label_value) < 1e-2 else 0.0
    except Exception as err:
        result = 0.0

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_pattern_recognition(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_numerical_judgement(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_causal(answer: str, attribute: str, cols: List[str]):
    result = float(answer.lower() == attribute.lower())

    return [], [], [result], [{'label': attribute, 'response': answer}]

def evaluate_compare_choice(answer: str, attribute: str, cols: List[str]):
    if match_metric_name(answer, attribute):
        reason_correct = 1.0
    else:
        reason_correct = 0.0

    return [], [], [reason_correct], [{'label': attribute, 'response': answer}]

def evaluate_mts_reason_calculation(answer: str, attribute: float, cols: List[str]):
    result = 0.0
    try:
        response_value = float(answer)
        label_value = float(attribute)
        if label_value > 1e-2:
            result = max(0.0, min(1.0, 1.0 - abs(response_value - label_value) / abs(label_value)))
        else:
            result = 1.0 if abs(response_value - label_value) < 1e-2 else 0.0
    except Exception as err:
        result = 0.0

    return [], [], [result], [{'label': attribute, 'response': answer}]

def ability_type_to_func(ability_type: str):
    return eval("evaluate_" + ability_type.replace('-', '_'))

def evaluate_qa(answer: str, sample: dict):
    # Extract from deep thinking
    if "</think>" in answer:
        answer = answer.split("</think>")[-1].strip()
    answer = answer.replace('<answer>', '').replace('</answer>', '').strip()

    # Get ability types
    ability_types = sample['ability_types']

    # Check if it is the single line (rlvr) model
    if '\\answer{}' in sample['question']:
        # It is RLVR model! We should parse the answer in \answer{} format
        logger.warning("[RLVR] Detected RLVR format answer.")
        if '\\answer{' in answer:
            answer = re.findall(r'\\answer\{(.*)\}', answer, re.DOTALL)[0]
        matched_cnt = 1
        answer_list = [answer]
        num_questions = 1
    else:
        # Use the regular parsing
        answer_list = re.findall(r'(?:^|\n).*?\d+\.\s*(.*?)(?=\n.*?\d+\.|$)', answer, re.MULTILINE | re.DOTALL)
        num_questions = len(sample['attributes'])
        matched_cnt = min(len(answer_list), num_questions)

        # Try match
        if matched_cnt < num_questions and matched_cnt == 1:
            idx_pos = []
            for idx in range(1, num_questions + 1):
                sub_str = f"{idx}. "
                if sub_str in answer:
                    idx_pos.append(answer.index(sub_str))
                else:
                    break

            if len(idx_pos) == num_questions:
                # Matched
                idx_pos.append(len(answer))
                answer_list = [answer[idx_pos[i] + len(f"{i+1}. "):idx_pos[i + 1]] for i in range(num_questions)]
                matched_cnt = min(len(answer_list), num_questions)
                logger.warning("[TRY MATCH]---------------------------")
                logger.warning(f"[TRY MATCH] {answer_list}")
                logger.warning("[TRY MATCH]---------------------------")
        elif num_questions == 1 and matched_cnt == 1 and len(answer_list[0].strip()) == 0:
            # Empty answer
            answer_list[0] = answer.rstrip()
            logger.warning("[TRY MATCH]---------------------------")
            logger.warning(f"[TRY MATCH] {answer_list}")
            logger.warning("[TRY MATCH]---------------------------")

    result = {}

    for i in range(len(ability_types)):
        ability = ability_types[i]
        evaluate_func = ability_type_to_func(ability)
        cur_answer = answer_list[i] if i < matched_cnt else ""
        cate_correct, num_correct, reason_correct, reason_detail = evaluate_func(cur_answer, sample['attributes'][i], sample['cols'])

        if ability in result:
            # Extent current result to existed
            cate_correct = result[ability][0] + cate_correct
            num_correct = result[ability][1] + num_correct
            reason_correct = result[ability][2] + reason_correct
            reason_detail = result[ability][3] + reason_detail
        result[ability] = (cate_correct, num_correct, reason_correct, reason_detail)  

    return result

def process_sample(args):
    idx, sample, generated_answer = args
    try:
        label = sample['answer']

        # find idx in generated answer
        pos = -1
        for i, item in enumerate(generated_answer):
            if item['idx'] == idx:
                pos = i
                break
        answer = generated_answer[pos]['response']
        evaluation_result = evaluate_qa(answer, sample)

        return {
            'idx': idx,
            'label': label,
            'response': answer,
            'evaluation': evaluation_result
        }
    except Exception as err:
        logger.error(err)
        traceback.print_exc()
        return None

def evaluate_batch_qa(dataset, generated_answer, EXP, num_workers=8):
    detailed_result = []
    ability_result = {'categorical': {}, 'numerical': {}, 'reason': {}}
    overall_result = {'categorical': [], 'numerical': [], 'reason': []}

    # Evaluation
    logger.info("Start evaluation...")

    with Pool(processes=num_workers) as pool:
        results = list(tqdm(pool.imap(process_sample, [(idx, dataset[idx], generated_answer) for idx in range(len(dataset))]), total=len(dataset)))

    for result in results:
        if result is None:
            continue

        detailed_result.append(result)
        evaluation_result = result['evaluation']

        # Parse naive result
        for ability, (cate_correct, num_correct, reason_correct, reason_detail) in evaluation_result.items():
            ability_result['categorical'].setdefault(ability, [])
            ability_result['numerical'].setdefault(ability, [])
            ability_result['reason'].setdefault(ability, [])

            ability_result['categorical'][ability].extend(cate_correct)
            ability_result['numerical'][ability].extend(num_correct)
            ability_result['reason'][ability].extend(reason_correct)

            overall_result['categorical'].extend(cate_correct)
            overall_result['numerical'].extend(num_correct)
            overall_result['reason'].extend(reason_correct)

    # Calculate tokens
    total_tokens = 0
    for item in generated_answer:
        total_tokens += item.get('num_tokens', 0)

    logger.info(f"[RESULT] -----------------------------------------------------------------")
    logger.info(f"[RESULT] Experiment: {EXP}")
    logger.info(f"[RESULT] Total: {len(dataset)}, Success Evaluation: {len(detailed_result)}")
    logger.info(f"[RESULT] Detailed Categorical: {[(k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['categorical'].items()]}")
    logger.info(f"[RESULT] Detailed Numerical: {[(k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['numerical'].items()]}")
    logger.info(f"[RESULT] Detailed Reason: {[(k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['reason'].items()]}")
    logger.info(f"[RESULT] Overall Categorical: {round(float(np.mean(overall_result['categorical'])), 4)}; Overall Numerical: {round(float(np.mean(overall_result['numerical'])), 4)}; Overall Reason: {round(float(np.mean(overall_result['reason'])), 4)}")
    logger.info(f"[RESULT] Consumed tokens: {total_tokens}")
    logger.info(f"[RESULT] -----------------------------------------------------------------")

    # Save Result
    json.dump(detailed_result, open(f"exp/{EXP}/detailed_result.json", "wt"), ensure_ascii=False, indent=4)
    json.dump({
        'detail_categorical': dict((k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['categorical'].items()),
        'detail_numerical': dict((k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['numerical'].items()),
        'detail_reason': dict((k, round(float(np.mean(v)), 4)) for (k, v) in ability_result['reason'].items()),
        'overall_categorical': round(float(np.mean(overall_result['categorical'])), 4),
        'overall_numerical': round(float(np.mean(overall_result['numerical'])), 4),
        'overall_reason': round(float(np.mean(overall_result['reason'])), 4),
        'consumed_tokens': total_tokens,
    }, open(f"exp/{EXP}/result.json", "wt"), ensure_ascii=False, indent=4)
