from sympy import *
import sympy
import numpy as np
from kan.MultKAN import MultKAN
import torch


def next_nontrivial_operation(expr, scale=1, bias=0):
    '''
    remove the affine part of an expression

    Args:
    -----
        expr : sympy expression
        scale : float
            multiplicative constant accumulated so far (default 1)
        bias : float
            additive constant accumulated so far (default 0)

    Returns:
    --------
        expr : sympy expression
            the expression with the outer affine part removed
        scale : float
        bias : float

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> next_nontrivial_operation(expression)
    '''
    if expr.func == Add or expr.func == Mul:
        n_arg = len(expr.args)
        n_num = 0
        n_var_id = []
        n_num_id = []
        var_args = []
        for i in range(n_arg):
            is_number = expr.args[i].is_number
            n_num += is_number
            if not is_number:
                n_var_id.append(i)
                var_args.append(expr.args[i])
            else:
                n_num_id.append(i)
        if n_num > 0:
            # trivial (numeric) arguments are absorbed into scale/bias
            if expr.func == Add:
                for i in range(n_num):
                    if i == 0:
                        bias = expr.args[n_num_id[i]]
                    else:
                        bias += expr.args[n_num_id[i]]
            if expr.func == Mul:
                for i in range(n_num):
                    if i == 0:
                        scale = expr.args[n_num_id[i]]
                    else:
                        scale *= expr.args[n_num_id[i]]
            return next_nontrivial_operation(expr.func(*var_args), scale, bias)
        else:
            return expr, scale, bias
    else:
        return expr, scale, bias


def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False):
    '''
    compile a symbolic formula to a MultKAN

    Args:
    -----
        input_variables : a list of sympy symbols
        expr : sympy expression
        grid : int
            the number of grid intervals
        k : int
            spline order
        auto_save : bool
            if auto_save = True, models are automatically saved

    Returns:
    --------
        MultKAN

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = exp(sin(pi*a) + b**2)
    >>> model = kanpiler(input_vars, expression)
    >>> x = torch.rand(100,2) * 2 - 1
    >>> model(x)
    >>> model.plot()
    '''
    class Node:
        def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
            self.expr = expr
            self.mult_bool = mult_bool
            if self.mult_bool:
                self.mult_arity = mult_arity
            self.depth = depth

            if len(Nodes) <= depth:
                Nodes.append([])
                index = 0
            else:
                index = len(Nodes[depth])

            Nodes[depth].append(self)

            self.index = index
            if parent == None:
                self.parent_index = None
            else:
                self.parent_index = parent.index
            self.child_index = []

            # update parent's child_index
            if parent != None:
                parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class SubNode:
        def __init__(self, expr, depth, scale, bias, parent=None):
            self.expr = expr
            self.depth = depth

            if len(SubNodes) <= depth:
                SubNodes.append([])
                index = 0
            else:
                index = len(SubNodes[depth])

            SubNodes[depth].append(self)

            self.index = index
            self.parent_index = None  # shape: (2,)
            self.child_index = []  # shape: (n, 2)

            # update parent's child_index
            parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class Connection:
        def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None):
            # connection = activation function that connects a subnode to a node in the next layer
            self.affine = affine  # [1,0,1,0] # (a,b,c,d)
            self.fun = fun  # y = c*fun(a*x+b)+d
            self.fun_name = fun_name
            self.parent_index = parent.index
            self.depth = parent.depth
            self.child_index = child.index
            self.power_exponent = power_exponent  # if fun == Pow
            Connections[(self.depth, self.parent_index, self.child_index)] = self

    def create_node(expr, parent=None, n_layer=None):
        #print('before', expr)
        expr, scale, bias = next_nontrivial_operation(expr)
        #print('after', expr)
        if parent == None:
            depth = 0
        else:
            depth = parent.depth
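
        # Dispatch on the outermost operation of expr:
        #   Mul    -> multiplication node, with one SubNode per factor
        #   Add    -> addition node, with a single SubNode collecting all summands
        #   Symbol -> leaf (input variable); when n_layer is known, it is padded
        #             with identity connections so it reaches the deepest layer
        #   other  -> unary (1D) function node, e.g. exp, sin, Pow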
        if expr.func == Mul:
            mult_arity = len(expr.args)
            node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity)
            # create mult_arity SubNodes, + 1
            for i in range(mult_arity):
                # create SubNode
                expr_i, scale, bias = next_nontrivial_operation(expr.args[i])
                subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node)
                if expr_i.func == Add:
                    for j in range(len(expr_i.args)):
                        expr_ij, scale, bias = next_nontrivial_operation(expr_i.args[j])
                        # expr_ij cannot be Add; it should be Mul or a 1D function
                        if expr_ij.func == Mul:
                            #print(expr_ij)
                            # create a node with expr_ij
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            # create a connection which is a linear function
                            c = Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node)

                        elif expr_ij.func == Symbol:
                            #print(expr_ij)
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            c = Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                        else:
                            # 1D function case
                            # create a node with expr_ij.args[0]
                            new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer)
                            # create 1D function expr_ij.func
                            if expr_ij.func == Pow:
                                power_exponent = expr_ij.args[1]
                            else:
                                power_exponent = None
                            Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name=expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent)

                elif expr_i.func == Mul:
                    # create a node with expr_i
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    # create 1D function, linear
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                else:
                    # 1D function case
                    # create a node with expr_i.args[0]
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    # create 1D function expr_i.func
                    if expr_i.func == Pow:
                        power_exponent = expr_i.args[1]
                    else:
                        power_exponent = None
                    Connection([1,0,1,0], expr_i.func, fun_name=expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Add:
            node = Node(expr, False, depth, scale, bias, parent=parent)
            subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)

            for i in range(len(expr.args)):
                expr_i, scale, bias = next_nontrivial_operation(expr.args[i])

                if expr_i.func == Mul:
                    # create a node with expr_i
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    # create a connection which is a linear function
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name='x', parent=subnode, child=new_node)

                else:
                    # 1D function case
                    # create a node with expr_i.args[0]
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    # create 1D function expr_i.func
                    if expr_i.func == Pow:
                        power_exponent = expr_i.args[1]
                    else:
                        power_exponent = None
                    Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name=expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Symbol:
            # expr is a symbol (one of the input variables)
            if n_layer == None:
                node = Node(expr, False, depth, scale, bias, parent=parent)
            else:
                node = Node(expr, False, depth, scale, bias, parent=parent)
                return_node = node
                for i in range(n_layer - depth):
                    subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)
                    node = Node(expr, False, subnode.depth, 1, 0, parent=subnode)
                    Connection([1,0,1,0], lambda x: x, fun_name='x', parent=subnode, child=node)
                node = return_node

            Start_Nodes.append(node)

        else:
            # expr.func is a 1D function
            #print(expr, scale, bias)
            node = Node(expr, False, depth, scale, bias, parent=parent)
            expr_i, scale, bias = next_nontrivial_operation(expr.args[0])
            subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node)
            # create a node with expr.args[0]
            new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer)
            # create 1D function expr.func
            if expr.func == Pow:
                power_exponent = expr.args[1]
            else:
                power_exponent = None
            Connection([1,0,1,0], expr.func, fun_name=expr.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        return node
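
    # The tree is built twice: the first pass (n_layer=None) only measures how deep
    # the expression tree is; the second pass uses that depth to pad every input
    # symbol with identity connections, so that all leaf variables end up in the
    # deepest tree layer (which becomes the KAN input layer after index reversal).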
    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    create_node(expr, n_layer=None)

    n_layer = len(Nodes) - 1

    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    create_node(expr, n_layer=n_layer)

    # move affine parameters in leaf nodes to connections
    for node in Start_Nodes:
        c = Connections[(node.depth, node.parent_index, node.index)]
        c.affine[0] = float(node.scale)
        c.affine[1] = float(node.bias)
        node.scale = 1.
        node.bias = 0.

    # map each leaf node to the index of its input variable
    node2var = []
    for node in Start_Nodes:
        for i in range(len(input_variables)):
            if node.expr == input_variables[i]:
                node2var.append(i)

    # Nodes
    n_mult = []
    n_sum = []
    for layer in Nodes:
        n_mult.append(0)
        n_sum.append(0)
        for node in layer:
            if node.mult_bool == True:
                n_mult[-1] += 1
            else:
                n_sum[-1] += 1

    # depth
    n_layer = len(Nodes) - 1

    # converter
    # input tree node id, output kan node id (distinguish sum and mult node)
    # input tree subnode id, output kan subnode id

    # node id
    subnode_index_convert = {}
    node_index_convert = {}
    connection_index_convert = {}
    mult_arities = []
    for layer_id in range(n_layer+1):
        mult_arity = []
        i_sum = 0
        i_mult = 0
        for i in range(len(Nodes[layer_id])):
            node = Nodes[layer_id][i]
            if node.mult_bool == True:
                kan_node_id = n_sum[layer_id] + i_mult
                arity = len(node.child_index)
                for j in range(arity):
                    subnode = SubNodes[node.depth+1][node.child_index[j]]
                    kan_subnode_id = n_sum[layer_id] + np.sum(mult_arity) + j
                    subnode_index_convert[(subnode.depth, subnode.index)] = (int(n_layer-subnode.depth), int(kan_subnode_id))
                i_mult += 1
                mult_arity.append(arity)
            else:
                kan_node_id = i_sum
                if len(node.child_index) > 0:
                    subnode = SubNodes[node.depth+1][node.child_index[0]]
                    kan_subnode_id = i_sum
                    subnode_index_convert[(subnode.depth, subnode.index)] = (int(n_layer-subnode.depth), int(kan_subnode_id))
                i_sum += 1

            if layer_id == n_layer:
                # input layer
                node_index_convert[(node.depth, node.index)] = (int(n_layer-node.depth), int(node2var[kan_node_id]))
            else:
                node_index_convert[(node.depth, node.index)] = (int(n_layer-node.depth), int(kan_node_id))

            # node:    depth (node.depth -> n_layer - node.depth)
            #          width (node.index -> kan_node_id)
            # subnode: depth (subnode.depth -> n_layer - subnode.depth)
            #          width (subnode.index -> kan_subnode_id)

        mult_arities.append(mult_arity)

    for index in list(Connections.keys()):
        depth, subnode_id, node_id = index
        # depth -> int(n_layer - depth)
        _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)]
        _, kan_node_id = node_index_convert[(depth, node_id)]
        connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id)

    n_sum.reverse()
    n_mult.reverse()
    mult_arities.reverse()

    width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))]
    width[0][0] = len(input_variables)
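
    # width holds one [n_sum, n_mult] pair per layer, ordered from the input layer
    # to the output layer (hence the reverse() calls above). For the docstring
    # example exp(sin(pi*a) + b**2), this should work out to
    # width = [[2, 0], [1, 0], [1, 0]], i.e. a [2, 1, 1] KAN with no
    # multiplication nodes.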
    # allow passing in other parameters (probably as a dictionary) in sf2kan, including grid, k, etc.
    model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)

    # clean the graph
    for l in range(model.depth):
        for i in range(model.width_in[l]):
            for j in range(model.width_out[l+1]):
                model.fix_symbolic(l, i, j, '0', fit_params_bool=False)

    # Nodes
    Nodes_flat = [x for xs in Nodes for x in xs]

    self = model

    for node in Nodes_flat:
        node_depth = node.depth
        node_index = node.index
        kan_node_depth, kan_node_index = node_index_convert[(node_depth, node_index)]
        #print(kan_node_depth, kan_node_index)
        if kan_node_depth > 0:
            self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale)
            self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias)

    # SubNodes
    SubNodes_flat = [x for xs in SubNodes for x in xs]

    for subnode in SubNodes_flat:
        subnode_depth = subnode.depth
        subnode_index = subnode.index
        kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode_depth, subnode_index)]
        #print(kan_subnode_depth, kan_subnode_index)
        self.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale)
        self.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias)

    # Connections
    Connections_flat = list(Connections.values())

    for connection in Connections_flat:
        c_depth = connection.depth
        c_j = connection.parent_index
        c_i = connection.child_index
        kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)]

        # get symbolic fun_name
        fun_name = connection.fun_name
        #if fun_name == Pow:
        #    print(connection.power_exponent)

        if fun_name == 'x':
            kfun_name = 'x'
        elif fun_name == exp:
            kfun_name = 'exp'
        elif fun_name == sin:
            kfun_name = 'sin'
        elif fun_name == cos:
            kfun_name = 'cos'
        elif fun_name == tan:
            kfun_name = 'tan'
        elif fun_name == sqrt:
            kfun_name = 'sqrt'
        elif fun_name == log:
            kfun_name = 'log'
        elif fun_name == tanh:
            kfun_name = 'tanh'
        elif fun_name == asin:
            kfun_name = 'arcsin'
        elif fun_name == acos:
            kfun_name = 'arccos'
        elif fun_name == atan:
            kfun_name = 'arctan'
        elif fun_name == atanh:
            kfun_name = 'arctanh'
        elif fun_name == sign:
            kfun_name = 'sgn'
        elif fun_name == Pow:
            alpha = connection.power_exponent
            if alpha == Rational(1,2):
                kfun_name = 'x^0.5'
            elif alpha == -Rational(1,2):
                kfun_name = '1/x^0.5'
            elif alpha == Rational(3,2):
                kfun_name = 'x^1.5'
            else:
                alpha = int(connection.power_exponent)
                if alpha > 0:
                    if alpha == 1:
                        kfun_name = 'x'
                    else:
                        kfun_name = f'x^{alpha}'
                else:
                    if alpha == -1:
                        kfun_name = '1/x'
                    else:
                        kfun_name = f'1/x^{-alpha}'

        model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
        model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)

    return model


sf2kan = kanpiler = expr2kan
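

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): it simply mirrors
# the docstring example of expr2kan/kanpiler above. Guarded so it only
# runs when this module is executed directly; model.plot() is omitted
# since it requires a plotting backend.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    a, b = symbols('a b')
    expression = exp(sin(pi * a) + b ** 2)
    model = kanpiler([a, b], expression)    # compile the formula into a MultKAN
    x = torch.rand(100, 2) * 2 - 1          # 100 random samples in [-1, 1]^2
    y = model(x)                            # evaluate the compiled network
    print(y.shape)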