import numpy as np
import math
from utils import adaptive_exploration, Split_X, house1
from collections import deque
from copy import copy

# Online implementation of the sparse residual tree (SRT).
class srt_online:
    """Sparse residual tree that is trained one sample at a time.

    Each sample handed to :meth:`update` is routed from the root down to a
    leaf. Every internal node on the way subtracts its own Gaussian expansion
    ``sum_i alpha[i] * exp(-SP2 * ||x - Xc[i]||^2)`` from the target, so a
    leaf only receives the residual left over by its ancestors.
    :meth:`predict` follows the same path and adds the expansions back up.
    """

    def __init__(self, WP, erro, fs=0, max_count=400) -> None:
        '''
        WP: some working parameters,
            WP[0] is the upper bound of condition of R
            WP[1] is the termination error of explorations
            WP[2] is the factor of shape parameters
            WP[3] is related to the max number of points to be trained
        erro: Termination condition
        fs: switch of sparse residual forest
        max_count: max number of points in the first layer
        '''
        self.WP = WP
        self.erro = erro
        self.fs = fs
        self.max_count = max_count

        # Running statistics of the raw targets seen so far:
        # TAE is the largest |y|, TRMSE the root mean square of y,
        # n the number of samples that went into both.
        self.TAE = 0
        self.TRMSE = 0
        self.n = 0
        self.tree = Node(WP=WP, erro=erro, fs=self.fs, max_count=max_count)

    def update(self, x, y) -> None:
        """Feed one sample ``(x, y)`` to the tree.

        Updates the running target statistics, stores the sample (as a
        residual) in the leaf it falls into, and trains that leaf once it
        has collected ``max_count / depth`` samples.
        """
        y = float(y)
        # update TAE and TRMSE
        self.TAE = max(self.TAE, abs(y))
        # Incremental RMS: the old sum of squares is TRMSE^2 * n; the new
        # square is added to it *before* dividing by the new sample count.
        self.TRMSE = math.sqrt(
            (self.TRMSE ** 2 * self.n + y ** 2) / (self.n + 1)
        )
        self.n += 1

        leaf, residual = self._findleaf(self.tree, x, y)
        leaf.add_data(x, residual)
        if leaf.total_num >= self.max_count / leaf.depth and not leaf.is_trained():
            leaf.update(self.TAE, self.TRMSE)

    @staticmethod
    def _node_value(node, x):
        """Evaluate the Gaussian expansion stored in ``node`` at ``x``."""
        value = 0
        for i in range(node.n):
            value += node.alpha[i] * np.exp(
                -node.SP2 * np.sum((x - node.Xc[i]) ** 2)
            )
        return value

    # find the leaf node where (x,y) belongs
    def _findleaf(self, tree, x, y):
        """Descend from ``tree`` to the leaf that contains ``x``.

        Returns ``(leaf, residual)`` where ``residual`` is ``y`` minus the
        expansions of every internal node passed on the way down.
        """
        node = tree
        while not node.is_leaf():
            y = y - self._node_value(node, x)
            # Points on the splitting hyperplane go to the left child.
            if np.dot(x, node.nv) <= node.center:
                node = node.left
            else:
                node = node.right
        return node, y

    def predict(self, x):
        """Return the tree's prediction at ``x``."""
        return self._predict(self.tree, x, 0)

    def _predict(self, tree, x, y):
        """Accumulate onto ``y`` the expansions along the path of ``x``.

        An untrained leaf contributes nothing; a trained leaf contributes
        its own expansion and ends the walk.
        """
        node = tree
        while True:
            if node.is_leaf() and not node.is_trained():
                return y
            y = y + self._node_value(node, x)
            if node.is_leaf():
                return y
            if np.dot(x, node.nv) <= node.center:
                node = node.left
            else:
                node = node.right

    def draw(self):
        """Placeholder; drawing is not implemented."""
        pass

# node class for sparse residual tree
class Node:
    """One node of the sparse residual tree.

    A node buffers the most recent samples that fell into its region. Once
    trained (see :meth:`update`) it holds a Gaussian expansion
    ``sum_i alpha[i] * exp(-SP2 * ||x - Xc[i]||^2)`` over ``n`` centres
    ``Xc``. If the fit is not good enough the node is split into ``left``
    and ``right`` children, which are seeded with this node's residuals.
    """

    # A node whose depth has reached this value is never split.
    MAX_DEPTH = 8

    def __init__(self, WP, X=None, Y=None, erro=0.01, depth=0, fs=0, max_count=200) -> None:
        """Create an untrained leaf.

        WP: working parameters, forwarded to ``adaptive_exploration``
        X, Y: initial samples and targets (default: none)
        erro: a node is split only while its RAE exceeds this value
        depth: depth of the parent; this node is stored at ``depth + 1``
        fs: forwarded to ``Split_X``
        max_count: buffer size of the first layer; a node at depth ``d``
            keeps at most ``int(max_count / d)`` samples
        """
        # None instead of a mutable [] default.
        X = [] if X is None else X
        Y = [] if Y is None else Y

        self.WP = WP

        self.erro = erro
        self.total_num = len(Y)
        self.depth = depth + 1
        self.fs = fs
        self.max_count = max_count

        # deque can get the most recent information from input data
        capacity = int(self.max_count / self.depth)
        self.X = deque(X, maxlen=capacity)
        self.Y = deque(Y, maxlen=capacity)

        # nv and center determine how to split the input space
        self.nv = 0
        self.center = 0

        self.n = 0          # number of centres in the expansion
        self.n_max = 0      # centre budget; exceeding it drops one centre
        self.P = 0          # indices into X selecting the centres
        self.alpha = []     # expansion coefficients
        self.SP2 = 0        # factor on the squared distance in the exponent
        self.Re_index = 0   # returned by adaptive_exploration, used by Split_X
        self.res = 0        # residuals returned by adaptive_exploration
        self.RAE = 0        # max |res| relative to TAE
        self.RRMSE = 0      # RMS of res relative to TRMSE
        self.Xc = 0         # centre coordinates, X[P]

        self.left = None
        self.right = None
        self.isleaf = True
        self.istrained = False

    def add_data(self, x, y):
        """Store the sample ``(x, y)``.

        On a trained leaf the sample is also checked against the current
        fit: if its residual is larger than every residual recorded at
        training time, ``x`` becomes a new centre and the coefficients are
        refitted.
        """
        self.X.append(x)
        self.Y.append(y)
        self.total_num += 1
        if self.isleaf and self.istrained:
            fitted = 0
            for i in range(self.n):
                fitted += self.alpha[i] * np.exp(
                    -self.SP2 * np.linalg.norm(x - self.Xc[i]) ** 2
                )
            # Compare the residual (target minus fit), not the fitted value
            # itself, with the worst residual seen at training time.
            if abs(y - fitted) > np.max(np.abs(self.res)):
                self.Xc = np.append(self.Xc, [x], axis=0)
                self.leaf_update()

    def is_leaf(self):
        """Return True while the node has not been split."""
        return self.isleaf

    def is_trained(self):
        """Return True once :meth:`update` has been run on this node."""
        return self.istrained

    def update(self, TAE, TRMSE):
        """Train the node on its buffered samples and split it if needed.

        TAE, TRMSE: normalisers for the relative errors RAE and RRMSE.
        The node is split when RAE exceeds ``erro`` and the depth is below
        ``MAX_DEPTH``; each child receives the samples on its side of the
        hyperplane together with this node's residuals as targets.
        """
        (
            self.n,
            self.P,
            self.alpha,
            self.SP2,
            self.Re_index,
            self.res,
        ) = adaptive_exploration(self.X, self.Y, self.WP, self.WP[3] * 1000)
        self.RAE = np.max(np.abs(self.res)) / TAE
        self.RRMSE = math.sqrt(np.mean(self.res ** 2)) / TRMSE
        self.istrained = True

        samples = np.array(self.X)
        self.Xc = samples[self.P]
        self.n_max = int(1.2 * self.n)

        if self.RAE > self.erro and self.depth < self.MAX_DEPTH:
            self.isleaf = False
            self.nv, self.center = Split_X(self.X, self.Re_index, self.fs, self.res)
            projection = np.dot(samples, self.nv)
            # Points on the hyperplane belong to the left child.
            go_left = np.where(projection <= self.center)[0]
            go_right = np.where(projection > self.center)[0]
            self.left = Node(
                self.WP, samples[go_left], self.res[go_left],
                self.erro, self.depth, self.fs, self.max_count,
            )
            self.right = Node(
                self.WP, samples[go_right], self.res[go_right],
                self.erro, self.depth, self.fs, self.max_count,
            )

    def _fit_alpha(self, X, Y):
        """Fit the coefficients for the current centres ``self.Xc``.

        Returns ``(alpha, GK)`` where column ``i`` of ``GK`` is the basis
        function of centre ``i`` evaluated at every row of ``X``.
        NOTE(review): ``house1`` comes from ``utils``; it is presumably one
        Householder QR step per added column - confirm against its source.
        """
        n_samples = len(Y)
        n_centers = len(self.Xc)
        U = None
        uY = copy(Y)
        beta = []
        columns = []
        for i in range(n_centers):
            u = np.exp(-self.SP2 * np.sum((X - self.Xc[i]) ** 2, axis=1))
            U, uY, beta = house1(i, n_samples, u, U, uY, beta)
            columns.append(u.copy())
        # Take the leading k-by-k block; U[0:k][0:k] would only slice rows.
        R = np.triu(np.asarray(U)[:n_centers, :n_centers])
        alpha = np.linalg.solve(R, uY[:n_centers])
        return alpha, np.column_stack(columns)

    def leaf_update(self):
        """Refit the coefficients after a centre has been appended to ``Xc``.

        If the number of centres then exceeds ``n_max``, the centre whose
        basis column contributes least (smallest norm of ``alpha[i] * GK[:, i]``)
        is removed and the coefficients are fitted once more.
        """
        X = np.array(self.X)
        Y = np.array(self.Y)

        self.alpha, GK = self._fit_alpha(X, Y)
        self.n += 1

        if self.n > self.n_max:
            # when center points is too much, drop one
            contributions = [
                np.linalg.norm(GK[:, i] * self.alpha[i])
                for i in range(len(self.Xc))
            ]
            weakest = int(np.argmin(contributions))
            self.Xc = np.delete(self.Xc, weakest, axis=0)
            self.alpha, _ = self._fit_alpha(X, Y)
            self.n -= 1