#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 15 17:19:34 2021

"""


import torch
import numpy as np
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import pickle
import torch.optim as optim
from tqdm import tqdm
import time

class Net(nn.Module):
    """Fully connected regressor mapping inputs of shape (..., 2) to (..., 1).

    Three hidden layers of width ``neurons`` with ReLU activations, followed by
    a linear output layer.
    """

    def __init__(self, neurons):
        """Build the layers.

        Args:
            neurons: width of each of the three hidden layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(2, neurons)
        self.fc2 = nn.Linear(neurons, neurons)
        self.fc3 = nn.Linear(neurons, neurons)
        self.fc4 = nn.Linear(neurons, 1)
        # NOTE: bn1..bn3 are registered but not used in forward(). They are
        # kept so the module's parameters / state_dict keys stay the same as
        # models saved by earlier runs of this script.
        self.bn1 = nn.BatchNorm1d(neurons)
        self.bn2 = nn.BatchNorm1d(neurons)
        self.bn3 = nn.BatchNorm1d(neurons)

    def forward(self, x):
        """Return the network's regression output for ``x`` of shape (..., 2).

        The output layer is deliberately linear. A ReLU there would clamp every
        prediction to >= 0, so negative targets could never be fit. It would
        also zero the gradient whenever fc4's pre-activation is negative, which
        can leave the output permanently stuck at 0.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)
    
    
    
  
    
 
 
neurons = 512            # width of each hidden layer in Net
net = Net(neurons)
Num_traj = 100           # nominal number of trajectories in the data file
Num_data_per_traj = 500  # nominal samples per trajectory

optimizer = optim.Adam(net.parameters(), lr=10**(-3))
EPOCHS = 50

# NOTE(review): pickle.load can run arbitrary code stored in the file; only
# load shuffled_data.pkl if it comes from a trusted source.
with open('shuffled_data.pkl', 'rb') as f:
    training_data = pickle.load(f)

# Size X/Y from the data actually loaded. Preallocating the nominal
# Num_traj * Num_data_per_traj rows raises IndexError when the file holds more
# samples. When it holds fewer, the untouched all-zero rows get trained on as
# if they were real samples.
num_samples = len(training_data)
if num_samples == 0:
    raise ValueError('shuffled_data.pkl contains no samples')
expected_samples = Num_traj * Num_data_per_traj
if num_samples != expected_samples:
    print('warning: expected', expected_samples, 'samples but loaded', num_samples)

X = torch.zeros((num_samples, 2))  # inputs: one [q_norm, q_dot_norm] row per sample
Y = torch.zeros((num_samples, 1))  # targets: u_norm

print('setting data..')
for ii in tqdm(range(num_samples)):
    cur_data = training_data[ii]
    # Fields are read as (q_norm, q_dot_norm, t_norm, u_norm); index 2 is not
    # used as a model input.
    q_norm = cur_data[0]
    q_dot_norm = cur_data[1]
    u_norm = cur_data[3]

    # axis=None flattens both parts before joining them into one 1-D array.
    x_input = np.concatenate((q_norm, q_dot_norm), axis=None)
    X[ii] = torch.as_tensor(x_input, dtype=torch.float32)
    Y[ii] = torch.as_tensor(u_norm, dtype=torch.float32).reshape(1)

print('training...')
BATCH_SIZE = 256
for epoch in tqdm(range(EPOCHS)):
    total_loss = 0.0
    kk = 0  # number of mini-batches seen this epoch
    # Iterate over the rows X actually has, not the nominal sample count.
    for ii in tqdm(range(0, num_samples, BATCH_SIZE)):
        batch_X = X[ii:ii + BATCH_SIZE]
        batch_Y = Y[ii:ii + BATCH_SIZE]

        optimizer.zero_grad()  # clear gradients left from the previous step
        output = net(batch_X)
        loss = F.mse_loss(output, batch_Y)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float. Summing the loss tensor itself would chain
        # every batch's autograd history into total_loss for the whole epoch.
        total_loss += loss.item()
        kk += 1

    print('loss =', total_loss / kk)

# NOTE(review): this pickles the whole module, so loading it later requires the
# Net class to be importable; saving net.state_dict() is more portable.
torch.save(net, 'NeuralNetworkPendulum')
