import re
import numpy as np
import json
import os
from collections import defaultdict
from tqdm import tqdm

from examples.scripts.values_utils import compare_text


# Regex tail shared by every title call: the opening parenthesis followed by a
# string literal as the first argument.
#   * `[rRuUfFbB]{0,2}` accepts an optional literal prefix (r'', f'', rf'', ...)
#     so raw strings and f-strings are recognised, not only bare quotes.
#   * group 1 is the opening quote; group 2 is the literal's body.
#   * the body is "backslash + any char" or "any char that is neither the
#     opening quote, a backslash, nor a newline", so an escaped quote such as
#     \' does not terminate the capture early.  The two alternatives cannot
#     match the same character, which keeps matching linear.
_TITLE_STRING_ARG = r'\s*\(\s*[rRuUfFbB]{0,2}([\'"])((?:\\.|(?!\1)[^\\\n])*)\1'

# Call forms that set a title, in the order their matches are reported:
# figure-level titles first, then per-axes titles.
_TITLE_CALLS = (
    # Global titles
    r'fig\.suptitle',
    r'plt\.suptitle',
    # Regular titles
    r'plt\.title',
    r'ax\.set_title',
    r'ax\d+\.set_title',
    # One-dimensional indexing
    r'axes\[\d+\]\.set_title',
    r'axs\[\d+\]\.set_title',
    r'ax\[\d+\]\.set_title',
    # Two-dimensional indexing
    r'axs\[\d+\s*,\s*\d+\]\.set_title',
    r'axes\[\d+\s*,\s*\d+\]\.set_title',
)

# Compiled once at import time instead of on every call.
_TITLE_PATTERNS = tuple(
    re.compile(call + _TITLE_STRING_ARG) for call in _TITLE_CALLS
)


def extract_title(code):
    """
    Extract titles from code

    Scans the code text for title-setting calls (fig.suptitle, plt.suptitle,
    plt.title, ax.set_title, ax1.set_title, axes[i].set_title,
    axs[i, j].set_title, ...) whose first argument is a single-line string
    literal, and collects the literal's text.

    The literal may carry a string prefix (r'...', f"...", rf'...'); the prefix
    is not part of the returned text. Escape sequences are returned exactly as
    written in the code (they are not decoded), and f-string placeholders are
    returned verbatim.

    Parameters:
    code (str): code string

    Returns:
    list: all extracted titles, grouped by call form (suptitle forms first)
          and, within one call form, in order of appearance in the code
    """
    return [
        found.group(2)
        for pattern in _TITLE_PATTERNS
        for found in pattern.finditer(code)
    ]

def compare_title(completion_titles, answer_titles):
    """
    Score how well the generated titles reproduce the reference titles

    Every reference title is paired greedily, in order, with the not-yet-used
    generated title that compare_text rates highest; a generated title can be
    consumed by at most one reference title. The result is the mean of the
    per-reference best scores.

    Parameters:
    completion_titles (list): titles from model-generated code
    answer_titles (list): titles from reference answer code

    Returns:
    float: 1.0 when the reference has no titles, 0.0 when only the completion
           has none, otherwise the mean best-match score (presumably within
           0 to 1, depending on what compare_text returns)
    """
    # Nothing to reproduce: the completion cannot be wrong.
    if not answer_titles:
        return 1.0

    # The reference has titles but the completion offers none.
    if not completion_titles:
        return 0.0

    # Work on a copy so the caller's list is never modified.
    unused = completion_titles.copy()
    per_answer = []

    for reference in answer_titles:
        top_score = 0.0
        top_position = -1

        # With an empty pool this loop does nothing and the reference scores 0.0.
        for position, candidate in enumerate(unused):
            similarity = compare_text(candidate, reference)
            # Strict '>' keeps the earliest candidate on ties and ignores 0.0 scores.
            if similarity > top_score:
                top_score, top_position = similarity, position

        per_answer.append(top_score)

        # Consume the winner so it cannot be matched again.
        if top_position >= 0:
            unused.pop(top_position)

    return sum(per_answer) / len(per_answer) if per_answer else 0.0

# Regex tail shared by every label call: the opening parenthesis followed by a
# string literal as the first argument.
#   * `[rRuUfFbB]{0,2}` accepts an optional literal prefix (r'', f'', rf'', ...)
#     so raw strings and f-strings are recognised, not only bare quotes.
#   * group 1 is the opening quote; group 2 is the literal's body.
#   * the body is "backslash + any char" or "any char that is neither the
#     opening quote, a backslash, nor a newline", so an escaped quote such as
#     \' does not terminate the capture early.
_LABEL_STRING_ARG = r'\s*\(\s*[rRuUfFbB]{0,2}([\'"])((?:\\.|(?!\1)[^\\\n])*)\1'

# Possible spellings of the axes object in front of `.set_<axis>label`,
# in the order their matches are reported.
_LABEL_AXES_PREFIXES = (
    r'ax',                          # ax.set_xlabel()
    r'axes\[\d+\]',                 # axes[0].set_xlabel()
    r'axs\[\d+\]',                  # axs[0].set_xlabel()
    r'ax\[\d+\]',                   # ax[0].set_xlabel()
    r'axes\[\d+\s*,\s*\d+\]',       # axes[0,0].set_xlabel()
    r'axs\[\d+\s*,\s*\d+\]',        # axs[0,0].set_xlabel()
    r'ax\d+',                       # ax1.set_xlabel()
)


def _compile_label_patterns(axis):
    """Build the compiled patterns for one axis letter ('x', 'y' or 'z')."""
    # The pyplot form (plt.xlabel) is tried first, then every axes-object form.
    calls = [r'plt\.' + axis + 'label']
    calls.extend(
        prefix + r'\.set_' + axis + 'label' for prefix in _LABEL_AXES_PREFIXES
    )
    return tuple(re.compile(call + _LABEL_STRING_ARG) for call in calls)


# Result key -> patterns, compiled once at import time. Insertion order fixes
# the key order of the dictionary returned by extract_labels.
_LABEL_PATTERNS = {
    'xlabels': _compile_label_patterns('x'),
    'ylabels': _compile_label_patterns('y'),
    'zlabels': _compile_label_patterns('z'),
}


def extract_labels(code):
    """
    Extract axis labels (xlabel, ylabel, zlabel) from code

    Scans the code text for label-setting calls (plt.xlabel, ax.set_xlabel,
    axes[i].set_ylabel, axs[i, j].set_zlabel, ax1.set_xlabel, ...) whose first
    argument is a single-line string literal, and collects the literal's text.

    The literal may carry a string prefix (r'...', f"...", rf'...'); the prefix
    is not part of the returned text. Escape sequences are returned exactly as
    written in the code (they are not decoded), and f-string placeholders are
    returned verbatim.

    Parameters:
    code (str): code string

    Returns:
    dict: dictionary with the keys 'xlabels', 'ylabels' and 'zlabels', each
          mapping to a list of label strings (empty when none were found).
          Within a list, labels are grouped by call form (plt form first) and,
          within one call form, appear in order of appearance in the code
    """
    return {
        key: [
            found.group(2)
            for pattern in patterns
            for found in pattern.finditer(code)
        ]
        for key, patterns in _LABEL_PATTERNS.items()
    }

def _mean_best_label_match(candidates, references):
    """Greedy one-to-one matching: mean of each reference label's best score."""
    # Work on a copy so the caller's list is never modified.
    unused = candidates.copy()
    best_per_reference = []

    for reference in references:
        top_score = 0.0
        top_position = -1

        # With an empty pool this loop does nothing and the reference scores 0.0.
        for position, candidate in enumerate(unused):
            similarity = compare_text(candidate, reference)
            # Strict '>' keeps the earliest candidate on ties and ignores 0.0 scores.
            if similarity > top_score:
                top_score, top_position = similarity, position

        best_per_reference.append(top_score)

        # Consume the winner so it cannot be matched again.
        if top_position >= 0:
            unused.pop(top_position)

    if not best_per_reference:
        return 0.0
    return sum(best_per_reference) / len(best_per_reference)


def compare_labels(completion_labels, answer_labels):
    """
    Score how well the generated axis labels reproduce the reference labels

    Each label type ('xlabels', 'ylabels', 'zlabels') is scored separately.
    A type the reference does not use is left out of the result entirely; a
    type the reference uses but the completion lacks scores 0. Otherwise every
    reference label is paired greedily with the best remaining generated label
    (as rated by compare_text) and the per-label scores are averaged. The
    final score is the mean over the label types that were scored.

    Parameters:
    completion_labels (dict): axis labels from model-generated code
    answer_labels (dict): axis labels from reference answer code

    Returns:
    float: 1.0 when the reference has no labels at all, otherwise the mean of
           the per-type scores (presumably within 0 to 1, depending on what
           compare_text returns)
    """
    type_scores = []

    for label_type in ('xlabels', 'ylabels', 'zlabels'):
        # Both dictionaries must provide every label type.
        references = answer_labels[label_type]
        candidates = completion_labels[label_type]

        # Label types absent from the reference do not influence the score.
        if not references:
            continue

        if candidates:
            type_scores.append(_mean_best_label_match(candidates, references))
        else:
            # The reference has labels of this type, the completion has none.
            type_scores.append(0.0)

    # The reference defines no labels: nothing to get wrong.
    if not type_scores:
        return 1.0

    return sum(type_scores) / len(type_scores)


if __name__ == "__main__":
    pred = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_c_perpendicular = [0.23, 0.23, 0.9, 0.9, 1.6]\nvo_i = [1.77, np.nan, 2.42, 2.0, np.nan] # Use NaN for missing data\nvo_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nvo_iii = [1.10, np.nan, 1.36, 1.7, np.nan]\n\n# Create indices for plotting\nx_indices = np.arange(len(methods))\n\n# Create figure and subplots\nfig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot Ti barrier energies\naxes[0].plot(x_indices, ti_c_parallel, marker='o', linestyle='-', color='skyblue', label='Ti$_{\\mathrm{i}}$ c$_{\\parallel}$')\naxes[0].plot(x_indices, ti_c_perpendicular, marker='s', linestyle='--', color='salmon', label='Ti$_{\\mathrm{i}}$ c$_{\\perp}$')\n\n# Add values as text labels for Ti\nfor i, txt in enumerate(ti_c_parallel):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_parallel[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_c_perpendicular):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_perpendicular[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[0].set_ylabel('Barrier Energy (eV)')\naxes[0].set_title('Barrier energies of Ti$_{\\mathrm{i}}$ ($c_{\\parallel}$ and $c_{\\perp}$) and V$_{\\mathrm{O}}$ (I, II, and III) paths', fontsize=14)\naxes[0].legend()\naxes[0].grid(True, linestyle='--', alpha=0.6)\n\n# Plot VO barrier energies\naxes[1].plot(x_indices, vo_i, marker='^', linestyle='-', color='lightgreen', label='V$_{\\mathrm{O}}$ I')\naxes[1].plot(x_indices, vo_ii, marker='v', linestyle='--', color='orchid', label='V$_{\\mathrm{O}}$ II')\naxes[1].plot(x_indices, vo_iii, marker='d', linestyle=':', color='gold', label='V$_{\\mathrm{O}}$ III')\n\n# Add values as text labels for VO, handling NaN\nfor i, txt in enumerate(vo_i):\n    if not 
np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_i[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_ii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_ii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_iii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_iii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[1].set_ylabel('Barrier Energy (eV)')\naxes[1].set_xticks(x_indices)\naxes[1].set_xticklabels(methods, rotation=45, ha='right')\naxes[1].legend()\naxes[1].grid(True, linestyle='--', alpha=0.6)\n\n# Set overall title and adjust layout\nfig.suptitle('Calculated and Experimental Barrier Energies', fontsize=16, y=1.02)\nplt.tight_layout()\n\n# Show the plot\nplt.show()\n\n```"
    gt = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_i_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_i_c_perp = [0.225, 0.23, 0.90, 0.9, 1.6]\nv_o_i = [1.77, np.nan, 2.42, 2.0, np.nan]\nv_o_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nv_o_iii = [1.1, np.nan, 1.36, 1.7, np.nan]\n\n# Define x-axis positions for each method\nx = np.arange(len(methods))\n\n# Create subplots for different barrier types\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot for Ti_i barrier energies\nax1.plot(x, ti_i_c_parallel, label='Ti$_\\mathrm{i}$ $c_\\parallel$', marker='o', linestyle='-', color='skyblue')\nax1.plot(x, ti_i_c_perp, label='Ti$_\\mathrm{i}$ $c_\\perp$', marker='s', linestyle='--', color='salmon')\nax1.set_ylabel('Barrier Energy (eV)')\nax1.set_title('Barrier energies of Ti$_\\mathrm{i}$ ($c_\\parallel$ and $c_\\perp$) and V$_\\mathrm{O}$ (I, II, and III) paths')\nax1.legend(loc='upper left')\nax1.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to ax1\nfor i, txt in enumerate(ti_i_c_parallel):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i],txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_i_c_perp):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Plot for V_O barrier energies\nax2.plot(x, v_o_i, label='V$_\\mathrm{O}$ I', marker='^', linestyle='-', color='lightgreen')\nax2.plot(x, v_o_ii, label='V$_\\mathrm{O}$ II', marker='v', linestyle='--', color='orchid')\nax2.plot(x, v_o_iii, label='V$_\\mathrm{O}$ III', marker='d', linestyle=':', color='gold')\nax2.set_ylabel('Barrier Energy (eV)')\nax2.set_xticks(x)\nax2.set_xticklabels(methods, rotation=45, ha='right')\nax2.legend(loc='upper left')\nax2.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to 
ax2\nfor i, txt in enumerate(v_o_i):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_ii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_iii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Add a title to the entire figure\nfig.suptitle('Barrier Energies of Ti$_\\mathrm{i}$ and V$_\\mathrm{O}$ Paths by Various Methods', fontsize=14, y=1.02)\n\nplt.tight_layout()\nplt.show()\n\n```"
    
    # Echo both raw code samples so the extracted values can be checked by eye.
    print(f"pred: {pred}")
    print(f"gt: {gt}")

    # Run the title check first, then the axis-label check. Each pass prints
    # what was extracted from the prediction, what was extracted from the
    # reference, and finally the score comparing the two.
    for extract, compare in (
        (extract_title, compare_title),
        (extract_labels, compare_labels),
    ):
        from_pred = extract(pred)
        print(f"pred processed data: {from_pred}")

        from_gt = extract(gt)
        print(f"gt processed data: {from_gt}")

        print(f"score: {compare(from_pred, from_gt)}")
