#!/usr/bin/env python3
# -*- coding: utf-8 -*-
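"""
Created on Thu Apr 15 17:19:34 2021

Train a fully connected regression network that maps the normalized
(q, q_dot) features stored in shuffled_data.pkl to the normalized control
u, then save the trained model to 'NeuralNetworkUR5'.
"""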


import torch
import numpy as np
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import pickle
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
import time
class Net(nn.Module):
    """Fully connected regression network (12 inputs -> 6 outputs by default).

    Three hidden ``Linear`` layers of width ``neurons``, each followed by a
    ReLU (optionally preceded by ``BatchNorm1d``), and a *linear* output
    layer suitable for an MSE regression target.

    Args:
        neurons: width of every hidden layer.
        input_dim: number of input features (default 12).
        output_dim: number of regression outputs (default 6).
        use_batchnorm: if True, apply ``bn1``/``bn2``/``bn3`` after each
            hidden ``Linear`` and before its ReLU. Off by default.
    """

    def __init__(self, neurons, input_dim=12, output_dim=6, use_batchnorm=False):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.fc1 = nn.Linear(input_dim, neurons)
        self.bn1 = nn.BatchNorm1d(neurons)
        self.fc2 = nn.Linear(neurons, neurons)
        self.bn2 = nn.BatchNorm1d(neurons)
        self.fc3 = nn.Linear(neurons, neurons)
        self.bn3 = nn.BatchNorm1d(neurons)
        self.fc4 = nn.Linear(neurons, output_dim)
        # The BatchNorm layers are always constructed (even when unused) so the
        # parameter/state_dict layout does not depend on ``use_batchnorm``.

    def forward(self, x):
        """Map a (batch, input_dim) tensor to a (batch, output_dim) prediction."""
        # getattr: models pickled before this attribute existed still load and run.
        use_bn = getattr(self, "use_batchnorm", False)
        for fc, bn in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            x = fc(x)
            if use_bn:
                x = bn(x)
            x = F.relu(x)
        # No activation on the output layer: a ReLU here would clamp every
        # prediction to >= 0, so negative (normalized) targets could never be
        # fit under the MSE loss.
        return self.fc4(x)
    
    

neurons = 512

# Expected dataset size; used only as a sanity check against the loaded file.
Num_traj = 100
Num_data_per_traj = 500


EPOCHS = 50


def _build_tensors(samples):
    """Split raw samples into float32 input/target tensors.

    Each ``samples[i][0]`` is read as a flat feature vector laid out as
    q_norm[0:6], q_dot_norm[6:12], t_norm[12], u_norm[13:19]; t_norm is
    not used. Returns ``(X, Y)`` with shapes (len(samples), 12) and
    (len(samples), 6), sized from the data actually loaded.
    """
    n = len(samples)
    inputs = np.empty((n, 12), dtype=np.float32)
    targets = np.empty((n, 6), dtype=np.float32)
    for ii, sample in enumerate(tqdm(samples)):
        cur_data = sample[0]
        q_norm = cur_data[0:6]
        q_dot_norm = cur_data[6:12]
        u_norm = cur_data[13:19]
        inputs[ii] = np.concatenate((q_norm, q_dot_norm), axis=None)
        targets[ii] = np.ravel(u_norm)
    return torch.from_numpy(inputs), torch.from_numpy(targets)


def _train_epoch(model, opt, X, Y, batch_size):
    """Run one pass over (X, Y) in a fresh random order; return the mean per-sample MSE."""
    n = len(X)
    order = torch.randperm(n)  # reshuffle every epoch so batches differ between epochs
    total_loss = 0.0
    for start in tqdm(range(0, n, batch_size)):
        idx = order[start:start + batch_size]
        batch_X = X[idx]
        batch_Y = Y[idx]

        opt.zero_grad()  # reset gradients accumulated by the previous step
        loss = F.mse_loss(model(batch_X), batch_Y)
        loss.backward()
        opt.step()

        # .item() yields a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the end of the epoch.
        # Weight by batch size so the short final batch is not over-counted.
        total_loss += loss.item() * len(idx)
    return total_loss / n


with open('shuffled_data.pkl', 'rb') as f:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    training_data = pickle.load(f)

if not training_data:
    raise ValueError('shuffled_data.pkl contains no samples')

print('setting data..')

net = Net(neurons)
optimizer = optim.Adam(net.parameters(), lr=.001)

expected_samples = Num_traj * Num_data_per_traj
if len(training_data) != expected_samples:
    print('warning: expected', expected_samples, 'samples, got', len(training_data))

X, Y = _build_tensors(training_data)

print('training...')
BATCH_SIZE = 256

torch.set_printoptions(threshold=10_000)

for epoch in tqdm(range(EPOCHS)):
    print('loss = ', _train_epoch(net, optimizer, X, Y, BATCH_SIZE))

# NOTE: this pickles the whole module, so loaders need ``Net`` importable
# under the same module path; saving net.state_dict() avoids that coupling.
torch.save(net, 'NeuralNetworkUR5')
