import numpy as np
import copy
import pickle
import os
import random
from scipy import linalg
from sklearn.cluster import KMeans

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1 + np.exp(-x)
    return 1/denominator

def softmax(x, axis = 1):
    """Softmax of an array of scores, normalised along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. Softmax is
    invariant to that shift, but without it large scores overflow ``np.exp``
    to inf and the result becomes nan.

    Args:
        x: array of scores.
        axis: axis to normalise over (default 1, i.e. per row of a 2-D array).

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis`` sum to 1.
    """
    shifted = x - np.max(x, axis = axis, keepdims = True)
    e = np.exp(shifted)
    return e / np.sum(e, axis = axis, keepdims = True)

def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero (same dtype as ``x``)."""
    floor = np.zeros_like(x)
    return np.maximum(x, floor)

def norm2(x):
    """Return the squared 2-norm (Frobenius norm for matrices) of ``x``."""
    magnitude = np.linalg.norm(x)
    return magnitude**2

def pass_positive(x):
    """Return ``x`` when it is non-negative, otherwise 1.0.

    Anything failing ``x >= 0.0`` (negative values, NaN) maps to 1.0.
    """
    return x if x >= 0.0 else 1.0

def dot(A, B):
    """Matrix product ``A @ B`` computed with BLAS ``dgemm`` (alpha = 1, float64)."""
    gemm = linalg.blas.dgemm
    return gemm(1.0, A, B)

def solve(A, b):
    """Least-squares solve via the eigen-decomposition of the Gram matrix A^T A.

    With A^T A = U diag(D) U^T, the system is rewritten as Ap = (U sqrt(D))^T
    against bp = (b A) U / sqrt(D), and ``np.linalg.lstsq(Ap, bp)`` is
    returned (the full lstsq tuple: solution, residuals, rank, singular values).

    NOTE(review): lstsq needs bp to have as many rows as Ap (A's column count),
    but bp is (rows of b, cols of A). The normal-equation derivation suggests
    ``bp.T`` was intended; confirm against callers. Zero or negative
    eigenvalues in D also make the sqrt / division produce inf or nan.
    """
    gemm = linalg.blas.dgemm
    gram = gemm(alpha = 1.0, a = A.T, b = A)
    projected_b = gemm(alpha = 1.0, a = b, b = A)
    eigvals, eigvecs = np.linalg.eigh(gram)
    root = np.sqrt(eigvals)
    A_p = (eigvecs * root).T
    b_p = gemm(alpha = 1.0, a = projected_b, b = eigvecs) / root
    return np.linalg.lstsq(A_p, b_p, rcond = None)

class Input:
    """Input samples for a network, with normalisation and optional batching.

    Samples come either from ``data`` (transposed on storage, so a list of
    per-feature rows becomes one sample per row) or from the Cartesian grid
    of ``aranges``; each entry of ``aranges`` is ``[start, stop, step]`` for
    ``np.arange``. When both are given, the grid replaces ``data``.
    """
    def __init__(self, aranges = None, data = None, batches = 1, use_distr = False, random_batch = False, random_data = False, random_state = None):
        """Build the sample matrix, its normalisation and (optionally) batches.

        Args:
            aranges: optional list of ``[start, stop, step]`` per input axis;
                a copy is stored so ``update_intervals`` never mutates it.
            data: optional array/list laid out feature-major (transposed here).
            batches: number of contiguous batches (> 1 enables batching; the
                last batch takes the remainder).
            use_distr: if True, mean/std are those of a uniform distribution
                over ``[min, max]`` per column instead of the sample statistics.
            random_batch: shuffle the order of the batches.
            random_data: shuffle the sample rows (on a private copy).
            random_state: int seed for numpy's global RNG.

        Raises:
            ValueError: if neither ``data`` nor ``aranges`` is given.
        """
        if isinstance(random_state, int):
            np.random.seed(random_state)

        # Private, mutable copy: update_intervals() edits the step in place.
        self.aranges = [list(a) for a in aranges] if aranges else []

        if data is None and not self.aranges:
            raise ValueError("Input requires either `data` or `aranges`")

        if data is not None:
            if isinstance(data, list):
                self.data = np.array(data).T
            else:
                self.data = data.T

        if self.aranges:
            self.data = self._grid(self.aranges)

        if random_data:
            # data.T may be a view of the caller's array; shuffle a copy instead.
            self.data = self.data.copy()
            np.random.shuffle(self.data)

        if use_distr:
            hi = np.max(self.data, axis = 0)
            lo = np.min(self.data, axis = 0)
            self.mean = (hi + lo)/2
            # standard deviation of a uniform distribution on [lo, hi]
            self.std = np.sqrt((hi - lo)**2/12)
        else:
            self.mean = np.mean(self.data, axis = 0)
            self.std = np.std(self.data, axis = 0)

        self.norm_data = (self.data - self.mean)/self.std

        self.batch_data = {}
        self.batch_norm_data = {}
        self.batches = batches
        if self.batches > 1:
            d = self.data.shape[0] // self.batches
            for i in range(self.batches):
                stop = None if i == self.batches - 1 else d*(i+1)
                chunk = self.data[d*i:stop, :]
                self.batch_data[i] = chunk
                self.batch_norm_data[i] = (chunk - self.mean)/self.std

        # Only shuffle when batches exist (batches == 1 used to raise IndexError).
        if random_batch and self.batch_data:
            order = list(self.batch_data.keys())
            # Same Fisher-Yates as random.shuffle(x, random=np.random.rand), whose
            # `random` argument was removed in Python 3.11; keeps seeded results.
            for i in reversed(range(1, len(order))):
                j = int(np.random.rand() * (i + 1))
                order[i], order[j] = order[j], order[i]

            old_data, old_norm = self.batch_data, self.batch_norm_data
            self.batch_data = {i: old_data[k] for i, k in enumerate(order)}
            self.batch_norm_data = {i: old_norm[k] for i, k in enumerate(order)}

    @staticmethod
    def _grid(aranges):
        """Cartesian grid of the ``np.arange`` axes, first axis varying slowest."""
        axes = [np.arange(a[0], a[1], a[2]) for a in aranges]
        grids = np.meshgrid(*axes, indexing = 'ij')
        return np.stack([g.ravel() for g in grids], axis = 1)

    def update_intervals(self, intervals):
        """Set new step sizes for the grid axes and regenerate the grid.

        The existing ``mean``/``std`` are kept, so ``norm_data`` stays on the
        original scale. When the input was not built from ``aranges`` the data
        is returned unchanged.

        Returns:
            The (possibly regenerated) sample matrix ``self.data``.
        """
        if self.aranges:
            for i, step in enumerate(intervals):
                self.aranges[i][2] = step

            self.data = self._grid(self.aranges)
            self.norm_data = (self.data - self.mean)/self.std

        return self.data

    def normalize(self, x):
        """Map ``x`` onto the stored normalisation (subtract mean, divide by std)."""
        return (x - self.mean)/self.std

    def denorm(self, x):
        """Inverse of ``normalize``."""
        return self.mean + x * self.std

class Func:
    """Target function given as a Python expression in ``x`` plus output normalisation.

    SECURITY: ``expr`` is executed with ``eval``; only pass trusted expressions.
    """
    def __init__(self, expr):
        self.expr = expr
        self.out = None
        self.mean = None
        self.std = None
        self.norm_out = None

    def run(self, x, new_norm = True, use_distr = False):
        """Evaluate ``expr`` on ``x`` and store raw and normalised outputs.

        Args:
            x: input samples, visible to ``expr`` as the name ``x``.
            new_norm: recompute ``mean``/``std`` from this output; otherwise
                reuse the previously stored ones.
            use_distr: use the mean/std of a uniform distribution over
                ``[min, max]`` instead of the sample statistics.

        Returns:
            The raw output as a float array.
        """
        # eval sees this frame's locals, so the expression can reference `x`.
        self.out = eval(self.expr).astype(float)

        if new_norm:
            if use_distr:
                self.mean = (np.max(self.out) + np.min(self.out))/2
                self.std = np.sqrt((np.max(self.out) - np.min(self.out))**2/12)
            else:
                self.mean = np.mean(self.out)
                self.std = np.std(self.out)

        self.norm_out = (self.out - self.mean)/self.std

        return self.out

    def run_batches(self, batch_data):
        """Evaluate ``expr`` on every batch, normalised with the stored mean/std.

        ``run`` must have been called first so ``mean``/``std`` are set.

        Returns:
            ``(batch_out, batch_norm_out)`` dicts keyed like ``batch_data``.
        """
        batch_out = {}
        batch_norm_out = {}

        for i in batch_data.keys():
            x = batch_data[i]
            batch_out[i] = eval(self.expr).astype(float)
            batch_norm_out[i] = (batch_out[i] - self.mean)/self.std

        return batch_out, batch_norm_out

    def for_NN(self, batched = None, normalize = True):
        """Return outputs shaped as network targets (one column per output).

        If the selected output (normalised or raw) already has at least two
        dimensions it is returned unchanged and ``batched`` is ignored.
        Otherwise a new trailing axis is added: to every entry of ``batched``
        when it is a non-empty dict, else to the stored output itself.
        """
        out = self.norm_out if normalize else self.out
        # Explicit dimensionality check instead of probing shape[1] under a bare except.
        if getattr(out, 'ndim', 0) >= 2:
            return out
        if batched:
            return {b: np.expand_dims(v, axis = 1) for b, v in batched.items()}
        return np.expand_dims(out, axis = 1)

    def denorm(self, y):
        """Map a normalised output back to the original scale."""
        return self.mean + y * self.std

class Batch:
    """Split a labelled dataset into ``batches`` per-class batches.

    For each class, its samples are either split contiguously (last batch
    takes the remainder) or, with ``cluster=True``, assigned by KMeans cluster
    label. In cluster mode ``add_data[i]`` additionally receives a random
    fraction ``p`` of the samples from earlier clusters of each class
    (cumulative over i), e.g. for rehearsal.

    Attributes:
        batch_data: {i: {'x': array, 'y': array[, 'y_enc': array]}}.
        add_data: same layout; an entry stays an empty list when nothing was added.
        idxs_r: retained indices accumulated for the last processed class.
    """
    def __init__(self, x, y, y_enc = None, batches = 2, p = 0.0, cluster = False, random_state = None):
        self.classes = sorted(set(list(y)))
        self.batches = batches

        if random_state is not None:
            np.random.seed(random_state)

        keys = ['x', 'y'] if y_enc is None else ['x', 'y', 'y_enc']
        # Created up front for both modes: the non-cluster path previously never
        # built add_data and raised KeyError in the conversion loop below.
        self.batch_data = {i: {k: [] for k in keys} for i in range(self.batches)}
        self.add_data = {i: {k: [] for k in keys} for i in range(self.batches)}
        self.idxs_r = []

        for c in self.classes:
            self.idxs_r = []
            cls_idx = np.where(y == c)[0]

            if cluster:
                labels = KMeans(n_clusters = batches, random_state = random_state).fit(x[cls_idx, :]).labels_
                for i in range(self.batches):
                    self._extend(self.batch_data[i], x, y, y_enc, cls_idx[labels == i])

                    if i > 0:
                        self.idxs_r.extend(self._sample(cls_idx[labels == i - 1], p))
                        self._extend(self.add_data[i], x, y, y_enc, self.idxs_r)
            else:
                d = cls_idx.shape[0] // self.batches
                for i in range(self.batches):
                    stop = None if i == self.batches - 1 else d*(i+1)
                    self._extend(self.batch_data[i], x, y, y_enc, cls_idx[d*i:stop])

        for i in range(self.batches):
            bd = self.batch_data[i]
            ad = self.add_data[i]
            bd['x'] = np.array(bd['x'])
            bd['y'] = np.array(bd['y'])

            if ad['x']:
                ad['x'] = np.array(ad['x'])
            if ad['y']:
                ad['y'] = np.array(ad['y'])

            if y_enc is not None:
                bd['y_enc'] = np.array(bd['y_enc'])
                if ad['y_enc']:
                    ad['y_enc'] = np.array(ad['y_enc'])

    @staticmethod
    def _extend(store, x, y, y_enc, sel):
        """Append the rows selected by ``sel`` to the list-valued entries of ``store``."""
        store['x'].extend(x[sel, :].tolist())
        store['y'].extend(y[sel].tolist())
        if y_enc is not None:
            store['y_enc'].extend(y_enc[sel, :].tolist())

    @staticmethod
    def _sample(idxs_b, p):
        """Pick ``int(p * len(idxs_b))`` random entries of ``idxs_b`` (all of them if p >= 1)."""
        n = idxs_b.shape[0]
        topk = int(p*n)
        # Drawn unconditionally so the global RNG stream matches the original.
        r = np.random.rand(n)
        if topk >= n:
            # argpartition requires kth < n; taking everything needs no selection.
            return idxs_b
        return idxs_b[np.argpartition(r, topk)[:topk]]

class NN:
    """Fully connected feed-forward network trained by full-batch gradient descent.

    Layer ``l`` maps ``A[l]`` (with a leading bias column of ones) to
    ``A[l+1]`` through ``W[l]`` of shape ``(A_sizes[l] + 1, A_sizes[l+1])``.
    The output layer is always linear. For classification, the output error
    is formed after a sigmoid (one output unit) or softmax (several units).

    Retention: ``Wm`` is a 0/1 mask per layer. At the end of ``train`` it is
    zeroed wherever the gradient of the summed output is non-zero. With
    ``retention`` on, weights where ``Wm == 0`` are pulled back toward ``Wr``
    (set by ``load_weights``). With ``shut_off``, only weights where
    ``Wm == 1`` are updated instead.
    """
    def __init__(self, n_inputs, n_hidden = (10, 1), activations = ('tanh', 'lin'), classify = False, w_range = (-1, 1), w_scale = False, compare = False, random_state = None):
        """Create the network and draw its initial weights.

        Args:
            n_inputs: number of input features.
            n_hidden: units per layer; the last entry is the output width.
            activations: per-layer activation ('tanh', 'sig', 'lin', 'relu');
                the last one is forced to 'lin' (on a private copy).
            classify: treat outputs as class scores.
            w_range: weights are drawn uniformly from [w_range[0], w_range[1]),
                divided by sqrt(fan-in) when ``w_scale`` is set.
            compare: dump the initial weights to ``W.pkl`` for later reloading.
            random_state: int seed for numpy's global RNG.

        A length mismatch between ``n_hidden`` and ``activations`` is reported
        with a message and leaves the instance uninitialised.

        Raises:
            ValueError: if ``n_hidden`` is empty.
        """
        if len(n_hidden) != len(activations):
            print("\033[91mn_hidden and activations lengths do not match!\033[0m")
            return
        if len(n_hidden) == 0:
            raise ValueError("n_hidden must contain at least one layer")

        self.classify = classify
        self.n_layers = len(n_hidden)
        self.n_hidden = list(n_hidden)
        # Copy first: forcing a linear output layer must not mutate the caller's list.
        self.activations = list(activations)
        self.activations[-1] = 'lin'
        self.compare = compare

        # A_sizes[l]: width of layer l (0 = input), excluding the bias column.
        self.A_sizes = {0: n_inputs}
        for l in range(1, self.n_layers + 1):
            self.A_sizes[l] = self.n_hidden[l-1]

        self.A = {l: None for l in range(self.n_layers + 1)}
        self.dA = {l: None for l in range(self.n_layers + 1)}
        self.n_patterns = 0

        self.g_scaler = 2
        self.lam_scaler = 2
        self.ret_scaler = 2
        self.lam = None
        self.lam_zeros = None
        self.retention = False
        self.shut_off = False

        self.X_ret = []
        self.Y_ret = []
        self.topk = 10

        if isinstance(random_state, int):
            np.random.seed(random_state)

        self.W = {}
        self.Wm = {}
        self.Wr = {}
        self.Wt = {}
        self.nW = 0
        self.nWt = 1
        for l in range(self.n_layers):
            shape = (self.A_sizes[l] + 1, self.A_sizes[l+1])
            if w_scale:
                scale = np.sqrt(self.A_sizes[l])
                self.W[l] = w_range[0]/scale + (w_range[1] - w_range[0])*np.random.rand(*shape)/scale
            else:
                self.W[l] = w_range[0] + (w_range[1] - w_range[0])*np.random.rand(*shape)

            self.Wm[l] = np.ones_like(self.W[l])
            self.nW += np.prod(self.W[l].shape)

        if self.compare:
            with open("W.pkl", "wb") as fp:
                pickle.dump(self.W, fp)

    def reload_init_weights(self):
        """Restore the weights saved to ``W.pkl`` (if present) by ``compare=True``.

        SECURITY: pickle can execute code; only load trusted files.
        """
        if os.path.exists("W.pkl"):
            with open("W.pkl", "rb") as fp:
                self.W = pickle.load(fp)

    def load_weights(self, file, retention = False):
        """Load pickled weights from ``file`` (if it exists) and set retention mode.

        With ``retention`` the current weights are also copied to ``Wr``, the
        anchor that frozen weights are pulled back to. SECURITY: pickle can
        execute code; only load trusted files.
        """
        if os.path.exists(file):
            with open(file, "rb") as fp:
                self.W = pickle.load(fp)

        self.retention = retention
        if self.retention:
            self.Wr = copy.deepcopy(self.W)

    def predict(self, X, from_layer = 0):
        """Forward pass starting at ``from_layer``; caches activations and derivatives.

        ``A[l]`` for l < n_layers holds activations with a leading bias column;
        ``dA[l]`` holds the activation derivative without that column.

        Returns:
            The raw output ``A[n_layers]`` for regression. For classification
            a 0/1 array: sigmoid >= 0.5 for one output unit, or 1 where a
            unit equals its row maximum for several units.

        Raises:
            ValueError: on an unknown activation name.
        """
        self.n_patterns = X.shape[0]
        self.A[from_layer] = np.hstack((np.ones((self.n_patterns, 1)), X))

        for l in range(from_layer, self.n_layers):
            a = self.activations[l]
            Z = np.dot(self.A[l], self.W[l])
            if a == 'tanh':
                self.A[l+1] = np.tanh(Z)
                self.dA[l+1] = 1.0 - self.A[l+1]**2
            elif a == 'sig':
                self.A[l+1] = sigmoid(Z)
                self.dA[l+1] = self.A[l+1]*(1.0 - self.A[l+1])
            elif a == 'lin':
                self.A[l+1] = Z
                # The derivative of the identity is 1. Storing Z here corrupted
                # backprop through hidden linear layers.
                self.dA[l+1] = np.ones_like(Z)
            elif a == 'relu':
                self.A[l+1] = relu(Z)
                self.dA[l+1] = (self.A[l+1] > 0.0).astype(float)
            else:
                raise ValueError("unknown activation: " + repr(a))

            if l < self.n_layers - 1:
                self.A[l+1] = np.hstack((np.ones((self.n_patterns, 1)), self.A[l+1]))

        out = self.A[self.n_layers]
        if self.classify:
            if self.n_hidden[-1] == 1:
                return (sigmoid(out) >= 0.5).astype(float)
            elif self.n_hidden[-1] > 1:
                return (out == np.max(out, axis = 1, keepdims = True)).astype(float)

        return out

    def compute_scaler(self, _type = 0):
        """Set the gradient step scaler and, under retention, the retention scaler.

        ``_type == 1`` uses 1 / sum(A[n_layers-1]**2). Any other value uses
        1 / max(n_patterns, width of the layer feeding the output).
        """
        if _type == 1:
            J = np.sum(self.A[self.n_layers-1]**2)
            self.g_scaler = 1/J
        else:
            # A_sizes[n_layers-1] equals n_hidden[-2] for deep nets and
            # n_inputs for a single-layer net (n_hidden[-2] raised there).
            self.g_scaler = 1/max([self.n_patterns, self.A_sizes[self.n_layers-1]])

        if self.retention:
            frozen = 1.0 - self.Wm[self.n_layers-1]
            denom = self.n_patterns*np.sum(frozen)
            if denom > 0:
                self.ret_scaler = max([np.sum(np.dot(self.A[self.n_layers-1], frozen))/denom, 0.0])
            else:
                # No frozen output weights yet: 0/0 used to give nan, which
                # turned every weight into NaN in update().
                self.ret_scaler = 0.0

    def compute_grads(self, err, until_layer = 0):
        """Backpropagate ``err`` (output-layer error) into per-layer weight gradients.

        Layers below ``until_layer`` get zero gradients.
        """
        Q = {self.n_layers-1: err}

        g = {}
        for l in reversed(range(self.n_layers)):
            g[l] = np.zeros_like(self.W[l], float)
            if l >= until_layer:
                g[l] = np.dot(self.A[l].T, Q[l])

                if l > 0:
                    # Drop the bias row of W[l]; dA[l] has no bias column.
                    Q[l-1] = np.dot(Q[l], self.W[l][1:].T) * self.dA[l]

        return g

    def norms(self, until_layer = 0):
        """Return the sum of squared weights of the trained layers.

        Under retention it also refreshes ``Wt``, the deviation of the frozen
        (``Wm == 0``) weights from ``Wr``.
        """
        W_norm2 = 0
        for l in reversed(range(self.n_layers)):
            if l >= until_layer:
                W_norm2 += np.sum(self.W[l]**2)

                if self.retention:
                    self.Wt[l] = (self.W[l] - self.Wr[l]) * (1.0 - self.Wm[l])

        return W_norm2

    def ret_notify(self, until_layer = 0):
        """Store the number of frozen weights in ``nWt`` and print how many remain trainable."""
        at = 0
        bt = 0
        for l in reversed(range(self.n_layers)):
            if l >= until_layer:
                at += np.sum(1.0 - self.Wm[l])
                bt += np.sum(self.Wm[l])

        self.nWt = at

        print("No. of weights to train: " + str(int(bt)))

    def update(self, err, until_layer, lr = 1.0, reg = 0.0, ret = 0.0):
        """Take one gradient step on layers ``until_layer`` and above.

        Args:
            err: output-layer error.
            until_layer: lowest layer that is trained; lower layers are left untouched.
            lr: learning rate, scaled by ``g_scaler / n_hidden[-1]``.
            reg: accepted for API compatibility; currently unused.
            ret: retention strength (scaled by ``ret_scaler``).

        Returns:
            Mean squared weight, 0.5 * sum(W**2) / nW over the trained layers.
        """
        g_err = self.compute_grads(err, until_layer)
        W_norm2 = self.norms(until_layer)

        nw = lr*self.g_scaler/self.n_hidden[-1]
        b0 = self.ret_scaler

        # Frozen layers have zero gradient; skipping them also avoids using a
        # missing or stale Wt entry for layers norms() did not refresh.
        for l in range(max(until_layer, 0), self.n_layers):
            if self.retention:
                if self.shut_off:
                    self.W[l] = self.W[l] - nw*g_err[l]*self.Wm[l]
                else:
                    self.W[l] = self.W[l] - nw*g_err[l] - ret*b0*self.Wt[l]
            else:
                self.W[l] = self.W[l] - nw*g_err[l]

        return 0.5*W_norm2/self.nW

    def calc_acc(self, y, Y):
        """Fraction of rows where prediction ``y`` matches target ``Y``."""
        s = 0
        if self.n_hidden[-1] == 1:
            for y_, Y_ in zip(y, Y):
                s += (y_ == Y_)[0].astype(float)
        elif self.n_hidden[-1] > 1:
            for y_, Y_ in zip(y, Y):
                s += np.dot(y_.T, Y_)

        return s/Y.shape[0]

    def _output_error(self, Y):
        """Error of the cached output after sigmoid (1 unit) or softmax (>1 units)."""
        out = self.A[self.n_layers]
        if self.n_hidden[-1] == 1:
            return sigmoid(out) - Y
        elif self.n_hidden[-1] > 1:
            return softmax(out) - Y
        raise ValueError("output layer must have at least one unit")

    def _train_step(self, X, Y, k, until_layer, lr, reg, ret, lr_type, classify):
        """One forward/backward/update iteration; logs every 100th. Returns (err, MSE, ACC)."""
        y = self.predict(X)
        self.compute_scaler(lr_type)
        if classify:
            err = self._output_error(Y)
            ACC = self.calc_acc(y, Y)
        else:
            err = y - Y
            ACC = None
        MSE = 0.5*np.mean(err**2)
        MSW = self.update(err, until_layer, lr, reg, ret)
        if k % 100 == 0:
            if classify:
                print('@iteration ' + str(k) + ': ACC = ' + str(ACC) + ", MSE = " + str(MSE) + ", MSW = " + str(MSW))
            else:
                print('@iteration ' + str(k) + ': MSE = ' + str(MSE) + ", MSW = " + str(MSW))
        return err, MSE, ACC

    def _collect_border_points(self, X_, Y_, X, Y):
        """Store, per class, the worst-fit samples and the samples farthest from them."""
        self.predict(X_)
        err = self._output_error(Y_)
        ae = abs(err) * Y_
        for c in range(Y.shape[1]):
            ret_idxs_mx = np.argpartition(ae[:, c], -self.topk)[-self.topk:]

            X_c = X_[Y_[:, c] == 1.0, :]
            Y_c = Y_[Y_[:, c] == 1.0, :]
            # X_ is a prefix of X, so indices into X_ address the same rows of X.
            dist = np.array([np.min(np.sum((X_c[i, :] - X[ret_idxs_mx, :])**2, axis = 1)) for i in range(X_c.shape[0])])
            ret_idxs_mn = np.argpartition(dist, -self.topk)[-self.topk:]

            self.X_ret.extend(X_[ret_idxs_mx, :].tolist())
            self.Y_ret.extend(Y_[ret_idxs_mx, :].tolist())
            self.X_ret.extend(X_c[ret_idxs_mn, :].tolist())
            self.Y_ret.extend(Y_c[ret_idxs_mn, :].tolist())

    def train(self, X_, Y_, X_a = None, Y_a = None, until_layer = 0, by = None, lr = 1.0, reg = 0.0, ret = 0.0,
              lr_type = 0, border_pts = False, chkpt_file = None, contd = True):
        """Train on (X_, Y_) plus optional extra data until a stopping criterion.

        Args:
            X_, Y_: training inputs and targets.
            X_a, Y_a: optional extra rows stacked under X_/Y_.
            until_layer: lowest layer that is trained.
            by: stopping criterion, checked in this order: {'iter': n} runs
                n + 1 steps; {'mse': t} runs while MSE > t; {'acc': t} runs
                while accuracy < t. Defaults to {'iter': 1}. 'iter' defaults
                to 1 only when no 'mse'/'acc' threshold is given.
            lr, reg, ret, lr_type: forwarded to update()/compute_scaler().
            border_pts: (classification) train with previously stored border
                points and collect new ones afterwards.
            chkpt_file: if set, pickle the final weights there.
            contd: if False, first restore the weights saved in W.pkl.
        """
        if by is None:
            by = {'iter': 1, 'mse': None, 'acc': None}
        metric_given = by.get('mse') is not None or by.get('acc') is not None
        By = {'iter': by.get('iter', None if metric_given else 1),
              'mse': by.get('mse'),
              'acc': by.get('acc')}

        if not contd:
            self.reload_init_weights()

        if X_a is not None:
            X_ = np.vstack((X_, X_a))
        if Y_a is not None:
            Y_ = np.vstack((Y_, Y_a))

        if self.retention:
            self.ret_notify()

        X, Y = X_, Y_
        if self.classify and border_pts:
            n_patterns = Y_.shape[0]
            if self.X_ret:
                X = np.vstack((X_, np.array(self.X_ret)))
                Y = np.vstack((Y_, np.array(self.Y_ret)))

            print("Border patterns added: " + str(Y.shape[0] - n_patterns))

        step_args = (until_layer, lr, reg, ret, lr_type)
        err = None
        if By['iter'] is not None:
            for k in range(By['iter'] + 1):
                err, _, _ = self._train_step(X, Y, k, *step_args, self.classify)
        elif By['mse'] is not None:
            MSE = 1.0
            k = 0
            while MSE > By['mse']:
                err, MSE, _ = self._train_step(X, Y, k, *step_args, self.classify)
                k += 1
        elif By['acc'] is not None:
            ACC = 0.0
            k = 0
            while ACC < By['acc']:
                # accuracy mode always uses the classification error
                err, _, ACC = self._train_step(X, Y, k, *step_args, True)
                k += 1

        # Freeze (mask out) every weight that influences the output on this data.
        if err is not None:
            g = self.compute_grads(np.ones_like(err))
            for l in g.keys():
                self.Wm[l] = self.Wm[l] * (g[l] == 0.0).astype(float)

        if self.classify and border_pts:
            self._collect_border_points(X_, Y_, X, Y)

        if chkpt_file:
            with open(chkpt_file, "wb") as fp:
                pickle.dump(self.W, fp)







    
