from pathlib import Path
import json
import random
import util
from data import snli_labels, snli_keys_no_parses

def find_prompt_examples(data):
    """Return the caption groups whose unanimous items cover every SNLI label.

    An item is kept only when it has at least five annotator labels and
    every one of them equals the item's ``gold_label``.  The kept items are
    grouped by ``captionID`` through ``util.group_records_by``; a group is
    returned only when the set of gold labels found in it equals
    ``snli_labels``.

    Args:
        data: iterable of records (dict-like) providing the keys
            ``annotator_labels``, ``gold_label`` and ``captionID``.

    Returns:
        A list of the qualifying groups, in the order the grouping helper
        yields them.
    """
    def is_unanimous(record):
        votes = record['annotator_labels']
        # The example must have been validated by at least 5 annotators ...
        if len(votes) < 5:
            return False
        # ... and all of them must agree with the gold label.
        gold = record['gold_label']
        return all(vote == gold for vote in votes)

    unanimous_items = list(filter(is_unanimous, data))

    # NOTE(review): assumes group_records_by yields one collection of records
    # per distinct captionID -- confirm against util.
    selected = []
    for bucket in util.group_records_by(unanimous_items, 'captionID'):
        labels_present = {record['gold_label'] for record in bucket}
        if labels_present == snli_labels:
            selected.append(bucket)
    return selected

def sample_examples(grouped_examples, n):
    """Sample ``n`` caption groups and pick one hypothesis per label in each.

    For every sampled group the result holds the group's ``captionID`` and
    ``sentence1`` (both read from the group's first item) plus, for each
    label in ``snli_labels``, the ``pairID`` and ``sentence2`` of one
    randomly chosen item carrying that gold label.

    Args:
        grouped_examples: sequence of groups; each group is a sequence of
            records with the keys ``captionID``, ``sentence1``,
            ``sentence2``, ``pairID`` and ``gold_label``.
        n: number of groups to draw without replacement.

    Returns:
        A list of ``n`` dicts, one per sampled group.

    Raises:
        ValueError: if ``n`` exceeds ``len(grouped_examples)`` (raised by
            ``random.sample``).
        IndexError: if a sampled group has no item for one of the labels
            (raised by ``random.choice`` on an empty list).
    """
    # Walk the labels in a fixed (sorted) order.  If snli_labels is a set of
    # strings, its iteration order can differ between interpreter runs
    # (string hash randomization), which would change both the order of the
    # random.choice draws and the key order of each result even when the
    # caller has seeded the RNG.
    labels = sorted(snli_labels)

    examples = []
    for group in random.sample(grouped_examples, n):
        first = group[0]
        ex = {
            'captionID': first['captionID'],
            'sentence1': first['sentence1'],
        }
        for label in labels:
            label_items = [item for item in group if item['gold_label'] == label]
            item = random.choice(label_items)
            ex[label] = {
                'pairID': item['pairID'],
                'sentence2': item['sentence2'],
            }
        examples.append(ex)
    return examples


if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('command', type=str, choices=['perfect-examples', 'random-examples'])
    args = parser.parse_args()

    random.seed(42)

    match args.command:

        case 'perfect-examples':
            data_dir = Path('./data/snli_1.0')
            train_data_path = data_dir/'snli_1.0_train.jsonl'
            example_path = Path('./prompts/perfect-snli-examples.json')
            data = util.load_jsonl(train_data_path, keys=snli_keys_no_parses)
            grouped_examples = find_prompt_examples(data)
            examples = sample_examples(grouped_examples, 10)
            if not example_path.exists():
                with example_path.open('w') as f:
                    json.dump(examples, f)

        case 'random-examples':
            data_dir = Path('./data/snli_1.0')
            train_data_path = data_dir/'snli_1.0_train.jsonl'
            example_path = Path('./prompts/random-snli-examples.json')
            data = util.load_jsonl(train_data_path, keys=snli_keys_no_parses)
            examples = random.sample(data, 10)
            if not example_path.exists():
                with example_path.open('w') as f:
                    for item in examples:
                        util.write_jsonl(item, f)
