import numpy as np
import tensorflow as tf


def display(count_name, count, display_dict):
    """Print a single progress line of the form ``<count_name> <count>: k1: v1,<TAB> k2: v2``.

    Args:
        count_name: Label placed in front of the counter (e.g. ``'Iteration'``).
        count: Counter value shown after the label.
        display_dict: Mapping from a metric name to a value; each value is
            rendered with the ``.5e`` (scientific, five decimals) format, so it
            must support that format specification.
    """
    formatted = []
    for name, value in display_dict.items():
        formatted.append("{}: {:.5e}".format(name, value))

    prefix = count_name + " {}: ".format(count)
    print(prefix + ",\t ".join(formatted))


class NeuralNet(object):
    """Abstract base class that owns the datasets, hyper-parameters and training loops.

    Sub-classes provide the model itself by implementing ``build_model``,
    ``update``, ``get_loss``, ``get_prediction_acc`` and ``stop_crit_check``.
    This class supplies two training drivers (``train_rnd_batch`` and
    ``train_mini_batch``), parameter bookkeeping, and validating property
    setters for the datasets and hyper-parameters.
    """

    def __init__(self, session, train_data, test_data=None, data_recorder=None):
        """Store the session and datasets and register recorder channels.

        Args:
            session: Session object kept as ``self._sess`` (not validated here).
            train_data: Mapping with ``'inputs'`` and ``'outputs'`` entries whose
                ``shape[0]`` / ``shape[1]`` give the sample count and the
                input/output dimensions.
            test_data: Optional mapping with the same layout. It is discarded
                (with a printed notice) when its input or output dimension
                differs from the training data.
            data_recorder: Optional recorder object. It must provide
                ``init(channel_name=...)`` returning a channel handle and
                ``update(value, count, channel)``.
        """
        self._sess = session
        self._dr = data_recorder
        self._dtype = tf.float32  # presumably the dtype sub-classes build their graph with

        # process input dataset
        self._train_data = train_data
        self._test_data = test_data
        self._train_sample_num = train_data['outputs'].shape[0]
        self._input_dim = train_data['inputs'].shape[1]
        self._output_dim = train_data['outputs'].shape[1]

        if self._test_data is not None:
            if (self._test_data['inputs'].shape[1] != self._input_dim) or (
                    self._test_data['outputs'].shape[1] != self._output_dim):
                print('test data has different input/output dimension than train data, discard test data')
                self._test_data = None

        # parameters for this object, sub-class may add more parameters
        self._value_name_list = ['learning_rate', 'loss_type', 'batch_size']

        self._learning_rate = None
        self._loss_type = None
        self._batch_size = None

        # parameters with default values for this object
        self._opt_value_name_list = ['init_scale', 'learning_rate_decay_every', 'learning_rate_decay']

        self._init_scale = 1e-2
        self._learning_rate_decay = 1.0
        self._learning_rate_decay_every = 10000

        if self._dr is not None:
            self._dr_train_loss_num = self._dr.init(channel_name='Train Loss')
            self._dr_train_accuracy_num = self._dr.init(channel_name='Train Accuracy')
            print("Record data:")
            print("Train loss at channel {}".format(self._dr_train_loss_num))
            print("Train accuracy at channel {}".format(self._dr_train_accuracy_num))

            if self._test_data is not None:
                self._dr_test_loss_num = self._dr.init(channel_name='Test Loss')
                self._dr_test_accuracy_num = self._dr.init(channel_name='Test Accuracy')
                print("Test loss at channel {}".format(self._dr_test_loss_num))
                print("Test accuracy at channel {}".format(self._dr_test_accuracy_num))

    def set_value(self, dic):
        """Assign every entry of ``dic`` whose key is an existing attribute of this object.

        Keys that are not attributes are silently ignored. Keys that name a
        property go through that property's validating setter, which may raise.
        """
        for key, value in dic.items():
            if hasattr(self, key):
                setattr(self, key, value)

    def get_value(self, value_name):
        """Return the value of a required or optional parameter.

        Returns ``None`` (after printing a notice) when ``value_name`` is in
        neither ``_value_name_list`` nor ``_opt_value_name_list``.
        """
        if value_name in self._value_name_list or value_name in self._opt_value_name_list:
            return getattr(self, value_name)
        print("{} not found in parameters".format(value_name))
        return None

    def report_parameters(self):
        """Print the current value of every required and optional parameter."""
        # Instances have no ``__name__`` attribute; the class does.
        print("Parameters for object {}:".format(type(self).__name__))
        for value_name in self._value_name_list + self._opt_value_name_list:
            print("- {}: {}".format(value_name, getattr(self, value_name)))

    def check_parameters(self):
        """Return ``True`` when every required parameter is set.

        Prints the name of the first parameter that is still ``None`` and
        returns ``False`` in that case.
        """
        for value_name in self._value_name_list:
            if getattr(self, value_name) is None:
                print("Parameter {} is not set".format(value_name))
                return False
        return True

    # ------------------------------------------------------------------
    # Interface that sub-classes must implement.
    # NotImplementedError derives from Exception, so callers that catch
    # ``Exception`` keep working.
    # ------------------------------------------------------------------
    def build_model(self):
        """Build the model; must be overridden by sub-classes."""
        raise NotImplementedError(
            'NeuralNet :: buildModel : Subclasses have to implement this function. Quitting...')

    def update(self, inputs_batch, outputs_batch, global_step_count):
        """Perform one optimisation step; must be overridden by sub-classes.

        The training loops unpack the result as
        ``(batch_loss, grad_norm, lr, debug)``.
        """
        raise NotImplementedError(
            'NeuralNet :: update : Subclasses have to implement this function. Quitting...')

    def get_loss(self, inputs, outputs):
        """Return the loss on the given data; must be overridden by sub-classes."""
        raise NotImplementedError(
            'NeuralNet :: get_loss : Subclasses have to implement this function. Quitting...')

    def get_prediction_acc(self, inputs, outputs):
        """Return the prediction accuracy, or ``None`` when not applicable; must be overridden."""
        raise NotImplementedError(
            'NeuralNet :: get_prediction : Subclasses have to implement this function. Quitting...')

    def stop_crit_check(self, **kwargs):
        """Return a truthy value when training should stop; must be overridden.

        The training loops pass ``loss_val`` and ``grad_norm`` (and, for
        mini-batch training, ``batch_loss_val``) as keyword arguments.
        """
        raise NotImplementedError(
            'NeuralNet :: stop_crit_check : Subclasses have to implement this function. Quitting...')

    # ------------------------------------------------------------------
    # Helpers shared by the two training loops.
    # ------------------------------------------------------------------
    def _test_due(self, count, test_every):
        """Return True when test data exists and ``count`` is a multiple of a positive ``test_every``."""
        return (self._test_data is not None) and (test_every > 0) and (count % test_every == 0)

    def _record_train_metrics(self, loss, count):
        """Compute the training accuracy and push loss/accuracy to the recorder, if there is one."""
        acc = self.get_prediction_acc(self._train_data['inputs'], self._train_data['outputs'])
        if self._dr is not None:
            self._dr.update(loss, count, self._dr_train_loss_num)
            if acc is not None:
                self._dr.update(acc, count, self._dr_train_accuracy_num)

    def _run_test(self, count_name, count):
        """Evaluate loss and accuracy on the test data, record them (if possible) and print them."""
        test_loss = self.get_loss(self._test_data['inputs'], self._test_data['outputs'])
        if self._dr is not None:
            self._dr.update(test_loss, count, self._dr_test_loss_num)
        print("Testing...")
        display('Test at ' + count_name, count, {'loss': test_loss})

        acc = self.get_prediction_acc(self._test_data['inputs'], self._test_data['outputs'])
        if acc is not None:
            # The recorder is optional; only the display is unconditional.
            if self._dr is not None:
                self._dr.update(acc, count, self._dr_test_accuracy_num)
            display('Prediction at ' + count_name, count, {'accuracy': acc})

    def train_rnd_batch(self, train_iters=50000, rnd_batch_size=False, stop_check_every=1, display_every=5,
                        test_every=100):
        """Train with a freshly sampled random batch at every iteration.

        Args:
            train_iters: Maximum number of update steps.
            rnd_batch_size: When true, the batch size of each step is drawn from
                a binomial distribution with ``batch_size / sample_count`` as
                success probability (at least 1); otherwise ``batch_size`` is used.
            stop_check_every: Call ``stop_crit_check`` every this many steps.
            display_every: Record and print training metrics every this many steps.
            test_every: Evaluate on the test data every this many steps
                (skipped when not positive or when there is no test data).

        Returns:
            ``(step, data_recorder)``. ``step`` is ``-1`` when the training loss
            became NaN, the step at which ``stop_crit_check`` fired, or the last
            executed step index (``0`` when ``train_iters`` is 0).
        """
        print('Training by random batches...\n')

        sample_prob = None
        if rnd_batch_size:
            sample_prob = float(self._batch_size) / self._train_sample_num
            print('Random batch size, every data has probability {:.2f} to be selected'.format(sample_prob))

        # Pre-set so the final return is defined even when the loop body never runs.
        global_step_count = 0
        for global_step_count in range(train_iters):
            # sample data batch: the first ``batch_size`` entries of a fresh permutation
            perm = np.arange(self._train_sample_num)
            np.random.shuffle(perm)
            if rnd_batch_size:
                batch_size = max(np.random.binomial(self._train_sample_num, sample_prob), 1)
            else:
                batch_size = self._batch_size

            batch_idx = perm[:batch_size]
            inputs_batch = self._train_data['inputs'][batch_idx, :]
            outputs_batch = self._train_data['outputs'][batch_idx, :]

            # gradient descent on batch loss
            batch_loss, grad_norm, lr, debug = self.update(inputs_batch, outputs_batch,
                                                           global_step_count)

            display_due = global_step_count % display_every == 0
            stop_due = global_step_count % stop_check_every == 0
            if display_due or stop_due:
                # evaluate training loss whenever necessary
                loss = self.get_loss(self._train_data['inputs'], self._train_data['outputs'])
                if np.isnan(loss):
                    print("Quitting, loss value is nan")
                    return -1, self._dr

                # stop check
                if stop_due and self.stop_crit_check(loss_val=loss, grad_norm=grad_norm):
                    return global_step_count, self._dr

                # training display
                if display_due:
                    self._record_train_metrics(loss, global_step_count)
                    display('Iteration', global_step_count,
                            {'loss': loss, 'batch loss': batch_loss, 'gradient norm': grad_norm, 'Learning rate': lr})
                    if debug:
                        display('Iteration', global_step_count,
                                {'debug info': debug})

            # test whenever necessary
            if self._test_due(global_step_count, test_every):
                self._run_test('Iteration', global_step_count)

        # maximum iteration reached
        print("Quitting, stopping criteria satisfied: maximum iteration count ({})".format(train_iters))
        return global_step_count, self._dr

    def train_mini_batch(self, train_epochs=50000, stop_check_every=1, display_every=5, test_every=100):
        """Train by sweeping over shuffled mini-batches, one sweep per epoch.

        Each epoch shuffles the training data and runs
        ``sample_count // batch_size`` updates (at least one, so a batch size
        larger than the dataset still trains on the whole dataset).

        Args:
            train_epochs: Maximum number of epochs.
            stop_check_every: Call ``stop_crit_check`` every this many epochs,
                passing the full training loss plus the largest batch loss and
                gradient norm seen in the epoch.
            display_every: Record and print training metrics every this many epochs.
            test_every: Evaluate on the test data every this many epochs
                (skipped when not positive or when there is no test data).

        Returns:
            ``(epoch, data_recorder)``. ``epoch`` is ``-1`` when the training
            loss became NaN, the epoch at which ``stop_crit_check`` fired, or
            the last executed epoch index (``0`` when ``train_epochs`` is 0).
        """
        print('Training by mini batches...\n')
        # global_step_count counts individual updates across all epochs
        global_step_count = 0

        # At least one batch per epoch; otherwise no update would run and the
        # learning rate reported below would be undefined.
        batch_num = max(1, self._train_sample_num // self._batch_size)

        # Pre-set so the final return is defined even when the loop body never runs.
        epoch_count = 0
        for epoch_count in range(train_epochs):

            max_loss_val = 0.0
            max_grad_norm = 0.0

            # shuffle the training data
            perm = np.arange(self._train_sample_num)
            np.random.shuffle(perm)
            inputs_perm = self._train_data['inputs'][perm, :]
            outputs_perm = self._train_data['outputs'][perm, :]

            # update through mini-batches for one epoch
            for batch_count in range(batch_num):

                # get data batch
                start_idx = batch_count * self._batch_size
                end_idx = min(self._train_sample_num, start_idx + self._batch_size)
                inputs_batch = inputs_perm[start_idx:end_idx, :]
                outputs_batch = outputs_perm[start_idx:end_idx, :]

                # gradient descent on batch loss
                batch_loss, grad_norm, lr, debug = self.update(inputs_batch, outputs_batch,
                                                               global_step_count)

                # store maximum batch loss and batch gradient in one epoch
                if batch_loss > max_loss_val:
                    max_loss_val = batch_loss
                if grad_norm > max_grad_norm:
                    max_grad_norm = grad_norm

                # debug info
                if debug:
                    display('Iteration', global_step_count,
                            {'debug info': debug})

                # add 1 to global_step_count for every update
                global_step_count += 1

            display_due = epoch_count % display_every == 0
            stop_due = epoch_count % stop_check_every == 0
            if display_due or stop_due:
                # evaluate training loss whenever necessary
                loss = self.get_loss(self._train_data['inputs'], self._train_data['outputs'])
                if np.isnan(loss):
                    print("Quitting, loss value is nan")
                    return -1, self._dr

                # stop check
                if stop_due and self.stop_crit_check(loss_val=loss, batch_loss_val=max_loss_val,
                                                     grad_norm=max_grad_norm):
                    return epoch_count, self._dr

                # training display
                if display_due:
                    self._record_train_metrics(loss, epoch_count)
                    display('Epoch', epoch_count,
                            {'loss': loss, 'max batch loss': max_loss_val, 'max gradient norm': max_grad_norm,
                             'Learning rate': lr})

            # test whenever necessary
            if self._test_due(epoch_count, test_every):
                self._run_test('Epoch', epoch_count)

        # maximum epoch count reached
        print("Quitting, stopping criteria satisfied: maximum epoch count ({})".format(train_epochs))
        return epoch_count, self._dr

    def test(self):
        """Return ``(loss, accuracy)`` evaluated on the test data."""
        loss = self.get_loss(self._test_data['inputs'], self._test_data['outputs'])
        acc = self.get_prediction_acc(self._test_data['inputs'], self._test_data['outputs'])

        return loss, acc

    # handlers for inputs validation
    @staticmethod
    def _validate_dataset(dic, label):
        """Raise ``Exception`` unless ``dic`` is a well-formed dataset.

        A well-formed dataset is a dict holding two-dimensional numpy arrays
        under ``'inputs'`` and ``'outputs'`` with the same number of rows.
        ``label`` (``'training'`` or ``'test'``) only affects the error text.
        """
        if not isinstance(dic, dict):
            raise Exception('Error: {} data must be a dictionary'.format(label))
        if ('outputs' not in dic) or ('inputs' not in dic):
            raise Exception('Error: {} data must contain inputs and outputs'.format(label))
        if not (isinstance(dic['outputs'], np.ndarray) and isinstance(dic['inputs'], np.ndarray)):
            raise Exception('Error: inputs and outputs must be numpy arrays')
        # Both arrays must be matrices (the rank of *each* one is checked).
        if (dic['inputs'].ndim != 2) or (dic['outputs'].ndim != 2):
            raise Exception('Error: inputs and outputs must be numpy matrices')
        if dic['outputs'].shape[0] != dic['inputs'].shape[0]:
            raise Exception('Error: inputs and outputs must have same number of rows(samples)')

    @property
    def sess(self):
        """The session object held by this network."""
        return self._sess

    @sess.setter
    def sess(self, s):
        # Only an object whose exact type is tf.Session is accepted.
        if type(s) is not tf.Session:
            raise Exception('Error: no tensorflow session')
        self._sess = s

    @property
    def train_data(self):
        """Training dataset: dict with 2-D ``'inputs'`` and ``'outputs'`` arrays."""
        return self._train_data

    @train_data.setter
    def train_data(self, dic):
        self._validate_dataset(dic, 'training')
        self._train_data = dic

    @property
    def test_data(self):
        """Test dataset (same layout as the training data) or ``None``."""
        return self._test_data

    @test_data.setter
    def test_data(self, dic):
        if dic is not None:
            self._validate_dataset(dic, 'test')
        self._test_data = dic

    @property
    def batch_size(self):
        """Number of samples per batch; must be a positive ``int``."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, bs):
        if (type(bs) is int) and (bs > 0):
            self._batch_size = bs
        else:
            raise Exception('Error: batch_size must be positive integer')

    @property
    def learning_rate(self):
        """Learning rate; must be a positive ``float``."""
        return self._learning_rate

    @learning_rate.setter
    def learning_rate(self, lr):
        if (type(lr) is float) and (lr > 0):
            self._learning_rate = lr
        else:
            raise Exception('Error: learning_rate must be positive float number')

    @property
    def loss_type(self):
        """Loss identifier, ``'l2'`` or ``'soft_max'``."""
        return self._loss_type

    @loss_type.setter
    def loss_type(self, lt):
        # Unsupported values fall back to 'l2' instead of raising.
        if lt in ['l2', 'soft_max']:
            self._loss_type = lt
        else:
            self._loss_type = 'l2'
            print("Unsupported loss type {}, set to default loss type (l2)".format(lt))

    @property
    def init_scale(self):
        """Initialisation scale; must be a positive ``float`` (default 1e-2)."""
        return self._init_scale

    @init_scale.setter
    def init_scale(self, s):
        if (type(s) is float) and (s > 0):
            self._init_scale = s
        else:
            raise Exception('Error: init_scale must be positive float number')

    @property
    def learning_rate_decay(self):
        """Learning-rate decay factor; must be a ``float`` in (0, 1] (default 1.0)."""
        return self._learning_rate_decay

    @learning_rate_decay.setter
    def learning_rate_decay(self, lrd):
        if (type(lrd) is float) and (0 < lrd <= 1):
            self._learning_rate_decay = lrd
        else:
            raise Exception('Error: learning_rate_decay must be float number in (0,1]')

    @property
    def learning_rate_decay_every(self):
        """Decay interval; must be a positive ``int`` (default 10000)."""
        return self._learning_rate_decay_every

    @learning_rate_decay_every.setter
    def learning_rate_decay_every(self, lrde):
        if (type(lrde) is int) and (lrde > 0):
            self._learning_rate_decay_every = lrde
        else:
            raise Exception('Error: learning_rate_decay_every must be positive integer')


# Running this module directly does nothing; the guard body is only a placeholder.
if __name__ == '__main__':
    pass
