import os
import pandas as pd
from lm_polygraph.utils.manager import UEManager

from lm_polygraph.estimators.greedy_supervised_cocoa import *
from lm_polygraph.estimators.max_probability import MaximumSequenceProbability
from lm_polygraph.estimators.perplexity import Perplexity
from lm_polygraph.estimators.token_entropy import MeanTokenEntropy
from lm_polygraph.estimators.greedy_semantic_average_ue_average_similarity import *

# Models (table row groups) and datasets (table columns), in display order.
models = [
    "mistral",
    "llama",
    "falcon",
]
datasets = [
    "xsum",
    "wmt14_fren",
    "wmt19_deen",
    "coqa",
    "triviaqa",
    "mmlu",
    "gsm8k",
]

# Directory containing the .man result files.
results_dir = 'old_mbr/final'

# Accumulators filled while walking the model/dataset pairs:
# `rows` receives one record per (model, method, dataset) score.
summary = dict()
rows = list()


def get_quality_metrics(decoding='greedy'):
    """Return the quality-metric name to use for each dataset.

    The metric names share one base per dataset and differ only by a prefix
    selected by ``decoding``:

    * ``'greedy'`` -> no prefix,
    * ``'sample'`` -> ``'BestSample'``,
    * any other value (e.g. ``'mbr'``) -> ``'MbrSample'``.

    Returns:
        dict mapping dataset name -> metric name.
    """
    if decoding == 'greedy':
        prefix = ''
    elif decoding == 'sample':
        prefix = 'BestSample'
    else:
        prefix = 'MbrSample'

    base_metric_by_dataset = (
        ('triviaqa', 'AlignScoreOutputTarget'),
        ('coqa', 'AlignScoreOutputTarget'),
        ('gsm8k', 'Accuracy'),
        ('wmt14_fren', 'Comet'),
        ('wmt19_deen', 'Comet'),
        ('mmlu', 'Accuracy'),
        ('xsum', 'AlignScoreInputOutput'),
    )
    return {name: prefix + metric for name, metric in base_metric_by_dataset}


def _prefixed_blocks(prefix, block_templates):
    """Build ``{block: [prefix + name, ...]}`` keeping the given block/name order."""
    return {
        block: [prefix + name for name in names]
        for block, names in block_templates
    }


def _variant_prefix(decoding):
    """Name prefix for the 'formulations'/'similarity' tables.

    ``'mbr'`` -> ``'Mbr'``, ``'greedy'`` -> no prefix, anything else -> ``'Best'``.
    """
    if decoding == 'mbr':
        return 'Mbr'
    if decoding == 'greedy':
        return ''
    return 'Best'


def _all_methods(decoding):
    """Method blocks for the main comparison table (``type_='all'``)."""
    sampling_baselines = [
        'MonteCarloSequenceEntropy',
        'MonteCarloNormalizedSequenceEntropy',
        'SemanticEntropy',
        'DegMat_NLI_score_entail',
        'EigValLaplacian_NLI_score_entail',
        'SAR_t0.001',
    ]
    if decoding == 'greedy':
        return {
            'general_baselines': sampling_baselines + [
                'SupervisedCocoa',
                'GreedyAveDissimilarity',
            ],
            'msp': [
                'MaximumSequenceProbability',
                'GreedySemanticEnrichedMaxprobAveDissimilarity',
                'SupervisedCocoaMSP',
            ],
            'ppl': [
                'Perplexity',
                'GreedySemanticEnrichedPPLAveDissimilarity',
                'SupervisedCocoaPPL',
            ],
            'mte': [
                'MeanTokenEntropy',
                'GreedySemanticEnrichedMTEAveDissimilarity',
                'SupervisedCocoaMTE',
            ],
        }
    if decoding == 'sample':
        return {
            'general_baselines': sampling_baselines + ['BestAveDissimilarity'],
            'msp': [
                'BestSampledMaximumSequenceProbability',
                'BestSemanticEnrichedMaxprobAveDissimilarity',
                'BestSampledSupervisedCocoaMSP',
            ],
            'ppl': [
                'BestSampledPerplexity',
                'BestSemanticEnrichedPPLAveDissimilarity',
                'BestSampledSupervisedCocoaPPL',
            ],
            'mte': [
                'BestSampledMeanTokenEntropy',
                'BestSemanticEnrichedMTEAveDissimilarity',
                'BestSampledSupervisedCocoaMTE',
            ],
        }
    # Any other decoding value is treated as MBR. The supervised CoCoA
    # variants are not listed for MBR decoding.
    return {
        'general_baselines': sampling_baselines + ['MbrAveDissimilarity'],
        'msp': [
            'MbrSampledMaximumSequenceProbability',
            'MbrSemanticEnrichedMaxprobAveDissimilarity',
        ],
        'ppl': [
            'MbrSampledPerplexity',
            'MbrSemanticEnrichedPPLAveDissimilarity',
        ],
        'mte': [
            'MbrSampledMeanTokenEntropy',
            'MbrSemanticEnrichedMTEAveDissimilarity',
        ],
    }


def _formulation_methods(decoding):
    """Method blocks comparing alternative formulations (``type_='formulations'``)."""
    templates = (
        ('msp', (
            'SumSemanticMaxprob',
            'SemanticEnrichedMaxprobTotalDissimilarity',
            'SemanticEnrichedMaxprobAveDissimilarityexp',
            'SemanticEnrichedMaxprobAveDissimilarity',
        )),
        ('ppl', (
            'SumSemanticPPL',
            'SemanticEnrichedPPLTotalDissimilarity',
            'SemanticEnrichedPPLAveDissimilarityexp',
            'SemanticEnrichedPPLAveDissimilarity',
        )),
        # The MTE block has no "...AveDissimilarityexp" entry.
        ('mte', (
            'SumSemanticMTE',
            'SemanticEnrichedMTETotalDissimilarity',
            'SemanticEnrichedMTEAveDissimilarity',
        )),
    )
    return _prefixed_blocks(_variant_prefix(decoding), templates)


def _similarity_methods(decoding):
    """Method blocks comparing similarity functions (``type_='similarity'``)."""
    similarity_suffixes = (
        'align_semantic_matrix',
        'rouge_semantic_matrix',
        'semantic_matrix_entail',
        'sample_sentence_similarity',
    )
    base_by_block = (
        ('msp', 'SemanticEnrichedMaxprobAveDissimilarity'),
        ('ppl', 'SemanticEnrichedPPLAveDissimilarity'),
        ('mte', 'SemanticEnrichedMTEAveDissimilarity'),
    )
    templates = tuple(
        (block, tuple(f"{base}_{suffix}" for suffix in similarity_suffixes))
        for block, base in base_by_block
    )
    return _prefixed_blocks(_variant_prefix(decoding), templates)


def get_methods(type_='all', decoding='greedy'):
    """Return the estimator names to report, grouped into ordered blocks.

    Args:
        type_: which table to build: ``'all'``, ``'formulations'`` or
            ``'similarity'``.
        decoding: decoding strategy the estimator names refer to. For
            ``type_='all'`` the recognised values are ``'greedy'`` and
            ``'sample'``, and any other value selects the MBR names. For the
            other two types the recognised values are ``'mbr'`` and
            ``'greedy'``, and any other value selects the ``Best``-prefixed
            (sample) names.

    Returns:
        dict mapping block name (e.g. ``'msp'``, ``'ppl'``, ``'mte'``) to a
        freshly built list of estimator names, in display order.

    Raises:
        ValueError: if ``type_`` is not one of the supported table types.
    """
    if type_ == 'all':
        return _all_methods(decoding)
    if type_ == 'formulations':
        return _formulation_methods(decoding)
    if type_ == 'similarity':
        return _similarity_methods(decoding)
    raise ValueError(
        f"Unknown type_ {type_!r}; expected 'all', 'formulations' or 'similarity'"
    )


# NOTE(review): this replaces whatever `results_dir` was configured earlier in
# the file, so the .man files are looked up relative to the current working
# directory. Behaviour is kept as-is; remove this line to use the configured
# directory instead.
results_dir = ''

# The reported methods and the per-dataset quality metric do not depend on the
# file being read, so they are built once up front.
methods = get_methods(type_='similarity', decoding='mbr')
quality_metrics = get_quality_metrics(decoding='mbr')

# Collect one PRR score per (model, method, dataset) from the saved managers.
for model in models:
    for dataset in datasets:
        file_path = os.path.join(results_dir, f"{model}_{dataset}.man")
        if not os.path.exists(file_path):
            print(f"Missing: {file_path}")
            continue

        man = UEManager.load(file_path)
        for block_methods in methods.values():
            for method in block_methods:
                metric_key = ('sequence', str(method), quality_metrics[dataset], 'prr_0.5_normalized')
                try:
                    prr = man.metrics[metric_key]
                except KeyError:
                    # Report and skip, mirroring how missing files are handled;
                    # the cell is simply absent from the collected rows.
                    print(f"Missing metric in {file_path}: {metric_key}")
                    continue
                rows.append({
                    "Model": model,
                    "Method": str(method),
                    "Dataset": dataset,
                    "PRR": float(prr)
                })

# Without any collected score there is no "Method"/"PRR" column to pivot on;
# stop with a clear message instead of an opaque pandas/NameError traceback.
if not rows:
    raise SystemExit("No PRR results were collected; check that the .man result files exist.")

df_long = pd.DataFrame(rows)

# Flatten all methods in their defined (block, then within-block) order.
method_order = [
    str(m)
    for block in get_methods(type_='similarity', decoding='mbr').values()
    for m in block
]
method_cat_type = pd.CategoricalDtype(method_order, ordered=True)

# Apply method order before pivot; pivot_table averages duplicate entries.
df_long["Method"] = df_long["Method"].astype(method_cat_type)
df_pivot = df_long.pivot_table(index=["Model", "Method"], columns="Dataset", values="PRR")

# Force a full (model x method) row grid and the configured dataset column
# order; combinations without a score become NaN rows/cells.
row_index_order = pd.MultiIndex.from_tuples(
    [(m, method) for m in models for method in method_order],
    names=["Model", "Method"]
)
df_pivot = df_pivot.reindex(index=row_index_order)
df_pivot = df_pivot.reindex(columns=datasets)


# Column headers shown in the table for each dataset key.
dataset_pretty = dict(
    xsum="XSum",
    wmt14_fren="WMT14FrEn",
    wmt19_deen="WMT19DeEn",
    coqa="CoQa",
    triviaqa="Trivia",
    mmlu="MMLU",
    gsm8k="GSM8k",
)

# Group-header labels shown in the table for each model key.
models_pretty = dict(
    mistral='Mistral7b-Base',
    llama='Llama8b-Base',
    falcon='Falcon7b-Base',
)

# Display label for each estimator name; names without an entry are printed
# unchanged by the table code (it uses `.get(name, name)`).
method_mapping = {
    # Sampling-based baselines.
    "MonteCarloSequenceEntropy": "MCSE",
    "MonteCarloNormalizedSequenceEntropy": "MCNSE",
    "SemanticEntropy": "Semantic Entropy",
    "SAR_t0.001": "SAR",
    "BestSemanticDensity": "Semantic Density",
    "GreedySemanticDensity": "Semantic Density",
    "SampledSemanticDensity": "Semantic Density",
    "DegMat_NLI_score_entail": "DegMat",
    "EigValLaplacian_NLI_score_entail": "EigValLaplacian",
    # Information-based scores and their CoCoA counterparts (greedy / sampled).
    "BestSampledMaximumSequenceProbability": "MSP",
    "SampledMaximumSequenceProbability": "MSP",
    "MaximumSequenceProbability": "MSP",
    "BestSampledPerplexity": "PPL",
    "SampledPerplexity": "PPL",
    "Perplexity": "PPL",
    "SemanticEnrichedMaxprobAveDissimilarity": r"$\text{CoCoA}_{MSP}$",
    "SemanticEnrichedPPLAveDissimilarity": r"$\text{CoCoA}_{PPL}$",
    "GreedySemanticEnrichedMTEAveDissimilarity": r"$\text{CoCoA}_{MTE}$",
    "GreedySemanticEnrichedPPLAveDissimilarity": r"$\text{CoCoA}_{PPL}$",
    "GreedySemanticEnrichedMaxprobAveDissimilarity": r"$\text{CoCoA}_{MSP}$",
    "BestSampledMeanTokenEntropy": "MTE",
    "SampledMeanTokenEntropy": "MTE",
    "MeanTokenEntropy": "MTE",
    "BestSemanticEnrichedMaxprobAveDissimilarity": r"$\text{CoCoA}_{MSP}$",
    "BestSemanticEnrichedPPLAveDissimilarity": r"$\text{CoCoA}_{PPL}$",
    "BestSemanticEnrichedMTEAveDissimilarity": r"$\text{CoCoA}_{MTE}$",
    "SemanticEnrichedMTEAveDissimilarity": r"$\text{CoCoA}_{MTE}$",
    # MBR decoding.
    'MbrSampledMaximumSequenceProbability': 'MSP',
    # Maxprob-based CoCoA is the MSP variant, as for the greedy/Best entries
    # above (this entry previously carried the MTE label by mistake).
    'MbrSemanticEnrichedMaxprobAveDissimilarity': r"$\text{CoCoA}_{MSP}$",
    'MbrSampledPerplexity': 'PPL',
    'MbrSemanticEnrichedPPLAveDissimilarity': r"$\text{CoCoA}_{PPL}$",
    'MbrSampledMeanTokenEntropy': 'MTE',
    'MbrSemanticEnrichedMTEAveDissimilarity': r"$\text{CoCoA}_{MTE}$",
    'MbrSampledSupervisedCocoaPPL': 'SupervisedCocoaPPL',
    'MbrSampledSupervisedCocoaMTE': 'SupervisedCocoaMTE',
    'MbrSampledSupervisedCocoaMSP': 'SupervisedCocoaMSP',
    'MbrAveDissimilarity': 'Dissimilarity',
    # Alternative formulations (MBR).
    'MbrSumSemanticMTE': r"$\text{AdditiveCoCoA}_{MTE}$",
    'MbrSemanticEnrichedMTETotalDissimilarity': r"$\text{FullSampleCoCoA}_{MTE}$",
    'MbrSumSemanticPPL': r"$\text{AdditiveCoCoA}_{PPL}$",
    'MbrSemanticEnrichedPPLTotalDissimilarity': r"$\text{FullSampleCoCoA}_{PPL}$",
    'MbrSemanticEnrichedPPLAveDissimilarityexp': r"$\text{ProbCoCoA}_{PPL}$",
    'MbrSumSemanticMaxprob': r"$\text{AdditiveCoCoA}_{MSP}$",
    'MbrSemanticEnrichedMaxprobTotalDissimilarity': r"$\text{FullSampleCoCoA}_{MSP}$",
    'MbrSemanticEnrichedMaxprobAveDissimilarityexp': r"$\text{ProbCoCoA}_{MSP}$",
    # Similarity-function ablation (MBR): labelled by the similarity used.
    'MbrSemanticEnrichedMaxprobAveDissimilarity_rouge_semantic_matrix': 'RougeL',
    'MbrSemanticEnrichedPPLAveDissimilarity_rouge_semantic_matrix': 'RougeL',
    'MbrSemanticEnrichedMTEAveDissimilarity_rouge_semantic_matrix': 'RougeL',
    'MbrAveDissimilarity_rouge_semantic_matrix': 'RougeL',
    'MbrSemanticEnrichedMaxprobAveDissimilarity_align_semantic_matrix': 'AlignScore',
    'MbrSemanticEnrichedPPLAveDissimilarity_align_semantic_matrix': 'AlignScore',
    'MbrSemanticEnrichedMTEAveDissimilarity_align_semantic_matrix': 'AlignScore',
    'MbrAveDissimilarity_align_semantic_matrix': 'AlignScore',
    'MbrSemanticEnrichedMaxprobAveDissimilarity_semantic_matrix_entail': 'NLI',
    'MbrSemanticEnrichedPPLAveDissimilarity_semantic_matrix_entail': 'NLI',
    'MbrSemanticEnrichedMTEAveDissimilarity_semantic_matrix_entail': 'NLI',
    'MbrAveDissimilarity_semantic_matrix_entail': 'NLI',
    'MbrSemanticEnrichedMaxprobAveDissimilarity_sample_sentence_similarity': 'CrossEncoder',
    'MbrSemanticEnrichedPPLAveDissimilarity_sample_sentence_similarity': 'CrossEncoder',
    'MbrSemanticEnrichedMTEAveDissimilarity_sample_sentence_similarity': 'CrossEncoder',
    'MbrAveDissimilarity_sample_sentence_similarity': 'CrossEncoder',
}

# Render the pivot table as a LaTeX tabular: a two-row header, then one
# shaded group-header row per model followed by one row per method.
n_cols = len(df_pivot.columns)
latex_lines = [
    "\\begin{tabular}{l" + "r" * n_cols + "}",
    "\\toprule",
    "    \\multirow{2}{*}{\\textbf{Method}}  & \\multicolumn{" + str(n_cols) + "}{c}{\\textbf{Dataset}}  \\\\",
    # No row terminator after \cmidrule: a trailing "\\" would emit an empty
    # table row between the two header rows spanned by \multirow.
    "      \\cmidrule(lr){2-" + str(n_cols + 1) + "}",
    "  & " + " & ".join(dataset_pretty.get(col, col) for col in df_pivot.columns) + " \\\\",
    "  \\midrule",
]

# Method -> block lookup, used to place a \midrule between blocks. It is the
# same for every row, so it is built once.
method_to_block = {
    m: block
    for block, lst in get_methods(type_='similarity', decoding='mbr').items()
    for m in lst
}


def _best_and_second(col):
    """Return the labels of the largest and second-largest values in ``col``.

    NaN entries are ignored; a label is ``None`` when there are not enough
    non-NaN values (e.g. a dataset whose result file was missing).
    """
    valid = col.dropna()
    best = valid.idxmax() if len(valid) else None
    second = col.nlargest(2).index[-1] if len(valid) >= 2 else None
    return best, second


grouped = df_pivot.groupby(level=0, sort=False)

for model, group_df in grouped:
    latex_lines.append(f"\\rowcolor[gray]{{0.9}} & \\multicolumn{{{n_cols}}}{{c}}{{{models_pretty[model]}}} \\\\")
    # Index the group by method name only.
    group_df = group_df.droplevel(0)
    top_two = {col: _best_and_second(group_df[col]) for col in df_pivot.columns}
    method_names = list(group_df.index)

    for i, (method, row) in enumerate(group_df.iterrows()):
        method_display = method_mapping.get(method, method)
        row_str = [method_display]
        for col in df_pivot.columns:
            val = row[col]
            best, second = top_two[col]
            if pd.isna(val):
                cell = ""
            elif method == best:
                cell = f"\\textbf{{{val:.3f}}}"
            elif method == second:
                cell = f"\\underline{{{val:.3f}}}"
            else:
                cell = f"{val:.3f}"
            row_str.append(cell)
        latex_lines.append(" & ".join(row_str) + " \\\\")

        # Separate method blocks with a rule, but not after the group's last row.
        next_method = method_names[i + 1] if i + 1 < len(method_names) else None
        curr_block = method_to_block.get(method, None)
        next_block = method_to_block.get(next_method, None)
        if curr_block != next_block and next_method is not None:
            latex_lines.append("  \\midrule")

latex_lines.append("\\bottomrule")
latex_lines.append("\\end{tabular}")

latex_output = "\n".join(latex_lines)
print(latex_output)
