from __future__ import annotations
from antlr4 import *

from rlang.knowledge import RLangKnowledge
from rlang.language.RLangLexer import RLangLexer
from rlang.language.RLangParser import RLangParser
from rlang.language.RLangErrorListener import RLangErrorListener
from rlang.language.RLangListener import RLangListener

import os, sys

class HiddenErrors:
    """Context manager that silences writes to ``sys.stderr``.

    While active, ``sys.stderr`` is replaced by a text handle opened on
    ``os.devnull``. On exit the original stream is restored and the devnull
    handle is closed. Exceptions raised inside the block are not suppressed.
    """

    def __enter__(self):
        self._original_stderr = sys.stderr
        # Keep our own reference so __exit__ closes exactly the handle we
        # opened, even if code inside the block rebinds sys.stderr.
        self._devnull = open(os.devnull, 'w')
        sys.stderr = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so sys.stderr never points at a closed stream.
        sys.stderr = self._original_stderr
        self._devnull.close()
        # Implicitly returns None (falsy): in-flight exceptions propagate.



def reformat(line):
    """Expand ``>>``/``<<`` markers in a one-line string into indented lines.

    The text is cut at every ``>>`` and ``<<`` marker. Each resulting piece is
    emitted on its own line, prefixed by a newline, so the result always
    starts with a newline character. Crossing a ``>>`` marker raises the tab
    depth by one, and crossing a ``<<`` marker lowers it by one. The depth is
    not clamped: it may go negative, in which case pieces get no tabs until
    enough ``>>`` markers bring it back above zero.
    """
    depth = 0
    pieces = []
    for segment in line.split('>>'):
        parts = segment.split('<<')
        for offset, part in enumerate(parts):
            # Each '<<' seen so far within this segment lowers the depth by one.
            pieces.append('\n' + '\t' * (depth - offset) + part)
        # Net change for the next segment: -1 per '<<' here, +1 for the '>>'
        # that ends this segment.
        depth += 1 - (len(parts) - 1)
    return ''.join(pieces)

def parse_file(rlang_fname: str):
    """Return the fraction of lines in a file that parse as RLang.

    Each line is treated as one snippet. It is expanded with ``reformat``
    (``>>``/``<<`` markers become newline + tab indentation) and checked with
    ``parse``. The number of lines is printed first, and every line that
    fails to parse is printed as-is.

    Args:
        rlang_fname: path to a text file with one encoded snippet per line.

    Returns:
        Fraction of lines for which ``parse`` returned ``True``, or ``0.0``
        when the file has no lines.
    """
    with open(rlang_fname, 'r') as f:
        # Iterating the file yields exactly one entry per line. Unlike
        # ``read().split('\n')``, a trailing newline at EOF does not create a
        # phantom empty line that would be parsed and counted in the score.
        lines = [line.rstrip('\n') for line in f]
    print(len(lines))
    if not lines:
        return 0.0
    score = 0
    for line in lines:
        if parse(reformat(line)):
            score += 1
        else:
            print(line)
    return score / len(lines)

def parse(rlang):
    """Report whether an RLang string parses without raising.

    The string is lexed with ``RLangLexer`` and parsed with ``RLangParser``
    from its ``program`` rule. ``RLangErrorListener`` is attached to the
    parser. Anything written to ``sys.stderr`` during parsing is hidden via
    ``HiddenErrors``.

    Args:
        rlang: string containing RLang source.

    Returns:
        ``True`` if ``parser.program()`` completes without raising an
        ``Exception``, otherwise ``False``.
    """
    lexer = RLangLexer(InputStream(rlang))
    parser = RLangParser(CommonTokenStream(lexer))
    # NOTE(review): assumes RLangErrorListener raises on syntax errors. If it
    # only reports them, ANTLR recovers and program() returns normally. Also,
    # only the parser gets this listener, so lexer-level errors may not cause
    # a False result — confirm.
    parser.addErrorListener(RLangErrorListener())
    with HiddenErrors():
        try:
            parser.program()
        except Exception:
            # Narrower than a bare except: KeyboardInterrupt / SystemExit
            # still propagate instead of being reported as a parse failure.
            return False
    return True

def exact_match(preds, labels):
    """Return the fraction of positions where a prediction equals its label.

    Args:
        preds: sequence of predictions.
        labels: sequence of targets aligned index-by-index with ``preds``;
            entries beyond ``len(preds)`` are ignored.

    Returns:
        Number of positions ``i`` with ``preds[i] == labels[i]`` divided by
        ``len(preds)``, or ``0.0`` when ``preds`` is empty.

    Raises:
        ValueError: if ``labels`` has fewer entries than ``preds``.
    """
    n = len(preds)
    if n == 0:
        return 0.0
    if len(labels) < n:
        raise ValueError(
            'expected at least {} labels, got {}'.format(n, len(labels)))
    score = sum(pred == label for pred, label in zip(preds, labels))
    return score / n

def rate_functional_accuracy(preds, labels, filenames=False):
    """Interactively score predictions against targets by human judgement.

    For each pair, the prompt ``Target: <label>`` / ``Prediction: <pred>`` is
    shown via ``input()``. A response of ``t`` (case-insensitive, surrounding
    whitespace ignored) counts as correct.

    Args:
        preds: sequence of predictions, or a file path when ``filenames``.
        labels: sequence of targets, or a file path when ``filenames``.
        filenames: when true, ``preds`` and ``labels`` are paths to text
            files read one item per line, with line terminators stripped.

    Returns:
        Fraction of pairs judged correct over ``len(preds)``, or ``0.0``
        when there are no predictions.

    Raises:
        ValueError: if there are fewer labels than predictions. This is
            checked before any prompting, so no manual judgements are lost.
    """
    if filenames:
        # Strip terminators so every prompt is formatted the same way
        # (readlines() left '\n' on all lines but possibly the last).
        with open(preds, 'r') as f:
            preds = [line.rstrip('\n') for line in f]
        with open(labels, 'r') as f:
            labels = [line.rstrip('\n') for line in f]
    n = len(preds)
    if n == 0:
        return 0.0
    if len(labels) < n:
        raise ValueError(
            'expected at least {} labels, got {}'.format(n, len(labels)))
    score = 0
    for pred, label in zip(preds, labels):
        query = 'Target: {}\nPrediction: {}'.format(label, pred)
        response = input(query)
        if response.strip().lower() == 't':
            score += 1

    return score / n