import re
import json

import pandas as pd


class ResultUtils:
    @staticmethod
    def load_jsonl_dicts(file_paths):
        all_dicts = []
        for file_path in file_paths:
            with open(file_path, 'r') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        if isinstance(data, dict):
                            all_dicts.append(data)
                        else:
                            print(f"Skipped non-dict line in {file_path}: {line}")
                    except json.JSONDecodeError as e:
                        print(f"Error decoding line in {file_path}: {e}")
        return pd.DataFrame(all_dicts)


    @staticmethod
    def get_latex_table(
                grouped_df: pd.DataFrame,
                caption: str,
                label: str,
                metric_mean: str = 'test_acc_mean(%)',
                metric_std: str = 'test_acc_std(%)',
                bold_best: bool = True,
                underline_best_among_flat_pooling: bool = True,
                greater_is_better: bool = True,
                decimals: int = 2,
                output_tex_file: str = None,
                pooling_method_order: list = None,
            ) -> None:
        # 1) Start from your “grouped” DataFrame:
        #    cols = ['dataset','pooling_method','model_size',
        #            'test_acc_mean(%)','test_acc_std(%)']
        df = grouped_df.copy()

        # 2) Build the raw “Value” string:
        # 2) Build the raw “Value” string with user‐specified precision:
        fmt = f"{{:.{decimals}f}}"
        df['Value'] = (
            "$"
            + df[metric_mean].map(lambda x: fmt.format(x))
            + r" \pm "
            + df[metric_std].map(lambda x: fmt.format(x))
            + "$"
        )

        if bold_best:
            # 3) Locate the best‐mean index per (model_size, dataset):
            if greater_is_better:
                best_idx = df.groupby(
                    ['model_size','dataset']
                )[metric_mean].idxmax()
            else:
                best_idx = df.groupby(
                    ['model_size','dataset']
                )[metric_mean].idxmin()

            # 5) Wrap best in \textbf{…} and second‐best in \underline{…}:
            df.loc[best_idx, 'Value'] = df.loc[best_idx, 'Value'].apply(
                lambda s: r'$\mathbf{' + s.strip('$') + '}$'
            )

        if underline_best_among_flat_pooling:
            # 1) Define which pooling methods you want to underline the “best of”
            subset_methods = ['Last', 'Avg', 'Sum', 'Max']

            # 2) Restrict to just those rows
            subset_df = df[df['pooling_method'].isin(subset_methods)]

            # 3) Find the best‐mean index *within* that subset, per group
            if greater_is_better:
                underline_idx = subset_df.groupby(
                    ['model_size', 'dataset']
                )[metric_mean].idxmax()
            else:
                underline_idx = subset_df.groupby(
                    ['model_size', 'dataset']
                )[metric_mean].idxmin()

            # 4) Wrap those Values in \underline{…}
            df.loc[underline_idx, 'Value'] = df.loc[underline_idx, 'Value'].apply(
                lambda s: r'\underline{' + s + '}'
            )

        model_size_order = ['small','base','large']
        # turn model_size into an ordered categorical
        df['model_size'] = pd.Categorical(
            df['model_size'],
            categories=model_size_order,
            ordered=True
        )

        if pooling_method_order is not None:
            df['pooling_method'] = pd.Categorical(
                df['pooling_method'],
                categories=pooling_method_order,
                ordered=True
            )

        # --- 3) Pivot to MultiIndex rows & columns ---
        df_wide = df.pivot_table(
            index=['model_size','pooling_method'],
            columns=['dataset'],
            values='Value',
            aggfunc='first'    # each cell is unique already
        )

        # Because `model_size` is ordered, a simple sort_index will give the right order:
        order_list = ['model_size']
        if pooling_method_order is not None:
            order_list.append('pooling_method')
        df_wide = df_wide.sort_index(level=order_list)
        
        df_wide.index = pd.MultiIndex.from_tuples([
                (f"\\rotatebox{{90}}{{MOMENT-{i}}}", j) for i, j in df_wide.index
            ])


        # 7) Export to LaTeX exactly as before, **but** with escape=False:
        latex = df_wide.to_latex(
            multirow=True,
            multicolumn=True,
            escape=False,                 # so \textbf{} survives
            index_names=False,
            # caption=(
            # "Average Test accuracy or MSE ($\\pm$ denotes standard deviation), "
            # "on the Base ViT on different computer vision tasks. "
            # "Best performance per dataset and model in \\textbf{bold}."
            # ),
            caption=caption,
            label=label,
            column_format="cc|cccccccccc"
        )

        # 2. Replace all \multirow[t]{...} with \multirow{...}
        latex = re.sub(r'\\multirow\[t\]', r'\\multirow', latex)
        # 8) Wrap in your \renewcommand{\arraystretch} & \resizebox boilerplate:
        full = latex

        if output_tex_file is not None:
            with open(output_tex_file, "w") as f:
                f.write(full)

        return df_wide, full