from collections import defaultdict
import os
import sys
import queue
import time
# Put this file's own directory on sys.path so the tcheck_* imports below
# resolve regardless of the working directory. abspath() matters because
# os.path.dirname() of a relative __file__ is '' (i.e. "whatever the CWD is"),
# which silently stops working after a chdir. Skip the append if the directory
# is already importable to avoid duplicate sys.path entries.
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.append(script_dir)
#USAGE: file containing the structure representing an Extended type instance
#       Extended types have two main subclasses: token types and business types
#       Token types follow the default token structure of:
#           - numerator type(s)
#           - denominator type(s)
#           - normalization amt
#           - linked contract (for addresses)
#           - fields          (for objects)
#       Business types (usually one) are categorical, e.g. (fee, balance, ...), and their relations are defined here

from tcheck_parser import f_type_name, f_type_num, update_start
from tcheck_help import eprint, extract_operations, copy_queue, view_queue, printFull, convert_int_ir, get_search_count, extract_operations_source, IRHLCReturned


class SourceTraceTree():
    """Tree of source expressions visited while tracing a finance-type ancestry.

    Each node stores the source expression of one trace step (``source_exp``)
    and, optionally, its source mapping. Child nodes are the steps that the
    node was derived from. Leaf nodes are rendered with a " = LLM_Set" suffix.
    """

    def __init__(self):
        self._source_exp = None
        self._isRoot = False
        self._parent = None
        self._children = []
        # Root-to-leaf path strings for the subtree; rebuilt by _collect_branches().
        self._allBranches = []
        self._source_mapping = None

    @property
    def source_exp(self):
        return self._source_exp

    @source_exp.setter
    def source_exp(self, exp):
        self._source_exp = exp

    @property
    def source_mapping(self):
        return self._source_mapping

    @source_mapping.setter
    def source_mapping(self, mapping):
        self._source_mapping = mapping

    @property
    def isRoot(self):
        return self._isRoot

    @isRoot.setter
    def isRoot(self, value):
        self._isRoot = value

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, parent):
        self._parent = parent

    @property
    def children(self):
        return self._children

    def add_child(self, child):
        """Append *child* and set this node as its parent."""
        self._children.append(child)
        child.parent = self

    @property
    def allBranches(self):
        """Root-to-leaf path strings computed by the last flattening render."""
        return self._allBranches

    @property
    def isLeaf(self):
        """True when the node has no children."""
        return not self._children

    def _collect_branches(self):
        """Rebuild and return this subtree's root-to-leaf path strings.

        A leaf contributes ``"<exp> = LLM_Set\\n"``; an inner node prefixes
        ``"<exp>\\n"`` to every branch of each child. The list is rebuilt from
        scratch on every call (so repeated renders do not accumulate
        duplicates) and is written into the existing ``_allBranches`` list
        object so references obtained via ``allBranches`` stay valid.
        """
        if self.isLeaf:
            branches = [str(self._source_exp) + " = LLM_Set\n"]
        else:
            # str() guards against a None / non-str expression on inner nodes.
            prefix = str(self._source_exp) + "\n"
            branches = [prefix + branch
                        for child in self._children
                        for branch in child._collect_branches()]
        self._allBranches[:] = branches
        return branches

    def __str__(self):
        """Render the tree as text.

        A leaf renders as ``"<exp> = LLM_Set\\n"``. When ``printFull`` is
        truthy an inner node renders as its expression line followed by the
        renderings of its children (depth-first). Otherwise the tree is
        flattened into one string per root-to-leaf path; the paths are stored
        in ``allBranches`` and their concatenation is returned.
        """
        if self.isLeaf:
            return str(self._source_exp) + " = LLM_Set\n"

        if printFull:
            return str(self._source_exp) + "\n" + "".join(str(child) for child in self._children)

        return "".join(self._collect_branches())






#Deprecated
class ExpressionTrace():
    """(Deprecated) Candidate source expression(s) for one step of an ancestry trace.

    Either a single expression is set directly (``set_expression``), or a list
    of candidates is stored (``set_possible_expressions``) and the first one is
    picked by ``set_state``. ``state`` becomes True once an expression has been
    settled. The owning contract/function pointers and the IR nodes that
    contributed to the trace are recorded alongside.
    """

    def __init__(self):
        self._possible_expressions = []
        self._expression = None
        self._contract = None
        self._function = None
        # IR nodes whose (non-None) names feed get_variable_names().
        self._relevant_ir = []
        self._state = False

    def add_contract(self, contract):
        self._contract = contract

    def add_function(self, function):
        self._function = function

    def add_relevant_ir(self, ir):
        self._relevant_ir.append(ir)

    def set_expression(self, expression):
        """Settle on *expression*; used when there is exactly one candidate."""
        self._expression = expression
        self._state = True

    def set_state(self):
        """Settle on the first possible expression unless already settled.

        Used when a choice must be made between many expressions (not
        preferable). The trace is marked settled even when there are no
        candidates; in that case the expression is left unchanged and
        "Bad set..." is logged.
        """
        if self._state:
            eprint("Already has valid expression: {}".format(self._expression))
            return
        self._state = True
        if len(self._possible_expressions) > 0:
            self._expression = self._possible_expressions[0]
        else:
            eprint("Bad set...")

    def set_possible_expressions(self, possible_expressions):
        self._possible_expressions = possible_expressions

    def get_variable_names(self) -> list:
        """Return the names of the recorded IR nodes, skipping ``None`` names."""
        return [ir.name for ir in self._relevant_ir if ir.name is not None]

    @property
    def expression(self):
        return self._expression

    @property
    def contract(self):
        return self._contract

    @property
    def function(self):
        return self._function

    @property
    def relevant_ir(self):
        return self._relevant_ir

    @property
    def state(self):
        return self._state

    def __str__(self):
        """Return the expression as text, prefixed with "Maybe " if not settled.

        The expression is passed through str() because candidates are not
        guaranteed to be strings, and __str__ must always return a str.
        """
        if self._expression is None:
            eprint("None, something is wrong")
            return "None, not set!"
        if self._state:
            return str(self._expression)
        return "Maybe " + str(self._expression)

class ExtendedType():
    """Type information attached to one variable / IR node during the check.

    Two kinds of information are tracked:
      * token type: numerator/denominator token types, normalization,
        address, linked contract and object fields;
      * business type: a categorical ``finance_type`` plus bookkeeping
        recording which node(s) propagated it here (parents, the producing
        IR, children).

    The string ``'u'`` is the "undefined" marker for address, norm and value.
    """

    def __init__(self):
        # Starts out as an 'undefined' data type.
        self._name = None
        self._function_name = None
        # TODO add the function ir
        self._contract_name = None
        # --- Token type ---
        self._num_token_types = []
        self._den_token_types = []
        self._base_decimals = 0
        self._address = 'u'
        self._norm = 'u'
        self._value = 'u'
        # All references have an assumed type: once they are set on the lhs, they are locked.
        self._reference_locked = False
        # TODO address lf should automatically be set to its name
        self._linked_contract = None
        self._fields = []
        self._reference_root = None
        self._reference_field = None
        self._trace = None

        # Root of the SourceTraceTree built by the last get_all_finance_parents() call.
        self._sourceTraceTree = None
        # --- Business type ---
        self._finance_type = -1
        # SINGLE parent node (or tuple of parents) which propagated its type to this
        # node / node(s) of the operation which generated the finance_type.
        self._finance_type_parent = None
        # finance_type -> parent assigned while that finance_type was current (history).
        self._parent_mappings = {}
        # Parents displaced by forced re-assignment of finance_type_parent.
        self._finance_step_parents = []
        self._ignore_llm = False

        # Search-attempt id; compared with get_search_count() to detect cycles.
        self._searched_id = -1
        # IR operation that produced this node from the parent.
        self._finance_type_ir = None
        # ALL children nodes which directly receive the finance type of this variable.
        self._finance_type_children = []
        self._updated = False

        # Flag for if the type was directly set by the LLM (parameters and global variables).
        self._llm_set = False
        # Flags for LLM remedies 1 and 2 (see set_remedy).
        self._llm_remedy = [False, False]

    # ------------------------------------------------------------------
    # Simple accessors
    # ------------------------------------------------------------------
    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, sname):
        self._name = sname

    @property
    def ignore_llm(self):
        return self._ignore_llm

    @ignore_llm.setter
    def ignore_llm(self, bol):
        self._ignore_llm = bol

    def set_remedy(self, remedy):
        """Flag LLM remedy 1 or 2 as applied; any other value is ignored."""
        if remedy not in (1, 2):
            return
        self._llm_remedy[remedy - 1] = True

    @property
    def searched_id(self):
        return self._searched_id

    @searched_id.setter
    def searched_id(self, counter):
        self._searched_id = counter

    @property
    def llm_remedy(self):
        return self._llm_remedy

    @property
    def ref_root(self):
        return self._reference_root

    @property
    def ref_field(self):
        return self._reference_field

    def ref(self, ref):
        """Store ``ref[0]`` as the reference root and ``ref[1]`` as the reference field."""
        self._reference_root = ref[0]
        self._reference_field = ref[1]

    @property
    def llm_set(self):
        return self._llm_set

    @llm_set.setter
    def llm_set(self, _llm_set):
        self._llm_set = _llm_set

    @property
    def address(self):
        return self._address

    @address.setter
    def address(self, x):
        self._address = x

    @property
    def function_name(self):
        return self._function_name

    @function_name.setter
    def function_name(self, fname):
        self._function_name = fname

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        # None is normalized to the 'u' (undefined) marker.
        if value is None:
            value = 'u'
        self._value = value

    @property
    def contract_name(self):
        return self._contract_name

    @contract_name.setter
    def contract_name(self, cname):
        self._contract_name = cname

    # ------------------------------------------------------------------
    # Label resolution
    # ------------------------------------------------------------------
    # DEPRECATED
    def resolve_trace(self, trace_labels):
        """Replace every token type that is a key of *trace_labels* with its mapped value.

        Both lists are rewritten in place (slice assignment) so references
        obtained from num_token_types / den_token_types see the update.
        """
        self._num_token_types[:] = [trace_labels[n] if n in trace_labels else n
                                    for n in self._num_token_types]
        self._den_token_types[:] = [trace_labels[d] if d in trace_labels else d
                                    for d in self._den_token_types]

    def resolve_labels(self, label_sets):
        """Replace every token type that is a key of *label_sets* with that entry's ``head``.

        Both lists are rewritten in place (slice assignment) so references
        obtained from num_token_types / den_token_types see the update.
        """
        self._num_token_types[:] = [label_sets[n].head if n in label_sets else n
                                    for n in self._num_token_types]
        self._den_token_types[:] = [label_sets[d].head if d in label_sets else d
                                    for d in self._den_token_types]

    # ------------------------------------------------------------------
    # Numerator / denominator token types
    # ------------------------------------------------------------------
    @property
    def num_token_types(self):
        return self._num_token_types

    def add_num_token_type(self, token_type):
        """Add *token_type* to the numerator.

        ``-1`` (the constant marker, see is_constant) is only added when the
        numerator is empty. Any other token type first cancels one matching
        denominator entry; if there is none, it replaces a ``-1`` marker in the
        numerator (if present) and is appended.
        """
        if token_type == -1:
            if not self._num_token_types:
                self._num_token_types.append(token_type)
        elif token_type in self._den_token_types:
            self._den_token_types.remove(token_type)
        else:
            if -1 in self._num_token_types:
                self._num_token_types.remove(-1)
            self._num_token_types.append(token_type)

    def clear_num(self):
        self._num_token_types.clear()

    @property
    def den_token_types(self):
        return self._den_token_types

    def add_den_token_type(self, token_type):
        """Add *token_type* to the denominator (mirror image of add_num_token_type)."""
        if token_type == -1:
            if not self._den_token_types:
                self._den_token_types.append(token_type)
        elif token_type in self._num_token_types:
            self._num_token_types.remove(token_type)
        else:
            if -1 in self._den_token_types:
                self._den_token_types.remove(-1)
            self._den_token_types.append(token_type)

    def clear_den(self):
        self._den_token_types.clear()

    # ------------------------------------------------------------------
    # Finance-type ancestry search
    # ------------------------------------------------------------------
    # USAGE: given an ir, returns a list of all parents contributing to the current ir (potential to rerun)
    def get_all_finance_parents(self, contract_pointer = None, function_metadata = None, maxTrace = False):
        """Trace this node's finance-type ancestry and return the ancestors found.

        Side effects: resets ``self.all_expression_traces``, stores the root of
        the walk in ``self._sourceTraceTree``, appends one "Ancestor: ..." line
        per result to ``finance_errors.txt``, and writes debug output via eprint.

        Returns the de-duplicated list of ancestor nodes (possibly empty).
        """
        self.all_expression_traces = []
        sourceTraceTree = SourceTraceTree()
        sourceTraceTree.isRoot = True

        expressionTraces = queue.Queue()
        # Seed the queue with one empty (not yet resolved) trace.
        expressionTraces.put(ExpressionTrace())
        self.all_expression_traces.append(expressionTraces)
        # TODO don't pass contract pointer and function_metadata, simply create new fields
        results = self.remove_dup_parents(self.search_single_finance_line(self, contract_pointer, function_metadata, expressionTraces, sourceTraceTree, maxTrace))
        self._sourceTraceTree = sourceTraceTree
        eprint("Listing...")
        if results:
            # One append-mode handle for every result instead of reopening per result.
            with open("finance_errors.txt", "a") as file:
                for r in results:
                    eprint(r)
                    file.write("Ancestor: {} in function {}\n".format(r.name, r.function_name))

        all_related_exp = set()
        for branch in self.all_expression_traces:
            eprint("Branch here")
            all_related_exp.update(view_queue(branch))

        eprint(all_related_exp)
        eprint("****")
        eprint(self)
        return results

    def remove_dup_parents(self, results):
        """Return *results* without duplicates, keeping first-occurrence order.

        ``None`` (nothing found) yields an empty list. Equality-based (not
        hash-based) so elements need not be hashable.
        """
        unique = []
        for parent in results or ():
            if parent not in unique:
                unique.append(parent)
        return unique

    def search_single_finance_line(self, x, contract_pointer, function_metadata, expressionTraces, sourceTraceTree, maxTrace = False):
        """Recursively follow the finance-type parents of *x*, collecting ancestors.

        Parameters:
            x: the node currently visited.
            contract_pointer: recorded on freshly created ExpressionTrace objects.
            function_metadata: mapping ``function_name -> [entry, ...]`` whose
                first entry is unpacked as
                ``(file_name, function_pointer, start_line, end_line)``.
            expressionTraces: queue of ExpressionTrace objects for this branch;
                one item is taken per visit, so it must not be empty.
            sourceTraceTree: tree node describing *x*; a child is added for
                every parent that is followed.
            maxTrace: when ``False``, stop at the first named (non TMP/REF)
                variable instead of continuing to the head ancestor.

        Returns a list of ancestor nodes, or a falsy value (``[]`` / ``None``)
        when nothing is found on this line (cut at an IRHLCReturned operation,
        or *x* already visited in the current search).

        Side effects: appends debug output to ``temp_analysis.txt``, writes to
        stderr via eprint, and stamps ``x._searched_id``.
        """
        # 'x' is an extok
        eprint("----------")
        # TODO filter out names s.a. bound_ ...
        eprint(x)

        with open("temp_analysis.txt", "a") as file:
            file.write("_____________________\n")
            file.write(str(x) + '\n')

        global_searched_id = get_search_count()

        if x._searched_id == global_searched_id:
            # Already visited during this search attempt: cycle in the parent graph.
            eprint("Caught loop: {}\n".format(global_searched_id))
            eprint("Parent tuple: {}\n".format(x._parent_mappings))
            eprint("Step parents: {}\n".format(x._finance_step_parents))
            return None
        x._searched_id = global_searched_id

        try:
            eprint(x.finance_type_ir)
            converted = convert_int_ir(x.finance_type_ir)

            with open("temp_analysis.txt", "a") as file:
                file.write(str(x.finance_type_ir) + '\n')

            if converted is None:
                eprint("Source: {}".format(x.finance_type_ir.expression))
                converted = str(x.finance_type_ir.expression)
                sourceTraceTree.source_mapping = x.finance_type_ir.expression.source_mapping
            else:
                if converted == IRHLCReturned:
                    # Don't continue this ancestor
                    return []
                if sourceTraceTree.isRoot is False and x.name in sourceTraceTree.parent.source_exp:
                    append_name = x.name
                else:
                    # Drop only the last "_<suffix>" (maxsplit=1) so names that
                    # themselves contain underscores are kept intact.
                    append_name = x.name.rsplit('_', 1)[0]

                converted += ': ' + append_name + ' in ' + (x.function_name if x.function_name is not None else "?")

            with open("temp_analysis.txt", "a") as file:
                file.write("Source type: {}\n".format(type(x.finance_type_ir)))
                file.write("Source: {}\n".format(converted))
            eprint("Source: {}".format(converted))
            sourceTraceTree.source_exp = converted

            # Check string source
            if function_metadata is not None and x.function_name in function_metadata:
                # [func_file_name, function, func_start_line, func_end_line]
                file_name, function_pointer, startln, endln = function_metadata[x.function_name][0]

                possible_exps = extract_operations_source(converted, file_name, startln, endln)

                with open("temp_analysis.txt", "a") as file:
                    file.write("Possible ops: ({})\n".format(possible_exps))

        except Exception as e:
            # Best-effort: the trace continues even if the source cannot be rendered.
            # From experience, this seems to be a sign of llm set
            eprint("Could not print source: {}".format(e))

        # expressionTraces should not be empty; get_nowait() raises queue.Empty
        # instead of blocking forever if that invariant is ever broken.
        unfound_exp = expressionTraces.get_nowait()
        eprint("unfound_exp exp: {}".format(unfound_exp._expression))
        if unfound_exp.function is not None and str(unfound_exp.function) != x.function_name:
            # Crossed into a different function: settle the pending trace.
            unfound_exp.set_state()

        if unfound_exp.state:
            expressionTraces.put(unfound_exp)
            unfound_exp = ExpressionTrace()
            unfound_exp.add_contract(contract_pointer)

        if x.name and not x.name.startswith(("TMP", "REF")):

            if function_metadata is not None and x.function_name in function_metadata:
                # [func_file_name, function, func_start_line, func_end_line]
                file_name, function_pointer, startln, endln = function_metadata[x.function_name][0]
                unfound_exp.add_function(function_pointer)
                current_names = unfound_exp.get_variable_names()
                current_names.append(x.name)

                found_results = extract_operations(current_names, file_name, startln, endln)
                if len(found_results) == 0:
                    # Retry with the trailing "_<suffix>" stripped from the name.
                    current_names.pop()
                    current_names.append(x.name.rsplit('_', 1)[0])
                    found_results = extract_operations(current_names, file_name, startln, endln)
                if len(found_results) == 0:
                    unfound_exp.set_state()
                    # only one variable + none set = must be contract set/parameter
                    expressionTraces.put(unfound_exp)
                    unfound_exp = ExpressionTrace()
                    eprint("[!] Did not find occurance of variable in this function!")
                elif len(found_results) == 1:
                    # Best case scenario
                    unfound_exp.set_expression(found_results[0])
                else:
                    unfound_exp.set_possible_expressions(found_results)
                unfound_exp.add_relevant_ir(x)

            else:
                eprint("Function name not tracked... [Fix?]")

            if maxTrace == False:
                return [x]

        expressionTraces.put(unfound_exp)

        if isinstance(x.finance_type_parent, tuple):
            # Two parents: follow each one with its own copy of the trace queue.
            temp = copy_queue(expressionTraces, False)
            self.all_expression_traces.append(temp)
            source_child_1 = SourceTraceTree()
            source_child_2 = SourceTraceTree()
            sourceTraceTree.add_child(source_child_1)
            sourceTraceTree.add_child(source_child_2)
            child_1_res = self.search_single_finance_line(x.finance_type_parent[0], contract_pointer, function_metadata, expressionTraces, source_child_1, maxTrace)
            child_2_res = self.search_single_finance_line(x.finance_type_parent[1], contract_pointer, function_metadata, temp, source_child_2, maxTrace)
            if child_1_res:
                if child_2_res:
                    return child_1_res + child_2_res
                return child_1_res
            return child_2_res

        if x.finance_type_parent is None or x.llm_set != False:
            # Found head ancestor
            eprint(self.all_expression_traces)
            eprint(unfound_exp._possible_expressions)
            unfound_exp.set_state()
            eprint("Unfound exp: {}".format(unfound_exp.expression))
            try:
                # Debug peek at the queue front (note: get + put rotates a FIFO queue).
                temp = expressionTraces.get_nowait()
                eprint("et.get(): {}".format(temp.expression))
                expressionTraces.put(temp)
            except Exception:
                eprint("COuld not peek")
            return [x]

        source_child = SourceTraceTree()
        sourceTraceTree.add_child(source_child)
        return self.search_single_finance_line(x.finance_type_parent, contract_pointer, function_metadata, expressionTraces, source_child, maxTrace)

    # ------------------------------------------------------------------
    # Finance-type provenance
    # ------------------------------------------------------------------
    @property
    def finance_type_parent(self):
        return self._finance_type_parent

    @finance_type_parent.setter
    def finance_type_parent(self, x):
        """Record the node(s) this node's finance type came from.

        *x* must be a tuple: ``(parent, ir)``, or ``(True, parent, ir)`` to
        force re-assignment when a parent is already set (the displaced parent
        is kept in ``_finance_step_parents``). ``parent`` must be an
        ExtendedType or a non-empty tuple whose first element is one. Any other
        input terminates the process with exit status 1. Without force, only
        the first assignment takes effect.
        """
        force = False
        ir = None
        if not isinstance(x, tuple):
            eprint("[!] No ir set here")
            # Plug leakage
            sys.exit(1)
        if x[0] == True:
            force = True
            ir = x[2]
            x = x[1]
        else:
            ir = x[1]
            x = x[0]

        if not (isinstance(x, ExtendedType) or (isinstance(x, tuple) and len(x) > 0 and isinstance(x[0], ExtendedType))):
            eprint("Bad assignemnt of extended type")
            eprint(x)
            eprint("------------")
            sys.exit(1)
        if not force and self._finance_type_parent is not None:
            # Do not allow assignment of parent multiple times
            return
        print("Set trace: {}".format(x))
        if self._finance_type_parent:
            self._finance_step_parents.append(self._finance_type_parent)

        self._finance_type_parent = x

        self._parent_mappings[self._finance_type] = x  # Save history

        if ir is not None:
            # Written directly (string or ir): bypasses the int-only property setter.
            self._finance_type_ir = ir

    @property
    def finance_type_ir(self):
        # String or ir
        return self._finance_type_ir

    @finance_type_ir.setter
    def finance_type_ir(self, val):
        # Only int values are accepted through the property; anything else is
        # silently ignored (the finance_type_parent setter stores other IR).
        if not isinstance(val, int):
            return
        self._finance_type_ir = val

    def get_finance_type_children(self):
        return self._finance_type_children

    def add_finance_type_children(self, x):
        self._finance_type_children.append(x)

    # ------------------------------------------------------------------
    # Normalization / decimals
    # ------------------------------------------------------------------
    @property
    def norm(self):
        return self._norm

    @norm.setter
    def norm(self, a):
        self._norm = a

    @property
    def base_decimals(self):
        return self._base_decimals

    @base_decimals.setter
    def base_decimals(self, a):
        self._base_decimals = a

    def total_decimals(self):
        """Return ``base_decimals + norm``, or ``"*"`` when norm is ``"*"``.

        NOTE(review): while norm is still the 'u' marker this addition raises
        TypeError (assuming base_decimals is numeric) - callers must check first.
        """
        if self._norm == "*":
            return "*"
        return self._base_decimals + self._norm

    # ------------------------------------------------------------------
    # Linked contract and fields
    # ------------------------------------------------------------------
    @property
    def linked_contract(self):
        return self._linked_contract

    @linked_contract.setter
    def linked_contract(self, a):
        self._linked_contract = a

    @property
    def fields(self):
        return self._fields

    def add_field(self, new_field):
        """Append *new_field*, first removing an existing field with the same name."""
        existing = self.check_field_dupe(new_field)
        if existing is not None:
            self._fields.remove(existing)
        self._fields.append(new_field)

    def check_field_dupe(self, new_field):
        """Return the first field whose name equals ``new_field.name``, else None."""
        for field in self._fields:
            if field.name == new_field.name:
                return field
        return None

    # USAGE: creates a dictionary storing all the fields
    def return_all_fields(self, _preamble = None):
        """Collect finance types of this node and (recursively) its fields.

        Keys are name paths ``_preamble + (self.name,)``; each field's
        ``extok`` is visited with the extended path. Only entries whose
        finance_type is not -1 are stored. Only financial meaning is saved.
        """
        return_dict = {}
        eprint("{}, {}".format(_preamble, (self.name,)))
        if _preamble is None:
            preamble = (self.name,)
        else:
            preamble = _preamble + (self.name,)

        for field in self._fields:
            return_dict.update(field.extok.return_all_fields(_preamble = preamble))

        if self.finance_type != -1:
            return_dict[preamble] = self.finance_type
        return return_dict

    # USAGE: applies the dictionary to relevant Extended Type objects
    def apply_return_dict(self, return_dict, _preamble = None):
        """Inverse of return_all_fields: set finance types from a name-path dict."""
        if _preamble is None:
            preamble = (self.name,)
        else:
            preamble = _preamble + (self.name,)

        for field in self._fields:
            field.extok.apply_return_dict(return_dict, _preamble = preamble)

        if preamble in return_dict:
            self._finance_type = return_dict[preamble]

    def print_fields(self, print_extok = False):
        """Print field names to stdout (and each field's extok when requested)."""
        print(f"{self._name} Fields:")
        for field in self._fields:
            print(f"{field.name}")
            if print_extok:
                print(field.extok)
        print("^^^")

    # ------------------------------------------------------------------
    # Token-type predicates and reset
    # ------------------------------------------------------------------
    def is_undefined(self) -> bool:
        """True when no token types are set and the address is still 'u'."""
        return (len(self._num_token_types) == 0
                and len(self._den_token_types) == 0
                and self._address == 'u')

    def is_constant(self) -> bool:
        """True when numerator and denominator are both exactly ``[-1]``."""
        return self._num_token_types == [-1] and self._den_token_types == [-1]

    def is_address(self) -> bool:
        """True when an address other than 'u' / None has been assigned."""
        return self._address != 'u' and self._address is not None

    def token_type_clear(self):
        """Reset numerator, denominator and address (norm is left untouched)."""
        self.clear_num()
        self.clear_den()
        self._address = 'u'
        # NOTE(review): nothing in this class reads `link_function`; possibly
        # meant `_linked_contract` - confirm against callers before changing.
        self.link_function = None

    def init_constant(self):
        """Turn this node into a constant: num = den = [-1], norm 'u', not updated."""
        self.token_type_clear()
        self.add_num_token_type(-1)
        self.add_den_token_type(-1)
        self.norm = 'u'
        self._updated = False

    # ------------------------------------------------------------------
    # Business (finance) type
    # ------------------------------------------------------------------
    @property
    def finance_type(self):
        return self._finance_type

    @finance_type.setter
    def finance_type(self, f_type):
        # Values above update_start encode an "updated" variant of a type.
        self._updated = f_type > update_start
        self._finance_type = f_type

    @property
    def pure_type(self):
        """finance_type with the update_start offset removed, if present."""
        if self._finance_type > update_start:
            return self._finance_type - update_start
        return self._finance_type

    @property
    def updated(self):
        return self._updated

    @updated.setter
    def updated(self, is_updated):
        self._updated = is_updated
        if self._finance_type <= update_start and is_updated:
            self._finance_type += update_start

    def __str__(self):
        """Multi-line human-readable dump of this node (never raises on unknown types)."""
        num_token_types_str = ", ".join(str(elem) for elem in self._num_token_types)
        den_token_types_str = ", ".join(str(elem) for elem in self._den_token_types)
        fields_str = ", ".join(str(elem.name) for elem in self._fields)
        if self._updated == True:
            try:
                finance_type = "updated " + f_type_num[self._finance_type - update_start]
            except Exception:
                # Fall back to the raw label; None if that is unknown as well.
                finance_type = f_type_num[self._finance_type] if self._finance_type in f_type_num else None
        elif self._finance_type in f_type_num:
            finance_type = f_type_num[self._finance_type]
        else:
            finance_type = None
        if self._finance_type_parent is None:
            finance_type_parent = "None (Orphan)"
        elif isinstance(self._finance_type_parent, tuple):
            finance_type_parent = "{}, {}".format(self._finance_type_parent[0].name, self._finance_type_parent[1].name)
        else:
            finance_type_parent = self._finance_type_parent.name
        return (
            f"\n"
            f"Name: {self._name} Function: {self._function_name}\n"
            f"Num: {num_token_types_str}\n"
            f"Den: {den_token_types_str}\n"
            f"Address: {self._address}\n"
            f"Norm: {self._norm}\n"
            f"LF: {self._linked_contract}\n"
            f"Value: {self._value}\n"
            f"Fields: {fields_str}\n"
            f"Finance Type: {finance_type}\n"
            f"Dir LLM Set: {self._llm_set}\n"
            f"Search counter: {self._searched_id}\n"
            f"Finance Type Parent: {finance_type_parent}"
        )
 
        
