import numpy as np


def saveNNet(model, nnetFile, means, ranges, inputMins, inputMaxes):
    """
    Export the parameters of a model to a .nnet file
    Args:
        model: Object whose get_weights() returns the network parameters as a
            sequence in which weights and biases alternate, weights first
            (presumably a Keras-style model -- confirm against the caller)
        nnetFile (str): File where the network will be written
        means (list): Mean values passed through to writeNNet
        ranges (list): Range values passed through to writeNNet
        inputMins (list): Minimum values for each input
        inputMaxes (list): Maximum values for each input
    """
    params = model.get_weights()

    # Even positions are taken as weight matrices, odd positions as the
    # matching bias vectors.
    layerWeights = params[::2]
    layerBiases = params[1::2]

    # Note the argument order differs between the two functions, so pass by name.
    writeNNet(
        weights=layerWeights,
        biases=layerBiases,
        inputMins=inputMins,
        inputMaxes=inputMaxes,
        means=means,
        ranges=ranges,
        fileName=nnetFile,
    )


def writeNNet(weights, biases, inputMins, inputMaxes, means, ranges, fileName):
    """
    Write network data to the .nnet file format

    The complete file content is assembled in memory before the output file is
    opened, so invalid input (for example an empty weight list or a
    normalization list that is too short) raises before the file is created or
    truncated, instead of leaving a partially written file behind.

    Args:
        weights (list): Weight matrices in the network order. Each matrix is
            indexed as w[input][output]; the number of network inputs is taken
            from weights[0].shape[0]
        biases (list): Bias vectors in the network order
        inputMins (list): Minimum values for each input
        inputMaxes (list): Maximum values for each input
        means (list): Mean values for each input and a mean value for all outputs. Used to normalize inputs/outputs
        ranges (list): Range values for each input and a range value for all outputs. Used to normalize inputs/outputs
        fileName (str): File where the network will be written
    Raises:
        IndexError: If weights or biases is empty, or if inputMins, inputMaxes,
            means or ranges holds fewer entries than the header requires
    """

    # File layout:
    # The first line written is just a line of text
    # The second line gives the four values:
    #     Number of fully connected layers in the network
    #     Number of inputs to the network
    #     Number of outputs from the network
    #     Maximum layer size (largest of the input size and every bias vector length)
    # The third line gives the sizes of each layer, including the input and output layers
    # The fourth line gives an outdated flag, so this can be ignored
    # The fifth line specifies the minimum values each input can take
    # The sixth line specifies the maximum values each input can take
    #     Inputs passed to the network are truncated to be between this range
    # The seventh line gives the mean value of each input and of all outputs
    # The eighth line gives the range of each input and of all outputs
    #     These two lines are used to map raw inputs to the 0 mean, unit range of the inputs and outputs
    #     used during training
    # The ninth line begins the network weights and biases

    def headerRow(values, count):
        # First `count` entries of `values`, comma separated with a trailing comma.
        # Indexing (rather than slicing) keeps the IndexError on short input.
        return ",".join(str(values[i]) for i in range(count)) + ",\n"

    # Extract the necessary information for the header
    numLayers = len(weights)
    inputSize = weights[0].shape[0]
    outputSize = len(biases[-1])
    layerSizes = [inputSize] + [len(b) for b in biases]
    maxLayerSize = max(layerSizes)

    lines = [
        "// Neural Network File Format by Kyle Julian, Stanford 2016\n",
        "%d,%d,%d,%d,\n" % (numLayers, inputSize, outputSize, maxLayerSize),
        "".join("%d," % size for size in layerSizes) + "\n",
        "0,\n",  # Unused Flag
        headerRow(inputMins, inputSize),  # Minimum Input Values
        headerRow(inputMaxes, inputSize),  # Maximum Input Values
        headerRow(means, inputSize + 1),  # Means for normalizations
        headerRow(ranges, inputSize + 1),  # Ranges for normalizations
    ]

    ##################
    # Weights and biases of the neural network
    # First, the weights from the input layer to the first hidden layer are written
    # Then, the biases of the first hidden layer are written
    # The pattern is repeated by next writing the weights from the first hidden layer to the second hidden layer,
    # followed by the biases of the second hidden layer.
    # Five digits are written per value. More can be used, but that requires more space.
    ##################
    for w, b in zip(weights, biases):
        numIn, numOut = w.shape[0], w.shape[1]
        # One file row per output unit: all incoming weights of unit j
        for j in range(numOut):
            lines.append("".join("%.5e," % w[i][j] for i in range(numIn)) + "\n")

        # One bias value per row
        lines.extend("%.5e,\n" % value for value in b)

    # Only touch the file once everything has been formatted successfully
    with open(fileName, "w") as f2:
        f2.writelines(lines)
