import os
import numpy as np

from tab4_turn import get_data as get_data_ssr
from tab3_align import get_data as get_data_isr
from tab6_csr_full import get_data as get_data_csr

from utils.parse_xls import get_full_name
from utils.get_rank import rank_columns_desc

# Short model keys, in the order they are laid out in the table.
# The first N_ENG_MODEL keys form the first group; the remaining keys form the
# second group (presumably English- vs. Chinese-developed models, judging by
# the constant's name -- confirm).
KEY_LIST = (
    # first group
    'GPT-4o',
    'GPT-4-Turbo',
    'Claude-3',
    'Llama3.1-70B',
    'Llama3.1-8B',
    'Mixtral-8x22B',
    'GPT-3.5',
    'Mixtral-8x7B',
    # second group
    'Qwen2.5-72B',
    'GLM-4',
    'Qwen2-72B',
    'DeepSeek-V2',
    'Moonshot',
    'GLM-4-9B',
    'ERNIE-4',
    'Qwen2-7B',
)
# Number of leading entries of KEY_LIST that belong to the first group.
N_ENG_MODEL = 8

def number_mapper(x, col_id):
    r"""Format a fraction as a LaTeX percentage with one decimal place.

    Args:
        x: Value to format; it is multiplied by 100 before formatting.
        col_id: Column index of the value. Not used by this formatter; it is
            accepted so callers can pass the column along uniformly.

    Returns:
        The formatted string followed by an escaped percent sign,
        e.g. ``0.5`` becomes ``50.0\%``.
    """
    return f'{x * 100:.1f}' + r'\%'


def hilight_mapper(x, col_id, rank):
    r"""Format a value and emphasise it according to its rank.

    Args:
        x: Value to format (see ``number_mapper``).
        col_id: Column index, forwarded to ``number_mapper``.
        rank: Rank of the value; ``0`` is wrapped in ``\textbf{...}``,
            ``1`` in ``\underline{...}``, anything else is left plain.

    Returns:
        The LaTeX snippet for the table cell.
    """
    # Format once; every branch below only differs in the wrapper it applies.
    text = number_mapper(x, col_id)
    if rank == 0:
        return f'\\textbf{{{text}}}'
    if rank == 1:
        return f'\\underline{{{text}}}'
    return text

# Row strut heights in `ex` units: (first row, every following row).
row_margins = (2.0, 1.5)

def generate_table():
    """Collect the last CSR, ISR and SSR value of every model in ``KEY_LIST``.

    For each key the three data loaders are queried and the final element of
    each result is kept.

    Returns:
        A NumPy array with one row per key of ``KEY_LIST`` (in the same
        order) and the columns ordered CSR, ISR, SSR.
    """
    loaders = (get_data_ssr, get_data_isr, get_data_csr)

    def last_scores(model_key):
        # Query all loaders first, then pick the final element of each.
        series = [load(model_key) for load in loaders]
        ssr, isr, csr = [values[-1] for values in series]
        # Column order of the table is CSR, ISR, SSR.
        return np.array([csr, isr, ssr])

    return np.array([last_scores(model_key) for model_key in KEY_LIST])

# LaTeX that opens the table environment and writes the header row.
# NOTE(review): the tabular spec declares four columns (name + CSR/ISR/SSR),
# while the row printer in this file emits two name+scores groups per row --
# confirm the column spec and header before emitting this wrapper.
BEFORE_TEX = '\n'.join([
    r'\begin{table}[t]',
    r'\centering',
    r'\small',
    r'\begin{tabular}{c|ccc}',
    r'    \toprule',
    r'    Full Model Name & \textbf{CSR} & \textbf{ISR} & \textbf{SSR} \\',
    r'    \midrule',
    '',  # keeps the trailing newline
])
# LaTeX that closes the tabular and the table environment.
AFTER_TEX = '\n'.join([
    r'\bottomrule',
    r'\end{tabular}',
    r'\caption{Table 2}',
    r'\end{table}',
    '',  # keeps the trailing newline
])

def _model_cells(idx, data_table, data_ranked):
    """Return the LaTeX cells for one model: its full name, then its scores.

    Args:
        idx: Row index into ``KEY_LIST``, ``data_table`` and ``data_ranked``.
        data_table: Score array, one row per model.
        data_ranked: Per-column ranks of ``data_table`` as returned by
            ``rank_columns_desc``; passed to ``hilight_mapper`` to pick the
            emphasis of each cell.

    Returns:
        A list of strings, one per table cell.
    """
    full_name = get_full_name(KEY_LIST[idx])
    scores = [
        hilight_mapper(score, col, data_ranked[idx, col])
        for col, score in enumerate(data_table[idx])
    ]
    return [full_name, *scores]


def main():
    """Print the table body as LaTeX rows on stdout.

    Each printed row holds two models side by side: entry ``i`` of the first
    ``N_ENG_MODEL`` keys on the left and entry ``i`` of the remaining keys on
    the right. When the two groups differ in size, the missing half is
    filled with empty cells so every row has the same number of columns.
    """
    data_table = generate_table()

    # find descending order for each column
    data_ranked = rank_columns_desc(data_table)

    # print(BEFORE_TEX)
    print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n")
    print(f"% --- generated by {os.path.basename(__file__)} --- %\n")

    n_chn_model = len(KEY_LIST) - N_ENG_MODEL
    for i in range(max(N_ENG_MODEL, n_chn_model)):
        eng_i, chn_i = i, i + N_ENG_MODEL

        # Invisible strut that sets the row height (taller for the first row).
        print(r'\rule{0pt}{' + str(row_margins[0 if i == 0 else 1]) + r'ex}')

        eng_cells = (_model_cells(eng_i, data_table, data_ranked)
                     if eng_i < N_ENG_MODEL else [])
        chn_cells = (_model_cells(chn_i, data_table, data_ranked)
                     if chn_i < len(KEY_LIST) else [])

        # Pad an absent half to the width of the present one. Padding the
        # right half with a fixed four separators produced one cell too many.
        width = max(len(eng_cells), len(chn_cells))
        eng_cells += [''] * (width - len(eng_cells))
        chn_cells += [''] * (width - len(chn_cells))

        print(' & '.join(eng_cells + chn_cells) + ' \\\\')

    print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %")
    # print(AFTER_TEX)


if __name__ == '__main__':
    main()