import numpy as np
from matplotlib.colors import ListedColormap
from lm_polygraph.utils.manager import UEManager
import pandas as pd
import seaborn as sns
import re

# Diverging "coolwarm" palette, requested as a colormap object (as_cmap=True).
# NOTE(review): `cm` is not referenced anywhere in the visible script —
# possibly a leftover from a colour-styled table; confirm before removing.
cm = sns.color_palette("coolwarm", as_cmap=True)

# Short model keys; each selects the `{model}_{dataset}.man` result files.
models = ['mistral', 'llama', 'falcon']

# Quality metric evaluated for each dataset. `datasets` and `metrics` are
# derived from this single mapping (in insertion order) so the two lists can
# never drift out of positional alignment — they are consumed pairwise.
dataset_metrics = {
    'triviaqa': 'AlignScoreOutputTarget',
    'mmlu': 'Accuracy',
    'coqa': 'AlignScoreOutputTarget',
    'gsm8k': 'Accuracy',
    'xsum': 'AlignScoreInputOutput',
    'wmt14_fren': 'Comet',
    'wmt19_deen': 'Comet',
}
datasets = list(dataset_metrics)
metrics = list(dataset_metrics.values())

# Display name of each metric in the table (both AlignScore variants are
# shown simply as "Align Score").
metric_pretty = {
    'AlignScoreOutputTarget': 'Align Score',
    'Accuracy': 'Accuracy',
    'AlignScoreInputOutput': 'Align Score',
    'Comet': 'Comet',
}

# Display name of each dataset in the table.
datasets_pretty = {
    'triviaqa': 'Trivia',
    'mmlu': 'MMLU',
    'coqa': 'CoQa',
    'gsm8k': 'GSM8k',
    'xsum': 'XSUM',
    'wmt14_fren': 'WMT14FrEn',
    'wmt19_deen': 'WMT19DeEn',
}

# Output column label -> prefix prepended to the metric name to select that
# decoding strategy's metric ('' selects the plain/greedy metric).
column_names = {
    'Greedy': '',
    'First Sample': 'Sample',
    'Best Sample': 'BestSample',
    # Disabled; re-enable to add a normalised best-sample column.
    # 'Best Normalized Sample': 'BestNormalizedSample',
    'MBR Sample': 'MbrSample',
}

# Model key -> label used for the per-model separator rows in the table.
model_names = {
    'mistral': 'Mistral7b-Base',
    'llama': 'Llama8b-Base',
    'falcon': 'Falcon7b-Base',
}

def postprocess_latex(latex):
    r"""Rewrite a LaTeX table string into the layout used for the result tables.

    Transformations, in order:
      * every unescaped ``_`` becomes ``\_`` (already-escaped ``\_`` is kept);
      * the ``{table}`` environment is renamed to ``{table*}``;
      * ``\footnotesize`` is inserted right after the first line (which is
        expected to open the table environment);
      * the ``\begin{tabular}`` line is replaced by a fixed six-column spec
        followed by ``\toprule``, a fixed header row and ``\midrule``;
      * ``\bottomrule`` is inserted before ``\end{tabular}``;
      * caption lines are passed through untouched;
      * lines that look like data rows (at least three ``&``-separated
        cells) get a trailing ``\\`` unless they already end with one.

    Args:
        latex: LaTeX source of a table.

    Returns:
        The rewritten LaTeX source, lines joined with ``\n``.
    """
    # Escape underscores for LaTeX, skipping ones that are already escaped so
    # pre-escaped input (e.g. pandas output) does not become "\\_".
    latex = re.sub(r'(?<!\\)_', r'\\_', latex)
    # Rename only the environment: a bare 'table' replacement would also
    # rewrite caption text and turn an existing 'table*' into 'table**'.
    latex = latex.replace('{table}', '{table*}')

    new_latex = []
    for i, line in enumerate(latex.splitlines()):
        if i == 0:
            new_latex.append(line)
            new_latex.append('\\footnotesize')
        elif line.startswith('\\begin{tabular}'):
            new_latex.extend([
                '\\begin{tabular}{llcccc}',
                '\\toprule',
                'Dataset & Metric & Greedy & Sample & Best Sample & Mbr Sample \\\\',
                '\\midrule',
            ])
        elif 'caption' in line:
            # Checked before the data-row rule so a caption containing '&'
            # never gets a row terminator appended.
            new_latex.append(line)
        elif line.startswith('\\end{tabular}'):
            new_latex.append('\\bottomrule')
            new_latex.append(line)
        elif re.match(r'^[^&]+ & [^&]+ &', line) and not line.rstrip().endswith('\\\\'):
            # Data row without a terminator: append one (never a second one).
            new_latex.append(line + ' \\\\')
        else:
            new_latex.append(line)

    return '\n'.join(new_latex)

# Collect one row per (model, dataset): the mean of the dataset's quality
# metric under every decoding strategy listed in `column_names`.

# metrics/datasets are consumed pairwise with zip(), which silently truncates
# on a length mismatch; fail loudly instead of dropping datasets.
if len(metrics) != len(datasets):
    raise ValueError(
        f'metrics ({len(metrics)}) and datasets ({len(datasets)}) '
        'must have the same length'
    )

base_dir = 'old_mbr'
all_rows = []

for model in models:
    for metric_name, dataset in zip(metrics, datasets):
        man = UEManager.load(f'{base_dir}/{model}_{dataset}.man')
        row = {
            'Dataset': datasets_pretty[dataset],
            'Metric': metric_pretty[metric_name],
        }
        for column_name, metric_prefix in column_names.items():
            # The decoding strategy is selected by prefixing the metric name;
            # values are looked up under a ('sequence', name) key.
            metric_values = man.gen_metrics[('sequence', metric_prefix + metric_name)]
            row[column_name] = np.mean(metric_values)
        row['Model'] = model
        all_rows.append(row)

# Convert to DataFrame with a fixed column order.
df = pd.DataFrame(all_rows)
df = df[['Model', 'Dataset', 'Metric'] + list(column_names.keys())]

# Group rows by model (alphabetical), then by pretty dataset name.
df = df.sort_values(by=['Model', 'Dataset'])

# Derive the header and column layout from `column_names` so the table always
# has exactly as many columns as each data row. Previously the header declared
# a 'BestNormalizedSample' column that no row filled. The greedy column has an
# empty prefix, so it is labelled 'Greedy'.
header_labels = [prefix or 'Greedy' for prefix in column_names.values()]
n_columns = 2 + len(column_names)  # Dataset + Metric + one per strategy

latex_lines = [
    '\\begin{table*}',
    '\\footnotesize',
    '\\centering',
    '\\caption{Base quality metrics for all models}',
    '\\begin{tabular}{ll' + 'c' * len(column_names) + '}',
    '\\toprule',
    ' & '.join(['Dataset', 'Metric'] + header_labels) + ' \\\\',
    '\\midrule',
]

current_model = None
for _, row in df.iterrows():
    if row['Model'] != current_model:
        current_model = row['Model']
        # Shaded separator row spanning the full table width.
        # \rowcolor comes from the LaTeX colortbl package, so the document
        # preamble must load it (e.g. via xcolor's `table` option).
        latex_lines.append(
            f'\\rowcolor[gray]{{0.9}} \\multicolumn{{{n_columns}}}{{c}}'
            f'{{{model_names[current_model]}}} \\\\'
        )
    cells = [row['Dataset'], row['Metric']]
    cells += [f'{row[column_name]:.3f}' for column_name in column_names]
    latex_lines.append(' & '.join(cells) + ' \\\\')

latex_lines += ['\\bottomrule', '\\end{tabular}', '\\end{table*}']

# Save to file.
latex_str = '\n'.join(latex_lines)
with open(f'{base_dir}/all_models_base_quality.tex', 'w', encoding='utf-8') as f:
    f.write(latex_str)
