import re
import numpy as np
import json
import os
from collections import defaultdict
from tqdm import tqdm


def _split_call_args(code, start):
    """
    Split the argument list of a call into its top-level arguments
    
    Parameters:
    code (str): code string
    start (int): index of the first character after the call's opening parenthesis
    
    Returns:
    list: stripped, non-empty argument strings; commas nested inside
          (), [] or {} do not split an argument
    """
    args = []
    current = []
    depth = 0
    # String literals are not tokenized: a quote containing a bracket or comma
    # can produce odd pieces, which the caller simply treats as unparseable.
    for ch in code[start:]:
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            if depth == 0:
                break  # closing parenthesis of the call itself
            depth -= 1
        elif ch == ',' and depth == 0:
            args.append(''.join(current).strip())
            current = []
            continue
        current.append(ch)
    args.append(''.join(current).strip())
    return [arg for arg in args if arg]


def _grid_from_partial_subplots(code):
    """
    Find a plt.subplots call that specifies the grid only partially or across lines
    
    Covers plt.subplots(2), plt.subplots(nrows=2, ...), plt.subplots(ncols=3, ...),
    plt.subplots(2, ncols=3) and keyword calls wrapped over several lines.
    A dimension that is not given defaults to 1, as in matplotlib.
    
    Parameters:
    code (str): code string
    
    Returns:
    tuple or None: (rows, cols) of the first usable call, or None if no call
                   gives at least one dimension as an integer literal
    """
    for call in re.finditer(r'plt\.subplots\s*\(', code):
        rows = None
        cols = None
        resolvable = True
        for position, arg in enumerate(_split_call_args(code, call.end())):
            keyword = re.fullmatch(r'(nrows|ncols)\s*=\s*(.*)', arg, re.DOTALL)
            if keyword:
                axis, value = keyword.group(1), keyword.group(2).strip()
            elif position < 2 and '=' not in arg:
                # The first two positional arguments are nrows and ncols
                axis, value = ('nrows', 'ncols')[position], arg
            else:
                continue
            if not re.fullmatch(r'\d+', value):
                # Variable or expression: the grid cannot be known statically
                resolvable = False
                break
            if axis == 'nrows':
                rows = int(value)
            else:
                cols = int(value)
        if resolvable and (rows is not None or cols is not None):
            return (1 if rows is None else rows, 1 if cols is None else cols)
    return None


def _grid_from_subplots(code):
    """
    Extract (rows, cols) from a plt.subplots call
    
    Parameters:
    code (str): code string
    
    Returns:
    tuple or None: (rows, cols), or None if no plt.subplots call states a grid
    """
    # Form 1: plt.subplots(rows, cols, ...)
    match = re.search(r'plt\.subplots\s*\(\s*(\d+)\s*,\s*(\d+)', code)
    if match:
        return (int(match.group(1)), int(match.group(2)))
    
    # Form 2: plt.subplots(nrows=rows, ncols=cols, ...) on a single line
    match = re.search(r'plt\.subplots\s*\(.*?nrows\s*=\s*(\d+).*?ncols\s*=\s*(\d+)', code)
    if match:
        return (int(match.group(1)), int(match.group(2)))
    
    # Form 3: plt.subplots(ncols=cols, nrows=rows, ...) on a single line
    match = re.search(r'plt\.subplots\s*\(.*?ncols\s*=\s*(\d+).*?nrows\s*=\s*(\d+)', code)
    if match:
        return (int(match.group(2)), int(match.group(1)))
    
    # Fallback: only one dimension given, mixed positional/keyword, or multi-line call
    return _grid_from_partial_subplots(code)


def _grid_from_add_subplot(code):
    """
    Extract (rows, cols) from the first fig.add_subplot call
    
    Parameters:
    code (str): code string
    
    Returns:
    tuple or None: (rows, cols), or None if no fig.add_subplot call is found
    """
    # Form 1: fig.add_subplot(rows, cols, idx, ...)
    three_arg_calls = re.findall(r'fig\.add_subplot\s*\(\s*(\d+)\s*,\s*(\d+)\s*,', code)
    if three_arg_calls:
        rows, cols = three_arg_calls[0]
        return (int(rows), int(cols))
    
    # Form 2: fig.add_subplot(rcidx, ...), e.g., fig.add_subplot(221)
    compact_calls = re.findall(r'fig\.add_subplot\s*\(\s*(\d{3,})\s*[,\)]', code)
    if compact_calls:
        # The pattern guarantees at least three digits: rows, cols, index
        rcidx = compact_calls[0]
        return (int(rcidx[0]), int(rcidx[1]))
    
    return None


def extract_layout(code):
    """
    Extract layout information from code, including subplot arrangement and figure size
    
    The code is scanned with regular expressions and is not executed, so only
    numeric literals are recognized. plt.subplots takes precedence over
    fig.add_subplot. A plt.subplots call that gives only one dimension
    (e.g. plt.subplots(2) or plt.subplots(ncols=3)) uses 1 for the other,
    which is matplotlib's default.
    
    Parameters:
    code (str): code string
    
    Returns:
    dict: dictionary containing subplot arrangement and figure size
          'figsize': (width, height) truncated to integers, or None if not found
          'subplot_layout': (rows, cols), (1, 1) if nothing is found
    """
    layout = {
        'figsize': None,  # Figure size, format (width, height)
        'subplot_layout': (1, 1)  # Subplot arrangement, format (rows, cols), default is 1x1
    }
    
    # Extract figsize
    figsize_pattern = r'figsize\s*=\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)'
    figsize_match = re.search(figsize_pattern, code)
    if figsize_match:
        # int() cannot parse "6.4" directly, so go through float; the fraction is dropped
        width = int(float(figsize_match.group(1)))
        height = int(float(figsize_match.group(2)))
        layout['figsize'] = (width, height)
    
    # Extract subplot arrangement, preferring plt.subplots over fig.add_subplot
    grid = _grid_from_subplots(code)
    if grid is None:
        grid = _grid_from_add_subplot(code)
    if grid is not None:
        layout['subplot_layout'] = grid
    
    return layout


def compare_layout(completion_layout, answer_layout):
    """
    Score how closely the generated layout matches the reference layout
    
    The score is the sum of two independent checks, each worth 0.5:
    the subplot arrangement and the figure size.
    
    Parameters:
    completion_layout (dict): layout from model-generated code
    answer_layout (dict): layout from reference answer code
    
    Returns:
    float: consistency score, one of 0.0, 0.5 or 1.0
    """
    # Check 1: subplot arrangement must be identical
    grids_agree = completion_layout['subplot_layout'] == answer_layout['subplot_layout']
    
    # Check 2: figure size. A reference without a figure size awards the points
    # unconditionally; otherwise the generated size must equal it (a missing
    # generated size is None and therefore never equal).
    reference_size = answer_layout['figsize']
    if reference_size is None:
        sizes_agree = True
    else:
        sizes_agree = completion_layout['figsize'] == reference_size
    
    total = 0.0
    for passed in (grids_agree, sizes_agree):
        if passed:
            total += 0.5
    return total


# Manual smoke test: extract the layout from a sample completion and from its
# reference answer, then print both layouts and the resulting score.
if __name__ == "__main__":
    # Sample completion: matplotlib code wrapped in a Markdown ```python fence.
    pred = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_c_perpendicular = [0.23, 0.23, 0.9, 0.9, 1.6]\nvo_i = [1.77, np.nan, 2.42, 2.0, np.nan] # Use NaN for missing data\nvo_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nvo_iii = [1.10, np.nan, 1.36, 1.7, np.nan]\n\n# Create indices for plotting\nx_indices = np.arange(len(methods))\n\n# Create figure and subplots\nfig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot Ti barrier energies\naxes[0].plot(x_indices, ti_c_parallel, marker='o', linestyle='-', color='skyblue', label='Ti$_{\\mathrm{i}}$ c$_{\\parallel}$')\naxes[0].plot(x_indices, ti_c_perpendicular, marker='s', linestyle='--', color='salmon', label='Ti$_{\\mathrm{i}}$ c$_{\\perp}$')\n\n# Add values as text labels for Ti\nfor i, txt in enumerate(ti_c_parallel):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_parallel[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_c_perpendicular):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_perpendicular[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[0].set_ylabel('Barrier Energy (eV)')\naxes[0].set_title('Barrier energies of Ti$_{\\mathrm{i}}$ ($c_{\\parallel}$ and $c_{\\perp}$) and V$_{\\mathrm{O}}$ (I, II, and III) paths', fontsize=14)\naxes[0].legend()\naxes[0].grid(True, linestyle='--', alpha=0.6)\n\n# Plot VO barrier energies\naxes[1].plot(x_indices, vo_i, marker='^', linestyle='-', color='lightgreen', label='V$_{\\mathrm{O}}$ I')\naxes[1].plot(x_indices, vo_ii, marker='v', linestyle='--', color='orchid', label='V$_{\\mathrm{O}}$ II')\naxes[1].plot(x_indices, vo_iii, marker='d', linestyle=':', color='gold', label='V$_{\\mathrm{O}}$ III')\n\n# Add values as text labels for VO, handling NaN\nfor i, txt in enumerate(vo_i):\n    if not 
np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_i[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_ii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_ii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_iii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_iii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[1].set_ylabel('Barrier Energy (eV)')\naxes[1].set_xticks(x_indices)\naxes[1].set_xticklabels(methods, rotation=45, ha='right')\naxes[1].legend()\naxes[1].grid(True, linestyle='--', alpha=0.6)\n\n# Set overall title and adjust layout\nfig.suptitle('Calculated and Experimental Barrier Energies', fontsize=16, y=1.02)\nplt.tight_layout()\n\n# Show the plot\nplt.show()\n\n```"
    # Reference answer, also wrapped in a Markdown ```python fence.
    gt = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_i_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_i_c_perp = [0.225, 0.23, 0.90, 0.9, 1.6]\nv_o_i = [1.77, np.nan, 2.42, 2.0, np.nan]\nv_o_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nv_o_iii = [1.1, np.nan, 1.36, 1.7, np.nan]\n\n# Define x-axis positions for each method\nx = np.arange(len(methods))\n\n# Create subplots for different barrier types\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot for Ti_i barrier energies\nax1.plot(x, ti_i_c_parallel, label='Ti$_\\mathrm{i}$ $c_\\parallel$', marker='o', linestyle='-', color='skyblue')\nax1.plot(x, ti_i_c_perp, label='Ti$_\\mathrm{i}$ $c_\\perp$', marker='s', linestyle='--', color='salmon')\nax1.set_ylabel('Barrier Energy (eV)')\nax1.set_title('Barrier energies of Ti$_\\mathrm{i}$ ($c_\\parallel$ and $c_\\perp$) and V$_\\mathrm{O}$ (I, II, and III) paths')\nax1.legend(loc='upper left')\nax1.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to ax1\nfor i, txt in enumerate(ti_i_c_parallel):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i],txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_i_c_perp):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Plot for V_O barrier energies\nax2.plot(x, v_o_i, label='V$_\\mathrm{O}$ I', marker='^', linestyle='-', color='lightgreen')\nax2.plot(x, v_o_ii, label='V$_\\mathrm{O}$ II', marker='v', linestyle='--', color='orchid')\nax2.plot(x, v_o_iii, label='V$_\\mathrm{O}$ III', marker='d', linestyle=':', color='gold')\nax2.set_ylabel('Barrier Energy (eV)')\nax2.set_xticks(x)\nax2.set_xticklabels(methods, rotation=45, ha='right')\nax2.legend(loc='upper left')\nax2.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to 
ax2\nfor i, txt in enumerate(v_o_i):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_ii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_iii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Add a title to the entire figure\nfig.suptitle('Barrier Energies of Ti$_\\mathrm{i}$ and V$_\\mathrm{O}$ Paths by Various Methods', fontsize=14, y=1.02)\n\nplt.tight_layout()\nplt.show()\n\n```"
    
    print(f"pred: {pred}")
    print(f"gt: {gt}")
    
    # extract_layout is applied to the raw strings; the Markdown fences are not stripped first.
    pred_data = extract_layout(pred)
    print(f"pred processed data: {pred_data}")

    gt_data = extract_layout(gt)
    print(f"gt processed data: {gt_data}")
    
    # Both samples call plt.subplots(2, 1, figsize=(12, 10), ...), so the score printed here is 1.0.
    score = compare_layout(pred_data, gt_data)
    print(f"score: {score}")
