import pandas as pd
from util import load_jsonl, write_jsonl
from data import SNLIDataset
from pathlib import Path

def possible_labels_EI(pattern, labels):
    """
    Meta-inferentially consistent labels under the existential import reading.
    """

    match (pattern, labels):

        case ('ab-ba', ('contradiction', )):
            return {'contradiction', 'neutral'}
        case ('ab-ba', ('entailment', )):
            return {'entailment', 'neutral'}
        case ('ab-ba', ('neutral', )):
            return {'contradiction', 'entailment', 'neutral'}
        case ('ab-ba', (_, )):
            return None

        case ('ab-bc-ac', ('entailment', 'entailment')):
            return {'entailment'}
        case ('ab-bc-ac', ('entailment', 'contradiction')):
            return {'contradiction'}
        case ('ab-bc-ac', ('neutral', 'entailment')):
            return {'entailment', 'neutral'}
        case ('ab-bc-ac', ('neutral', 'contradiction')):
            return {'neutral', 'contradiction'}
        case ('ab-bc-ac', (_, _)):
            return None 

        case ('ab-bc-ca', ('contradiction', 'neutral')):
            return {'entailment', 'neutral', 'contradiction'}
        case ('ab-bc-ca', ('entailment', 'entailment')):
            return {'entailment', 'neutral'}
        case ('ab-bc-ca', ('entailment', 'contradiction')):
            return {'contradiction', 'neutral'}
        case ('ab-bc-ca', ('neutral', 'entailment')):
            return {'entailment', 'neutral', 'contradiction'}
        case ('ab-bc-ca', (_, _)):
            return None

        case ('ab-ac-bc', ('entailment', 'contradiction')):
            return {'contradiction', 'neutral'}
        case ('ab-ac-bc', ('entailment', 'neutral')):
            return {'neutral'} 
        case ('ab-ac-bc', ('neutral', 'entailment')):
            return {'entailment', 'neutral'}
        case ('ab-ac-bc', ('neutral', 'contradiction')):
            return {'contradiction', 'neutral'}
        case ('ab-ac-bc', (_, _)):
            return None

        case (_, _):
            raise ValueError(f"No rule for pattern {pattern} and labels {labels}.")

def possible_labels_SC(pattern, labels):
    """
    Meta-inferentially consistent labels under the strict conditional reading.
    """

    match (pattern, labels):

        case ('ab-ba', ('contradiction', )):
            return {'contradiction'}
        case ('ab-ba', ('entailment', )):
            return {'entailment', 'contradiction', 'neutral'}
        case ('ab-ba', ('neutral', )):
            return {'entailment', 'neutral'}

        case ('ab-bc-ca', ('contradiction', 'neutral')):
            return {'neutral', 'contradiction'}
        case ('ab-bc-ca', ('entailment', 'entailment')):
            return {'entailment', 'contradiction', 'neutral'}
        case ('ab-bc-ca', ('entailment', 'contradiction')):
            return {'contradiction'}
        case ('ab-bc-ca', ('neutral', 'entailment')):
            return {'entailment', 'neutral'}

        case ('ab-ac-bc', ('entailment', 'contradiction')):
            return {'entailment', 'contradiction', 'neutral'}

        case (_, _):
            return possible_labels_EI(pattern, labels)

def infer_items(pattern, df_L, df_R, sent1, sent2, right_on=None, left_on=None,):
    """
    Build the inferred items for one inference pattern.

    With ``df_R is None`` every row of ``df_L`` is a candidate item and its
    'label', 'pairID', 'sentence1' and 'sentence2' columns are renamed with an
    '_L' suffix. Otherwise ``df_L`` and ``df_R`` are right-joined on
    ``left_on`` / ``right_on`` (overlapping columns get '_L' / '_R' suffixes)
    and joined rows whose 'pairID_L' equals 'pairID_R' are discarded.

    Each candidate gets a new 'pairID' (source pair ID(s) and the pattern,
    joined by '__') and the columns 'possible_labels_EI' and
    'possible_labels_SC'. Candidates for which both label columns are missing
    are dropped. The result also carries 'inference_pattern', 'sentence1' and
    'sentence2'.

    Args:
        pattern: name of the inference pattern, passed on to
            possible_labels_EI / possible_labels_SC.
        df_L: DataFrame of given pairs (left side of the join).
        df_R: DataFrame of second given pairs, or None for patterns that use
            a single given pair.
        sent1: name of the column copied into 'sentence1' of the result.
        sent2: name of the column copied into 'sentence2' of the result.
        right_on: join key in ``df_R`` (ignored when ``df_R`` is None).
        left_on: join key in ``df_L`` (ignored when ``df_R`` is None).

    Returns:
        A new DataFrame; the input frames are not modified.
    """
    if df_R is None:
        cols = ['label', 'pairID', 'sentence1', 'sentence2']
        df = df_L.rename(columns={col: col + '_L' for col in cols})
        df['pairID'] = df['pairID_L'] + '__' + pattern
        given_labels = [(label, ) for label in df['label_L']]
    else:
        df = pd.merge(df_L, df_R, left_on=left_on, right_on=right_on, how='right', suffixes=['_L', '_R'])
        # Copy after filtering so the column assignments below write to an
        # independent frame (avoids pandas' SettingWithCopyWarning).
        df = df[df['pairID_L'] != df['pairID_R']].copy()
        df['pairID'] = df['pairID_L'] + '__' + df['pairID_R'] + '__' + pattern
        given_labels = list(zip(df['label_L'], df['label_R']))

    # Iterate over the label columns directly rather than using a row-wise
    # DataFrame.apply, which would build a Series for every row.
    df['possible_labels_EI'] = pd.Series(
        [possible_labels_EI(pattern, labels) for labels in given_labels],
        index=df.index, dtype=object,
    )
    df['possible_labels_SC'] = pd.Series(
        [possible_labels_SC(pattern, labels) for labels in given_labels],
        index=df.index, dtype=object,
    )

    # Keep only the situations we have inference values for under at least
    # one reading; copy for the same reason as above.
    has_inference = df['possible_labels_EI'].notna() | df['possible_labels_SC'].notna()
    df = df[has_inference].copy()

    df['inference_pattern'] = pattern
    df['sentence1'] = df[sent1]  # get the correct sentences for the inferred item
    df['sentence2'] = df[sent2]  # these could be inferred from the pattern; the caller passes them explicitly
    return df
       

if __name__ == '__main__':

    import argparse

    # Command line: the data directory and the name of the generation model
    # whose output file should be combined with SNLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=Path)
    parser.add_argument("gen_model", type=str)
    args = parser.parse_args()
    data_dir, gen_model = args.data_dir, args.gen_model

    out_dir = data_dir / 'inferred'
    out_dir.mkdir(exist_ok=True)

    # SNLI test items: the gold label serves as the item label.
    snli_records = load_jsonl(data_dir / 'snli_1.0/snli_1.0_test.jsonl')
    snli = pd.DataFrame.from_records(snli_records)
    snli = snli.assign(label=snli['gold_label'])
    snli = snli[['pairID', 'captionID', 'sentence1', 'sentence2', 'label']]

    # Generated items: the model label serves as the item label.
    gen_records = load_jsonl(data_dir / f'generated/{gen_model}_test.jsonl')
    gen = pd.DataFrame.from_records(gen_records)
    gen = gen.assign(label=gen['model_label'])
    gen = gen[['pairID', 'premiseID', 'sentence1', 'sentence2', 'label']]

    # One job per inference pattern:
    # (pattern, left frame, right frame, sentence1 column, sentence2 column, join keys)
    jobs = [
        ('ab-bc-ac', snli, gen, 'sentence1_L', 'sentence2_R', {'left_on': 'pairID', 'right_on': 'premiseID'}),
        ('ab-bc-ca', snli, gen, 'sentence2_R', 'sentence1_L', {'left_on': 'pairID', 'right_on': 'premiseID'}),
        ('ab-ba', snli, None, 'sentence2_L', 'sentence1_L', {}),
        ('ab-ac-bc', snli, snli, 'sentence2_L', 'sentence2_R', {'left_on': 'captionID', 'right_on': 'captionID'}),
    ]
    frames = [infer_items(*positional, **join_keys) for *positional, join_keys in jobs]

    # Stack the per-pattern results and drop the join-key columns.
    df = pd.concat(frames, axis=0)
    df = df.drop(['premiseID', 'captionID'], axis=1)

    out_path = out_dir / f'{gen_model}_test.jsonl'
    with out_path.open('w') as f:
        f.write(df.to_json(orient='records', lines=True))



