import json
import os
import re
from collections import Counter, defaultdict

import numpy as np
from tqdm import tqdm


# Regex alternatives for the receiver of a plotting call (pyplot or an axes object).
_AXIS_PREFIXES = (
    r'plt',                         # plt.plot()
    r'ax',                          # ax.plot()
    r'axes\[\d+\]',                 # axes[0].plot()
    r'axs\[\d+\]',                  # axs[0].plot()
    r'ax\[\d+\]',                   # ax[0].plot()
    r'axes\[\d+\s*,\s*\d+\]',       # axes[0,0].plot()
    r'axs\[\d+\s*,\s*\d+\]',        # axs[0,0].plot()
    r'ax\d+',                       # ax1.plot()
)
_AXIS_PATTERN = '|'.join(_AXIS_PREFIXES)

# Methods whose every call maps directly to one chart type, in output order.
_SIMPLE_CALL_LABELS = (
    ('stackplot', 'Area'),
    ('boxplot', 'Box'),
    ('violinplot', 'Violin'),
    ('hist', 'Histogram'),
    ('contourf', 'Contour'),
    ('quiver', 'Quiver'),
)

# errorbar(..., fmt='-...'): a format string that starts with '-'.
_FMT_LINE = re.compile(r'fmt\s*=\s*[\'"]-(.*?)[\'"]', re.DOTALL)

# The ``s`` keyword of scatter(). The look-behind keeps the match from firing
# inside longer names such as ``edgecolors=`` / ``linewidths=`` or attributes.
_S_KWARG = re.compile(r'(?<![\w.])s\s*=(?!=)')
# ``s`` given as a bare numeric literal (constant marker size).
_S_CONST = re.compile(r'(?<![\w.])s\s*=\s*(?:\d+(?:\.\d*)?|\.\d+)\s*[,)]')


def _count_axis_calls(code, method_pattern):
    """Count ``<axis>.<method>(`` occurrences; method_pattern is a regex fragment."""
    return len(re.findall(rf'(?:{_AXIS_PATTERN})\.(?:{method_pattern})\s*\(', code))


def _iter_axis_calls(code, method):
    """Yield the text of each ``<axis>.<method>(...)`` call, one per occurrence.

    The call text runs from the axis prefix to the parenthesis that closes the
    call. Calls whose parentheses never balance are skipped.
    """
    for found in re.finditer(rf'(?:{_AXIS_PATTERN})\.{method}\s*\(', code):
        depth = 0
        # found.end() - 1 is the opening parenthesis of this particular call.
        for pos in range(found.end() - 1, len(code)):
            char = code[pos]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    yield code[found.start():pos + 1]
                    break


def extract_chart_type(code):
    """
    Extract chart types from plotting code using regex heuristics

    One entry is appended per detected plotting call (duplicates are kept), in
    this fixed order: 3D, Heatmap, Treemap, Graph, Multi-axes, Area, Box,
    Violin, Histogram, Contour, Quiver, Density, Errorbar/Errorpoint,
    Pie/Ring, Scatter/Bubble, Bar or Rose, Radar or Line.

    Parameters:
    code (str): code string

    Returns:
    list: list of extracted chart types; empty for empty or non-string input
    """
    # Nothing to scan (also covers a missing completion passed as None).
    if not isinstance(code, str) or not code:
        return []

    chart_types = []

    # 3D plots: one entry per projection='3d' setting
    chart_types.extend(["3D"] * len(re.findall(r'projection\s*=\s*[\'"]3d[\'"]', code)))

    # Heatmap (seaborn) and Treemap (squarify)
    chart_types.extend(["Heatmap"] * len(re.findall(r'sns\.heatmap\s*\(.*?\)', code, re.DOTALL)))
    chart_types.extend(["Treemap"] * len(re.findall(r'squarify\.plot\s*\(.*?\)', code, re.DOTALL)))

    # Graph (add only once, as import usually appears once)
    if "import networkx" in code:
        chart_types.append("Graph")

    # Multi-axes: twinx() called with no arguments
    twinx_calls = re.findall(rf'(?:{_AXIS_PATTERN})\.twinx\s*\(\s*\)', code)
    chart_types.extend(["Multi-axes"] * len(twinx_calls))

    # Area, Box, Violin, Histogram, Contour, Quiver: one entry per call
    for method, label in _SIMPLE_CALL_LABELS:
        chart_types.extend([label] * _count_axis_calls(code, method))

    # Density: gaussian_kde together with a plot call and a fill_between call
    has_kde = re.search(r'gaussian_kde\s*\(.*?\)', code, re.DOTALL)
    has_plot = re.search(rf'(?:{_AXIS_PATTERN})\.plot\s*\(.*?\)', code, re.DOTALL)
    has_fill_between = re.search(rf'(?:{_AXIS_PATTERN})\.fill_between\s*\(.*?\)', code, re.DOTALL)
    if has_kde and has_plot and has_fill_between:
        chart_types.append("Density")

    # Errorbar vs Errorpoint: each call is judged on its own arguments;
    # a fmt starting with '-' means the points are joined by a line.
    for call in _iter_axis_calls(code, 'errorbar'):
        chart_types.append("Errorbar" if _FMT_LINE.search(call) else "Errorpoint")

    # Pie vs Ring: a ring is a pie drawn with wedgeprops that set a width
    for call in _iter_axis_calls(code, 'pie'):
        is_ring = 'wedgeprops' in call and 'width' in call
        chart_types.append("Ring" if is_ring else "Pie")

    # Scatter vs Bubble: a bubble chart passes a non-constant ``s`` keyword
    for call in _iter_axis_calls(code, 'scatter'):
        if _S_KWARG.search(call) and not _S_CONST.search(call):
            chart_types.append("Bubble")
        else:
            chart_types.append("Scatter")

    # Bar, Rose, Line and Radar depend on whether any polar setting exists
    polar_settings = re.findall(
        r'projection\s*=\s*[\'"]polar[\'"]|polar\s*=\s*True|\.set_polar\s*\(\s*True\s*\)',
        code,
    )
    has_polar = bool(polar_settings)

    # With a polar setting every bar/barh call is a Rose, otherwise a Bar
    bar_total = _count_axis_calls(code, 'bar') + _count_axis_calls(code, 'barh')
    chart_types.extend(["Rose" if has_polar else "Bar"] * bar_total)

    plot_total = _count_axis_calls(code, 'plot')
    fill_total = _count_axis_calls(code, 'fill|fill_between|polygon')

    if has_polar and fill_total > 0:
        # Assume each plot call pairs with a fill call to form a Radar
        chart_types.extend(["Radar"] * min(plot_total, fill_total))
    elif not has_polar and "Density" not in chart_types:
        # Without polar settings every plot call is a Line, unless the plot
        # calls were already attributed to a Density chart
        chart_types.extend(["Line"] * plot_total)

    return chart_types


def compare_chart_type(completion_chart_types, answer_chart_types):
    """
    Score how many of the reference chart types are also present in the completion

    The comparison is order-independent and respects multiplicity: each entry
    in the answer can be matched by at most one equal entry in the completion.
    Extra entries in the completion are not penalized.

    Parameters:
    completion_chart_types (list): List of chart types from model-generated code
    answer_chart_types (list): List of chart types from reference answer code

    Returns:
    float: Consistency score, ranging from 0 to 1
           (matched answer entries / total answer entries)
    """
    # If there are no chart types in the answer, score is 1.0
    if not answer_chart_types:
        return 1.0

    # If there are no chart types in the completion but there are in the answer, score is 0.0
    if not completion_chart_types:
        return 0.0

    # Multiset intersection keeps, per chart type, the smaller of the two
    # counts, which is exactly the number of one-to-one matches.
    shared = Counter(answer_chart_types) & Counter(completion_chart_types)
    matched_count = sum(shared.values())

    return matched_count / len(answer_chart_types)


if __name__ == "__main__":
    # Demo: extract chart types from a model prediction and from a reference
    # answer (both are markdown-fenced matplotlib scripts), then score them.
    #
    # Each sample is written as adjacent string literals inside parentheses:
    # a double-quoted literal cannot span physical lines, so the pieces are
    # joined by implicit concatenation and the runtime text is unchanged.
    pred = (
        "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_c_perpendicular = [0.23, 0.23, 0.9, 0.9, 1.6]\nvo_i = [1.77, np.nan, 2.42, 2.0, np.nan] # Use NaN for missing data\nvo_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nvo_iii = [1.10, np.nan, 1.36, 1.7, np.nan]\n\n# Create indices for plotting\nx_indices = np.arange(len(methods))\n\n# Create figure and subplots\nfig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot Ti barrier energies\naxes[0].plot(x_indices, ti_c_parallel, marker='o', linestyle='-', color='skyblue', label='Ti$_{\\mathrm{i}}$ c$_{\\parallel}$')\naxes[0].plot(x_indices, ti_c_perpendicular, marker='s', linestyle='--', color='salmon', label='Ti$_{\\mathrm{i}}$ c$_{\\perp}$')\n\n# Add values as text labels for Ti\nfor i, txt in enumerate(ti_c_parallel):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_parallel[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_c_perpendicular):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_perpendicular[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[0].set_ylabel('Barrier Energy (eV)')\naxes[0].set_title('Barrier energies of Ti$_{\\mathrm{i}}$ ($c_{\\parallel}$ and $c_{\\perp}$) and V$_{\\mathrm{O}}$ (I, II, and III) paths', fontsize=14)\naxes[0].legend()\naxes[0].grid(True, linestyle='--', alpha=0.6)\n\n# Plot VO barrier energies\naxes[1].plot(x_indices, vo_i, marker='^', linestyle='-', color='lightgreen', label='V$_{\\mathrm{O}}$ I')\naxes[1].plot(x_indices, vo_ii, marker='v', linestyle='--', color='orchid', label='V$_{\\mathrm{O}}$ II')\naxes[1].plot(x_indices, vo_iii, marker='d', linestyle=':', color='gold', label='V$_{\\mathrm{O}}$ III')\n\n# Add values as text labels for VO, handling NaN\nfor i, txt in enumerate(vo_i):\n    if not "
        "np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_i[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_ii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_ii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_iii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_iii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[1].set_ylabel('Barrier Energy (eV)')\naxes[1].set_xticks(x_indices)\naxes[1].set_xticklabels(methods, rotation=45, ha='right')\naxes[1].legend()\naxes[1].grid(True, linestyle='--', alpha=0.6)\n\n# Set overall title and adjust layout\nfig.suptitle('Calculated and Experimental Barrier Energies', fontsize=16, y=1.02)\nplt.tight_layout()\n\n# Show the plot\nplt.show()\n\n```"
    )
    gt = (
        "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_i_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_i_c_perp = [0.225, 0.23, 0.90, 0.9, 1.6]\nv_o_i = [1.77, np.nan, 2.42, 2.0, np.nan]\nv_o_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nv_o_iii = [1.1, np.nan, 1.36, 1.7, np.nan]\n\n# Define x-axis positions for each method\nx = np.arange(len(methods))\n\n# Create subplots for different barrier types\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot for Ti_i barrier energies\nax1.plot(x, ti_i_c_parallel, label='Ti$_\\mathrm{i}$ $c_\\parallel$', marker='o', linestyle='-', color='skyblue')\nax1.plot(x, ti_i_c_perp, label='Ti$_\\mathrm{i}$ $c_\\perp$', marker='s', linestyle='--', color='salmon')\nax1.set_ylabel('Barrier Energy (eV)')\nax1.set_title('Barrier energies of Ti$_\\mathrm{i}$ ($c_\\parallel$ and $c_\\perp$) and V$_\\mathrm{O}$ (I, II, and III) paths')\nax1.legend(loc='upper left')\nax1.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to ax1\nfor i, txt in enumerate(ti_i_c_parallel):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i],txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_i_c_perp):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Plot for V_O barrier energies\nax2.plot(x, v_o_i, label='V$_\\mathrm{O}$ I', marker='^', linestyle='-', color='lightgreen')\nax2.plot(x, v_o_ii, label='V$_\\mathrm{O}$ II', marker='v', linestyle='--', color='orchid')\nax2.plot(x, v_o_iii, label='V$_\\mathrm{O}$ III', marker='d', linestyle=':', color='gold')\nax2.set_ylabel('Barrier Energy (eV)')\nax2.set_xticks(x)\nax2.set_xticklabels(methods, rotation=45, ha='right')\nax2.legend(loc='upper left')\nax2.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to "
        "ax2\nfor i, txt in enumerate(v_o_i):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_ii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_iii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Add a title to the entire figure\nfig.suptitle('Barrier Energies of Ti$_\\mathrm{i}$ and V$_\\mathrm{O}$ Paths by Various Methods', fontsize=14, y=1.02)\n\nplt.tight_layout()\nplt.show()\n\n```"
    )

    print(f"pred: {pred}")
    print(f"gt: {gt}")

    # Chart types detected in the model-generated code
    pred_data = extract_chart_type(pred)
    print(f"pred processed data: {pred_data}")

    # Chart types detected in the reference answer
    gt_data = extract_chart_type(gt)
    print(f"gt processed data: {gt_data}")

    # Fraction of reference chart types that the prediction also contains
    score = compare_chart_type(pred_data, gt_data)
    print(f"score: {score}")
