from metagen.benchmarks import *
import pandas as pd
import numpy as np

def make_pivot(df):
    """Build the per-model summary pivot across categories.

    Keeps only rows whose ``metric`` is one of the four reported metrics,
    swaps raw category and metric identifiers for their display labels, and
    averages ``value`` for every (Model, Category, Metric) combination.

    Args:
        df: Long-format results with at least the columns ``Model``,
            ``Category``, ``metric`` and ``value``. It is not modified.

    Returns:
        A DataFrame indexed by ``Model`` whose columns are a
        ``(Category, Metric)`` MultiIndex holding mean values.
    """
    reported_metrics = ['Valid', 'IoU', 'Chamfer Distance', 'Average Normalized Error']

    # Display labels; identifiers missing from these tables pass through as-is.
    category_labels = {
        'inverse_design': 'Inverse Design',
        'material_understanding': 'Material Understanding',
        'reconstruction': 'Reconstruction',
    }
    metric_labels = {
        'Average Normalized Error': 'Error',
        'Chamfer Distance': 'CD',
    }

    selected = df.loc[df['metric'].isin(reported_metrics)].copy()
    selected['Category'] = selected['Category'].replace(category_labels)
    selected['metric'] = selected['metric'].replace(metric_labels)

    return selected.rename(columns={'metric': 'Metric'}).pivot_table(
        index='Model',
        columns=['Category', 'Metric'],
        values='value',
        aggfunc='mean',
    )

# Define which metrics should be minimized or maximized
minimize_metrics = {'Error', 'CD'}
maximize_metrics = {'IoU', 'Valid'}

def _format_cell(x, metric_name, best_val):
    """Render one numeric cell as LaTeX text, bolded when it equals best_val."""
    if pd.isna(x):
        return r'\textemdash{}'  # Represent NaNs as a dash
    if metric_name == 'Valid':
        # 'Valid' is shown as a percentage; an exact 100 drops the decimal.
        percent = x * 100
        formatted = "100\\%" if percent == 100 else f"{percent:.1f}\\%"
    else:
        formatted = f"{x:.3f}"
    return f"\\textbf{{{formatted}}}" if x == best_val else formatted

# Function to bold best values
def highlight_best(df):
    """Format a pivot table as LaTeX strings and bold the best value per column.

    For every column whose metric is listed in ``minimize_metrics`` or
    ``maximize_metrics`` the cells are converted to strings: NaN becomes
    ``\\textemdash{}``, ``Valid`` is shown as a percentage with one decimal,
    other metrics with three decimals. Every cell equal to the column's best
    value (minimum or maximum, NaNs ignored) is wrapped in ``\\textbf{}``, so
    ties are all bolded. Columns with any other metric are left untouched.

    Args:
        df: Pivot table of numeric values. Columns may be a MultiIndex, in
            which case the metric name is the second level, or a flat index
            whose labels are the metric names themselves.

    Returns:
        A copy of ``df`` with the recognised columns replaced by formatted
        strings. ``df`` itself is not modified.
    """
    formatted_df = df.copy()

    for col in df.columns:
        # Tuple labels come from MultiIndex columns such as (Category, Metric).
        # A flat label is the metric itself; indexing it with [1] would pick a
        # single character and the column would be skipped silently.
        metric_name = col[1] if isinstance(col, tuple) else col
        values = df[col]

        # Determine best value (ignoring NaNs)
        if metric_name in minimize_metrics:
            best_val = values.min(skipna=True)
        elif metric_name in maximize_metrics:
            best_val = values.max(skipna=True)
        else:
            continue

        formatted_df[col] = values.apply(_format_cell, args=(metric_name, best_val))

    return formatted_df

def to_latex(formatted_pivot):
    """Render a formatted pivot table as a LaTeX tabular with centred columns.

    Args:
        formatted_pivot: DataFrame whose cells already hold the text to emit
            (for example the output of ``highlight_best``).

    Returns:
        The LaTeX source as a string. Escaping is disabled so markup already
        present in the cells, such as ``\\textbf{...}``, is emitted verbatim.
    """
    # One alignment letter per rendered column: each row-index level becomes
    # its own leading column, followed by the data columns.
    n_cols = formatted_pivot.index.nlevels + formatted_pivot.shape[1]

    return formatted_pivot.to_latex(
        escape=False,
        multicolumn=True,
        multicolumn_format='c',
        column_format='c' * n_cols,
    )

def make_category_pivot(df, category):
    """Build the per-model pivot for a single category, broken down by task.

    Keeps only rows of the requested category whose ``metric`` is one of the
    four reported metrics, replaces raw metric and task identifiers with
    their display labels, and averages ``value`` for every
    (Model, Task, Metric) combination.

    Args:
        df: Long-format results with at least the columns ``Model``,
            ``Category``, ``Task``, ``metric`` and ``value``. It is not
            modified.
        category: Raw category identifier to keep, compared against the
            unmapped ``Category`` column (e.g. ``'reconstruction'``).

    Returns:
        A DataFrame indexed by ``Model`` whose columns are a
        ``(Task, Metric)`` MultiIndex holding mean values.
    """
    reported_metrics = ['Valid', 'IoU', 'Chamfer Distance', 'Average Normalized Error']

    # Display labels; identifiers missing from these tables pass through as-is.
    category_labels = {
        'inverse_design': 'Inverse Design',
        'material_understanding': 'Material Understanding',
        'reconstruction': 'Reconstruction',
    }
    metric_labels = {
        'Average Normalized Error': 'Error',
        'Chamfer Distance': 'CD',
    }
    task_labels = {
        'multiview_and_code_material_understanding': '4 View + Code',
        'single_view_material_understanding': '1 View',
        '1_view_reconstruction': '1 View',
        '2_view_reconstruction': '2 View',
        '3_view_reconstruction': '3 View',
        '4_view_reconstruction': '4 View',
        '1_target_inverse_design': '1 Target',
        '2_target_inverse_design': '2 Target',
        '3_target_inverse_design': '3 Target',
        '4_target_inverse_design': '4 Target',
        '5_target_inverse_design': '5 Target',
        '6_target_inverse_design': '6 Target',
    }

    # Select on the raw category id before any relabelling happens.
    wanted_rows = df['metric'].isin(reported_metrics) & (df['Category'] == category)
    subset = df.loc[wanted_rows].copy()

    for column, labels in (
        ('Category', category_labels),
        ('metric', metric_labels),
        ('Task', task_labels),
    ):
        subset[column] = subset[column].replace(labels)

    return subset.rename(columns={'metric': 'Metric'}).pivot_table(
        index='Model',
        columns=['Task', 'Metric'],
        values='value',
        aggfunc='mean',
    )


def main():
    """Evaluate the configured model runs and print four LaTeX tables.

    Builds one config dict per model run (test-set path, prediction and
    processed-output directories, display name, indexing flag), passes them
    to ``evaluate_benchmarks`` together with a ``Database`` and a parquet
    path, then prints the summary table followed by one table per category.
    """
    def run_config(test_path, run_dir, model_name, index_at_one, has_processed=True):
        # Each run keeps its outputs in two sibling folders under run_dir.
        return {
            'test_path': test_path,
            'prediction_path': run_dir + '/predicted_code',
            'processed_path': (run_dir + '/predicted_successes') if has_processed else None,
            'model_name': model_name,
            'index_at_one': index_at_one,
        }

    # Test sets shared by several runs.
    omnitask_test = '/benchmark/omnitask/test.jsonl'
    reconstruction_test = '/benchmark/reconstruction/4_view_reconstruction/test.jsonl'
    inverse_design_test = '/benchmark/inverse_design/4_target_inverse_design/test.jsonl'
    material_understanding_test = (
        '/benchmark/material_understanding/multiview_and_code_material_understanding/test.jsonl'
    )

    model_results = [
        # o3_omnitask
        run_config(omnitask_test, '/workspace/inference_data/OpenAIO3/omnitask',
                   'OpenAIO3', False),
        # nova_omnitask
        run_config(omnitask_test, '/workspace/inference_data/NovaOmniTask/omnitask',
                   'NovaOmniTask', False),
        # novalite_omnitask
        run_config(omnitask_test, '/workspace/inference_data/NovaLite/omnitask',
                   'NovaLite', False),
        # nova_inverse_design
        run_config(inverse_design_test,
                   '/workspace/inference_data/NovaSingleTask/inverse_design/4_target_inverse_design',
                   'NovaSingleTask', False),
        # nova_material_understanding
        run_config(material_understanding_test,
                   '/workspace/inference_data/NovaSingleTask/material_understanding/multiview_and_code_material_understanding',
                   'NovaSingleTask', False),
        # nova_reconstruction
        run_config(reconstruction_test,
                   '/workspace/inference_data/NovaSingleTask/reconstruction/4_view_reconstruction',
                   'NovaSingleTask', False),
        # llava_omnitask
        run_config(omnitask_test,
                   '/workspace/benchmark_inference/omni/evaluate_test_18500_sysp_omni_test',
                   'LLaVAOmniTask', True),
        # llava_reconstruction
        run_config(reconstruction_test,
                   '/workspace/benchmark_inference/reconstruction/llava/evaluate_test_6500_raw_sysp_4view_newdata',
                   'LLaVASingleTask', True),
        # llava_inverse_design
        run_config(inverse_design_test,
                   '/workspace/benchmark_inference/inverse_design/llava/evaluate_test_9000_raw_sysp_4view',
                   'LLaVASingleTask', True),
        # llava_material_understanding: no need to process, so no processed_path
        run_config(material_understanding_test,
                   '/workspace/benchmark_inference/material_understanding/llava/evaluate_test_7000_sysp_raw_4view',
                   'LLaVASingleTask', True, has_processed=False),
    ]

    db = Database('/data/metagen-data/v3/')
    # Saving to 'unfiltered' parquet file now because NovaLite has some ill-behaved materials
    # that standard validation probably should catch but does not.
    # NOTE(review): the path below does not mention 'unfiltered' - confirm it is the intended file.
    results_data = evaluate_benchmarks(db, '/workspace/benchmark_results.parquet', model_results)

    # Build every table before printing anything, so a failure prints nothing partial.
    tables = [
        ("\nSummary Table:\n", highlight_best(make_pivot(results_data))),
        ("\nReconstruction Table:\n",
         highlight_best(make_category_pivot(results_data, 'reconstruction'))),
        ("\nInverse Design Table:\n",
         highlight_best(make_category_pivot(results_data, 'inverse_design'))),
        ("\nMaterial Understanding Table:\n",
         highlight_best(make_category_pivot(results_data, 'material_understanding'))),
    ]

    for heading, table in tables:
        print(heading)
        print(to_latex(table))

if __name__ == '__main__':
    main()