import os
import sys

import numpy as np
import pandas as pd

from utils.parse_xls import parse_xls
from utils.get_rank import rank_columns_desc

# Model names to tabulate. Each one is passed to parse_xls as the workbook key;
# rows are re-sorted by the Total column in __main__.
KEY_LIST = (
    'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3',
    'Llama3.1-70B', 'Llama3.1-8B',
    'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'GLM-4-9B',
    'ERNIE-4', 'Qwen2-7B', 'Qwen2.5-72B', 'Mixtral-8x7B', 'Mixtral-8x22B',
)

# Table column label -> column header looked up in the parsed sheet.
# Key order fixes the column order of the generated table and must stay in
# sync with the hand-written header row in BEFORE_TEX.
LABEL_MAP = {
    'Action': '动作约束',      # action constraints
    'Content': '内容约束',     # content constraints
    'Background': '背景约束',  # background constraints
    'Role': '角色约束',        # role constraints
    'Format': '格式约束',      # format constraints
    'Style': '风格约束',       # style constraints
    'Total': 'total',          # overall
}

def number_mapper(x, col_id):
    r"""Format ratio ``x`` as a LaTeX percentage with one decimal place.

    Example: ``0.1234`` -> ``12.3\%`` (the ``%`` is escaped for LaTeX).

    ``col_id`` is accepted so all cell mappers share a ``(value, column, ...)``
    signature; it is not used.
    """
    return f'{x * 100:.1f}' + r'\%'


def hilight_mapper(x, col_id, rank):
    r"""Format ``x`` like :func:`number_mapper`, emphasised by its column rank.

    Args:
        x: ratio to render as a percentage.
        col_id: column index, forwarded to ``number_mapper`` (unused there).
        rank: position of ``x`` within its column. In ``__main__`` this comes
            from ``rank_columns_desc``, so 0 presumably marks the column maximum.

    Returns:
        ``\textbf{...}`` for rank 0, ``\underline{...}`` for rank 1, and the
        plain formatted number otherwise.
    """
    text = number_mapper(x, col_id)
    if rank == 0:
        return f'\\textbf{{{text}}}'
    if rank == 1:
        return f'\\underline{{{text}}}'
    return text

# Height (in ex) of the invisible \rule strut printed before each data row:
# element 0 is used for the first row, element 1 for every later row.
row_margins = (2.0, 1.5)

# Number of metric columns in the table: one per LABEL_MAP entry.
N = len(LABEL_MAP)
def get_data(key, root_dir='output'):
    """Return one model's per-constraint-type ratios, in LABEL_MAP order.

    Reads the '不同约束类型遵循' ("adherence by constraint type") sheet through
    ``parse_xls``. For each sheet column named by LABEL_MAP it computes
    ``column[1] / column[0]``.
    NOTE(review): row 1 is presumably the satisfied count and row 0 the total
    count; confirm against parse_xls output.

    Args:
        key: model identifier forwarded to ``parse_xls``.
        root_dir: directory forwarded to ``parse_xls``.

    Returns:
        A float ndarray of length N. If ``parse_xls`` raises, a message is
        written to stderr and an all-zero array is returned, so the model
        still gets a row instead of aborting the whole table.

    Lookup errors for a LABEL_MAP column missing from the sheet are not
    caught here.
    """
    res = np.zeros(N)

    try:
        df = parse_xls(key, sheet_name='不同约束类型遵循', root_dir=root_dir)
    except Exception as e:
        # Deliberate best-effort. The report goes to stderr because stdout
        # carries the generated LaTeX; an error line there would be pasted
        # straight into the .tex output.
        print(f'Error: {e}, when reading {key}', file=sys.stderr)
        return res

    for i, col in enumerate(LABEL_MAP.values()):
        column = df[col]
        res[i] = column[1] / column[0]

    return res

# LaTeX emitted before the generated data rows. It opens a table* with one
# "Model" column plus seven metric columns under a shared CSR heading. The
# second header row is hand-written and must list LABEL_MAP keys in the same
# order. \multirow requires the LaTeX `multirow` package.
BEFORE_TEX = '\n'.join([
    r'\begin{table*}[htp]',
    r'\centering',
    r'% \small',
    r'\begin{tabular}{|c|cccccc|c|}',
    r'    \hline',
    r'    \rule{0pt}{2.0ex}',
    r'    \multirow{2}{*}{Model} & \multicolumn{7}{c|}{\textbf{CSR}} \\',
    r'        & Action & Content & Background & Role & Format & Style & Total \\\hline',
])

# LaTeX emitted after the data rows: closing rule, tabular, caption, table*.
AFTER_TEX = '\n'.join([
    r'\hline',
    r'\end{tabular}',
    r'\caption{Table 6}',
    r'\end{table*}',
])

if __name__ == '__main__':
    # One row of ratios per model, in KEY_LIST order; each `row` is a view
    # into `scores`, so slice-assignment fills the matrix in place.
    scores = np.zeros((len(KEY_LIST), N))
    for row, model in zip(scores, KEY_LIST):
        row[:] = get_data(model)

    # Re-order rows by the last (Total) column, largest first.
    order = np.argsort(scores[:, -1])[::-1]
    scores = scores[order]
    labels = [KEY_LIST[k] for k in order]

    # Per-cell rank within its column (descending), used for bold/underline.
    ranks = rank_columns_desc(scores)

    print(BEFORE_TEX)
    print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n")
    print(f"% --- generated by {os.path.basename(__file__)} --- %\n")

    for row_no, (label, row) in enumerate(zip(labels, scores)):
        # Taller strut above the first data row, a shorter one afterwards.
        margin = row_margins[0] if row_no == 0 else row_margins[1]
        print(rf'\rule{{0pt}}{{{margin}ex}}')
        cells = [hilight_mapper(value, col, ranks[row_no, col])
                 for col, value in enumerate(row)]
        print(f'{label} & ' + ' & '.join(cells) + ' \\\\')

    print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %")
    print(AFTER_TEX)