import os
import numpy as np
import pandas as pd
from utils.parse_xls import parse_xls, TURN_NUMBER
from utils.get_rank import rank_columns_desc

# Model identifiers; the __main__ block passes each one to get_data() and
# emits one table row per entry.
KEY_LIST = (
    'GPT-4o',
    'GPT-4-Turbo',
    'GPT-3.5',
    'Claude-3',
    'Llama3.1-70B',
    'Mixtral-8x22B',
    'Qwen2-72B',
    'GLM-4',
    'DeepSeek-V2',
    'Moonshot',
    'GLM-4-9B',
    'ERNIE-4',
    'Qwen2-7B',
    'Qwen2.5-72B',
    'Mixtral-8x7B',
    'Llama3.1-8B',
)

def number_mapper(x, col_id, slope_col=5):
    """Format one numeric table cell as LaTeX text.

    Two formats are produced:

    * percentage: ``x * 100`` with one decimal followed by a
      LaTeX-escaped percent sign;
    * negated slope: ``-x`` with three decimals.

    The slope format is chosen by column, so a zero or positive slope is no
    longer mistaken for a percentage.  For any other column the original
    sign-based rule is kept: a value that is not ``>= 0`` (negative or NaN)
    still falls back to the three-decimal format.

    Args:
        x: The value to format.
        col_id: Zero-based column index of the cell.
        slope_col: Index of the column that holds slopes (default 5, the
            same column hilight_mapper() exempts from highlighting).

    Returns:
        The formatted cell as a string.
    """
    use_slope_format = col_id == slope_col or not x >= 0
    if use_slope_format:
        text = f'{-x:.3f}'
        # Negating a zero (or tiny positive) slope would otherwise print
        # a confusing "-0.000".
        return '0.000' if text == '-0.000' else text
    return f'{x * 100:.1f}' + r'\%'

def hilight_mapper(x, col_id, rank):
    r"""Format a table cell and decorate it according to its column rank.

    The value is formatted by number_mapper().  Column 5 is never
    decorated.  In every other column a cell with ``rank == 0`` is wrapped
    in ``\textbf{...}``, a cell with ``rank == 1`` in ``\underline{...}``,
    and any other rank is returned undecorated.

    Args:
        x: The value to format.
        col_id: Zero-based column index of the cell.
        rank: Rank of the cell within its column, as supplied by the caller.

    Returns:
        The LaTeX source for the cell.
    """
    cell = number_mapper(x, col_id)
    if col_id != 5:
        if rank == 0:
            return '\\textbf{' + cell + '}'
        if rank == 1:
            return '\\underline{' + cell + '}'
    return cell

# Heights, in ex, used for the \rule{0pt}{...} emitted before each table row:
# (first row of a section, every following row).
row_margins = (2.0, 1.5)

def get_data(key, root_dir='output'):
    """Aggregate per-round success statistics for one model.

    The sheet for ``key`` is loaded with ``parse_xls`` and walked row by
    row.  A row's round is ``index % TURN_NUMBER``; a streak counter is
    reset at round 0 and grows by one for every consecutive round whose
    '是否可用' value is positive, stopping at the first failure.

    The result has 13 slots:

    * 0-4   rows with ``multi_rounds_related`` true, one slot per round:
            fraction of rows for which every round up to and including
            this one succeeded;
    * 5     same rows: mean final streak divided by ``TURN_NUMBER``;
    * 6-10  rows with ``multi_rounds_related`` false, one slot per round;
    * 11    same rows: mean final streak divided by ``TURN_NUMBER``;
    * 12    all rows: mean final streak divided by ``TURN_NUMBER``.

    NOTE(review): the slot offsets (6, +5, 13) are hard-coded and only line
    up when ``TURN_NUMBER`` is 5 -- confirm against utils.parse_xls.

    Args:
        key: Model identifier forwarded to ``parse_xls``.
        root_dir: Directory forwarded to ``parse_xls``.

    Returns:
        A float array of length 13.  If loading fails the error is printed
        and all slots are 0.  A slot that received no samples is 0 as well,
        instead of the NaN a plain division would produce.

    Raises:
        AssertionError: If a row's ``multi_rounds_related`` is not a bool.
    """
    totals = np.zeros(13)
    counts = np.zeros(13)

    def update_res(slot, val):
        totals[slot] += val
        counts[slot] += 1

    try:
        df = parse_xls(key, root_dir=root_dir)
    except Exception as e:
        print(f'Error: {e}, when reading {key}')
        return totals

    for index, row in df.iterrows():
        val = row['是否可用']
        turn = index % TURN_NUMBER

        # The streak counter is only created at round 0, so the first row
        # must sit at an index that is a multiple of TURN_NUMBER.
        if turn == 0:
            accumulated_turn = 0

        related = row['multi_rounds_related']
        assert isinstance(related, (bool, np.bool_)), f'Unknown multi_rounds_related value: {row["multi_rounds_related"]}, at index {index} of {key}'
        base = 0 if related else 6

        # Extend the streak only while it is still unbroken.
        if val > 0 and accumulated_turn == turn:
            accumulated_turn += 1

        update_res(base + turn, 1 if accumulated_turn == turn + 1 else 0)

        if turn == TURN_NUMBER - 1:
            update_res(base + 5, accumulated_turn)
            update_res(12, accumulated_turn)

    # Turn the streak sums into fractions of the dialogue length.
    totals[[5, 11, 12]] /= TURN_NUMBER

    # Average each slot; slots without samples stay 0 rather than 0/0 = NaN,
    # matching what the load-failure path returns.
    return np.divide(totals, counts, out=np.zeros_like(totals), where=counts > 0)

def hypothesis_testing(x1, x2):
    """Run a one-sided Mann-Whitney U test of ``x1 > x2``.

    The test statistic and p-value are printed, followed by a verdict at
    the 0.05 significance level.

    Args:
        x1: First sample.
        x2: Second sample.

    Returns:
        The p-value of the test.
    """
    from scipy.stats import mannwhitneyu

    outcome = mannwhitneyu(x1, x2, alternative='greater')
    stat, p = outcome[0], outcome[1]
    print(f'Statistics={stat:.3f}, p={p:.3f}')

    verdict = (
        'Reject the null hypothesis, x1 > x2'
        if p < 0.05
        else 'Fail to reject the null hypothesis (x1 <= x2)'
    )
    print(verdict)
    return p

def get_linear_slope(y):
    """Return the slope of the least-squares line fitted through ``y``.

    The values are placed at x = 0, 1, ..., len(y) - 1 and a degree-1
    polynomial is fitted with ``np.polyfit``.
    """
    positions = np.arange(len(y))
    slope, _intercept = np.polyfit(positions, y, 1)
    return slope

# Opening part of the LaTeX table (column-type definition, environment, header
# rows); the __main__ block prints it before the generated rows.
# NOTE(review): this tabular declares 14 columns (Model + 2 x (R1-R5, SSR) +
# Total), whereas each row printed by the __main__ block has 9 cells (section
# label, model, five rounds, slope, SSR). The header looks like it belongs to
# an earlier table layout -- confirm and align before using the output.
BEFORE_TEX = r'''
% !!!!!!!!! set this at beginning of the document !!!!!!!!!
\newcolumntype{M}[1]{>{\centering\arraybackslash}m{#1}}
% !!!!!!!!! set this at beginning of the document !!!!!!!!!

\begin{table*}[t]
\centering
\small
\begin{tabular}{c|M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}|M{0.75cm}|M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}|M{0.75cm}|M{0.75cm}}
    \hline
    \multirow{2}{*}{Model} & \multicolumn{6}{c|}{Multi-turn Dependent} & \multicolumn{6}{c|}{Multi-turn Parallel} & Total\\
        & R1 & R2 & R3 & R4 & R5 & \textbf{SSR} & R1 & R2 & R3 & R4 & R5 & \textbf{SSR} & \textbf{SSR}\\
    \hline\hline'''
# Closing part of the LaTeX table; the __main__ block prints it after the
# generated rows.
AFTER_TEX = r'''
\end{tabular}
    \caption{Table 4}
\end{table*}
'''

if __name__ == '__main__':
    # One row per model; the 13 columns follow the slot layout returned by
    # get_data().
    data_table = np.zeros((len(KEY_LIST), 13))
    for i, key in enumerate(KEY_LIST):
        data_table[i] = get_data(key)

    print(BEFORE_TEX)
    print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n")
    print(f"% --- generated by {os.path.basename(__file__)} --- %\n")
    type_list = ['Dependent', 'Parallel']
    # For each section: (columns of data_table holding the five per-round
    # rates, column holding that section's own summary rate).  get_data()
    # stores the parallel summary in slot 11; slot 12 is the summary over
    # all dialogues and must not stand in for the parallel figure.
    section_columns = {
        'Dependent': (slice(0, 5), 5),
        'Parallel': (slice(6, 11), 11),
    }
    n_models = len(KEY_LIST)
    # Negated slopes of each section, kept for the significance test below.
    decay_by_type = {}
    for type_name in type_list:
        type_string = r'\multirow{' + str(n_models) + r'}{*}{' + type_name + r'}'
        round_cols, summary_col = section_columns[type_name]

        # Section table: five round rates, the slope fitted across them,
        # and the section's summary rate.
        data_table_type = np.empty((n_models, 7))
        data_table_type[:, 0:5] = data_table[:, round_cols]
        data_table_type[:, 5] = [get_linear_slope(rates) for rates in data_table[:, round_cols]]
        data_table_type[:, 6] = data_table[:, summary_col]

        # Sort models by the last column, best first.
        index = np.argsort(data_table_type[:, -1])[::-1]
        data_table_type = data_table_type[index]
        label_list = [KEY_LIST[i] for i in index]
        data_ranked_type = rank_columns_desc(data_table_type)

        for i, key in enumerate(label_list):
            print(r'\rule{0pt}{' + str(row_margins[1 if i else 0]) + r'ex}')
            # The section label appears only on the section's first row.
            print(type_string if i == 0 else '', end=' & ')
            print(f'{key} & '
                + ' & '.join([hilight_mapper(x, j, data_ranked_type[i, j]) for j, x in enumerate(data_table_type[i])])
                + ' \\\\')
        if type_name != type_list[-1]:
            print(r'\midrule')

        decay_by_type[type_name] = -data_table_type[:, 5]

    print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %")
    # hypothesis_testing() prints its own report lines before the p-value
    # comment line is emitted.
    p_value = hypothesis_testing(decay_by_type['Dependent'], decay_by_type['Parallel'])
    print(f"\n% p_value = {p_value}")

    print(AFTER_TEX)