#-------------------------------------------------------------------------
#Parser.py -- ATG file parser
#Compiler Generator Coco/R,
#Copyright (c) 1990, 2004 Hanspeter Moessenboeck, University of Linz
#extended by M. Loeberbauer & A. Woess, Univ. of Linz
#ported from Java to Python by Ronald Longo
#Parser.frame modified for use as incremental model checker by Sean Sedwards
#
#This program is free software; you can redistribute it and/or modify it
#under the terms of the GNU General Public License as published by the
#Free Software Foundation; either version 2, or (at your option) any
#later version.
#
#This program is distributed in the hope that it will be useful, but
#WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
#or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
#for more details.
#
#You should have received a copy of the GNU General Public License along
#with this program; if not, write to the Free Software Foundation, Inc.,
#59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
#As an exception, it is allowed to write an extension of Coco/R that is
#used as a plugin in non-free software.
#
#If not otherwise stated, any source code generated by Coco/R (other than
#Coco/R itself) does not fall under the GNU General Public License.
#-------------------------------------------------------------------------*/

"""
These classes were taken from github.com/ashishgaurav13/wm2
"""

import sys

from tools.logic import Token, Scanner, Position


class ErrorRec(object):
    """One stored diagnostic: a source position plus its message text."""

    def __init__(self, l, c, s):
        """Record line ``l``, column ``c`` and message ``s``."""
        # Position of the diagnostic in the source being parsed.
        self.line, self.col = l, c
        # Message text for the diagnostic.
        self.str = s
        # Error number; every new record starts out at 0.
        self.num = 0


class Errors(object):
    """Class-level error collector and reporter used by the parser.

    All state is held in class attributes and every method is static, so the
    class is used directly (e.g. ``Errors.SynErr(...)``) rather than being
    instantiated.  ``Init`` must be called before ``SynErr``, ``SemErr`` or
    ``Warn`` are used, because it installs ``getParsingPos`` and
    ``errorMessages``.

    When ``mergeErrors`` is true, messages are stored and later merged with
    the source text into a listing file by ``Summarize``; otherwise each
    message is written to standard output as soon as it is reported.
    """
    # errMsgFormat = "file %(file)s : (%(line)d, %(col)d) %(text)s\n"
    # Sean removed "file" from message format
    errMsgFormat = "%(file)s : (%(line)d, %(col)d) %(text)s\n"
    eof = False
    count = 0  # number of errors detected
    fileName = ''
    listName = ''
    mergeErrors = False
    mergedList = None  # PrintWriter
    errors = []
    minErrDist = 2
    errDist = minErrDist
    # getParsingPos (installed by Init) is a function returning a tuple
    # ( line, column ) that refers to the location in the source file most
    # recently parsed.  errorMessages (also installed by Init) maps a
    # predefined error number to the corresponding error message.

    @staticmethod
    def Init(fn, dir, merge, getParsingPos, errorMessages):
        """Prepare the collector for a new source file.

        Args:
            fn: Name of the source file, used as a prefix in printed messages.
            dir: Prefix prepended verbatim to ``'listing.txt'`` to form the
                listing file name (no path separator is inserted).
            merge: If true, store messages and open the listing file for
                writing; if false, messages are printed immediately.
            getParsingPos: Callable returning ``(line, col)``; used when a
                report is made without an explicit position.
            errorMessages: Mapping from syntax error number to message text.

        Raises:
            RuntimeError: If the listing file cannot be opened.
        """
        # Start from an empty record list.  ``theErrors`` is the name the
        # original code reset; it is kept as an alias of the list that
        # storeError/Summarize really use.
        Errors.errors = []
        Errors.theErrors = Errors.errors
        Errors.getParsingPos = getParsingPos
        Errors.errorMessages = errorMessages
        Errors.fileName = fn
        # Must be stored on the class: Summarize prints it to tell the user
        # where the listing was written.
        Errors.listName = dir + 'listing.txt'
        Errors.mergeErrors = merge
        if Errors.mergeErrors:
            try:
                Errors.mergedList = open(Errors.listName, 'w')
            except IOError:
                raise RuntimeError('-- Compiler Error: could not open ' +
                                   Errors.listName)

    @staticmethod
    def storeError(line, col, s):
        """Store the message for the listing, or print it immediately."""
        if Errors.mergeErrors:
            Errors.errors.append(ErrorRec(line, col, s))
        else:
            Errors.printMsg(Errors.fileName, line, col, s)

    @staticmethod
    def SynErr(errNum, errPos=None):
        """Report the syntax error ``errNum`` and count it.

        ``errPos`` is an optional ``(line, col)`` pair; when omitted the
        position is obtained from ``getParsingPos``.
        """
        line, col = errPos if errPos else Errors.getParsingPos()
        msg = Errors.errorMessages[errNum]
        Errors.storeError(line, col, msg)
        Errors.count += 1
        # Sean added exit
        # sys.exit(1)

    @staticmethod
    def SemErr(errMsg, errPos=None):
        """Report the semantic error ``errMsg`` and count it.

        ``errPos`` is an optional ``(line, col)`` pair; when omitted the
        position is obtained from ``getParsingPos``.
        """
        line, col = errPos if errPos else Errors.getParsingPos()
        Errors.storeError(line, col, errMsg)
        Errors.count += 1
        # Sean added exit
        # sys.exit(1)

    @staticmethod
    def Warn(errMsg, errPos=None):
        """Report ``errMsg`` like an error but without increasing ``count``."""
        line, col = errPos if errPos else Errors.getParsingPos()
        Errors.storeError(line, col, errMsg)

    @staticmethod
    def Exception(errMsg):
        """Print ``errMsg`` and terminate the process with exit status 1."""
        print(errMsg)
        sys.exit(1)

    @staticmethod
    def printMsg(fileName, line, column, msg):
        """Write one message to standard output using ``errMsgFormat``."""
        vals = {'file': fileName, 'line': line, 'col': column, 'text': msg}
        sys.stdout.write(Errors.errMsgFormat % vals)

    @staticmethod
    def display(s, e):
        """Write a marker line for error record ``e`` to the listing.

        ``s`` is the text of the source line the error refers to.  Tabs in
        ``s`` before the error column are repeated in the marker line so the
        ``^`` lines up with the source column.
        """
        Errors.mergedList.write('**** ')
        for c in range(1, e.col):
            # Columns beyond the end of the line are padded with blanks
            # instead of indexing past the end of ``s``.
            if c <= len(s) and s[c - 1] == '\t':
                Errors.mergedList.write('\t')
            else:
                Errors.mergedList.write(' ')
        Errors.mergedList.write('^ ' + e.str + '\n')

    @staticmethod
    def Summarize(sourceBuffer):
        """Write the merged listing (if enabled) and print the error count.

        Args:
            sourceBuffer: Iterable yielding the source text line by line.

        When ``mergeErrors`` is true every source line is written to the
        listing with its line number, followed by a marker line for each
        stored error on that line.  The listing file is closed afterwards.
        The error count is always written to standard output.
        """
        if Errors.mergeErrors:
            listing = Errors.mergedList
            try:
                # Stable sort: errors keep their reporting order per line.
                pending = sorted(Errors.errors, key=lambda rec: rec.line)
                nextErr = 0
                for lineNum, lineStr in enumerate(sourceBuffer, 1):
                    listing.write('%4d %s\n' % (lineNum, lineStr))
                    # The '^' marker points upwards, so the marker lines go
                    # directly beneath the source line they refer to.
                    while (nextErr < len(pending)
                           and pending[nextErr].line <= lineNum):
                        Errors.display(lineStr, pending[nextErr])
                        nextErr += 1

                # Errors positioned past the last source line are still
                # reported rather than silently dropped.
                for rec in pending[nextErr:]:
                    Errors.display('', rec)

                listing.write('\n')
                listing.write('%d errors detected\n' % Errors.count)
            finally:
                listing.close()

        sys.stdout.write('%d errors detected\n' % Errors.count)
        if (Errors.count > 0) and Errors.mergeErrors:
            sys.stdout.write('see ' + Errors.listName + '\n')


class Parser(object):
    """Recursive-descent parser that evaluates a temporal property over a
    trace of states, one state at a time.

    The property is supplied as a scanner (see ``SetProperty``).  Its tokens
    are re-parsed for every new state; the operators recognised are those
    listed in ``errorMessages`` (U, F, G, X, =>, or, and, not, true, false
    and parentheses).  Each state is an integer used as a bit mask: a
    proposition ``p`` holds in a state when bit ``APdict[p]`` is set.

    ``APdict`` (proposition name -> bit position) is not created by this
    class; it must be assigned on the parser before a property containing
    propositions is checked.

    Evaluation results are the class constants ``FALSE``, ``UNDECIDED`` and
    ``TRUE``; ``UNDEFINED`` is returned when no valid property/factor could
    be parsed.
    """
    _EOF = 0
    _proposition = 1
    maxT = 14

    T = True
    x = False
    minErrDist = 2

    UNDEFINED = -1
    FALSE = 0
    UNDECIDED = 1
    TRUE = 2

    def Check_old(self, propscanner, trace):
        """Deprecated method to check an entire trace with a new property
        Scanner.

        Includes SetProperty, which includes ResetProperty.  Stops at the
        first state for which the result is no longer UNDECIDED.  An empty
        trace yields UNDECIDED.
        """
        self.SetProperty(propscanner)
        # Default for an empty trace: nothing has been decided.
        result = Parser.UNDECIDED
        for state in trace:
            result = self.CheckIncremental(state)
            if result != Parser.UNDECIDED:
                break
        return result

    def Check(self, trace):
        """Checks an entire trace w.r.t. an existing property Scanner.

        Includes ResetProperty, but not SetProperty.  Stops at the first
        state for which the result is no longer UNDECIDED.  An empty trace
        yields UNDECIDED.
        """
        self.ResetProperty()
        # Default for an empty trace: nothing has been decided.
        result = Parser.UNDECIDED
        for state in trace:
            result = self.CheckIncremental(state)
            if result != Parser.UNDECIDED:
                break
        return result

    def SetProperty(self, propscanner):
        """Sets the property Scanner that tokenizes the property."""
        self.scanner = propscanner
        self.ResetProperty()

    def ResetProperty(self):
        """Re-initializes an existing property Scanner.

        Resets the step counter, the highest factor index seen and the
        stored trace, and rewinds the scanner to its first token.
        """
        self.step = 0
        self.maxfactor = -1
        self.start = self.scanner.t = self.scanner.tokens
        self.trace = []  # ADDED

    def CheckIncremental(self, state):
        """Checks a new state w.r.t. an existing property Scanner.

        Constructs a trace from new states using a new or previous trace
        list. If a previous trace list is used, states after index self.step
        are not valid.

        Returns one of Parser.FALSE, Parser.UNDECIDED, Parser.TRUE (or
        whatever Property returns for an invalid property).
        """
        # print('CheckIncremental: %s, step %d, maxfactor %d' % (self.trace, self.step, self.maxfactor))
        if len(self.trace) <= self.step:
            self.trace.append(state)
        else:
            self.trace[self.step] = state
        if self.debug:
            print('CheckIncremental: %s, step %d, maxfactor %d' % (self.trace, self.step, self.maxfactor))
        result = Parser.UNDECIDED
        # Re-parse the whole property from its first token until a Factor
        # has been evaluated with an index reaching the current step
        # (tracked in maxfactor) or the result is decided.
        while self.step > self.maxfactor and result == Parser.UNDECIDED:
            self.scanner.t = self.start
            self.la = self.start
            self.Get()
            result = self.Property(0)
            self.Expect(0)
            # print ("step [{}] checked state [{}]".format(self.step, self.maxfactor))
        self.step += 1
        if self.debug:
            print('')
        return result

    def __init__(self, debug=False):
        """Create a parser with no property attached.

        Args:
            debug: If true, tokens and trace information are printed while
                checking.
        """
        self.scanner = None
        self.token = None  # last recognized token
        self.la = None  # lookahead token
        self.genScanner = False
        self.tokenString = ''  # used in declarations of literal tokens
        self.noString = '-none-'  # used in declarations of literal tokens
        self.errDist = Parser.minErrDist
        self.trace = []
        self.debug = debug
        # Evaluation state, normally (re)set by ResetProperty.  Initialised
        # here as well so that the attributes exist on a fresh instance.
        self.step = 0
        self.maxfactor = -1
        self.start = None

    def getParsingPos(self):
        """Return (line, col) of the lookahead token."""
        return self.la.line, self.la.col

    def SynErr(self, errNum):
        """Report syntax error errNum unless one was reported too recently."""
        if self.errDist >= Parser.minErrDist:
            Errors.SynErr(errNum)

        self.errDist = 0

    def SemErr(self, msg):
        """Report semantic error msg unless one was reported too recently."""
        if self.errDist >= Parser.minErrDist:
            Errors.SemErr(msg)

        self.errDist = 0

    def Warning(self, msg):
        """Report warning msg unless a message was reported too recently."""
        if self.errDist >= Parser.minErrDist:
            Errors.Warn(msg)

        self.errDist = 0

    def Successful(self):
        """Return True when no errors have been counted by Errors."""
        return Errors.count == 0

    def LexString(self):
        """Return the text of the last recognized token."""
        return self.token.val

    def LookAheadString(self):
        """Return the text of the lookahead token."""
        return self.la.val

    def Get(self):
        """Advance by one token.

        The current lookahead becomes ``token`` and the next scanned token
        becomes ``la``.  Scanned tokens whose kind is greater than maxT are
        skipped.
        """
        while True:
            self.token = self.la
            # Sean added print
            if self.debug:
                print(self.la)
            self.la = self.scanner.Scan()
            if self.la.kind <= Parser.maxT:
                self.errDist += 1
                break

            self.la = self.token

    def Expect(self, n):
        """Consume the lookahead if it has kind n, else report SynErr(n)."""
        if self.la.kind == n:
            self.Get()
        else:
            self.SynErr(n)

    def StartOf(self, s):
        """Return whether the lookahead kind is a member of token set s."""
        return self.set[s][self.la.kind]

    def ExpectWeak(self, n, follow):
        """Like Expect, but on a mismatch skip tokens until one in the
        token set ``follow`` is the lookahead."""
        if self.la.kind == n:
            self.Get()
        else:
            self.SynErr(n)
            while not self.StartOf(follow):
                self.Get()

    def WeakSeparator(self, n, syFol, repFol):
        """Handle a weak separator of kind n.

        Returns True if the separator was consumed, False if the lookahead
        is in token set repFol.  Otherwise reports SynErr(n), skips tokens
        until one in syFol, repFol or set 0 is reached, and returns whether
        the lookahead is in syFol.
        """
        s = [False for i in range(Parser.maxT + 1)]
        if self.la.kind == n:
            self.Get()
            return True
        elif self.StartOf(repFol):
            return False
        else:
            for i in range(Parser.maxT):
                s[i] = self.set[syFol][i] or self.set[repFol][i] or self.set[0][i]
            self.SynErr(n)
            while not s[self.la.kind]:
                self.Get()
            return self.StartOf(syFol)

    def PyCheck(self):
        """Unused dummy method."""
        result = self.Property(0)
        return result

    def Property(self, index):
        """Main property entry point.

        Parses one property starting at the lookahead token and evaluates
        it for trace position ``index``.  Per-operator progress is stored on
        the first token of the property (attributes ``result``, ``index``
        and ``until``), so it survives between re-parses of the same token
        list.

        Returns one of the result constants; Parser.UNDEFINED if the
        lookahead does not start a valid property (after SynErr(15)).
        """
        # Ensures a defined return value on the invalid-property path.
        result = Parser.UNDEFINED
        start = self.la
        if self.step == 0:
            start.result = Parser.UNDECIDED
            start.index = 0
        if index > start.index:
            start.index = index
            start.result = Parser.UNDECIDED
        if self.StartOf(1):
            result = self.Implication(start.index if start.until else index)
            if self.step == 0:
                start.result = result
            if (self.la.kind == 2):  # "U"
                if result == Parser.UNDECIDED:
                    self.SemErr("first argument of U is UNDECIDED")
                self.Get()
                start.until = True  # label previous Implication as part of Until
                r2 = self.Implication(start.index)
                if start.result == Parser.UNDECIDED:  # Until not yet FALSE of TRUE
                    if r2 == Parser.TRUE:
                        result = start.result = Parser.TRUE
                    elif result == Parser.FALSE:
                        start.result = result
                    else:
                        result = Parser.UNDECIDED
                        start.index += 1
                else:  # Until is already decided, so result is unchanged
                    result = start.result
        elif self.la.kind == 3:  # "F"
            self.Get()
            result = self.Implication(start.index)
            if start.result == Parser.UNDECIDED:  # Finally not yet FALSE or TRUE
                if result == Parser.TRUE:  # sub-formula is decided TRUE
                    start.result = result  # Finally will remain TRUE
                elif result == Parser.FALSE:  # sub-formula is decided FALSE,
                    result = Parser.UNDECIDED  # but Finally is not decided,
                    start.index += 1  # so check next sub-formula
                else:  # sub-formula is UNDECIDED,
                    result = Parser.UNDECIDED  # so Finally is not decided
                    start.index = start.index  # send the same index to sub-formula
            else:  # Finally is already decided, so result is unchanged
                result = start.result
        elif self.la.kind == 4:  # "G"
            self.Get()
            result = self.Implication(start.index)
            if start.result == Parser.UNDECIDED:  # Globally not yet FALSE or TRUE
                if result == Parser.FALSE:  # sub-formula is decided FALSE
                    start.result = Parser.FALSE  # Globally will remain FALSE
                elif result == Parser.TRUE:  # sub-formula is decided TRUE,
                    result = Parser.UNDECIDED  # but Globally is not decided
                    start.index += 1  # check the next sub-formula
                else:  # result is UNDECIDED, so sub-formula is not yet resolved
                    result = Parser.UNDECIDED  # and Globally is also not resolved
                    start.index = start.index  # send the same index to sub-formula
            else:  # Globally is already decided, so result is unchanged
                result = start.result
        elif self.la.kind == 5:  # "X"
            self.Get()
            result = self.Implication(start.index)
            if start.index == index + 1:
                if result == Parser.UNDECIDED:
                    result = Parser.UNDECIDED
                    start.index = start.index
                else:
                    start.result = result
                    start.index += 1
            elif start.index == index:
                start.index += 1
                result = Parser.UNDECIDED
            elif index > start.index:
                self.SemErr(
                    "X has lost synchronization: index = {}, start.index = {}".
                    format(index, start.index))
        else:
            self.SynErr(15)
        return result

    def Implication(self, index):
        """Parse ``Disjunction [ "=>" Disjunction ]`` and evaluate it."""
        result = self.Disjunction(index)
        if (self.la.kind == 6):
            self.Get()
            r2 = self.Disjunction(index)
            if result == Parser.FALSE or r2 == Parser.TRUE:
                result = Parser.TRUE
            elif result == Parser.TRUE and r2 == Parser.FALSE:
                result = Parser.FALSE
            elif result == Parser.UNDECIDED or r2 == Parser.UNDECIDED:
                result = Parser.UNDECIDED
            else:
                self.SemErr("unrecognized argument")
        return result

    def Disjunction(self, index):
        """Parse ``Conjunction { "or" Conjunction }`` and evaluate it."""
        result = self.Conjunction(index)
        while self.la.kind == 7:
            self.Get()
            r2 = self.Conjunction(index)
            if result == Parser.TRUE or r2 == Parser.TRUE:
                result = Parser.TRUE
            elif result == Parser.UNDECIDED or r2 == Parser.UNDECIDED:
                result = Parser.UNDECIDED
            else:
                result = Parser.FALSE

        return result

    def Conjunction(self, index):
        """Parse ``Factor { "and" Factor }`` and evaluate it."""
        result = self.Factor(index)
        while self.la.kind == 8:
            self.Get()
            r2 = self.Factor(index)
            if result == Parser.FALSE or r2 == Parser.FALSE:
                result = Parser.FALSE
            elif result == Parser.UNDECIDED or r2 == Parser.UNDECIDED:
                result = Parser.UNDECIDED
            else:
                result = Parser.TRUE

        return result

    def Factor(self, index):
        """Parse an optionally negated proposition, ``true``, ``false`` or
        parenthesised Property, and evaluate it.

        Also records the largest ``index`` seen in ``self.maxfactor``.  A
        proposition is evaluated against the state at ``self.step``.
        """
        if index > self.maxfactor:
            self.maxfactor = index
        neg = False
        result = Parser.UNDEFINED
        if (self.la.kind == 9):
            self.Get()
            neg = True
        if self.la.kind == 1:
            # NOTE(review): an unrecognized proposition is reported here,
            # but the APdict lookup below is still performed for it and will
            # raise KeyError if APdict is a plain dict.
            if self.la.val not in self.APdict:
                self.SemErr("unrecognized proposition")
            self.Get()
            # print('Trace: %s, index: %d' % (self.trace, index))
            # ADDED
            # self.trace[index] => self.trace[self.step]
            if self.trace[self.step] & (1 << self.APdict[self.token.val]):
                result = Parser.TRUE
            else:
                result = Parser.FALSE
        elif self.la.kind == 10:
            self.Get()
            result = Parser.TRUE
        elif self.la.kind == 11:
            self.Get()
            result = Parser.FALSE
        elif self.la.kind == 12:
            self.Get()
            result = self.Property(index)
            self.Expect(13)
        else:
            self.SynErr(16)
        # TRUE - result swaps TRUE and FALSE and leaves UNDECIDED unchanged.
        if neg:
            result = Parser.TRUE - result
        return result

    def Parse(self, scanner):
        """Parse a complete property from ``scanner`` up to end of input."""
        self.scanner = scanner
        self.la = Token()
        self.la.val = u''
        self.Get()
        self.PyCheck()
        self.Expect(0)

    # Token sets indexed as set[setNumber][tokenKind]; used by StartOf.
    set = [[T, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x],
           [x, T, x, x, x, x, x, x, x, T, T, T, T, x, x, x]]

    # Messages for the syntax error numbers passed to SynErr.
    errorMessages = {
        0: "EOF expected",
        1: "proposition expected",
        2: "\"U\" expected",
        3: "\"F\" expected",
        4: "\"G\" expected",
        5: "\"X\" expected",
        6: "\"=>\" expected",
        7: "\"or\" expected",
        8: "\"and\" expected",
        9: "\"not\" expected",
        10: "\"true\" expected",
        11: "\"false\" expected",
        12: "\"(\" expected",
        13: "\")\" expected",
        14: "??? expected",
        15: "invalid Property",
        16: "invalid Factor",
    }
