# Parse verl training metrics from plain-text console logs (for runs without wandb, or with wandb in offline mode).

import ast
import yaml
import re
import numpy as np

# Metric-name filters used by process_log: a metric is recorded when any of
# these strings appears anywhere in its name (substring match).
show_keys = [
    'acc',
    'actor',
    'critic',
    'response_length',
    'val-core',
    'val-aux',
]

def parse_custom_log_string(log_string: str) -> "dict | str":
    """
    Parse a custom log string in a dictionary-like format into a Python dictionary.

    Format: ``{key1=value1,key2=value2,...}``. Pairs are separated by commas
    that are followed (after optional whitespace) by a letter or underscore,
    so commas inside a value such as ``1,2`` do not split it. Each pair is
    split at its first ``=``; pairs without ``=`` are ignored. Keys and
    values are returned as whitespace-stripped strings.

    Args:
        log_string: text to parse, including the enclosing braces.

    Returns:
        A ``dict[str, str]`` of the parsed pairs, or the empty string ``''``
        when ``log_string`` is not a str wrapped in ``{...}``. The ``''``
        sentinel (rather than an exception) is kept for backward
        compatibility with callers that test ``result != ''``.
    """
    # 1. Reject anything that is not a brace-wrapped string.
    if not isinstance(log_string, str) or not log_string.startswith('{') or not log_string.endswith('}'):
        return ''

    # 2. Remove exactly one outer brace on each side. str.strip('{}') would
    #    also eat braces that belong to the last value, e.g. '{a={x}}' -> 'a={x'.
    content_str = log_string[1:-1]

    # 3. Split only on commas that start a new key: the lookahead
    #    (?=[a-zA-Z_]) requires the next non-space char to be a letter or '_'.
    pairs = re.split(r',\s*(?=[a-zA-Z_])', content_str)

    # 4. Split each pair at its first '=' so values may themselves contain '='.
    result_dict = {}
    for pair in pairs:
        key, sep, value = pair.partition('=')
        if sep:
            result_dict[key.strip()] = value.strip()

    return result_dict

def process_log(file_path, start_step=0):
    """
    Process a log file and extract step/value pairs for plotting.

    Each line is stripped, its first and last characters are dropped, and the
    rest is parsed with ``parse_custom_log_string``. A parsed line whose
    ``content`` field contains ``step:<N>`` is treated as a metric line of the
    form ``... step:<N> - name1:value1 - name2:value2 ...``. Every metric whose
    name contains one of ``show_keys`` (substring match) is recorded once.

    Args:
        file_path: path of the UTF-8 log file.
        start_step: offset added to every parsed step number.

    Returns:
        ``{metric_name: {'x': [step, ...], 'y': [float_value, ...]}}`` with
        points in file order.

    Lines that do not parse are skipped. A ``name:value`` item that cannot be
    parsed is skipped on its own, without discarding the other metrics on the
    same line. Errors opening the file propagate to the caller.
    """
    step_re = re.compile(r'\bstep:\s*(\d+)\b')
    plt_data = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data_dict = parse_custom_log_string(line.strip()[1:-1])
            # parse_custom_log_string returns '' (not a dict) for non-matching lines.
            if not isinstance(data_dict, dict):
                continue
            content = data_dict.get('content')
            if not isinstance(content, str):
                continue
            match = step_re.search(content)
            if match is None:
                continue
            step = int(match.group(1)) + start_step
            # Text after the step number looks like ' - name:value - name:value';
            # element 0 of the split is whatever directly followed the number.
            for item in content[match.end():].split(' - ')[1:]:
                name, sep, value = item.partition(':')
                # any() records a metric once even if its name matches
                # several keys (e.g. 'val-core/.../acc').
                if not sep or not any(kk in name for kk in show_keys):
                    continue
                try:
                    number = float(value)
                except ValueError:
                    continue
                series = plt_data.setdefault(name, {'x': [], 'y': []})
                series['x'].append(step)
                series['y'].append(number)

    return plt_data

def moving_average(data, window_size):
    """
    Compute the simple moving average of ``data`` using convolution.

    Only fully covered windows are returned ('valid' mode), so the result has
    ``len(data) - window_size + 1`` elements.

    Args:
        data: 1-D sequence of numbers.
        window_size: number of points per window; must be >= 1.

    Returns:
        List of floats. Empty when ``data`` is shorter than ``window_size``.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    # np.convolve swaps its arguments when the window is longer than the data,
    # which would yield meaningless values, and it raises on empty input.
    # Neither case has a complete window, so there is nothing to average.
    if len(data) < window_size:
        return []
    window = np.ones(window_size) / window_size
    return np.convolve(data, window, 'valid').tolist()

def increase_line(xdata, ydata, threshold=None):
    """
    Return the running-maximum envelope of a series.

    A point is kept only when its y-value is strictly greater than every
    y-value kept before it. When ``threshold`` is given, points with
    ``y >= threshold`` are ignored entirely and never raise the running
    maximum.

    Returns:
        Tuple ``(xs, ys)`` of lists holding the kept points in input order.
    """
    best = -np.inf
    kept_x, kept_y = [], []
    for x_val, y_val in zip(xdata, ydata):
        over_limit = threshold is not None and y_val >= threshold
        # Skip filtered points and anything that does not beat the best so far.
        if over_limit or not y_val > best:
            continue
        best = y_val
        kept_x.append(x_val)
        kept_y.append(y_val)
    return kept_x, kept_y

def add_loss_points(x, y, ss=5):
    """
    Fill missing points with the last known y-value (stepwise interpolation).

    The series is resampled onto an integer grid from ``min(x)`` to
    ``max(x)``. The grid stride is the greatest common divisor of every
    point's offset from ``min(x)``: the largest stride on which all original
    points fall. As a result, no logged point is dropped, and each missing
    step repeats the last known y-value.

    Args:
        x: integer step values; need not be sorted.
        y: values aligned element-wise with ``x``. For duplicate steps the
            last value wins.
        ss: fallback stride, used when no positive stride can be inferred
            (all ``x`` equal).

    Returns:
        ``(xx, yy)`` as lists. Inputs with fewer than two points are
        returned unchanged.
    """
    import math
    from functools import reduce

    if len(x) < 2:
        return x, y

    xmin = min(x)
    xmax = max(x)
    # Inferring the stride from only the first two points picks too coarse a
    # grid when the first gap is larger than later ones (e.g. [0, 10, 15, 20]
    # would give [0, 10, 20] and silently drop 15). The gcd of all offsets
    # equals that first gap whenever every point already lies on its grid.
    d = reduce(math.gcd, (xi - xmin for xi in x), 0)
    d = d if d > 0 else ss
    xx = list(range(xmin, xmax + 1, d))
    x2y = dict(zip(x, y))
    yy = []
    last_y = 0  # only observable if y is shorter than x and min(x) has no value
    for xi in xx:
        if xi in x2y:
            last_y = x2y[xi]
        yy.append(last_y)
    return xx, yy

# Example usage: print(len(process_log(file_path)))  # number of distinct metrics parsed
