""" Calculate the entailment requirement accuracy from the backoff answers. """


import click
import ujson as json
import transformers


def _ratio(count, total):
    """Return ``count / total``, or 0.0 when *total* is zero (nothing was scored)."""
    return count / total if total else 0.0


def _count_label(classifier, premises, hypothesis, label):
    """Count the premises for which *classifier* predicts *label* against *hypothesis*.

    Each premise is sent to the pipeline as ``{"text": premise, "text_pair": hypothesis}``,
    i.e. the premise is the first segment and the hypothesis the second. An empty
    premise list returns 0 without calling the pipeline.
    """
    if not premises:
        return 0
    predictions = classifier([{"text": premise, "text_pair": hypothesis} for premise in premises])
    return sum(1 for prediction in predictions if prediction['label'] == label)


@click.command()
@click.option("--input-path", type=click.Path(exists=True), help="Path to the input data.", required=True)
@click.option(
    "--num-positive",
    type=click.IntRange(min=0),
    default=5,
    show_default=True,
    help="Number of leading claims per item that should entail the backoff claim; "
         "the remaining claims should contradict it.",
)
@click.option(
    "--device",
    type=int,
    default=0,
    show_default=True,
    help="Device passed to the transformers pipeline (-1 for CPU).",
)
def main(
    input_path,
    num_positive=5,
    device=0,
):
    """Score how well each item's 'b-5' backoff claim meets the entailment requirement.

    INPUT_PATH is a JSON list of items. For every item, the first NUM_POSITIVE entries
    of item['claims'] should entail item['backoff_claims']['b-5']. The remaining entries
    should contradict it. An NLI model classifies each (claim, backoff claim) pair.

    Three averages over all items are printed. "pos" is the fraction of leading claims
    predicted 'entailment'. "neg" is the fraction of remaining claims predicted
    'contradiction'. "p*n" is the add-one smoothed product
    (pos + 1) / (n_pos + 1) * (neg + 1) / (n_neg + 1).
    """
    # Load and validate the data before the (large) model so bad input fails fast.
    with open(input_path, 'r', encoding='utf-8') as file_:
        data = json.load(file_)

    if not data:
        # The averages below would otherwise divide by zero.
        raise click.ClickException(f"No items found in {input_path}.")

    classifier = transformers.pipeline(
        "text-classification",
        model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
        device=device,
        batch_size=5,
    )

    pos = []
    neg = []
    results = []

    for item in data:
        hypothesis = item['backoff_claims']['b-5']
        premises_pos = item['claims'][:num_positive]
        premises_neg = item['claims'][num_positive:]

        n_entailed = _count_label(classifier, premises_pos, hypothesis, 'entailment')
        n_contradicted = _count_label(classifier, premises_neg, hypothesis, 'contradiction')

        # Normalise by the number of claims actually scored (not a fixed 5), so items
        # with a different number of claims are not mis-scored.
        pos.append(_ratio(n_entailed, len(premises_pos)))
        neg.append(_ratio(n_contradicted, len(premises_neg)))
        # Add-one smoothing keeps each factor strictly positive. With 5 + 5 claims this
        # equals the original (pos + 1) / 6 * (neg + 1) / 6.
        results.append(
            (n_entailed + 1) / (len(premises_pos) + 1)
            * (n_contradicted + 1) / (len(premises_neg) + 1)
        )

    print(f"pos: {sum(pos) / len(pos):.2f}")
    print(f"neg: {sum(neg) / len(neg):.2f}")
    print(f"p*n: {sum(results) / len(results):.2f}")
        
        
if __name__ == "__main__":
    # Run the click command when executed as a script; click reads the CLI arguments.
    main()