from pathlib import Path
import pandas as pd
import numpy as np

import util
from data import snli_labels, SNLIDataset
from torch.utils.data import ConcatDataset

# Root directories: one sub-directory per NLI model, and the evaluation data.
model_dir = Path('models/')
data_dir = Path('data/')
# Name of the predictions file read from each model's directory.
preds_filename = 'preds.jsonl'

# Base names of the evaluated NLI models, in `<model>_<training data>` form.
_nli_base_models = [
    'bert_snli',

    'bert_snli+75k-llama3.2:3b',
    'bert_snli+75k-llama3.3:70b',
    'bert_snli+75k-deepseek-r1:70b',

    'bert_snli^75k-llama3.2:3b',
    'bert_snli^75k-llama3.3:70b',
    'bert_snli^75k-deepseek-r1:70b',

    'RoBERTa+SE_snli^75k-llama3.3:70b',
    'RoBERTa+SE_snli^75k-deepseek-r1:70b',
]

# Every base model is immediately followed by its `_hypothesis-only` variant.
nli_models = [
    base_name + suffix
    for base_name in _nli_base_models
    for suffix in ('', '_hypothesis-only')
]

# Names of the generated test sets (`generated/<name>_test.jsonl` under data_dir).
gen_models = ['llama3.2:3b', 'llama3.3:70b', 'deepseek-r1:70b']


def info_from_model_name(model_name):
    """Split an NLI model name into its base identifier and a hypothesis-only flag.

    The name is cut at every ``_``.  The first two pieces (model and training
    data) are re-joined into the base identifier; any further piece marks the
    model as hypothesis-only, e.g. ``'bert_snli_hypothesis-only'`` gives
    ``('bert_snli', True)``.

    Returns:
        A ``(model_and_data, hypothesis_only)`` tuple.

    Raises:
        ValueError: if the name contains no underscore.
    """
    model, train_data, *extra = model_name.split('_')
    return f'{model}_{train_data}', bool(extra)

# Load the predictions of every NLI model into one long frame
print("Loading...")

print("\tPredictions...")
per_model_preds = []
for nli_model in nli_models:
    print(f"\t\t{nli_model}.")
    records = util.load_jsonl(model_dir / nli_model / preds_filename)
    model_preds = pd.DataFrame.from_records(records)
    # a model must not predict the same item twice
    assert not model_preds['pairID'].duplicated().any()
    # the predicted label is the label column holding the largest value
    model_preds['predicted_label'] = model_preds[snli_labels].idxmax(axis=1)
    # tag every row with the model it came from
    model_and_data, hyp_only = info_from_model_name(nli_model)
    model_preds['model+data'] = model_and_data
    model_preds['hyp_only'] = hyp_only
    per_model_preds.append(model_preds)
preds = pd.concat(per_model_preds)


def experiment1():
    """Score every NLI model on the SNLI test set and the generated test sets.

    Reads the SNLI test file and one generated test file per entry of
    ``gen_models`` from ``data_dir``, joins them with the module-level
    ``preds`` frame, and then:

    * prints per-label recall, precision and F1 for every
      (dataset, model+data, hyp_only) combination;
    * prints the macro-averaged F1 per dataset, separately for the
      hypothesis-only and the full models;
    * writes a LaTeX table of the macro-averaged F1 to
      ``paper/tables/sanity-check.tex``, in which the hypothesis-only columns
      hold the difference "full minus hypothesis-only" rather than the raw score.
    """
    print("\tSNLI test set.")
    df_snli = pd.DataFrame.from_records(util.load_jsonl(data_dir/'snli_1.0/snli_1.0_test.jsonl'))
    df_snli['dataset'] = 'snli'
    # df_snli = df_snli.drop(df_snli[(df_snli['gold_label'] == '-')].index)
    df_snli['true_label'] = df_snli['gold_label']

    print("\tGenerated test sets...")
    df_gens = []
    for model in gen_models:
        print(f"\t\t{model}.")
        df_gen = pd.DataFrame.from_records(util.load_jsonl(data_dir/f'generated/{model}_test.jsonl'))
        df_gen['dataset'] = model
        # for generated items the reference label is the generator's own label
        df_gen['true_label'] = df_gen['model_label']
        df_gens.append(df_gen)

    df_test = pd.concat([df_snli] + df_gens)
    # pair IDs must be unique across all test sets for the many-to-one merge below
    assert not df_test['pairID'].duplicated().any()

    print("# Test data results (original SNLI & generated)")

    print("## Full results")

    # one row per (model, test item); items without predictions are dropped (inner join)
    df = pd.merge(preds, df_test, left_on='pairID', right_on='pairID', validate='many_to_one')

    df['correct'] = df['predicted_label'] == df['true_label']

    # recall: accuracy among the items of each true label;
    # precision: accuracy among the items of each predicted label
    results = df.groupby(['dataset', 'true_label', 'model+data', 'hyp_only'])[['correct']].mean().rename({'correct': 'recall'}, axis=1)
    results['precision'] = df.groupby(['dataset', 'predicted_label', 'model+data', 'hyp_only'])[['correct']].mean()
    # NOTE(review): F1 is NaN when precision + recall == 0 or when a label is
    # never predicted; the macro averages below skip such NaN entries.
    results['f1'] = 2*results['precision']*results['recall'] / (results['precision'] + results['recall'])

    def fmt_1dp(x):
        # one decimal place, for values that are already percentages
        return f'{x:.1f}'

    print((results.unstack()*100).to_string(na_rep='-', float_format=fmt_1dp))
    print()

    print("## Macro-avg F1 summary (hyp only / full model)")
    # macro average: unweighted mean of the per-label F1 scores
    macro_f1 = results['f1'].groupby(['dataset', 'model+data', 'hyp_only']).mean()
    f1_by_dataset = macro_f1.unstack('dataset')[['snli'] + gen_models]
    print("Hypothesis only:")
    print((f1_by_dataset.xs(True, level='hyp_only') * 100).to_string(na_rep='-', float_format=fmt_1dp))
    print("Full:")
    print((f1_by_dataset.xs(False, level='hyp_only') * 100).to_string(na_rep='-', float_format=fmt_1dp))

    print("# LaTex tables")

    dfr = macro_f1.unstack(level=['dataset', 'hyp_only'])[['snli']+gen_models]
    # display hypothesis-only differential instead of raw
    dfr.iloc[:, dfr.columns.get_level_values(1)==True] = dfr.xs(False, level='hyp_only', axis=1) - dfr.xs(True, level='hyp_only', axis=1)
    dfr = dfr.reset_index()

    def display_model_data(model_data):
        """Turn a ``<model>_<data>`` identifier into display (model, data) labels."""
        model, data = model_data.split('_')
        model = 'BERT' if model == 'bert' else model
        if 'deepseek-r1' in data:
            data_ = 'DS-R1'
        elif 'llama3.3' in data:
            data_ = 'LL3.3'
        elif 'llama3.2' in data:
            data_ = 'LL3.2'
        else:
            data_ = ''
        # the character right after 'snli' says how the generated data was combined
        qual = {'^': '±', '+': '+'}.get(data[4]) if len(data) > 4 else ''

        return model, 'SNLI'+qual+data_

    dfr['model'], dfr['data'] = zip(*dfr['model+data'].apply(display_model_data))
    dfr = dfr.drop('model+data', axis=1)

    # The keys must match the dataset names exactly: rename() silently ignores
    # keys that are absent, which left the 'llama3.3:70b' column un-renamed
    # while the key was misspelled as 'llama3.2:70b'.
    dfr = dfr.rename(columns={
        "snli": "SNLI",
        'llama3.2:3b': 'Ll3.2',
        'llama3.3:70b': 'Ll3.3',
        'deepseek-r1:70b': 'DS-R1',
        })

    dfr = dfr.set_index(['model', 'data'])

    # explicit UTF-8: the row labels contain the non-ASCII character '±'
    with open('paper/tables/sanity-check.tex', 'w', encoding='utf-8') as f:
        f.write(dfr.to_latex(na_rep='-', float_format=lambda x: f'{x*100:.1f}', escape=True))


def experiment2():
    """Analyse the NLI models' predictions on the inferred items.

    Reads ``inferred/llama3.3:70b_test.jsonl`` from ``data_dir`` and joins it
    with the module-level ``preds`` frame, both for the inferred items
    themselves and for the original items they refer to (``pairID_L`` /
    ``pairID_R``).  Prints the distribution of inference patterns, the share
    of items whose antecedents were predicted correctly, and the aggregated
    prediction/consistency statistics.  Finally writes four LaTeX tables under
    ``paper/tables/``: for each of the two RoBERTa+SE models (full model, not
    hypothesis-only), one table over all items and one restricted to items
    whose antecedents were predicted correctly.
    """
    print("# Inferred item performance")

    df_inf = pd.DataFrame.from_records(util.load_jsonl(data_dir/'inferred/llama3.3:70b_test.jsonl'))

    df = pd.merge(df_inf, preds, left_on='pairID', right_on='pairID')
    # a missing right-hand label is shown as '-' so it survives the groupbys below
    df['label_R'] = df['label_R'].fillna('-')

    print(df[['inference_pattern', 'label_L', 'label_R']].value_counts(dropna=False))

    # flag predictions that fall inside the set of labels allowed for the item
    df['consistent_EI'] = df.apply(lambda x: x['predicted_label'] in x['possible_labels_EI'], axis=1)
    df['consistent_SC'] = df.apply(lambda x: x['predicted_label'] in x['possible_labels_SC'], axis=1)

    # get the predicted labels for the original items (same model, same hyp_only setting)
    keys = ['model+data', 'hyp_only']
    df = pd.merge(df, preds[['pairID', 'model+data', 'hyp_only', 'predicted_label']], left_on=['pairID_L']+keys, right_on=['pairID']+keys, suffixes=('', '_L'))
    # left join: rows are kept even when no prediction matches pairID_R
    df = pd.merge(df, preds[['pairID', 'model+data', 'hyp_only', 'predicted_label']], left_on=['pairID_R']+keys, right_on=['pairID']+keys, suffixes=('', '_R'), how='left')

    # mark the inferred items where the model was correct on its antecedents
    correct_L = df['predicted_label_L'] == df['label_L']
    # the 'ab-ba' pattern is exempt from the right-hand check
    correct_R = (df['predicted_label_R'] == df['label_R']) | (df['inference_pattern'] == 'ab-ba')
    df['anticedent_correct'] = correct_L & correct_R

    print(df.groupby(['inference_pattern', 'label_L', 'label_R', 'model+data', 'hyp_only'])['anticedent_correct'].mean().to_string())

    def aggregate_results(df):
        """Summarise predictions per (model, hyp_only, pattern, antecedent labels) group."""
        gb = df.groupby(['model+data', 'hyp_only', 'inference_pattern', 'label_L', 'label_R'])
        # share of each predicted label within the group, one column per label
        results = gb[['predicted_label']].value_counts(normalize=True).unstack()
        results['possible_labels_SC'] = gb[['possible_labels_SC']].agg('first')
        results['possible_labels_EI'] = gb[['possible_labels_EI']].agg('first')
        results['consistent_SC'] = gb[['consistent_SC']].mean()
        results['consistent_EI'] = gb[['consistent_EI']].mean()
        results['count'] = gb['pairID'].count()
        return results

    # aggregate twice: only items with correctly predicted antecedents, and all items
    results1 = aggregate_results(df[df['anticedent_correct']])
    results1['filter_anticedent_correct'] = True
    results2 = aggregate_results(df)
    results2['filter_anticedent_correct'] = False
    results = pd.concat([results1, results2])

    print(results.to_string(na_rep='-', float_format=lambda x: f'{x:.1f}'))
    print()

    def mm(x):
        """Wrap a string in LaTeX inline-math delimiters."""
        return '$' + x + '$'

    def format_possible_labels(possible_labels):
        """Abbreviate a set of allowed labels as LaTeX (e.g. everything but neutral -> '\\lnot N')."""
        pl = set(possible_labels)
        if pl == {'contradiction'}:
            return "C"
        elif pl == {'entailment'}:
            return "E"
        elif pl == {'neutral'}:
            return "N"
        elif pl == {'entailment', 'contradiction'}:
            return r"\lnot N"
        elif pl == {'entailment', 'neutral'}:
            return r"\lnot C"
        elif pl == {'contradiction', 'neutral'}:
            return r"\lnot E"
        elif pl == {'entailment', 'contradiction', 'neutral'}:
            return r"?"
        else:
            raise ValueError(f"Bad labels: {possible_labels}")

    def format_label(label):
        """Abbreviate a label to its capitalised first letter."""
        return label[0].upper()

    def format_pattern(pattern, label_L, label_R, possible_labels_SC, possible_labels_EI):
        """Render one inference pattern as LaTeX: (input items, inferred item, SC, EI)."""
        parts = pattern.split('-')
        if len(parts) == 2:
            # single-antecedent pattern, e.g. 'ab-ba'
            pattern_L, pattern_inf = parts
            pattern = format_label(label_L) + pattern_L
        elif len(parts) == 3:
            # two-antecedent pattern, e.g. 'ab-bc-ac'
            pattern_L, pattern_R, pattern_inf = parts
            pattern = format_label(label_L) + pattern_L + r"\land " +\
                    format_label(label_R) + pattern_R
        else:
            raise ValueError(f"Bad pattern: {pattern}")
        consequent_EI = format_possible_labels(possible_labels_EI) + pattern_inf
        consequent_SC = format_possible_labels(possible_labels_SC) + pattern_inf
        return mm(pattern), mm(parts[-1]), mm(consequent_SC), mm(consequent_EI)

    # values shown in the '$a$', '$b$', '$c$' columns for each inference pattern
    pattern_to_sources = {
        'ab-ac-bc': ('c', 'h', 'h'),
        'ab-bc-ac': ('c', 'h', 'g'),
        'ab-bc-ca': ('c', 'h', 'g'),
        'ab-ba':    ('c', 'h', '--')
    }

    results = results.reset_index()
    results.columns.name=''

    results['input items'], results['item'], results['SC'], results['EI'] = zip(*results.apply(lambda x: format_pattern(x['inference_pattern'], x['label_L'], x['label_R'], x['possible_labels_SC'], x['possible_labels_EI']), axis=1))
    results['$a$'], results['$b$'], results['$c$'] = zip(*results['inference_pattern'].apply(pattern_to_sources.get))

    # when every label is allowed ('?') consistency is uninformative: show '--' instead.
    # The consistency columns must be blanked before EI/SC themselves are overwritten.
    results['consistent_EI'] = results.apply(lambda x: '--' if '?' in x['EI'] else x['consistent_EI'], axis=1)
    results['consistent_SC'] = results.apply(lambda x: '--' if '?' in x['SC'] else x['consistent_SC'], axis=1)
    results['EI'] = results.apply(lambda x: '--' if '?' in x['EI'] else x['EI'], axis=1)
    results['SC'] = results.apply(lambda x: '--' if '?' in x['SC'] else x['SC'], axis=1)

    results = results.rename(columns={
        "entailment": "E",
        'contradiction': 'C',
        'neutral': 'N',
        'consistent_SC': r"SC$\checkmark$",
        'consistent_EI': r"EI$\checkmark$",
        })

    dfr = results.set_index(['filter_anticedent_correct', 'model+data', 'hyp_only'])[["input items", "item", '$c$', "count", 'E', 'C', 'N', 'SC', r"SC$\checkmark$", 'EI', r"EI$\checkmark$"]]

    def write_latex_table(filename, section):
        """Write the rows of ``dfr`` selected by ``section`` to ``filename`` as LaTeX.

        ``section`` is a (filter_anticedent_correct, model+data, hyp_only) tuple.
        """
        level = ('filter_anticedent_correct', 'model+data', 'hyp_only')
        with open(filename, 'w', encoding='utf-8') as f:
            # escape=False: the cells already contain LaTeX markup
            f.write(dfr.xs(section, level=level).to_latex(index=False, na_rep='-', float_format=lambda x: f'{x*100:.1f}', escape=False))

    write_latex_table(
            'paper/tables/inferred_RoBERTa+SE_llama3.3:70b_all.tex',
            (False, 'RoBERTa+SE_snli^75k-llama3.3:70b', False)
    )

    write_latex_table(
            'paper/tables/inferred_RoBERTa+SE_llama3.3:70b_anticedent-correct.tex',
            (True, 'RoBERTa+SE_snli^75k-llama3.3:70b', False)
    )

    write_latex_table(
            'paper/tables/inferred_RoBERTa+SE_deepseek-r1:70b_all.tex',
            (False, 'RoBERTa+SE_snli^75k-deepseek-r1:70b', False)
    )

    write_latex_table(
            'paper/tables/inferred_RoBERTa+SE_deepseek-r1:70b_anticedent-correct.tex',
            (True, 'RoBERTa+SE_snli^75k-deepseek-r1:70b', False)
    )


# Script entry point: only experiment 1 is run; the call to experiment 2 is
# currently commented out.
if __name__ == '__main__':
    experiment1()
    # experiment2()
