import json
import re
import ast
import io
import pandas as pd
import numpy as np
import argparse
import logging
from typing import List, Dict, Any, Optional, Tuple, Union
import math
from thefuzz import fuzz
from mathruler.grader import grade_answer


def parse_csv_string_to_dataframe(csv_string: str) -> Optional[pd.DataFrame]:
    """Parse a tab-separated table stored in a string into a DataFrame.

    Literal backslash-t / backslash-n escape sequences in the input are first
    turned into real tab and newline characters. Column names and string
    cells are stripped of surrounding whitespace. A non-numeric column is
    converted to a numeric dtype only when every non-null cell parses as a
    number, so text columns keep their original values.

    Args:
        csv_string: The tab-separated text, header row first.

    Returns:
        The parsed DataFrame, an empty DataFrame when pandas reports that
        there is no data to parse, or None when the input is empty, not a
        string, or cannot be parsed.
    """
    if not csv_string or not isinstance(csv_string, str):
        return None
    try:
        processed_string = csv_string.strip().replace('\\t', '\t').replace('\\n', '\n')
        df = pd.read_csv(
            io.StringIO(processed_string),
            sep='\t',
            engine='python',
            skipinitialspace=True,
            on_bad_lines='warn',
        )
        df.columns = df.columns.str.strip()
        for col in df.columns:
            if pd.api.types.is_string_dtype(df[col]):
                df[col] = df[col].str.strip()
        df = df.infer_objects()
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                continue
            try:
                numeric_col = pd.to_numeric(df[col], errors='coerce')
            except (ValueError, TypeError):
                continue
            # errors='coerce' always yields a numeric dtype (unparseable cells
            # become NaN), so the dtype alone says nothing. Accept the
            # conversion only if it did not turn any real value into NaN.
            if numeric_col.notna().sum() == df[col].notna().sum():
                df[col] = numeric_col
        return df
    except pd.errors.EmptyDataError:
        return pd.DataFrame()
    except Exception as e:
        # Best-effort parser: callers treat None as "no reference table".
        logging.debug(f"parse_csv_string_to_dataframe failed: {e}")
        return None

def extract_tag_content(tag: str, text: str) -> Optional[str]:
    """Return the stripped text between the first ``<tag ...>`` and ``</tag>``.

    Matching is case-insensitive and spans newlines. The tag name is treated
    as a literal string (regex metacharacters in it are escaped). For the
    ``answer`` tag only, a malformed closing tag written as ``/answer>``
    (missing the leading ``<``) is also accepted.

    Args:
        tag: Tag name without angle brackets.
        text: Text to search.

    Returns:
        The stripped inner text, or None when ``text`` is empty or no
        matching tag pair is found.
    """
    if not text:
        return None
    flags = re.DOTALL | re.IGNORECASE
    # Escape the tag so names containing regex metacharacters match literally.
    tag_pattern = re.escape(tag)
    match = re.search(rf"<{tag_pattern}(?: [^>]*)?>(.*?)</{tag_pattern}>", text, flags)
    if match is None and tag.lower() == 'answer':
        match = re.search(r"<answer(?: [^>]*)?>(.*?)/answer>", text, flags)
    return match.group(1).strip() if match else None

def extract_boxed_content(text: str) -> Optional[str]:
    r"""Return the stripped content of the first ``\boxed{...}`` in ``text``.

    Braces are matched by depth so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole. If the braces after the
    first ``\boxed{`` never balance, the function falls back to taking
    everything up to the first closing brace.

    Args:
        text: Text to search.

    Returns:
        The stripped boxed content, or None when ``text`` is empty or holds
        no ``\boxed{...}`` expression.
    """
    if not text:
        return None
    marker = "\\boxed{"
    start = text.find(marker)
    if start == -1:
        return None
    content_start = start + len(marker)
    depth = 1
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[content_start:pos].strip()
    # Unbalanced braces: keep the lenient "up to the first '}'" behaviour.
    match = re.search(r"\\boxed\{(.*?)\}", text, re.DOTALL)
    return match.group(1).strip() if match else None

def extract_first_number(text: str) -> Optional[float]:
    """Return the first number found in ``text`` as a float.

    The input is converted with ``str()`` and all commas are removed before
    searching, so ``"1,234.5"`` is read as ``1234.5``. Signed integers,
    decimals and scientific notation (``1.5e3``, ``2E-2``) are recognised.

    Args:
        text: Any value; None and blank strings yield None.

    Returns:
        The first number as a float, or None when no number is present.
    """
    if text is None:
        return None
    text = str(text).strip()
    if not text:
        return None
    text_no_commas = text.replace(',', '')
    # The exponent part is optional and only consumed when digits follow it.
    match = re.search(r'[-+]?(?:\d*\.\d+|\d+)(?:[eE][-+]?\d+)?', text_no_commas)
    if match is None:
        # No match means the text holds no digits at all, so stripping other
        # characters and retrying could never succeed.
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None

def compare_numbers(target_num: Optional[float],
                    prediction_num: Optional[float],
                    max_relative_change: float = 0.01) -> bool:
    """Check whether a prediction is within a relative tolerance of a target.

    Args:
        target_num: Reference value.
        prediction_num: Predicted value.
        max_relative_change: Largest accepted value of
            ``|prediction - target| / |target|``.

    Returns:
        False when either value is None or NaN. When the target's magnitude
        is below 1e-9 the prediction must also be below 1e-9 in magnitude
        (a relative error is undefined there). Otherwise True when the
        relative error does not exceed ``max_relative_change``.
    """
    if prediction_num is None or target_num is None:
        return False
    if math.isnan(prediction_num) or math.isnan(target_num):
        return False

    if abs(target_num) < 1e-9:
        return abs(prediction_num) < 1e-9

    relative_change = abs(prediction_num - target_num) / abs(target_num)
    is_correct = relative_change <= max_relative_change
    logging.debug(f"compare_numbers: Target={target_num}, Pred={prediction_num}. Relative change={relative_change:.4f}. Tolerance={max_relative_change}. Correct: {is_correct}")
    return is_correct

def exact_string_match(target_str: str, prediction_str: str) -> bool:
    """Compare two answers after case, punctuation and whitespace normalization.

    Each value is converted with ``str()``, lower-cased, stripped of every
    character that is neither a word character nor whitespace, and has its
    whitespace runs collapsed to single spaces. None is treated as the empty
    string.

    Args:
        target_str: Reference answer.
        prediction_str: Predicted answer.

    Returns:
        True when the two normalized strings are identical.
    """
    def _normalize(value: Any) -> str:
        if value is None:
            return ""
        without_punct = re.sub(r'[^\w\s]', '', str(value).strip().lower())
        # Removing punctuation can leave stray or doubled spaces behind
        # (e.g. "yes !!" -> "yes "); collapse them so they cannot cause a
        # false mismatch.
        return " ".join(without_punct.split())

    pred_clean = _normalize(prediction_str)
    target_clean = _normalize(target_str)

    is_correct = (pred_clean == target_clean)
    logging.debug(f"exact_string_match: Target='{target_clean}', Pred='{pred_clean}'. Correct: {is_correct}")
    return is_correct

def calculate_qa_accuracy(predicted_answer: Optional[str],
                          ground_truth_answer: str,
                          numeric_tolerance: float = 0.01) -> float:
    """Score a predicted answer against the ground truth as 1.0 or 0.0.

    If the ground truth contains any digit, the first number of each string
    is extracted and compared with ``compare_numbers`` using
    ``numeric_tolerance``; otherwise the strings are compared with
    ``exact_string_match``.

    Args:
        predicted_answer: Answer text produced by the model.
        ground_truth_answer: Reference answer.
        numeric_tolerance: Relative tolerance passed to ``compare_numbers``.

    Returns:
        1.0 for a correct answer, 0.0 otherwise (including a missing ground
        truth and a missing or blank prediction).
    """
    if ground_truth_answer is None or predicted_answer is None:
        return 0.0

    gt_text = str(ground_truth_answer).strip()
    pred_text = str(predicted_answer).strip()
    if pred_text == "":
        return 0.0

    if re.search(r'\d', gt_text) is None:
        eval_type = "Textual Comparison"
        logging.debug(f"QA Accuracy - Type: {eval_type}")
        is_correct = exact_string_match(gt_text, pred_text)
    else:
        eval_type = "Numeric Comparison"
        logging.debug(f"QA Accuracy - Type: {eval_type}")

        gt_number = extract_first_number(gt_text)
        pred_number = extract_first_number(pred_text)
        logging.debug(f"QA Accuracy - Extracted Nums: Target={gt_number}, Pred={pred_number}")

        if gt_number is None or pred_number is None:
            is_correct = False
            logging.debug(f"QA Accuracy - Numeric comparison failed: Could not extract number from both.")
        else:
            is_correct = compare_numbers(gt_number, pred_number, numeric_tolerance)

    logging.debug(f"QA Accuracy - Result: Correct={is_correct} (GT='{gt_text}', Pred='{pred_text}')")
    if is_correct:
        return 1.0
    return 0.0

class DataFrameDataExtractor(ast.NodeVisitor):
    """AST visitor that recovers the data passed to ``chart_data = pd.DataFrame(...)``.

    Visiting a tree records every single-name assignment whose value is a
    literal dict, list or tuple, and remembers the most recently visited
    assignment to a variable named ``chart_data``. Calling
    ``process_chart_data_assignment`` afterwards resolves the constructor's
    data argument, either as a literal or through a recorded variable.
    """

    # Marker returned by _resolve_data_node when nothing could be resolved;
    # distinct from None, which is itself a valid literal.
    _UNRESOLVED = object()

    def __init__(self):
        # Variable name -> literal dict/list/tuple assigned to it.
        self.literal_assignments: Dict[str, Any] = {}
        # Data argument recovered from the chart_data constructor call.
        self.chart_data_constructor_arg: Optional[Any] = None
        # Last visited ``chart_data = ...`` assignment node.
        self.chart_data_assignment_node: Optional[ast.Assign] = None
        self.found_chart_data_assign = False

    def visit_Assign(self, node: ast.Assign):
        """Record literal container assignments and the chart_data assignment."""
        targets = node.targets
        if len(targets) == 1 and isinstance(targets[0], ast.Name):
            name = targets[0].id
            try:
                literal = ast.literal_eval(node.value)
            except (ValueError, SyntaxError):
                # Not a pure literal expression; nothing to record.
                pass
            except Exception as e:
                logging.warning(f"{e}")
            else:
                if isinstance(literal, (dict, list, tuple)):
                    self.literal_assignments[name] = literal

            if name == 'chart_data':
                self.chart_data_assignment_node = node
                self.found_chart_data_assign = True
        self.generic_visit(node)

    def _resolve_data_node(self, data_node: ast.AST) -> Any:
        """Evaluate a data argument node, or return ``_UNRESOLVED``.

        A literal expression is evaluated directly; a bare variable name is
        looked up among the recorded literal assignments.
        """
        try:
            return ast.literal_eval(data_node)
        except (ValueError, SyntaxError):
            if isinstance(data_node, ast.Name) and data_node.id in self.literal_assignments:
                return self.literal_assignments[data_node.id]
        except Exception as e:
            logging.error(f"{e}")
        return self._UNRESOLVED

    def process_chart_data_assignment(self):
        """Fill ``chart_data_constructor_arg`` from the recorded assignment.

        Only calls of the exact form ``pd.DataFrame(...)`` are handled. The
        first positional argument is used when present; otherwise the
        ``data`` keyword argument is used.
        """
        node = self.chart_data_assignment_node
        if not node:
            return

        call = node.value
        if not isinstance(call, ast.Call):
            return
        func = call.func
        if not (isinstance(func, ast.Attribute) and func.attr == 'DataFrame'):
            return
        if not (isinstance(func.value, ast.Name) and func.value.id == 'pd'):
            return

        if call.args:
            data_node = call.args[0]
        else:
            data_node = next((kw.value for kw in call.keywords if kw.arg == 'data'), None)
            if not data_node:
                return

        resolved = self._resolve_data_node(data_node)
        if resolved is not self._UNRESOLVED:
            self.chart_data_constructor_arg = resolved

def extract_python_code_block(text: str) -> Optional[str]:
    """Return the stripped body of the first ```` ```python ```` fenced block.

    A fence whose opening line is followed by a newline and whose closing
    fence sits on its own line is preferred; if none exists, any
    ```` ```python ... ``` ```` span is accepted.

    Args:
        text: Text that may contain a fenced Python block.

    Returns:
        The stripped code, or None when ``text`` is empty or has no such block.
    """
    if not text:
        return None
    fence_patterns = (
        r"```python\s*\n(.*?)\n```",  # well-formed multi-line fence
        r"```python(.*?)```",         # lenient / inline fence
    )
    for pattern in fence_patterns:
        found = re.search(pattern, text, re.DOTALL)
        if found:
            return found.group(1).strip()
    return None

def extract_dataframe_constructor_data(code_string: str) -> Optional[Any]:
    """Parse code and return the data argument of its ``chart_data`` DataFrame.

    The code is parsed with ``ast`` and walked by ``DataFrameDataExtractor``;
    it is never executed.

    Args:
        code_string: Python source text.

    Returns:
        The recovered constructor argument, or None when the input is empty,
        cannot be parsed, or any other error occurs during extraction.
    """
    if not code_string:
        return None
    try:
        syntax_tree = ast.parse(code_string)
        visitor = DataFrameDataExtractor()
        visitor.visit(syntax_tree)
        visitor.process_chart_data_assignment()
    except Exception:
        # Covers SyntaxError as well; extraction is best-effort.
        return None
    return visitor.chart_data_constructor_arg

def construct_dataframe_from_extracted(extracted_data: Any) -> Optional[pd.DataFrame]:
    """Build a DataFrame from data recovered from a ``pd.DataFrame(...)`` call.

    After construction, a non-numeric column is converted to a numeric dtype
    only when every non-null cell parses as a number, so text columns keep
    their original values instead of being coerced to NaN.

    Args:
        extracted_data: Whatever was passed to the constructor (typically a
            dict of columns or a list of rows).

    Returns:
        The DataFrame, or None when ``extracted_data`` is None or pandas
        cannot build a frame from it.
    """
    if extracted_data is None:
        return None
    try:
        gen_df = pd.DataFrame(extracted_data)
    except Exception as e:
        logging.error(f"{e}")
        return None

    for col in gen_df.columns:
        try:
            if pd.api.types.is_numeric_dtype(gen_df[col]):
                continue
            numeric_col = pd.to_numeric(gen_df[col], errors='coerce')
            # errors='coerce' always produces a numeric dtype, so accept the
            # conversion only if no real value was turned into NaN.
            if numeric_col.notna().sum() == gen_df[col].notna().sum():
                gen_df[col] = numeric_col
        except Exception as e:
            # Best-effort per column (e.g. cells holding lists); keep the
            # column as-is and carry on.
            logging.error(f"{e}")
    return gen_df


def normalize_name(name: Union[str, int, float]) -> str:
    """Return a lower-cased column name with separators and punctuation removed.

    The value is converted with ``str()`` and lower-cased; then underscores,
    all whitespace and the characters ``% & ' ( ) * + , - . [ ] { }`` are
    deleted.

    Args:
        name: Column label.

    Returns:
        The normalized name.
    """
    # The removed set is written out explicitly. It is the same set the
    # previous pattern matched through the range ``%-.`` (every character
    # from '%' through '.'), which read like three separate literals.
    return re.sub(r"[_\s%&'()*+,\-.\[\]{}]", '', str(name).lower())

def compare_values(gen_val: Any, ref_val: Any) -> bool:
    """Decide whether a generated cell value matches the reference cell value.

    Two missing values match; one missing value never matches. Otherwise the
    stripped string forms are passed to ``grade_answer``. If ``grade_answer``
    raises, the values are compared as lower-cased strings and then as floats
    with ``np.isclose`` (absolute tolerance 1e-6, relative tolerance 1e-3).

    Args:
        gen_val: Value from the generated table.
        ref_val: Value from the reference table.

    Returns:
        A plain Python bool.
    """
    NUMERICAL_TOLERANCE = 1e-6
    gen_missing = pd.isna(gen_val)
    ref_missing = pd.isna(ref_val)
    if gen_missing and ref_missing:
        return True
    if gen_missing or ref_missing:
        return False

    gen_text = str(gen_val).strip()
    ref_text = str(ref_val).strip()
    try:
        return bool(grade_answer(gen_text, ref_text))
    except Exception:
        # NOTE(review): the tolerant comparison below only runs when the
        # grader raises, not when it returns False — confirm that is intended.
        gen_norm = gen_text.lower()
        ref_norm = ref_text.lower()
        if gen_norm == ref_norm:
            return True
        try:
            # np.isclose returns numpy.bool_, which is not JSON-serializable
            # and breaks the declared return type; coerce it.
            return bool(np.isclose(float(gen_norm), float(ref_norm),
                                   atol=NUMERICAL_TOLERANCE, rtol=1e-3))
        except (ValueError, TypeError):
            return False

def _match_reference_columns(
    ref_cols: List[Any],
    gen_cols: List[Any],
    threshold: int
) -> Dict[Any, Tuple[Optional[Any], int]]:
    """Pair each reference column with at most one generated column.

    Columns are compared by their ``normalize_name`` form. Exact matches are
    assigned first, so a fuzzy match for one reference column can never take
    a generated column that exactly matches another reference column. The
    remaining reference columns then take, in order, the unused generated
    column with the highest ``fuzz.partial_ratio`` score, provided it reaches
    ``threshold``.

    Returns:
        Mapping of reference column -> (generated column or None, score).
        For an unmatched column the score is the best score seen, or 0 when
        there were no candidates.
    """
    # If two generated columns normalize identically, the later one wins.
    gen_by_norm: Dict[str, Any] = {}
    for col in gen_cols:
        gen_by_norm[normalize_name(col)] = col

    used_norms = set()
    matches: Dict[Any, Tuple[Optional[Any], int]] = {}
    pending: List[Tuple[Any, str]] = []

    # Pass 1: exact matches on the normalized name.
    for ref_col in ref_cols:
        ref_norm = normalize_name(ref_col)
        if ref_norm in gen_by_norm and ref_norm not in used_norms:
            matches[ref_col] = (gen_by_norm[ref_norm], 100)
            used_norms.add(ref_norm)
        else:
            pending.append((ref_col, ref_norm))

    # Pass 2: greedy fuzzy matching over what is left.
    for ref_col, ref_norm in pending:
        best_score = -1
        best_norm = None
        for gen_norm in gen_by_norm:
            if gen_norm in used_norms:
                continue
            score = fuzz.partial_ratio(ref_norm, gen_norm)
            if score > best_score:
                best_score, best_norm = score, gen_norm

        if best_norm is not None and best_score >= threshold:
            matches[ref_col] = (gen_by_norm[best_norm], best_score)
            used_norms.add(best_norm)
        else:
            matches[ref_col] = (None, best_score if best_score != -1 else 0)

    return matches


def calculate_dataframe_comparison_reward(
    gen_df: Optional[pd.DataFrame],
    ref_df: Optional[pd.DataFrame]
) -> Dict[str, Any]:
    """Score a generated table against a reference table.

    The combined score is
    ``0.2 * column completeness + 0.1 * row completeness + 0.7 * accuracy``:

    * column completeness: fraction of reference columns matched to a
      generated column (exact normalized name, else fuzzy score >= 50);
    * row completeness: 1.0 when both tables have the same number of rows,
      else 0.0;
    * accuracy: mean, over matched columns, of the fraction of cells that
      ``compare_values`` accepts. Cells are compared by position over the
      first ``min(len(ref), len(gen))`` rows only.

    Args:
        gen_df: Table rebuilt from the model output, or None.
        ref_df: Reference table, or None.

    Returns:
        Dict with keys ``combined_score``, ``completeness_col_score``,
        ``completeness_row_score``, ``average_accuracy_score``,
        ``comparison_details`` (one entry per column, each with per-row
        ``(index, ref_value, gen_value, is_match)`` tuples) and
        ``error_message`` (always None here). All scores are 0.0 when either
        table is None or the reference table is empty.
    """
    FUZZY_MATCH_THRESHOLD = 50
    ACCURACY_WEIGHT = 0.7
    COMPLETENESS_COL_WEIGHT = 0.2
    COMPLETENESS_ROW_WEIGHT = 0.1

    results = {'combined_score': 0.0, 'completeness_col_score': 0.0, 'completeness_row_score': 0.0, 'average_accuracy_score': 0.0, 'comparison_details': [], 'error_message': None}

    if gen_df is None:
        if ref_df is not None and not ref_df.empty:
            for ref_col in ref_df.columns:
                results['comparison_details'].append({
                    'ref_col': ref_col, 'gen_col': None, 'match_score': 0,
                    'row_match': False, 'accuracy': 0.0,
                    'details': [(i, val, "N/A", False) for i, val in enumerate(ref_df[ref_col])],
                })
        return results
    if ref_df is None or ref_df.empty:
        if not gen_df.empty:
            for gen_col in gen_df.columns:
                results['comparison_details'].append({
                    'ref_col': "N/A", 'gen_col': gen_col, 'match_score': 0,
                    'row_match': False, 'accuracy': 0.0,
                    'details': [(i, "N/A", val, False) for i, val in enumerate(gen_df[gen_col])],
                })
        return results

    rows_match = len(gen_df) == len(ref_df)
    results['completeness_row_score'] = 1.0 if rows_match else 0.0

    ref_cols_original = ref_df.columns.tolist()
    gen_cols_original = gen_df.columns.tolist()
    best_matches = _match_reference_columns(
        ref_cols_original, gen_cols_original, FUZZY_MATCH_THRESHOLD
    )

    num_matched_cols = sum(1 for gen_col, _ in best_matches.values() if gen_col is not None)
    # ref_df is non-empty at this point, so it has at least one column.
    results['completeness_col_score'] = round(num_matched_cols / len(ref_cols_original), 4)

    total_col_accuracy_sum = 0.0
    num_compared_cols = 0

    for original_ref_col in ref_cols_original:
        original_gen_col, match_score = best_matches[original_ref_col]
        col_detail = {
            'ref_col': original_ref_col,
            'gen_col': original_gen_col,
            'match_score': match_score,
            'row_match': rows_match,
            'accuracy': 0.0,
            'details': []
        }
        if original_gen_col is None:
            col_detail['details'] = [
                (i, val, "N/A", False) for i, val in enumerate(ref_df[original_ref_col])
            ]
        else:
            num_compared_cols += 1
            ref_series = ref_df[original_ref_col]
            gen_series = gen_df[original_gen_col]
            rows_to_compare = min(len(ref_series), len(gen_series))
            correct_count = 0

            for i in range(rows_to_compare):
                ref_val = ref_series.iloc[i]
                gen_val = gen_series.iloc[i]
                is_match = compare_values(gen_val, ref_val)
                col_detail['details'].append((i, ref_val, gen_val, is_match))
                if is_match:
                    correct_count += 1

            col_accuracy = correct_count / rows_to_compare if rows_to_compare > 0 else 0.0
            col_detail['accuracy'] = round(col_accuracy, 4)
            total_col_accuracy_sum += col_accuracy

        results['comparison_details'].append(col_detail)

    results['average_accuracy_score'] = round(total_col_accuracy_sum / num_compared_cols, 4) if num_compared_cols > 0 else 0.0
    results['combined_score'] = round(
        COMPLETENESS_COL_WEIGHT * results['completeness_col_score'] +
        COMPLETENESS_ROW_WEIGHT * results['completeness_row_score'] +
        ACCURACY_WEIGHT * results['average_accuracy_score'], 6
    )

    logging.debug(
        f"Combined Score={results['combined_score']:.4f}, "
        f"Col Compl={results['completeness_col_score']:.4f}, "
        f"Row Compl={results['completeness_row_score']:.4f}, "
        f"Avg Acc={results['average_accuracy_score']:.4f}"
    )
    return results

def compute_score(predict_str: str, ground_truth_combined: str) -> Dict[str, float]:
    """Compute the overall reward and its components for one model response.

    The ground truth may carry ``<programability>``, ``<csv>`` and
    ``<answer>`` tags; without an ``<answer>`` tag the whole string is taken
    as the answer. The overall score is::

        0.8 * accuracy + 0.15 * r_data + 0.05 * format + 0.3 * decision_reward

    * accuracy: ``calculate_qa_accuracy`` on the ``\\boxed{...}`` answer;
    * r_data: table comparison between the ``chart_data`` DataFrame found in
      the response's Python code block and the reference CSV;
    * format: 1.0 when the response contains ``\\boxed{...}``;
    * decision_reward: non-zero only when the response's ``<CODE>`` /
      ``<DIRECT>`` tag agrees with programability ``yes`` / ``no``
      (``<CODE>`` takes precedence); 1.0 if the answer is also correct,
      else 0.5.

    Args:
        predict_str: Raw model response.
        ground_truth_combined: Ground-truth string, optionally tagged.

    Returns:
        Dict with ``overall``, ``accuracy``, ``r_data``, ``format``,
        ``decision_reward``, ``code_ratio`` and the three ``r_data_*``
        sub-scores.
    """
    W_QA_ACCURACY = 0.8
    W_R_DATA = 0.15
    W_FORMAT = 0.05
    W_DECISION = 0.3

    prediction = predict_str.strip() if predict_str else ""
    reference = ground_truth_combined.strip() if ground_truth_combined else ""

    # --- Ground-truth fields -------------------------------------------------
    programability = (
        extract_tag_content("programability", reference)
        if "<programability>" in reference else ""
    )
    actual_gt_csv = extract_tag_content("csv", reference) if "<csv>" in reference else ""
    actual_gt_qa = extract_tag_content("answer", reference) if "<answer>" in reference else reference

    # --- Format and answer accuracy -----------------------------------------
    has_boxed = bool(re.search(r"\\boxed\{.*?\}", prediction, re.DOTALL))
    format_score = 1.0 if has_boxed else 0.0
    predicted_boxed_answer = extract_boxed_content(prediction) if has_boxed else None

    accuracy_score = 0.0
    if actual_gt_qa is not None:
        accuracy_score = calculate_qa_accuracy(predicted_boxed_answer, actual_gt_qa)

    # --- Decision reward -----------------------------------------------------
    has_code_tag = "<CODE>" in prediction
    has_direct_tag = "<DIRECT>" in prediction

    if has_code_tag:
        expected_programability = "yes"
    elif has_direct_tag:
        expected_programability = "no"
    else:
        expected_programability = None

    decision_reward_score = 0.0
    if expected_programability is not None and programability == expected_programability:
        decision_reward_score = 1.0 if accuracy_score == 1.0 else 0.5

    code_ratio = 1.0 if has_code_tag else 0.0
    logging.debug(f"Code ratio: {code_ratio} (has_code_tag: {has_code_tag})")

    # --- Data (table) reward -------------------------------------------------
    r_data_score = 0.0
    completeness_col = 0.0
    completeness_row = 0.0
    avg_accuracy = 0.0

    predicted_code = extract_python_code_block(prediction)
    if predicted_code:
        gen_df = construct_dataframe_from_extracted(
            extract_dataframe_constructor_data(predicted_code)
        )
        ref_df = parse_csv_string_to_dataframe(actual_gt_csv)

        if gen_df is None or ref_df is None:
            logging.debug("r_data = 0")
        else:
            comparison = calculate_dataframe_comparison_reward(gen_df, ref_df)
            r_data_score = comparison['combined_score']
            completeness_col = comparison['completeness_col_score']
            completeness_row = comparison['completeness_row_score']
            avg_accuracy = comparison['average_accuracy_score']
            logging.debug(f"R_data Score (DataFrame): {r_data_score:.4f}, Col Compl: {completeness_col:.4f}, Row Compl: {completeness_row:.4f}, Avg Acc: {avg_accuracy:.4f}")
    else:
        logging.debug("r_data = 0")

    # --- Aggregate -----------------------------------------------------------
    overall_score = (W_QA_ACCURACY * accuracy_score +
                     W_R_DATA * r_data_score +
                     W_FORMAT * format_score + W_DECISION * decision_reward_score)
    logging.debug(f"Overall Score: {overall_score:.4f} (Acc: {accuracy_score*W_QA_ACCURACY:.4f}, R_data: {r_data_score*W_R_DATA:.4f}, Format: {format_score*W_FORMAT:.4f})")

    return {
        "overall": overall_score,
        "accuracy": accuracy_score,
        "r_data": r_data_score,
        "format": format_score,
        "decision_reward": decision_reward_score,
        "code_ratio": code_ratio,
        "r_data_completeness_col": completeness_col,
        "r_data_completeness_row": completeness_row,
        "r_data_avg_accuracy": avg_accuracy,
    }