#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 30 21:49:53 2021

"""


import torch
import numpy as np
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import pickle
import torch.optim as optim
from tqdm import tqdm
import time

class Net(nn.Module):
    """Fully connected network mapping 2 input features to 1 output.

    Three hidden ``Linear`` layers of width ``neurons`` are followed by a
    single-unit output layer. ReLU is applied after every layer, the output
    layer included, so the result is clamped at zero from below.

    The ``BatchNorm1d`` modules ``bn1``-``bn3`` are constructed (and so are
    registered as sub-modules) but ``forward`` does not apply them.
    """

    def __init__(self, neurons):
        """Create the layers.

        Args:
            neurons: Width of each of the three hidden layers.
        """
        super().__init__()
        in_features, out_features = 2, 1

        # Registration order is kept as fc1, bn1, fc2, bn2, fc3, bn3, fc4 so
        # the sub-module ordering stays stable.
        self.fc1 = nn.Linear(in_features, neurons)
        self.bn1 = nn.BatchNorm1d(neurons)
        self.fc2 = nn.Linear(neurons, neurons)
        self.bn2 = nn.BatchNorm1d(neurons)
        self.fc3 = nn.Linear(neurons, neurons)
        self.bn3 = nn.BatchNorm1d(neurons)
        self.fc4 = nn.Linear(neurons, out_features)

    def forward(self, x):
        """Run ``x`` through the four linear layers, each followed by ReLU.

        Args:
            x: Input tensor whose last dimension has size 2.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        hidden = x
        # The batch-norm modules are intentionally left out of this chain.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(dense(hidden))
        return hidden
    
# Load the saved model. The file is expected to hold a whole pickled module
# (presumably a ``Net``), not just a state_dict, since the result is used
# directly as a model below.
# map_location='cpu' keeps the weights on the same device as the CPU dummy
# input created below, so the export also works for a checkpoint that was
# saved from a GPU and on machines without CUDA.
# SECURITY: torch.load unpickles arbitrary objects -- only load checkpoint
# files that come from a trusted source.
# NOTE(review): recent PyTorch releases default to ``weights_only=True``,
# which refuses fully pickled modules; pass ``weights_only=False`` there if
# this load fails -- confirm against the installed version.
net = torch.load('NeuralNetworkPendulum', map_location='cpu')

# Dummy input used only to trace the graph: the last dimension (2) matches the
# network's input width, and axis 0 is declared dynamic in the export below.
x = torch.randn(128, 1, 2, requires_grad=True)

# Switch to inference mode before tracing.
net.eval()

torch.onnx.export(net,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "NeuralNetworkPendulum.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})