import re
import numpy as np
import nltk
import json
import os
from collections import defaultdict
from tqdm import tqdm
import math


def remove_inline_comment(line):
    """
    Remove the trailing ``#`` comment from a single line of code.

    A ``#`` located inside a single- or double-quoted string is kept.
    Backslash escapes inside a string are honoured, so an escaped quote does
    not terminate the string (and a ``#`` following it is still treated as
    string content).

    Parameters:
    line (str): one line of source code; string state is not carried across
                lines, so multi-line strings are not tracked

    Returns:
    str: the text before the first ``#`` that lies outside a string. Any
         whitespace in front of the comment is preserved.
    """
    kept = []
    string_delimiter = None  # quote char of the open string, None outside strings
    skip_next = False        # True right after a backslash inside a string

    for char in line:
        if skip_next:
            # Character escaped by a preceding backslash: copy it verbatim
            kept.append(char)
            skip_next = False
            continue

        if string_delimiter is not None:
            if char == '\\':
                skip_next = True
            elif char == string_delimiter:
                string_delimiter = None
        elif char in ('"', "'"):
            string_delimiter = char
        elif char == '#':
            break  # Start of the comment: drop the rest of the line

        kept.append(char)

    return ''.join(kept)


def parse_list_elements_with_type(list_content):
    """
    Split the inside of a flat list literal into typed elements.

    The content is split on commas that lie outside single- or double-quoted
    strings. Backslash escapes inside a string are honoured, so an escaped
    quote does not end the string. Each non-blank piece is converted with
    parse_value_type; blank pieces (e.g. from a trailing comma) are ignored.

    Parameters:
    list_content (str): text between the outer ``[`` and ``]``

    Returns:
    tuple: (elements, contains_var_ref) where elements is the list of parsed
           values and contains_var_ref is True if parse_value_type flagged at
           least one element as a variable reference
    """
    elements = []
    contains_var_ref = False
    pieces = []              # characters of the element being collected
    string_delimiter = None  # quote char of the open string, None outside strings
    escaped = False          # True right after a backslash inside a string

    def flush():
        """Parse the collected characters as one element and reset the buffer."""
        nonlocal contains_var_ref
        text = ''.join(pieces).strip()
        pieces.clear()
        if text:
            value, is_var = parse_value_type(text)
            if is_var:
                contains_var_ref = True
            elements.append(value)

    for char in list_content:
        if escaped:
            escaped = False
        elif string_delimiter is not None:
            if char == '\\':
                escaped = True
            elif char == string_delimiter:
                string_delimiter = None
        elif char in ('"', "'"):
            string_delimiter = char
        elif char == ',':
            # Separator outside a string: finish the current element
            flush()
            continue
        pieces.append(char)

    # The last element has no trailing comma
    flush()

    return elements, contains_var_ref


# A quoted literal: optional string prefix (one letter, or a two-letter
# raw-bytes / raw-f-string combination, any case), then a body that opens and
# closes with the same quote character.
_QUOTED_LITERAL_RE = re.compile(
    r'(?:[rRuUbBfF]|[rR][bBfF]|[bBfF][rR])?(["\'])(.*)\1',
    re.DOTALL,
)


def parse_value_type(value_str):
    """
    Convert the source text of a single value into a Python value.

    Recognised forms, in order:
    - ``np.nan``                       -> float NaN
    - quoted string, optionally with a
      prefix such as r, f, u, b, rb, fr -> the text between the quotes
                                           (prefix and quotes removed, escape
                                           sequences left as written)
    - integer / float text             -> int / float
    - anything else                    -> returned unchanged and flagged as a
                                           variable reference

    A string must have both an opening and a matching closing quote; a lone
    quote character is not accepted as an empty string.

    Parameters:
    value_str (str): source text of the value

    Returns:
    tuple: (value, is_var) where is_var is True only when the text could not
           be interpreted as a literal
    """
    value_str = value_str.strip()

    # Handle np.nan values
    if value_str == 'np.nan':
        return float('nan'), False

    # Handle quoted strings, with or without a prefix
    literal = _QUOTED_LITERAL_RE.fullmatch(value_str)
    if literal:
        return literal.group(2), False

    # Try to parse as number
    try:
        # Plain (optionally negative) digit runs stay integers
        if value_str.isdigit() or (value_str.startswith('-') and value_str[1:].isdigit()):
            return int(value_str), False

        return float(value_str), False
    except (ValueError, TypeError):
        # Not a literal we understand: report it as a variable reference
        return value_str, True


def extract_values(code_content):
    """
    Extract literal lists and dictionaries assigned before the plot is created.

    Only the lines that precede the first line containing both an ``=`` and
    either ``plt.subplots`` or ``plt.figure`` (after inline comments are
    removed) are inspected. Within those lines, every ``name = value``
    assignment whose value is a flat list, a flat dictionary, or such a value
    wrapped in ``np.array(...)`` is recorded. Nested lists/dictionaries and
    containers holding anything that parse_value_type reports as a variable
    reference are skipped.

    Parameters:
    code_content (str): code string

    Returns:
    extracted_data (dict): ``{}`` when no ``plt.subplots``/``plt.figure``
        assignment line is found; otherwise
        ``{'lists': {name: [elements]},
           'dicts': {name: {'keys': [...], 'values': [...]}}}``
    """
    # Split code by line
    lines = code_content.split('\n')
    # Process each line, while looking for lines containing plt.subplots() or plt.figure()
    cleaned_lines = []
    plot_line_index = -1

    for i, line in enumerate(lines):
        # Handle inline comments
        cleaned_line = remove_inline_comment(line)
        
        # If a line containing plt.subplots() or plt.figure() is found, record the index and stop adding lines
        # (the plot line itself is not added to cleaned_lines)
        if '=' in cleaned_line and ('plt.subplots' in cleaned_line or 'plt.figure' in cleaned_line):
            plot_line_index = i
            break
        
        # If the line is not empty, add it to the cleaned lines list
        if cleaned_line.strip():
            cleaned_lines.append(cleaned_line)
    
    if plot_line_index == -1:  # If no relevant line is found, filter directly
        return {}

    # Recombine into code without comments
    cleaned_code = '\n'.join(cleaned_lines)
    # print(f"Code after removing comments: {cleaned_code}")

    # Store extracted data
    extracted_data = {
        'lists': {},
        'dicts': {}
    }

    # Find all variable assignment statements.
    # Group 1 is the identifier on the left of '='; group 2 is the (lazily matched)
    # right-hand side, which runs until the next line that starts with
    # "identifier =", until trailing whitespace at the end of the text, or until
    # the end of the text. re.DOTALL lets the right-hand side span several lines.
    # NOTE(review): the pattern is not anchored to the start of a line.
    assignment_pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+?)(?=\n[a-zA-Z_][a-zA-Z0-9_]*\s*=|\n\s*$|\Z)'
    assignments = re.findall(assignment_pattern, cleaned_code, re.DOTALL)

    # Process each assignment statement
    for var_name, var_content in assignments:  
        var_name = var_name.strip()
        var_content = var_content.strip()

        # Check if it's a numpy array; if so, unwrap "np.array(" (9 characters) and ")"
        if var_content.startswith('np.array(') and var_content.endswith(')'):
            var_content = var_content[9:-1].strip()

        # Check if it's a dictionary definition
        if var_content.startswith('{') and var_content.endswith('}'):
            # This is a dictionary definition
            dict_content = var_content[1:-1].strip()

            # If the dictionary content contains braces, consider it a nested dictionary
            # Simple solution for nesting, doesn't consider k/v containing {} and [], not considered for now
            if '{' in dict_content or '}' in dict_content:
                continue  # Skip nested dictionaries

            # Extract dictionary keys and values
            keys = []
            values = []
            valid_dict = True  # Flag indicating whether the dictionary is valid

            try:
                # Use regex to extract key-value pairs.
                # A key is a quoted string or a bare word; the value is everything up
                # to the next ", key:" or the end of the dictionary content.
                key_value_pattern = r'(["\'][^"\']+["\']|\w+)\s*:\s*(.*?)(?=,\s*(?:["\'][^"\']+["\']|\w+)\s*:|$)'
                kv_matches = re.findall(key_value_pattern, dict_content, re.DOTALL)

                for key, value in kv_matches:
                    # Parse key, check if it's a variable reference
                    key_value, is_key_var = parse_value_type(key.strip())
                    if is_key_var:
                        valid_dict = False  # Key contains variable reference, mark as invalid
                        break

                    value_str = value.strip()

                    # Check if value is a list, and if it's a nested list
                    if value_str.startswith('[') and value_str.endswith(']'):
                        list_content = value_str[1:-1].strip()
                        
                        # If list content contains brackets, consider it a nested list
                        if '[' in list_content or ']' in list_content:
                            valid_dict = False  # Nested list, mark as invalid
                            break

                        # Parse list elements
                        list_elements, is_var_list = parse_list_elements_with_type(list_content)
                        if is_var_list:
                            valid_dict = False  # List contains variable reference, mark as invalid
                            break

                        keys.append(key_value)
                        values.append(list_elements)
                    else:
                        # Parse value, check if it's a variable reference
                        value_parsed, is_value_var = parse_value_type(value_str)
                        if is_value_var:
                            valid_dict = False  # Value contains variable reference, mark as invalid
                            break
                        
                        keys.append(key_value)
                        values.append(value_parsed)
                
                # Save dictionary data (only when every pair was valid and at least one pair was found)
                if valid_dict and keys and values:
                    extracted_data['dicts'][var_name] = {'keys': keys, 'values': values}
            except Exception:
                # Parsing error, skip this dictionary
                continue

        # Check if it's a list definition
        elif var_content.startswith('[') and var_content.endswith(']'):
            # This is a list definition
            list_content = var_content[1:-1].strip()
            
            # If list content contains brackets, consider it a nested list
            if '[' in list_content or ']' in list_content:
                continue  # Skip nested lists

            # Parse list elements
            try:
                list_elements, is_var_list = parse_list_elements_with_type(list_content)
                if is_var_list:
                    continue  # List contains variable reference, skip this variable
                
                # Save list data (a later assignment to the same name overwrites an earlier one)
                extracted_data['lists'][var_name] = list_elements
            except Exception as e:
                # Parsing error, skip this list
                continue

    return extracted_data


def convert_to_lists(values_dict):
    """
    Flatten an extracted-values structure into a plain list of lists.

    Variable names and dictionary names are discarded. Every entry of
    'lists' contributes its list. Every entry of 'dicts' contributes its
    keys list followed by its values: when at least one value is itself a
    list, each value becomes its own row (scalars wrapped in a one-element
    list); otherwise the values form a single row.

    Parameters:
    values_dict (dict): Data structure containing 'lists' and 'dicts'

    Returns:
    list: List of lists, each inner list contains a set of values
    """
    rows = []

    # Plain lists: one row per variable
    plain_lists = values_dict.get('lists')
    if plain_lists:
        rows.extend(plain_lists.values())

    # Dictionaries: keys row first, then the value row(s)
    dict_entries = values_dict.get('dicts')
    if not dict_entries:
        return rows

    for entry in dict_entries.values():
        key_row = entry.get('keys')
        if key_row:
            rows.append(key_row)

        value_items = entry.get('values')
        if not value_items:
            continue

        has_sublists = any(isinstance(item, list) for item in value_items)
        if not has_sublists:
            # 1D values: keep them together as a single row
            rows.append(value_items)
            continue

        # 2D values: one row per item, wrapping bare scalars
        rows.extend(
            item if isinstance(item, list) else [item]
            for item in value_items
        )

    return rows

def compare_list(list1, list2):
    """
    Compare the similarity of two lists.

    Elements are compared position by position over the common prefix with
    compare_element. The average element score is then scaled by the ratio of
    the shorter length to the longer length, which penalises lists of
    different lengths.

    Two empty lists are identical and score 1.0; an empty list compared with
    a non-empty one scores 0.0.

    Returns:
    float: similarity score, ranging from 0 to 1
    """
    longer = max(len(list1), len(list2))
    if longer == 0:
        # Nothing to compare on either side: treat as a perfect match
        return 1.0

    shorter = min(len(list1), len(list2))
    if shorter == 0:
        return 0.0

    # Penalty for differing lengths
    length_ratio = shorter / longer

    # zip stops at the shorter list, i.e. the common prefix
    element_scores = [compare_element(a, b) for a, b in zip(list1, list2)]
    avg_element_score = sum(element_scores) / len(element_scores)

    # Consider both length ratio and element matching score
    return length_ratio * avg_element_score

def compare_element(elem1, elem2):
    """
    Score how similar two list elements are, from 0 to 1.

    Two numbers (int or float) are scored with relaxed_accuracy, two strings
    with compare_text. Any other combination of types scores 0.0.
    """
    number_types = (int, float)

    # Numeric pair
    both_numeric = isinstance(elem1, number_types) and isinstance(elem2, number_types)
    if both_numeric:
        return relaxed_accuracy(elem1, elem2)

    # Text pair
    both_text = isinstance(elem1, str) and isinstance(elem2, str)
    if both_text:
        return compare_text(elem1, elem2)

    # Mixed or unsupported types never match
    return 0.0

def relaxed_accuracy(prediction, target, max_relative_change=0.05):
    """
    Score how close a predicted number is to the target, from 0 to 1.

    The relative error ``|prediction - target| / |target|`` (absolute error
    when target is 0) is mapped linearly to a score: an error of 0 scores 1.0
    and an error of max_relative_change or more scores 0.0.

    Special cases:
    - two NaN values match (1.0); NaN against a number does not (0.0)
    - equal values always score 1.0, including equal infinities
    - an infinity against any different value scores 0.0
    - a tolerance of 0 or less means only exactly equal values score 1.0

    Parameters:
    prediction (int | float): predicted value
    target (int | float): reference value
    max_relative_change (float): relative error at which the score reaches 0

    Returns:
    float: score in [0, 1]
    """
    # Handle special case of NaN values
    prediction_is_nan = math.isnan(prediction)
    target_is_nan = math.isnan(target)
    if prediction_is_nan or target_is_nan:
        return 1.0 if prediction_is_nan and target_is_nan else 0.0

    # Exact match; also covers inf == inf, where inf - inf would be NaN
    if prediction == target:
        return 1.0

    # Unequal values cannot be scored against an infinity or without a
    # positive tolerance (avoids NaN arithmetic and division by zero)
    if math.isinf(prediction) or math.isinf(target) or max_relative_change <= 0:
        return 0.0

    # Calculate relative error
    relative_change = abs(prediction - target) / (abs(target) if target != 0 else 1.0)
    # Linear interpolation between score 1 (no error) and score 0 (max_relative_change)
    score = max(0.0, 1.0 - (relative_change / max_relative_change))
    return min(score, 1.0)

def compare_text(text1, text2):
    """
    Score how similar two texts are, from 0 to 1, based on edit distance.

    Two empty texts score 1.0 and an empty text against a non-empty one
    scores 0.0. Otherwise the score is one minus the edit distance divided
    by the length of the longer text, floored at 0.0.
    """
    # Both empty: identical
    if not (text1 or text2):
        return 1.0
    # Exactly one empty: nothing in common
    if not (text1 and text2):
        return 0.0

    longest = max(len(text1), len(text2))
    if longest == 0:
        return 1.0

    distance = nltk.edit_distance(text1, text2)
    return max(0.0, 1.0 - (distance / longest))


def compare_values(completion_values, answer_values):
    """
    Score how well the values of a completion agree with the reference values.

    Both structures are flattened with convert_to_lists. Each reference list,
    in order, is paired with the not-yet-used completion list that gives the
    highest compare_list score (a completion list is only consumed when that
    score is above zero). The sum of the best scores is divided by the number
    of reference lists.

    Parameters:
    completion_values (dict): Values from model-generated code
    answer_values (dict): Values from reference answer code

    Returns:
    float: Consistency score, ranging from 0 to 1
    """
    def _has_data(values):
        return values.get('lists') or values.get('dicts')

    completion_present = _has_data(completion_values)
    answer_present = _has_data(answer_values)

    # Nothing on either side counts as a perfect match
    if not (completion_present or answer_present):
        return 1.0

    # Data on only one side is a complete mismatch
    if not (completion_present and answer_present):
        return 0.0

    # Convert all data to list format
    completion_lists = convert_to_lists(completion_values)
    answer_lists = convert_to_lists(answer_values)

    # If no lists were extracted, return 0.0
    if not (completion_lists and answer_lists):
        return 0.0

    unmatched = list(completion_lists)
    score_sum = 0.0

    for reference in answer_lists:
        if not unmatched:
            # Every completion list has already been paired
            break

        candidate_scores = [compare_list(candidate, reference) for candidate in unmatched]
        top_score = max(candidate_scores)

        # Only a strictly positive score counts as a match; the first
        # candidate reaching the top score is the one consumed
        if top_score > 0.0:
            score_sum += top_score
            unmatched.pop(candidate_scores.index(top_score))

    # Average over all reference lists, matched or not
    return score_sum / len(answer_lists)


# Demo: score a sample model completion against a reference answer.
if __name__ == "__main__":
    # pred: sample model output (a fenced Python code block given as a string)
    pred = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_c_perpendicular = [0.23, 0.23, 0.9, 0.9, 1.6]\nvo_i = [1.77, np.nan, 2.42, 2.0, np.nan] # Use NaN for missing data\nvo_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nvo_iii = [1.10, np.nan, 1.36, 1.7, np.nan]\n\n# Create indices for plotting\nx_indices = np.arange(len(methods))\n\n# Create figure and subplots\nfig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot Ti barrier energies\naxes[0].plot(x_indices, ti_c_parallel, marker='o', linestyle='-', color='skyblue', label='Ti$_{\\mathrm{i}}$ c$_{\\parallel}$')\naxes[0].plot(x_indices, ti_c_perpendicular, marker='s', linestyle='--', color='salmon', label='Ti$_{\\mathrm{i}}$ c$_{\\perp}$')\n\n# Add values as text labels for Ti\nfor i, txt in enumerate(ti_c_parallel):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_parallel[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_c_perpendicular):\n    axes[0].annotate(f'{txt:.2f}', (x_indices[i], ti_c_perpendicular[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[0].set_ylabel('Barrier Energy (eV)')\naxes[0].set_title('Barrier energies of Ti$_{\\mathrm{i}}$ ($c_{\\parallel}$ and $c_{\\perp}$) and V$_{\\mathrm{O}}$ (I, II, and III) paths', fontsize=14)\naxes[0].legend()\naxes[0].grid(True, linestyle='--', alpha=0.6)\n\n# Plot VO barrier energies\naxes[1].plot(x_indices, vo_i, marker='^', linestyle='-', color='lightgreen', label='V$_{\\mathrm{O}}$ I')\naxes[1].plot(x_indices, vo_ii, marker='v', linestyle='--', color='orchid', label='V$_{\\mathrm{O}}$ II')\naxes[1].plot(x_indices, vo_iii, marker='d', linestyle=':', color='gold', label='V$_{\\mathrm{O}}$ III')\n\n# Add values as text labels for VO, handling NaN\nfor i, txt in enumerate(vo_i):\n    if not 
np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_i[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_ii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_ii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(vo_iii):\n    if not np.isnan(txt):\n        axes[1].annotate(f'{txt:.2f}', (x_indices[i], vo_iii[i]), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n\naxes[1].set_ylabel('Barrier Energy (eV)')\naxes[1].set_xticks(x_indices)\naxes[1].set_xticklabels(methods, rotation=45, ha='right')\naxes[1].legend()\naxes[1].grid(True, linestyle='--', alpha=0.6)\n\n# Set overall title and adjust layout\nfig.suptitle('Calculated and Experimental Barrier Energies', fontsize=16, y=1.02)\nplt.tight_layout()\n\n# Show the plot\nplt.show()\n\n```"
    # gt: reference answer code in the same format
    gt = "```python\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data from the table\nmethods = ['GGA-PW91 (2007)', 'GGA-PW91 (2010)', 'LDA+U', 'DMC', 'DMC (neutral)']\nti_i_c_parallel = [0.37, 0.31, 0.54, 0.4, 2.6]\nti_i_c_perp = [0.225, 0.23, 0.90, 0.9, 1.6]\nv_o_i = [1.77, np.nan, 2.42, 2.0, np.nan]\nv_o_ii = [0.69, np.nan, 1.60, 0.9, np.nan]\nv_o_iii = [1.1, np.nan, 1.36, 1.7, np.nan]\n\n# Define x-axis positions for each method\nx = np.arange(len(methods))\n\n# Create subplots for different barrier types\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)\n\n# Plot for Ti_i barrier energies\nax1.plot(x, ti_i_c_parallel, label='Ti$_\\mathrm{i}$ $c_\\parallel$', marker='o', linestyle='-', color='skyblue')\nax1.plot(x, ti_i_c_perp, label='Ti$_\\mathrm{i}$ $c_\\perp$', marker='s', linestyle='--', color='salmon')\nax1.set_ylabel('Barrier Energy (eV)')\nax1.set_title('Barrier energies of Ti$_\\mathrm{i}$ ($c_\\parallel$ and $c_\\perp$) and V$_\\mathrm{O}$ (I, II, and III) paths')\nax1.legend(loc='upper left')\nax1.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to ax1\nfor i, txt in enumerate(ti_i_c_parallel):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i],txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(ti_i_c_perp):\n    if not np.isnan(txt):\n        ax1.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Plot for V_O barrier energies\nax2.plot(x, v_o_i, label='V$_\\mathrm{O}$ I', marker='^', linestyle='-', color='lightgreen')\nax2.plot(x, v_o_ii, label='V$_\\mathrm{O}$ II', marker='v', linestyle='--', color='orchid')\nax2.plot(x, v_o_iii, label='V$_\\mathrm{O}$ III', marker='d', linestyle=':', color='gold')\nax2.set_ylabel('Barrier Energy (eV)')\nax2.set_xticks(x)\nax2.set_xticklabels(methods, rotation=45, ha='right')\nax2.legend(loc='upper left')\nax2.grid(True, linestyle='--', alpha=0.6)\n\n# Add data labels to 
ax2\nfor i, txt in enumerate(v_o_i):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_ii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\nfor i, txt in enumerate(v_o_iii):\n    if not np.isnan(txt):\n        ax2.annotate(f'{txt:.2f}', (x[i], txt), textcoords=\"offset points\", xytext=(0,5), ha='center')\n\n# Add a title to the entire figure\nfig.suptitle('Barrier Energies of Ti$_\\mathrm{i}$ and V$_\\mathrm{O}$ Paths by Various Methods', fontsize=14, y=1.02)\n\nplt.tight_layout()\nplt.show()\n\n```"
    
    # Extract the literal lists/dicts defined before the plt.subplots(...) line of each sample
    pred_data = extract_values(pred)
    print(f"pred processed data: {pred_data}")

    gt_data = extract_values(gt)
    print(f"gt processed data: {gt_data}")
    
    # Compare the extracted values and print the resulting score
    score = compare_values(pred_data, gt_data)
    print(f"score: {score}")
