#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 15 17:19:34 2021

"""


import torch
import numpy as np
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import pickle
import torch.optim as optim
from tqdm import tqdm
import time

class Net(nn.Module):
    """Fully connected network mapping 5 input features to 2 outputs.

    The model stacks three ``Linear`` layers of width ``neurons`` followed by
    a 2-unit ``Linear`` layer. A ReLU is applied after every layer, including
    the last one, so the returned values are never negative.

    The ``bn1``-``bn3`` batch-norm modules are registered on the model (and
    therefore show up in its parameters and state dict), but ``forward`` does
    not apply them.
    """

    def __init__(self, neurons):
        """Create the layers.

        Args:
            neurons: Width of each hidden layer.
        """
        super().__init__()
        in_features, out_features = 5, 2

        # Layers are created in a fixed order (fc1, bn1, fc2, ...); that order
        # determines random-initialisation order and state-dict key order.
        self.fc1 = nn.Linear(in_features, neurons)
        self.bn1 = nn.BatchNorm1d(neurons)

        self.fc2 = nn.Linear(neurons, neurons)
        self.bn2 = nn.BatchNorm1d(neurons)

        self.fc3 = nn.Linear(neurons, neurons)
        self.bn3 = nn.BatchNorm1d(neurons)

        self.fc4 = nn.Linear(neurons, out_features)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor whose last dimension is 5.

        Returns:
            Tensor whose last dimension is 2; every value is >= 0 because the
            output layer is also passed through a ReLU.
        """
        # Only the linear layers take part; the batch-norm modules are skipped.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(layer(x))
        return x
    
    
    
   
 
 
# ---- Hyper-parameters and dataset dimensions ----
neurons = 512              # width of each hidden layer passed to Net
EPOCHS = 50                # number of passes over the data
Num_traj = 100
Num_data_per_traj = 500

# ---- Model and optimizer ----
net = Net(neurons)
optimizer = optim.Adam(net.parameters(), lr=0.001)

# ---- Sample buffers ----
# Zero-filled tensors with Num_traj * Num_data_per_traj rows each:
# 5 columns for the network inputs and 2 columns for the targets.
X = torch.zeros(Num_traj * Num_data_per_traj, 5)
Y = torch.zeros(Num_traj * Num_data_per_traj, 2)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load 'shuffled_data.pkl' when it comes from a trusted source.
with open('shuffled_data.pkl', 'rb') as f:
    training_data = pickle.load(f)

# X and Y are pre-allocated, so the number of samples has to fit them.
num_samples = len(training_data)
num_rows = X.shape[0]
if num_samples > num_rows:
    raise ValueError(
        f'shuffled_data.pkl holds {num_samples} samples but X/Y only have '
        f'{num_rows} rows'
    )
if num_samples < num_rows:
    # Rows that are never written keep their zero fill and are still part of
    # the tensors the training loop iterates over.
    print(f'warning: only {num_samples} of {num_rows} rows will be filled; '
          'the remaining rows stay zero')

print('setting data..')
# Assumes training_data iterates in index order (e.g. a list or array of
# records) -- TODO confirm against the code that writes the pickle.
for ii, cur_data in enumerate(tqdm(training_data)):
    # Record layout used here: [0:3] -> q_norm, [3:5] -> q_dot_norm,
    # [6:8] -> u_norm. Element 5 is not used by this script.
    q_norm = cur_data[0:3]
    q_dot_norm = cur_data[3:5]
    u_norm = cur_data[6:8]

    # axis=None flattens both slices into one 1-D feature vector.
    x_input = np.concatenate((q_norm, q_dot_norm), axis=None)

    # view(-1, 5) fails loudly if the record did not yield exactly 5 features.
    X[ii, :] = torch.FloatTensor(x_input).view(-1, 5)
    Y[ii, :] = torch.FloatTensor(u_norm)

print('training...')
BATCH_SIZE = 256
for epoch in tqdm(range(EPOCHS)):
    # Running sum of the per-batch losses, kept as a plain Python float.
    total_loss = 0.0
    num_batches = 0
    for ii in tqdm(range(0, Num_traj*Num_data_per_traj, BATCH_SIZE)):
        # Slicing clamps at the end of the tensor, so the final batch may be
        # smaller than BATCH_SIZE.
        batch_X = X[ii:ii + BATCH_SIZE]
        batch_Y = Y[ii:ii + BATCH_SIZE]

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        output = net(batch_X)
        loss = F.mse_loss(output, batch_Y)
        loss.backward()
        optimizer.step()

        # .item() extracts the scalar value; adding the loss tensor itself
        # would keep autograd history alive across the whole epoch.
        total_loss += loss.item()
        num_batches += 1

    # Mean of the per-batch MSE values for this epoch.
    print(total_loss / num_batches)

# Saves the entire module object (pickled), not just its state_dict, so
# loading this file requires the Net class definition to be importable.
torch.save(net, 'NeuralNetworkUnicycle')
