import numpy as np
from matplotlib.colors import ListedColormap
from lm_polygraph.utils.manager import UEManager
import tabulate
import pandas as pd
from functools import partial
import seaborn as sns
import re
import matplotlib.pyplot as plt
import argparse

# Seaborn "coolwarm" palette as a continuous colormap (not referenced by the
# code shown in this script; kept for anything importing it).
cm = sns.color_palette("coolwarm", as_cmap=True)

# Model-name prefixes of the saved managers, in table column order.
models = ["llama", "mistral", "falcon"]

def get_metrics(args):
    """Pick the quality metric used for each task family.

    Args:
        args: parsed CLI namespace. ``args.do_sample`` selects greedy vs.
            sampled metrics; ``args.sample_strategy`` is read only when
            sampling is enabled.

    Returns:
        Tuple ``(ats_metrics, nmt_metrics, short_qa_metrics, long_qa_metrics)``;
        each element is a new one-item list holding a metric name.

    Raises:
        ValueError: if sampling is enabled and ``args.sample_strategy`` is not
            one of ``'first'``, ``'Mbr'``, ``'Mbr_normalized'`` or ``'mbr'``.
    """
    if not args.do_sample:
        chosen = ('AlignScoreInputOutput', 'Comet', 'Accuracy', 'AlignScoreOutputTarget')
    else:
        # (strategy, (ats, nmt, short QA, long QA)), matched with ``==`` in order.
        # NOTE(review): 'Mbr' uses an AlignScore metric for the ats slot while
        # 'mbr' uses ROUGE-L -- confirm this asymmetry is intended.
        strategy_table = (
            ('first', ('SampleRouge_rougeL', 'SampleComet',
                       'SampleAccuracy', 'SampleAlignScoreOutputTarget')),
            ('Mbr', ('MbrSampleAlignScoreInputOutput', 'MbrSampleComet',
                     'MbrSampleAccuracy', 'MbrSampleAlignScoreOutputTarget')),
            ('Mbr_normalized', ('MbrNormalizedSampleRouge_rougeL',
                                'MbrNormalizedSampleComet',
                                'MbrNormalizedSampleAccuracy',
                                'MbrNormalizedSampleAlignScoreOutputTarget')),
            ('mbr', ('MbrSampleRouge_rougeL', 'MbrSampleComet',
                     'MbrSampleAccuracy', 'MbrSampleAlignScoreOutputTarget')),
        )
        for strategy, chosen in strategy_table:
            if args.sample_strategy == strategy:
                break
        else:
            raise ValueError(f'Invalid sample strategy: {args.sample_strategy}')
    return tuple([name] for name in chosen)

def get_methods(args):
    """Return the uncertainty-estimation methods to tabulate, grouped by family.

    The ``general_baselines`` group is identical for greedy and sampled runs.
    Each of the ``msp``, ``ppl`` and ``mte`` groups lists a base estimator,
    its ``SemanticEnriched`` (CoCoA) variant and its supervised CoCoA variant;
    when ``args.do_sample`` is truthy the MBR-sampled names are used instead.

    Note: ``args.exclude_ss`` is not consulted here.

    Returns:
        dict mapping group name to a fresh list of method names, with keys in
        the order ``general_baselines``, ``msp``, ``ppl``, ``mte``.
    """
    baselines = [
        'MonteCarloSequenceEntropy',
        'MonteCarloNormalizedSequenceEntropy',
        'SemanticEntropy',
        # 'CEDegMat',
        'DegMat_NLI_score_entail',
        'EigValLaplacian_NLI_score_entail',
        'SAR_t0.001',
    ]

    if args.do_sample:
        families = {
            'msp': [
                'MbrSampledMaximumSequenceProbability',
                'MbrSemanticEnrichedMaxprobAveDissimilarity',
                'MbrSampledSupervisedCocoaMSP',
            ],
            'ppl': [
                'MbrSampledPerplexity',
                'MbrSemanticEnrichedPPLAveDissimilarity',
                'MbrSampledSupervisedCocoaPPL',
            ],
            'mte': [
                'MbrSampledMeanTokenEntropy',
                'MbrSemanticEnrichedMTEAveDissimilarity',
                'MbrSampledSupervisedCocoaMTE',
            ],
        }
    else:
        families = {
            'msp': [
                'MaximumSequenceProbability',
                'GreedySemanticEnrichedMaxprobAveDissimilarity',
                'SupervisedCocoaMSP',
            ],
            'ppl': [
                'Perplexity',
                'GreedySemanticEnrichedPPLAveDissimilarity',
                'SupervisedCocoaPPL',
            ],
            'mte': [
                'MeanTokenEntropy',
                'GreedySemanticEnrichedMTEAveDissimilarity',
                'SupervisedCocoaMTE',
            ],
        }

    grouped = {'general_baselines': baselines}
    grouped.update(families)
    return grouped



def get_tasks():
    """Return dataset names grouped by task key.

    NOTE(review): the ``'ats'`` key holds the WMT translation datasets while
    ``'sum'`` holds XSum; in ``main`` the ``ats_metrics`` are applied to XSum,
    so the key naming looks inconsistent. Kept unchanged for callers.

    Returns:
        A new dict ``{'qa': [...], 'ats': [...], 'sum': [...]}``.
    """
    qa_sets = ['triviaqa', 'mmlu', 'coqa', 'gsm8k']
    translation_sets = ['wmt_14_fren', 'wmt_19_deen']
    summarization_sets = ['xsum']
    return {'qa': qa_sets, 'ats': translation_sets, 'sum': summarization_sets}

def parse_args(argv=None):
    """Parse command-line options for the table builder.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with ``do_sample`` (bool), ``exclude_ss`` (bool)
        and ``sample_strategy`` (str, default ``'mbr'``).
    """
    parser = argparse.ArgumentParser()
    # store_true already defaults to False.
    parser.add_argument('--do_sample', action='store_true',
                        help='Read sampled (MBR) results instead of greedy ones.')
    parser.add_argument('--exclude_ss', action='store_true',
                        help='Accepted for compatibility; not read by this script.')
    parser.add_argument('--sample_strategy', default='mbr',
                        help="Metric family used with --do_sample: "
                             "'first', 'Mbr', 'Mbr_normalized' or 'mbr'.")
    return parser.parse_args(argv)

# Key of the score read from each manager's ``metrics`` dict.
_PRR_KEY = 'prr_0.5_normalized'

# File-name stems of the saved managers: ``{base_dir}/{model}_{stem}.man``.
_DATASETS = ('triviaqa', 'mmlu', 'coqa', 'gsm8k', 'xsum', 'wmt14_fren', 'wmt19_deen')

# Method name -> label shown in the LaTeX table.
_METHOD_MAPPING = {
    'MonteCarloSequenceEntropy': 'MCSE',
    'MonteCarloNormalizedSequenceEntropy': 'MCNSE',
    'SemanticEntropy': 'Semantic Entropy',
    'CEDegMat': 'CEDegMat',
    'SAR_t0.001': 'SAR',
    'DegMat_NLI_score_entail': 'DegMat',
    'EigValLaplacian_NLI_score_entail': 'EigValLaplacian',
    'MaximumSequenceProbability': 'MSP',
    'GreedySemanticEnrichedMaxprobAveDissimilarity': 'MSP',
    'Perplexity': 'Perplexity',
    'GreedySemanticEnrichedPPLAveDissimilarity': 'Perplexity',
    'MeanTokenEntropy': 'MeanTokenEntropy',
    'GreedySemanticEnrichedMTEAveDissimilarity': 'MeanTokenEntropy',
    'SupervisedCocoaPPL': 'SupervisedCocoaPPL',
    'SupervisedCocoaMTE': 'SupervisedCocoaMTE',
    'SupervisedCocoaMSP': 'SupervisedCocoaMSP',
    'MbrSampledMaximumSequenceProbability': 'MSP',
    'MbrSemanticEnrichedMaxprobAveDissimilarity': 'MSP',
    'MbrSampledPerplexity': 'Perplexity',
    'MbrSemanticEnrichedPPLAveDissimilarity': 'Perplexity',
    'MbrSampledMeanTokenEntropy': 'MeanTokenEntropy',
    'MbrSemanticEnrichedMTEAveDissimilarity': 'MeanTokenEntropy',
    'MbrSampledSupervisedCocoaPPL': 'SupervisedCocoaPPL',
    'MbrSampledSupervisedCocoaMTE': 'SupervisedCocoaMTE',
    'MbrSampledSupervisedCocoaMSP': 'SupervisedCocoaMSP',
    'GreedyAveDissimilarity': 'Dissimilarity',
    'AveDissimilarity': 'Dissimilarity',
    'MbrAveDissimilarity': 'Dissimilarity',
    'SampledSupervisedCocoa': 'SampledSupervisedCocoa',
    'SupervisedCocoa': 'SupervisedCocoa',
}

# Closing part of the LaTeX table.
_TABLE_FOOTER = '\n'.join([
    '',
    '\\bottomrule',
    '\\end{tabular}}',
    '\\caption{Results for Evaluated Sequence - MBR Sample: Mean PRR across datasets '
    'for each task. The best performing method is in bold, and the second-best is '
    'underlined. Arrows indicate improvement in CoCoA over the base version.}',
    '\\label{tab:Mbr_sample_results}',
    '\\end{table*}',
])


def _safe_load(man_path):
    """Load a saved UEManager; return None (after printing a warning) on failure."""
    try:
        return UEManager.load(man_path)
    except Exception as err:  # best-effort: a missing dataset shows up as '-' in the table
        print(f'Warning: could not load {man_path}: {err}')
        return None


def _is_number(value):
    """Return True for Python or NumPy real scalars (bools excluded).

    ``np.float32`` is not a subclass of ``float``, so a plain
    ``isinstance(value, float)`` check would silently drop such scores.
    """
    return (isinstance(value, (int, float, np.integer, np.floating))
            and not isinstance(value, bool))


def _get_prr(manager, method, metric):
    """Return the normalized PRR of ``method`` under ``metric``, or None if unavailable."""
    if manager is None:
        return None
    try:
        return manager.metrics[('sequence', method, metric, _PRR_KEY)]
    except KeyError:
        return None


def _task_score(method, sources):
    """Average the numeric PRRs of ``method`` over ``sources``.

    ``sources`` is a list of ``(metric_names, managers)`` pairs. Returns None
    when no numeric PRR was found.
    """
    prrs = []
    for metric_names, managers in sources:
        for metric in metric_names:
            for manager in managers:
                prr = _get_prr(manager, method, metric)
                if _is_number(prr):
                    prrs.append(prr)
    return np.mean(prrs) if prrs else None


def _format_scores(scores):
    """Format one (model, task) column of ``{method: score or None}`` as LaTeX cells.

    The highest score is bold and the runner-up underlined (ties keep the
    method order). Missing scores become ``'-'``. NaN scores are printed but
    never ranked.
    """
    ranked = sorted(
        (m for m, s in scores.items() if s is not None and not np.isnan(s)),
        key=lambda m: scores[m],
        reverse=True,
    )
    best = ranked[0] if ranked else None
    second = ranked[1] if len(ranked) > 1 else None

    cells = {}
    for method, score in scores.items():
        if score is None:
            cells[method] = '-'
        elif method == best:
            cells[method] = f'\\textbf{{{score:.3f}}}'
        elif method == second:
            cells[method] = f'\\underline{{{score:.3f}}}'
        else:
            cells[method] = f'{score:.3f}'
    return cells


def _render_header(model_names, task_names):
    """Build the LaTeX table preamble: one multicolumn group per model, one column per task."""
    n_tasks = len(task_names)
    model_cells = ' & '.join(
        f'\\multicolumn{{{n_tasks}}}{{c}}{{\\textbf{{{name.capitalize()}}}}}'
        for name in model_names
    )
    rules = ' '.join(
        f'\\cmidrule(lr){{{2 + i * n_tasks}-{1 + (i + 1) * n_tasks}}}'
        for i in range(len(model_names))
    )
    task_cells = ' & '.join(f'\\textbf{{{task.upper()}}}' for task in task_names)
    sub_header = ' '.join(f'& {task_cells}' for _ in model_names)
    col_spec = 'l' + 'c' * (n_tasks * len(model_names))
    return '\n'.join([
        '',
        '\\begin{table*}[th!]',
        '\\centering',
        '\\renewcommand{\\arraystretch}{1.2} % Adjust row height',
        '\\scalebox{0.85}{',
        f'\\begin{{tabular}}{{{col_spec}}}',
        '\\toprule',  # booktabs: top rule opens the table
        f'\\textbf{{Metric}} & {model_cells} \\\\',
        rules,
        f'{sub_header} \\\\',
        '\\midrule',
        '',
    ])


def _render_rows(methods_dict, scores, cells, model_names, task_names):
    """Render one LaTeX row per method, with a ``\\midrule`` after every group but the last.

    For CoCoA (``'Enriched'``) methods, an up-arrow is appended to a cell
    only when its score beats the group's first method in the same cell.
    That first entry is assumed to be the base estimator.
    """
    rows = []
    last_group = list(methods_dict)[-1] if methods_dict else None
    for group_name, methods in methods_dict.items():
        base_scores = scores.get(methods[0], {}) if methods else {}
        for idx, method in enumerate(methods):
            label = _METHOD_MAPPING.get(method)
            if label is None:
                # Fall back to the raw name; underscores must be escaped in LaTeX text.
                label = method.replace('_', '\\_')
            is_cocoa = 'Enriched' in method
            if is_cocoa:
                # Braces keep the whole label in the subscript.
                row = '$\\text{CoCoA}_{\\text{' + label + '}}$'
            else:
                row = '$\\text{' + label + '}$'
            for model in model_names:
                for task in task_names:
                    row += ' & ' + cells[method][model][task]
                    new = scores[method][model][task]
                    old = base_scores.get(model, {}).get(task)
                    if is_cocoa and new is not None and old is not None and new > old:
                        row += '  \\(\\uparrow\\)  '
            row += ' \\\\'
            if idx == len(methods) - 1 and group_name != last_group:
                row += ' \\midrule'
            rows.append(row)
    return rows


def main():
    """Build a LaTeX table of mean normalized PRR per (method, model, task).

    For every model in ``models``, the saved UEManagers are loaded from
    ``base_dir``. The ``prr_0.5_normalized`` entries of each method are
    averaged over the datasets of each task (QA: CoQA/TriviaQA with the
    long-QA metric, MMLU/GSM8K with the short-QA metric; NMT: WMT14 fr-en and
    WMT19 de-en; SUM: XSum). The best and second-best method per column are
    highlighted. The table is written to ``{args.do_sample}_output_mbr.txt``.
    """
    args = parse_args()

    # Greedy and sampled runs are currently both read from the same directory.
    base_dir = 'old_mbr'
    tasks = ['qa', 'nmt', 'sum']

    methods_dict = get_methods(args)
    methods_dict['general_baselines'] = methods_dict['general_baselines'] + ['MbrAveDissimilarity']
    ats_metrics, nmt_metrics, short_qa_metrics, long_qa_metrics = get_metrics(args)
    # Unique method names, in table order.
    all_methods = list(dict.fromkeys(m for group in methods_dict.values() for m in group))

    # scores[method][model][task] -> mean PRR (NumPy scalar) or None.
    scores = {method: {} for method in all_methods}
    for model in models:
        managers = {name: _safe_load(f'{base_dir}/{model}_{name}.man') for name in _DATASETS}
        sources = {
            'qa': [(long_qa_metrics, [managers['coqa'], managers['triviaqa']]),
                   (short_qa_metrics, [managers['mmlu'], managers['gsm8k']])],
            'nmt': [(nmt_metrics, [managers['wmt14_fren'], managers['wmt19_deen']])],
            'sum': [(ats_metrics, [managers['xsum']])],
        }
        for method in all_methods:
            scores[method][model] = {task: _task_score(method, sources[task]) for task in tasks}

    # cells[method][model][task] -> formatted LaTeX cell text.
    cells = {method: {model: {} for model in models} for method in all_methods}
    for model in models:
        for task in tasks:
            column = {method: scores[method][model][task] for method in all_methods}
            for method, cell in _format_scores(column).items():
                cells[method][model][task] = cell

    rows = _render_rows(methods_dict, scores, cells, models, tasks)
    combined_text = _render_header(models, tasks) + '\n'.join(rows) + _TABLE_FOOTER
    output_file = f'{args.do_sample}_output_mbr.txt'
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(combined_text)


    
if __name__ == "__main__":
    # Script entry point: parse CLI options and write the LaTeX table.
    main()