import pandas as pd
import re
import os
import sys
# NOTE(review): this appends the current working directory (not the script's
# directory) to the import path; no local-module import appears in this file's
# imports above -- confirm it is still needed.
sys.path.append('.')
# Directory of the script as invoked (may be '' when run from that directory).
# Result CSVs are read from <path0>/result and table.tex is written to <path0>.
path0 = os.path.dirname(sys.argv[0])


# Datasets to tabulate, one tuple per UCI dataset:
#   (ID_UCI, N_CASES, (DATA_SIZE_a, DATA_SIZE_b))
# ID_UCI   : UCI dataset ID, used in the CSV names comparison-DS{ID_UCI}-C{case}.csv
# N_CASES  : number of cases, i.e. CSV files C0 .. C{N_CASES-1}; one table column each
# DATA_SIZE: row-key suffixes ("{model}_{DATA_SIZE}"); one table is written per size
#            (presumably training-set sizes -- see the N_s caption in write_table)
uci_case_list = [
    (165, 5, (160, 800)), 
    (186, 4, (400, 3200)), 
    (291, 5, (160, 800)), 
    (294, 5, (400, 3200)),
    # (464, 2, (400, 3200)),
]


# Maps the model name used in the CSV row keys ("{model}_{DATA_SIZE}") to the
# label printed in the first column of the LaTeX table. Dict insertion order
# determines the row order of the generated tables.
model_display = {
    'HVBLL': 'HVBLL',
    'VBLL': 'VBLL',
    'BLL': 'BLL',
    'MC-Dropout': 'Dropout',
    # 'Deep-GP': 'Deep-GP',    #! Very bad results
    'PNN': 'PNN',
    'SWAG': 'SWAG',
    'DVI': 'DVI',
    'MDN': 'MDN'
}


def _find_row(df, row_key):
    """Return the first row of ``df`` labelled ``row_key``, or None if absent.

    An exact label match is preferred. Otherwise the first label that equals
    ``row_key`` case-insensitively, once spaces are removed from the label,
    is used. Selecting a single position (instead of ``df.loc``) keeps the
    result one row even when the CSV contains duplicate labels.
    """
    labels = [str(label) for label in df.index]
    if row_key in labels:
        return df.iloc[labels.index(row_key)]
    target = row_key.lower()
    for position, label in enumerate(labels):
        if label.lower().replace(" ", "") == target:
            return df.iloc[position]
    return None


def write_table(ID_UCI, N_CASES, DATA_SIZE, metric='nll', models=None,
                result_dir=None, train_fraction=0.8):
    """Build a LaTeX table comparing models across the cases of one UCI dataset.

    For each case ``c`` in ``range(N_CASES)`` the file
    ``<result_dir>/comparison-DS{ID_UCI}-C{c}.csv`` is read with its first
    column as the row index. Rows are looked up as ``"{model}_{DATA_SIZE}"``.
    The columns ``test_{metric}_mean`` and ``test_{metric}_std`` are rendered
    as ``mean \\pm std`` with two decimals. In each case column, the lowest
    mean *as displayed* is set in bold, so entries that print identically tie.
    Missing files or rows are rendered as ``-`` and a warning is printed.

    Args:
        ID_UCI: UCI dataset ID; used in the CSV file names and the caption/label.
        N_CASES: number of cases (CSV files) to read; one table column per case.
        DATA_SIZE: suffix selecting the ``{model}_{DATA_SIZE}`` rows.
        metric: metric name; selects the CSV columns and is shown (upper-cased)
            in the caption and used in the label.
        models: mapping from CSV model name to displayed row label, in row
            order. Defaults to the module-level ``model_display``.
        result_dir: directory containing the CSV files. Defaults to
            ``<path0>/result``.
        train_fraction: the caption reports
            ``N_s = round(DATA_SIZE / train_fraction)``. The default 0.8
            reproduces the previously hard-coded value.

    Returns:
        The LaTeX source of the table, ending with a newline.

    Raises:
        KeyError: if a CSV lacks the ``test_{metric}_mean``/``_std`` columns.
    """
    if models is None:
        models = model_display
    if result_dir is None:
        result_dir = os.path.join(path0, 'result')
    mean_col = f'test_{metric}_mean'
    std_col = f'test_{metric}_std'
    case_ids = range(N_CASES)

    # cells[model][case_id] is None (no data) or a tuple
    # (cell text "mean \pm std", mean as displayed, i.e. rounded to 2 decimals).
    cells = {model: {} for model in models}
    for case_id in case_ids:
        csv_file = os.path.join(result_dir, f'comparison-DS{ID_UCI}-C{case_id}.csv')
        try:
            df = pd.read_csv(csv_file, index_col=0)
        except FileNotFoundError:
            print(f"Warning: File {csv_file} not found. Skipping case {case_id}.")
            for model in models:
                cells[model][case_id] = None
            continue
        # Coerce labels to str before stripping so non-string labels
        # (numbers, NaN) cannot break the lookup below.
        df.index = df.index.astype(str).str.strip()

        for model in models:
            row_key = f"{model}_{DATA_SIZE}"
            row = _find_row(df, row_key)
            if row is None:
                print(f"Warning: No data for {row_key} in case {case_id}")
                cells[model][case_id] = None
                continue
            mean_text = f"{row[mean_col]:.2f}"
            cells[model][case_id] = (
                f"{mean_text} \\pm {row[std_col]:.2f}",
                float(mean_text),
            )

    # Best (lowest) displayed mean per case. NaN means are ignored.
    # Exact comparison at display precision replaces a float epsilon that
    # could bold entries which differ by 0.01.
    best_by_case = {}
    for case_id in case_ids:
        means = [cells[model][case_id][1] for model in models
                 if cells[model][case_id] is not None]
        means = [value for value in means if pd.notna(value)]
        best_by_case[case_id] = min(means) if means else None

    # round() instead of int() so float error in the division cannot
    # truncate e.g. 299.99999999999994 down to 299.
    n_samples = round(DATA_SIZE / train_fraction)
    table_columns = "l " + " ".join(["c"] * N_CASES)
    case_headers = " & ".join(str(case_id + 1) for case_id in case_ids)

    lines = [
        '\\begin{table}[htbp]',
        '\\centering',
        f'\\caption{{{metric.upper()} results for UCI regression tasks '
        f'(UCI ID = {ID_UCI}, $N_\\text{{s}}={n_samples}$)}}',
        f'\\label{{tab:{metric}_uci_{ID_UCI}_{n_samples}}}',
        f'\\begin{{tabular}}{{{table_columns}}}',
        '\\toprule',
        f'Dataset ID & {case_headers} \\\\',
        '\\midrule',
    ]
    for model, display_name in models.items():
        row_cells = [display_name]
        for case_id in case_ids:
            cell = cells[model][case_id]
            if cell is None:
                row_cells.append("-")
            elif cell[1] == best_by_case[case_id]:
                row_cells.append(f"$\\mathbf{{{cell[0]}}}$")
            else:
                row_cells.append(f"${cell[0]}$")
        lines.append(" & ".join(row_cells) + " \\\\")
    lines += ['\\bottomrule', '\\end{tabular}', '\\end{table}']

    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Write one NLL table per (dataset, data size) pair into <path0>/table.tex.
    # The `with` block guarantees the file is closed even if write_table raises.
    output_path = os.path.join(path0, 'table.tex')
    with open(output_path, 'w', encoding='utf-8') as f:
        for ID_UCI, N_CASES, DATA_SIZE_list in uci_case_list:
            # Every data size listed for the dataset gets its own table.
            for DATA_SIZE in DATA_SIZE_list:
                latex_table_nll = write_table(ID_UCI, N_CASES, DATA_SIZE, metric='nll')
                f.write('\n\n')
                f.write(latex_table_nll)

    print("LaTeX table has been generated and saved to table.tex")
    