
# from w import *
import matplotlib.pyplot as plt
import numpy as np






class BinaryTree:
    """A mutable binary-tree node holding ``value`` and optional children.

    Short aliases are bound on every instance for terse tree building:
    ``il``/``ir`` call ``insertLeft``/``insertRight`` and ``l``/``r`` call
    ``getLeftChild``/``getRightChild``.
    """

    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        # 1-based heap position of this node; None until something assigns it.
        self.indexInHeap = None
        # Free-form style tag for this node ('square' by default).
        self.attr = 'square'
        self.il = self.insertLeft
        self.ir = self.insertRight
        self.l = self.getLeftChild
        self.r = self.getRightChild

    def __repr__(self):
        return 'BinaryTree(%r)' % (self.value,)

    def getIndexHeap(self):
        """Return this node's heap index (None until assigned)."""
        return self.indexInHeap

    def setIndexHeap(self, value):
        """Set this node's heap index."""
        self.indexInHeap = value

    def insertLeft(self, newNode):
        """Insert a node with value ``newNode`` as the left child and return it.

        If a left subtree already exists it becomes the new node's left child,
        so nothing is lost.
        """
        node = BinaryTree(newNode)
        node.left = self.left  # None when the slot was empty
        self.left = node
        return node

    def insertRight(self, newNode):
        """Insert a node with value ``newNode`` as the right child and return it.

        If a right subtree already exists it becomes the new node's right
        child, so nothing is lost.
        """
        node = BinaryTree(newNode)
        node.right = self.right  # None when the slot was empty
        self.right = node
        return node

    def getValue(self):
        """Return the value stored in this node."""
        return self.value

    def setValue(self, value):
        """Replace the value stored in this node."""
        self.value = value

    def getRightChild(self):
        """Return the right child, or None."""
        return self.right

    def getLeftChild(self):
        """Return the left child, or None."""
        return self.left













def _tree_depth(tree):
    """Return the number of levels in ``tree``: 0 for None, 1 for a lone node."""
    if tree is None:
        return 0
    return 1 + max(_tree_depth(tree.getLeftChild()),
                   _tree_depth(tree.getRightChild()))


def _assign_heap_indices(root):
    """Store a 1-based heap index on every node reachable from ``root``.

    The root gets 1 and the children of index ``i`` get ``2*i`` (left) and
    ``2*i + 1`` (right); indices are written with ``setIndexHeap``.
    """
    root.setIndexHeap(1)
    pending = [root]
    while pending:
        node = pending.pop()
        index = node.getIndexHeap()
        for child, child_index in ((node.getLeftChild(), 2 * index),
                                   (node.getRightChild(), 2 * index + 1)):
            if child is not None:
                child.setIndexHeap(child_index)
                pending.append(child)


def _heap_array(root, key):
    """Flatten ``root`` into a heap-ordered list of ``key(node)`` values.

    Slot 0 holds a ``0`` placeholder so that slot ``i`` matches heap index
    ``i``; slots with no node hold None. The list has ``2**depth`` entries.
    """
    _assign_heap_indices(root)
    slots = [0] + [None] * (2 ** _tree_depth(root) - 1)
    pending = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        slots[node.getIndexHeap()] = key(node)
        pending.append(node.getLeftChild())
        pending.append(node.getRightChild())
    return slots


# ------ hand-tuned layout fixes for specific figures ------
# Indices are 0-based positions in the `_layout_positions` output, i.e. heap
# index minus one. Row n (root row = 0) ends at index 2**(n+1) - 2:
# 0, 2, 6, 14, 30, 62, 126, ...
# Each line copies one slot's coordinates onto another. Order matters: a slot
# may be read before a later line overwrites it (e.g. 19 in comb2).
def apply_SymReL_comb2_fig_july24(pos):
    """Move nodes for the comb2 figure in place (touches indices up to 24)."""
    pos[11] = pos[12]
    pos[20] = pos[19]
    pos[19] = pos[16]
    pos[22] = pos[24]


def apply_SymReL_comb3_fig_july24(pos):
    """Move nodes for the comb3 figure in place (touches indices up to 125)."""
    pos[29] = pos[28]
    pos[27] = pos[19]
    pos[28] = pos[25]

    pos[59] = pos[49]
    pos[61] = pos[58]

    pos[123] = pos[93]
    pos[124] = pos[110]
    pos[125] = pos[120]


# layout_fix name -> (fix function, largest position index it touches)
_LAYOUT_FIXES = {
    'comb2': (apply_SymReL_comb2_fig_july24, 24),
    'comb3': (apply_SymReL_comb3_fig_july24, 125),
}


def _lookup_layout_fix(layout_fix):
    """Return ``(fix, max_index)`` for ``layout_fix``, or None when it is None.

    Raises:
        ValueError: if ``layout_fix`` is neither None nor a known fix name.
    """
    if layout_fix is None:
        return None
    try:
        return _LAYOUT_FIXES[layout_fix]
    except KeyError:
        raise ValueError(
            'unknown layout_fix {!r}; expected None or one of {}'.format(
                layout_fix, sorted(_LAYOUT_FIXES))) from None


def _apply_layout_fix(pos, layout_fix):
    """Apply the named hand-tuned fix to ``pos`` in place.

    The fix is skipped when ``pos`` is too short for the indices it touches
    (a shallower tree than the figure it was tuned for). This avoids raising
    ``IndexError`` after part of the fix has already been applied.
    """
    entry = _lookup_layout_fix(layout_fix)
    if entry is None:
        return
    fix, max_index = entry
    if len(pos) > max_index:
        fix(pos)


def _layout_positions(depth, shrink):
    """Return ``[x, y]`` slots for a complete binary tree of ``depth`` levels.

    Slots are listed row by row (root first, left to right), so the result
    has ``2**depth - 1`` entries and entry ``k`` belongs to heap index
    ``k + 1``. Leaves are spread evenly over x in [0, 1], and each parent sits
    midway between its two children. Rows are spread evenly over y in [0, 1],
    with the root at y = 1. Every point is then scaled by ``shrink`` about
    the centroid. Intended for ``depth >= 2``.
    """
    # linspace yields exactly the requested number of samples. Accumulating a
    # float step under `while x <= 1` can overshoot 1 and lose the last one.
    row_ys = np.linspace(0.0, 1.0, depth)[::-1]
    row = list(np.linspace(0.0, 1.0, 2 ** (depth - 1)))
    rows = []
    while row:
        rows.append(row)
        row = [(row[k] + row[k + 1]) / 2 for k in range(0, len(row) - 1, 2)]
    rows.reverse()  # root row first
    points = np.array([[x, y] for xs, y in zip(rows, row_ys) for x in xs])
    centre = points.mean(axis=0, keepdims=True)
    return (centre + (points - centre) * shrink).tolist()


def _draw_binary_tree(root, ax, node_style, shrink, size, layout_fix):
    """Draw ``root`` on ``ax``.

    Draws an arrow from every child to its parent, then every node's value
    as a boxed label. ``node_style`` is passed as the ``bbox`` of every
    annotation. Nodes whose value is falsy (e.g. '' or None) are not drawn.
    """
    depth = _tree_depth(root)
    if depth == 0:
        # Empty tree: nothing to draw.
        ax.axis('off')
        return
    if depth == 1:
        ax.annotate(root.getValue(), va="center", ha="center", xy=(0.5, 0.5),
                    bbox=node_style, size=size)
        ax.axis('off')
        return

    pos = _layout_positions(depth, shrink)
    _apply_layout_fix(pos, layout_fix)
    pos.insert(0, 0)  # pad so that pos[i] is the slot of heap index i
    labels = _heap_array(root, lambda node: node.getValue())

    # Arrows are added before labels so that, at equal z-order, the label
    # boxes are drawn over them.
    for i in range(1, len(labels)):
        if not labels[i]:
            continue
        for child in (2 * i, 2 * i + 1):
            if child < len(labels) and labels[child]:
                ax.annotate("", xy=tuple(pos[i]), xytext=tuple(pos[child]),
                            va="center", ha="center", bbox=node_style,
                            arrowprops=dict(arrowstyle="->"), size=size)
    for i in range(1, len(labels)):
        if labels[i]:
            ax.annotate(labels[i], va="center", ha="center", xy=tuple(pos[i]),
                        bbox=node_style, size=size)
    ax.axis('off')


def viz_tree(root, expr='', shrink=0.9, fontsize=20, color=(0.1, 0.7, 0.7),
             layout_fix='comb2'):
    """Draw the binary tree ``root`` on the current figure and save ``mytree.pdf``.

    Args:
        root: tree root exposing ``getValue``, ``getLeftChild``,
            ``getRightChild``, ``getIndexHeap`` and ``setIndexHeap`` (e.g. a
            ``BinaryTree``). Every node's heap index is overwritten as a side
            effect. None draws an empty figure.
        expr: text printed to stdout before drawing.
        shrink: factor by which node positions are scaled about their
            centroid.
        fontsize: label font size.
        color: face colour of the square node boxes.
        layout_fix: name of a hand-tuned position fix ('comb2' or 'comb3'),
            or None for the plain layout. Defaults to 'comb2' so existing
            callers get the previous output. A fix is skipped for trees too
            shallow for it.

    Raises:
        ValueError: if ``layout_fix`` is not a known fix name or None.
    """
    _lookup_layout_fix(layout_fix)  # fail fast, before printing or drawing
    print(expr)
    ax = plt.subplot(111, frameon=False)
    ax.cla()
    # Square boxes without an edge; 'round'/'circle' boxstyles and ec='k'
    # (black edge) are alternatives that were used before.
    node_style = dict(boxstyle='square', fc=color, ec='none')
    _draw_binary_tree(root, ax, node_style, shrink, fontsize, layout_fix)
    plt.savefig('mytree.pdf')
    print()





def expJuly24_SymReL_comb1():
    """Build the four-node "comb1" tree, draw it and save it as ``comb1.pdf``."""
    root = BinaryTree("BOOL: ROCK-is\n-in-same-layer")
    # Every insert returns the freshly created child, so the spine
    # root -> left -> left -> right can be built as one chain.
    (root.insertLeft("BOOL:contain-excircle\n-of-rock-obj")
         .insertLeft('get-all-crossed-objs')
         .insertRight('get-horizontal-line\n-anchored-to-man'))

    viz_tree(root, shrink=0.9, fontsize=15, color=(0.1, 0.7, 0.7))
    plt.savefig('comb1.pdf', bbox_inches='tight')

def expJuly24_SymReL_comb2():
    """Build the "comb2" tree, draw it and save it as ``comb2.pdf``."""
    print('need to change in w: search apply_SymReL_comb2_fig_july24')
    root = BinaryTree("+:final action options")

    # Left subtree of the root.
    left_gate = root.insertLeft("x")
    left_gate.insertLeft('BOOL:rock-is-in\n-same-layer')
    options = left_gate.insertRight('+: run-away\noptions')

    go_up = options.insertLeft('x')
    go_up.insertLeft('BOOL:has-\nupward-\nladder')
    go_up.insertRight('action:\ngo up\nladder')

    go_down = options.insertRight('x')
    go_down.insertLeft('NOT').insertLeft('BOOL:has-\nupward-ladder')
    go_down.insertRight('action:\ngo down')

    # Right subtree of the root.
    right_gate = root.insertRight('x')
    right_gate.insertLeft('NOT').insertLeft('BOOL:rock-is-in\n-same-layer')
    right_gate.insertRight('action:go up\nladder')

    viz_tree(root, shrink=1, fontsize=11, color=(0.4, 0.5, 0.7))
    plt.savefig('comb2.pdf', bbox_inches='tight')



def expJuly24_SymReL_comb3():
    """Build the "comb3" tree, draw it and save it as ``comb3.pdf``."""
    root = BinaryTree('+:go up ladder options')

    first = root.insertLeft('x')
    first.insertLeft('BOOL:\nupper-layer\n-is-clear')
    first.insertRight('action:\nup-neaest')

    negated = root.insertRight('x')
    negated.insertLeft('NOT').insertLeft('BOOL:upper-\nlayer-is-clear')
    safe_options = negated.insertRight('+:safely up options')

    nearest = safe_options.insertLeft('x')
    nearest.insertLeft('BOOL:upper-rock-\npassed-nearest-ladder')
    nearest.insertRight('action:\nup-nearest')

    negated = safe_options.insertRight('x')
    negated.insertLeft('NOT').insertLeft('BOOL:upper-\nrock-past-leftmost')
    remaining = negated.insertRight('+')

    leftmost = remaining.insertLeft('x')
    leftmost.insertLeft('BOOL:upper-\nrock-past-leftmost')
    leftmost.insertRight('action:\nup-leftmost')

    negated = remaining.insertRight('x')
    negated.insertLeft('NOT').insertLeft('BOOL:upper-\nrock-past-\nleftmost-ladder')
    negated.insertRight('wait')

    viz_tree(root, shrink=1, fontsize=10, color=(0.2, 0.6, 0.8))
    plt.savefig('comb3.pdf', bbox_inches='tight')





# Script driver: runs at module load (there is no __main__ guard), so importing
# this file also draws and saves the comb2 figure. Uncomment another call to
# produce the comb1 or comb3 figure instead.
expJuly24_SymReL_comb2()
# expJuly24_SymReL_comb1()
# expJuly24_SymReL_comb3()
















