import re
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt

from utils.parse_xls import parse_xls, TURN_NUMBER

# English display name -> Chinese constraint name as used in the spreadsheets.
# The Chinese names are the column headers read by get_data_old and the
# strings captured by `pattern` in get_data.
LABEL_MAP = {
    'Action': '动作约束',       # action constraint
    'Content': '内容约束',      # content constraint
    'Background': '背景约束',   # background constraint
    'Role': '角色约束',         # role constraint
    'Format': '格式约束',       # format constraint
    'Style': '风格约束',        # style constraint
}

# Y tick labels, in LABEL_MAP insertion order.
sector_labels = list(LABEL_MAP)

# One pastel colour per category, aligned index-by-index with LABEL_MAP.
sector_colors = [
    '#b7cef2',
    '#b9ecea',
    '#f2f2ca',
    '#f2ddb6',
    '#eec1c1',
    '#d2b6e2',
]

# Matches numbered items such as "1. 动作约束" and captures the constraint
# name: any two characters followed by "约束".
pattern = r'\d+\.\s(..约束)'

def get_data_old(key):
    """Read one value per constraint type from the aggregated sheet.

    Opens the '不同约束类型遵循' sheet for ``key`` via ``parse_xls``. For each
    category in ``LABEL_MAP`` order, it takes the value in the first row of
    the column named by the Chinese constraint name.

    Args:
        key: Identifier passed straight to ``parse_xls`` (e.g. a model name).

    Returns:
        An int numpy array of length ``len(LABEL_MAP)``. It is all zeros if
        the sheet could not be read. Values are cast to int on assignment;
        if the sheet stores fractional rates they are truncated. TODO confirm
        the sheet holds counts.
    """
    res = np.zeros(len(LABEL_MAP), dtype=int)

    try:
        df = parse_xls(key, sheet_name='不同约束类型遵循')
    except Exception as e:
        # Best-effort: report and fall back to zeros so plotting can continue.
        print(f'Error: {e}, when reading {key}')
        return res

    for i, col in enumerate(LABEL_MAP.values()):
        # .iloc[0] is the first row by position. df[col][0] would be a label
        # lookup, which breaks when the index is not a 0-based RangeIndex.
        res[i] = df[col].iloc[0]

    return res

def get_data(key):
    """Count, per constraint type, how many sessions mention it.

    Reads the default sheet for ``key`` via ``parse_xls``. Rows are grouped
    into consecutive sessions of ``TURN_NUMBER`` rows each. Within a session,
    the constraint names found by ``pattern`` in the '评判结果' column are
    collected into a set. Each category in ``LABEL_MAP`` is then counted at
    most once per session. Rows after the last complete session are ignored.

    Args:
        key: Identifier passed straight to ``parse_xls`` (e.g. a model name).

    Returns:
        An int numpy array of length ``len(LABEL_MAP)`` with per-category
        session counts. It is all zeros if the sheet could not be read.
    """
    res = np.zeros(len(LABEL_MAP), dtype=int)

    try:
        df = parse_xls(key)
    except Exception as e:
        # Best-effort: report and fall back to zeros so plotting can continue.
        print(f'Error: {e}, when reading {key}')
        return res

    constraints = set()
    # Use the row position rather than the DataFrame label so session
    # boundaries stay correct even when the index is not 0-based integers.
    for pos, text in enumerate(df['评判结果']):
        turn = pos % TURN_NUMBER

        if turn == 0:
            constraints = set()

        # Non-string cells (presumably NaN for blank cells) carry no
        # constraints. Skip them instead of letting re.findall raise TypeError.
        if isinstance(text, str):
            constraints.update(re.findall(pattern, text))

        if turn == TURN_NUMBER - 1:
            # Last turn of the session: credit each category seen at least once.
            for i, col in enumerate(LABEL_MAP.values()):
                if col in constraints:
                    res[i] += 1
    return res


# Output format for the figure written by the __main__ block: '.pdf' when
# True, '.png' otherwise.
to_pdf: bool = True

def plot_barh(ax, fontsize=14, height=0.9, stat_type='session', model='GPT-4o'):
    """Draw a horizontal bar chart of each constraint type's share.

    Args:
        ax: Matplotlib Axes to draw into.
        fontsize: Font size for the percentage labels and y tick labels.
        height: Bar thickness passed to ``ax.barh``.
        stat_type: ``'base'`` reads the aggregated sheet via
            ``get_data_old``. ``'session'`` counts per-session mentions via
            ``get_data``.
        model: Key passed to the data loader. It defaults to ``'GPT-4o'``,
            which was previously hard-coded.

    Raises:
        ValueError: If ``stat_type`` is neither ``'base'`` nor ``'session'``.
    """
    if stat_type == 'base':
        counts = get_data_old(model)
    elif stat_type == 'session':
        counts = get_data(model)
    else:
        raise ValueError(f'Invalid stat_type: {stat_type}')

    counts = counts.astype(np.float64)
    total = counts.sum()
    # The loaders return all zeros when the spreadsheet cannot be read.
    # Leave the shares at zero instead of dividing 0/0 into NaN bars and
    # "nan%" labels.
    data = counts / total if total > 0 else counts

    # Ascending order: the largest share is drawn at the highest y position.
    # A stable sort keeps ties in LABEL_MAP order, so the layout is
    # deterministic.
    indices = np.argsort(data, kind='stable')

    for i, idx in enumerate(indices):
        ax.barh(i, data[idx], height=height, color=sector_colors[idx])
        ax.text(data[idx], i, f"{data[idx]*100:.0f}%", ha='left', va='center', fontsize=fontsize)

    ax.set_yticks(np.arange(len(LABEL_MAP)))
    ax.set_yticklabels([sector_labels[i] for i in indices], fontsize=fontsize)
    ax.set_xticks([])

    # Hide every spine; only the y tick labels remain visible.
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

if __name__ == '__main__':
    from pathlib import Path

    # pyplot's rcParams is the same object as matplotlib's; use one name.
    mpl.rcParams['font.family'] = 'Calibri'
    mpl.rcParams.update({'font.size': 13})

    fig, ax = plt.subplots(1, 1, figsize=(5, 2.5), dpi=300, tight_layout=True)
    plot_barh(ax)

    out_path = Path('figures') / ('fig_constraint' + ('.pdf' if to_pdf else '.png'))
    # savefig does not create missing directories; ensure the target exists.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)