from io import StringIO
import re

from io import StringIO
from pysmt.smtlib.parser import SmtLibParser, SmtLibScript, SmtLibCommand
from pysmt.environment import reset_env
from pysmt.exceptions import PysmtSyntaxError, PysmtTypeError, UnsupportedOperatorError, PysmtInfinitesimalError
from pysmt.shortcuts import Symbol, Int, Real, NotEquals, Equals, Plus, Minus, Times, Div, ToReal, ToInt, to_smtlib
from pysmt.typing import INT, REAL

from MathGym.utils import gpt, solve


class simplifier:
    """Normalise an SMT-LIB script given as text.

    Each public method parses ``statement`` with pysmt, rebuilds it into
    ``self.simplified_script`` and returns the serialized result followed by
    the script's last non-comment ``get-value`` line.  The script and
    ``self.vars`` accumulate across calls, so use a new instance per
    transformation (the module-level helpers do).
    """

    # Characters allowed in an SMT-LIB simple symbol.  ``rename`` only
    # replaces a name when it is not adjacent to one of these, so ``x`` is not
    # rewritten inside ``x1`` or ``x.y``.
    _SYMBOL_CHARS = r"A-Za-z0-9~!@$%^&*_\-+=<>.?/"

    def __init__(self, comment):
        """
        comment: whether to add comment (passed as ``comment=`` to
        ``SmtLibScript.serialize`` by ``simplify`` and ``rename``)
        """
        self.id_counter = 0
        self.vars = dict()
        self.aux_vars = []
        self.simplified_script = SmtLibScript()
        self.comment = comment

    @staticmethod
    def _get_value_lines(statement):
        """Return the stripped, non-comment lines containing ``get-value``, in order."""
        stripped = (line.strip() for line in statement.split('\n'))
        return [line for line in stripped
                if "get-value" in line and not line.startswith(';')]

    @staticmethod
    def _is_integral(value, tol=1e-5):
        """Return True if the model value text evaluates to a number within ``tol`` of an integer.

        ``^`` is treated as exponentiation.  Text that does not evaluate to a
        real number yields False.  This covers solver syntax such as
        ``(/ 1 2)``, bare names and complex results.
        """
        # NOTE(review): ``value`` comes from solver output for a script that
        # may be model-generated; builtins are disabled, but eval is still not
        # a safe parser for hostile input.
        expr = value.replace("^", "**")
        try:
            number = eval(expr, {"__builtins__": {}}, {})
            # round() rather than int(): int() truncates toward zero, so
            # 2.9999999 would be judged 0.9999999 away from an integer.
            return abs(number - round(number)) <= tol
        except (TypeError, SyntaxError, NameError, ValueError,
                OverflowError, ZeroDivisionError):
            return False

    @classmethod
    def _substitute_names(cls, text, mappings):
        """Replace every whole symbol in ``text`` that is a key of ``mappings`` by its value.

        All names are replaced in a single pass, so a new name that happens to
        equal an old one is never renamed a second time.
        """
        if not mappings:  # an empty alternation would match everywhere
            return text
        alternatives = "|".join(re.escape(name) for name in mappings)
        chars = cls._SYMBOL_CHARS
        pattern = re.compile(rf"(?<![{chars}])(?:{alternatives})(?![{chars}])")
        return pattern.sub(lambda match: mappings[match.group(0)], text)

    def fresh(self, statement):
        """Re-declare REAL variables whose model value is integral as INT.

        The script up to (excluding) the first ``get-value`` line is sent to
        ``solve.solve`` with a trailing ``(get-model)``.  The solver output is
        read as ``name := value`` lines.  Each ``declare-fun`` whose name
        appears in the model is kept.  If that declaration is REAL and its
        value is within 1e-5 of an integer, it is re-declared as INT.
        Declarations of names missing from the model are dropped.
        NOTE(review): if such a name is still referenced, the output is not a
        valid script; confirm this is intended.
        ``declare-const`` is re-emitted as ``declare-fun``.  A nullary
        ``define-fun`` becomes a declaration plus an equality assert.

        statement: SMT-LIB script text.
        Returns the serialized script followed by the last ``get-value`` line.
        Returns ``statement`` unchanged in two cases: it has no ``get-value``
        line, or the solver reports failure or an empty model.
        """
        queries = self._get_value_lines(statement)
        if not queries:  # get-value could be simplified away in some cases
            return statement
        self.get_value = queries[-1]
        tmp_statement = statement.split(queries[0], 1)[0] + "\n(get-model)"
        ok, sols = solve.solve(tmp_statement)
        if not ok or not sols or sols[0] == "":
            return statement
        model = {}
        for line in sols:
            # Skip lines that are not "name := value" instead of aborting.
            name, sep, value = line.partition(' := ')
            if sep:
                model[name] = value
        cmds = SmtLibParser().get_command_generator(StringIO(tmp_statement))
        for cmd in cmds:
            # Note: cmd.args[0] is an FNode for declarations
            if cmd.name == "declare-fun":
                tmp_name, tmp_type = cmd.args[0].symbol_name(), cmd.args[0].symbol_type()
                if tmp_name in model:  # may exist duplicate definition
                    if tmp_type == REAL and self._is_integral(model[tmp_name]):
                        tmp_type = INT
                        # NOTE(review): stock pysmt's Symbol() raises
                        # PysmtTypeError when the name already exists in the
                        # environment with another type, and the parser has
                        # just declared it as REAL; refresh() catches that and
                        # falls back.  Confirm with the pysmt build in use.
                        cmd = SmtLibCommand(name="declare-fun", args=(Symbol(tmp_name, tmp_type),))
                    self.vars[tmp_name] = {'name': tmp_name, 'type': tmp_type}
                    self.simplified_script.add_command(cmd)
            elif cmd.name == "declare-const":
                # re-emitted as the equivalent declare-fun (not retyped)
                self.simplified_script.add("declare-fun", cmd.args)
            elif cmd.name == "define-fun":
                if cmd.args[1] == []:  # no parameters: this is a variable
                    var = Symbol(cmd.args[0], cmd.args[2])
                    self.simplified_script.add_command(SmtLibCommand(name="declare-fun", args=(var,)))
                    cmd = SmtLibCommand(name="assert", args=(Equals(var, cmd.args[3]),))
                self.simplified_script.add_command(cmd)
            elif cmd.name == "get-model":  # only added for the solver call
                continue
            else:
                self.simplified_script.add_command(cmd)
        buf_out = StringIO()
        self.simplified_script.serialize(buf_out, daggify=False)
        return buf_out.getvalue() + self.get_value

    def simplify(self, statement):
        """Simplify the formula of every ``assert`` with ``FNode.simplify``.

        ``declare-const`` is re-emitted as ``declare-fun``.  Every
        ``get-value`` command is dropped, and the last non-comment
        ``get-value`` line is appended after the serialized script.
        Returns ``statement`` unchanged when it has no ``get-value`` line.
        """
        queries = self._get_value_lines(statement)
        if not queries:  # get-value could be simplified away in some cases
            return statement
        self.get_value = queries[-1]
        cmds = SmtLibParser().get_command_generator(StringIO(statement))
        for cmd in cmds:
            # Note: cmd.args[0] is an FNode
            if cmd.name == "declare-const":
                self.simplified_script.add("declare-fun", cmd.args)
            elif cmd.name == "assert":
                fnode = cmd.args[0].simplify()
                self.simplified_script.add_command(SmtLibCommand(name="assert", args=(fnode,)))
            elif cmd.name == "get-value":  # re-appended once, below
                continue
            else:
                self.simplified_script.add_command(cmd)
        buf_out = StringIO()
        self.simplified_script.serialize(buf_out, comment=self.comment, daggify=False)
        return buf_out.getvalue() + self.get_value

    def rename(self, statement):
        """Rename declared variables to ``x_0``, ``x_1``, ... in declaration order.

        Only the text before the first ``get-value`` line is parsed; the last
        ``get-value`` line is appended and renamed as well.  ``declare-const``
        becomes ``declare-fun``.  Names are replaced textually, in one pass,
        only where they form a whole SMT-LIB symbol.
        Returns ``statement`` unchanged when it has no ``get-value`` line.
        """
        queries = self._get_value_lines(statement)
        if not queries:  # get-value could be simplified away in some cases
            return statement
        self.get_value = queries[-1]
        mappings = {}
        var_counter = 0
        tmp_statement = statement.split(queries[0], 1)[0]
        cmds = SmtLibParser().get_command_generator(StringIO(tmp_statement))
        for cmd in cmds:
            # Note: cmd.args[0] is an FNode for declarations
            if cmd.name == "declare-fun" or cmd.name == "declare-const":
                origin_name, tmp_type = cmd.args[0].symbol_name(), cmd.args[0].symbol_type()
                tmp_name = f"x_{var_counter}"
                var_counter += 1
                self.vars[tmp_name] = {'name': tmp_name, 'type': tmp_type}
                mappings[origin_name] = tmp_name
                # Keep the original symbol here; the single textual pass below
                # renames declarations and uses consistently, so an original
                # name equal to some generated x_k cannot be renamed twice.
                self.simplified_script.add_command(
                    SmtLibCommand(name="declare-fun", args=(cmd.args[0],)))
            elif cmd.name == "get-model":  # ignore get-model
                continue
            else:
                self.simplified_script.add_command(cmd)
        buf_out = StringIO()
        self.simplified_script.serialize(buf_out, comment=self.comment, daggify=False)
        return self._substitute_names(buf_out.getvalue() + self.get_value, mappings)

    
def refresh(statement, comment):
    """Retype integral Real variables, then simplify every assertion.

    ``simplifier.fresh`` runs first.  Its output goes to
    ``simplifier.simplify`` on a second, new ``simplifier`` instance.
    If either step raises one of the exceptions listed below, the original
    ``statement`` is returned untouched.
    """
    # fresh() may hit type mismatches (and parser errors); fall back to the
    # input script in that case.
    recoverable = (PysmtSyntaxError, NotImplementedError, AttributeError,
                   UnsupportedOperatorError, PysmtTypeError, ValueError)
    try:
        refreshed = simplifier(comment=comment).fresh(statement)
        return simplifier(comment=comment).simplify(refreshed)
    except recoverable:
        return statement

def simplify(statement, comment):
    """Return ``statement`` with each assertion simplified by a new ``simplifier``."""
    return simplifier(comment=comment).simplify(statement)

def rename(statement):
    """Return ``statement`` with its variables renamed to ``x_0``, ``x_1``, ... (no comments)."""
    return simplifier(comment=False).rename(statement)


