import pandas as pd
import re
import os
import sys
# Add the current working directory to the module search path.
sys.path.append('.')
# Directory component of the invoked script path (sys.argv[0]). The result CSV
# is read from, and table.tex is written to, locations relative to this
# directory. It is an empty string when the script is started by bare filename.
path0 = os.path.dirname(sys.argv[0])


# ERA5 experiment configuration.
# Data sizes evaluated in the comparison; each one becomes a table column.
DATA_SIZES = [400, 3200, 16000]
N_CASES = len(DATA_SIZES)  # number of data-size columns (3)

# Metrics for which a table is generated, in output order.
metric_list = ['nll', 'mae', 'crps']

# Model identifier (used to build CSV row keys) -> label printed in the table.
# Insertion order defines the row order; only MC-Dropout is relabelled.
model_display = {
    name: 'Dropout' if name == 'MC-Dropout' else name
    for name in ('HVBLL', 'VBLL', 'BLL', 'MC-Dropout', 'PNN', 'SWAG', 'DVI', 'MDN')
}


# Metrics for which a smaller mean is better. Only these get a bold "best"
# entry; any other metric is rendered without highlighting.
_LOWER_IS_BETTER = ('nll', 'mae', 'crps')

# Cells are displayed with two decimals, so distinct displayed values differ
# by about 0.01. Half of that step separates them reliably; a tolerance of
# exactly 0.01 does not (0.15 - 0.14 evaluates to 0.00999... in binary floats).
_BOLD_TOLERANCE = 0.005


def _find_row_key(index, row_key):
    """Return the label in *index* that matches *row_key*, or None.

    An exact match is preferred; otherwise the first label that equals
    *row_key* after lower-casing and removing spaces from the label is used.
    """
    if row_key in index:
        return row_key
    wanted = row_key.lower()
    for label in index:
        if label.lower().replace(" ", "") == wanted:
            return label
    return None


def _displayed_mean(cell):
    """Return the mean shown in a formatted cell, or None if there is none."""
    if cell == "-":
        return None
    try:
        return float(cell.split(' ')[0])
    except (ValueError, IndexError):
        return None


def write_table(metric='nll', *, csv_file=None, data_sizes=None, models=None):
    """Build a LaTeX table of ``mean \\pm std`` test results for one metric.

    Rows are models, columns are data sizes. Each cell is read from the CSV
    row ``"<model>_<data_size>"`` and the columns ``test_<metric>_mean`` and
    ``test_<metric>_std``. Missing rows (or a missing file) yield ``-``. For
    the lower-is-better metrics (nll, mae, crps) the smallest displayed mean
    of every column is set in bold; ties at the displayed precision are all
    bold.

    Args:
        metric: Metric name used to select the CSV columns and the caption.
        csv_file: Path of the results CSV. Defaults to
            ``<path0>/result/comparison-era5.csv``.
        data_sizes: Data sizes, one table column each. Defaults to the
            module-level ``DATA_SIZES``.
        models: Mapping of model identifier to display label, in row order.
            Defaults to the module-level ``model_display``.

    Returns:
        The complete ``table`` environment as a string.

    Raises:
        KeyError: If a matching row exists but the metric columns do not.
    """
    if csv_file is None:
        csv_file = os.path.join(path0, 'result', 'comparison-era5.csv')
    if data_sizes is None:
        data_sizes = DATA_SIZES
    if models is None:
        models = model_display
    data_sizes = list(data_sizes)
    n_cases = len(data_sizes)

    # Every cell starts as "-" and is overwritten when data is found.
    cells = {model: ["-"] * n_cases for model in models}

    try:
        df = pd.read_csv(csv_file, index_col=0)
    except FileNotFoundError:
        print(f"Warning: File {csv_file} not found.")
    else:
        # Cast to str first so stripping also works for non-string indexes.
        df.index = df.index.astype(str).str.strip()
        mean_col = f'test_{metric}_mean'
        std_col = f'test_{metric}_std'
        for col, data_size in enumerate(data_sizes):
            for model in models:
                row_key = f"{model}_{data_size}"
                match_key = _find_row_key(df.index, row_key)
                if match_key is None:
                    print(f"Warning: No data for {row_key}")
                    continue
                cells[model][col] = (
                    f"{df.loc[match_key, mean_col]:.2f} \\pm "
                    f"{df.loc[match_key, std_col]:.2f}"
                )

    # Best (lowest) displayed mean per column. It stays at +inf, so nothing
    # is bolded, for metrics without a known ordering. The `<` comparison
    # also skips NaN means.
    best_by_case = [float('inf')] * n_cases
    if metric in _LOWER_IS_BETTER:
        for col in range(n_cases):
            for model in models:
                value = _displayed_mean(cells[model][col])
                if value is not None and value < best_by_case[col]:
                    best_by_case[col] = value

    table_columns = "l " + " ".join(["c"] * n_cases)
    case_headers = " & ".join([str(data_size) for data_size in data_sizes])

    parts = [
        "\\begin{table}[htbp]\n",
        "\\centering\n",
        f"\\caption{{{metric.upper()} results for ERA5 regression tasks}}\n",
        f"\\label{{tab:{metric}_era5}}\n",
        f"\\begin{{tabular}}{{{table_columns}}}\n",
        "\\toprule\n",
        f"Model & {case_headers} \\\\\n",
        "\\midrule\n",
    ]

    for model, display_name in models.items():
        row = f"{display_name} "
        for col in range(n_cases):
            cell = cells[model][col]
            if cell == "-":
                row += "& - "
                continue
            value = _displayed_mean(cell)
            is_best = (
                value is not None
                and abs(value - best_by_case[col]) < _BOLD_TOLERANCE
            )
            if is_best:
                row += f"& $\\mathbf{{{cell}}}$ "
            else:
                row += f"& ${cell}$ "
        parts.append(row + "\\\\" + "\n")

    parts.append("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
    return "".join(parts)


if __name__ == "__main__":

    # Render every table before opening the output file, so that a failure in
    # write_table() does not leave a truncated table.tex behind.
    tables = [write_table(metric=metric) for metric in metric_list]

    # The context manager closes the file even if a write fails.
    with open(os.path.join(path0, 'table.tex'), 'w') as f:
        for latex_table in tables:
            # Two blank lines precede each table in the output file.
            f.write('\n\n')
            f.write(latex_table)

    print("LaTeX table has been generated and saved to table.tex")
