# -*- coding: utf-8 -*-
"""
Center-of-Mass training (on CPU)
"""

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split

def make_step(model, criterion, optimizer):
    """Build a closure that runs one model evaluation on a batch.

    Parameters
    ----------
    model : torch.nn.Module
        Network being trained; switched to train/eval mode on every call.
    criterion : callable
        ``criterion(yhat, y)`` returning a scalar loss tensor.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters; stepped only when training.

    Returns
    -------
    step : callable
        ``step(x, y, train=True) -> float``. With ``train=True`` the
        parameters are updated from the gradient of the loss on ``(x, y)``.
        With ``train=False`` the loss is only evaluated, and no gradients are
        tracked. In both cases the loss (computed before any update) is
        returned as a Python float.
    """
    def step(x, y, train=True):
        if train:
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            return loss.item()

        model.eval()
        # Evaluation never calls backward(), so building an autograd graph is
        # wasted memory/time; disable it even if the caller did not.
        with torch.no_grad():
            loss = criterion(model(x), y)
        return loss.item()

    return step
            
def _mean(total, count):
    """Return ``total / count``, or ``nan`` when ``count`` is zero (empty loader)."""
    return total / count if count else float('nan')


def _evaluate(step, loader):
    """Return the sample-weighted mean of ``step(x, y, train=False)`` over ``loader``."""
    total = 0.0
    count = 0
    for x, y in loader:
        batch_size = x.size(0)
        total += step(x, y, train=False) * batch_size
        count += batch_size
    # Divide by the samples actually seen (not len(dataset)) so the result is
    # correct even when the loader drops or skips samples (e.g. drop_last=True).
    return _mean(total, count)


def trainval(step, train_loader, val_loader, N_epochs=20, log_interval=100):
    """Train for ``N_epochs`` epochs, evaluating on ``val_loader`` after each.

    Parameters
    ----------
    step : callable
        ``step(x, y, train=bool) -> float``, e.g. as returned by ``make_step``.
        Its return value is assumed to be the per-sample mean loss of the
        batch, since it is weighted by the batch size when averaging.
    train_loader, val_loader : iterable of (x, y)
        Batches whose ``x`` supports ``x.size(0)``. ``train_loader`` must also
        provide ``len()`` and ``.dataset`` (used only for progress logging).
    N_epochs : int
        Number of passes over ``train_loader``.
    log_interval : int
        Print progress every ``log_interval`` batches (never for batch 0).
        A value <= 0 disables logging.

    Returns
    -------
    (losses, val_losses) : tuple of two lists of float
        Per-epoch sample-weighted mean training and validation loss. The
        training value is the running mean over the epoch, i.e. it mixes
        losses measured while the parameters were still changing. An epoch
        whose loader yields no batches records ``nan``.
    """
    losses = []
    val_losses = []

    for epoch in range(N_epochs):
        num_items = 0
        cumulative_loss = 0.0
        for batch_idx, (x, y) in enumerate(train_loader):
            batch_size = x.size(0)
            loss = step(x, y, train=True)

            num_items += batch_size
            cumulative_loss += loss * batch_size

            if log_interval > 0 and batch_idx != 0 and batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    cumulative_loss / num_items))

        losses.append(_mean(cumulative_loss, num_items))

        with torch.no_grad():
            val_losses.append(_evaluate(step, val_loader))

    return losses, val_losses
    
    
if __name__ == '__main__':
    # Change parameters here. Training is tracked. Final model is saved.
    import os
    import time

    # Imported up front so a missing plotting dependency fails before the
    # (long) training run rather than after it.
    import matplotlib.pyplot as plt

    from datasets import CoMdata
    from models import SimpleSetPoolingNet

    # Set random seeds
    np.random.seed(42)
    torch.manual_seed(42)

    N_train_clds = 10**6
    N_epochs = 50
    batch_size = 64
    train_fraction = 0.8

    # Generate Center of Mass data, train/val split, and create loaders
    data = CoMdata(numclds=N_train_clds)

    # random_split needs lengths that sum exactly to len(data); int(n*0.8) +
    # int(n*0.2) can fall short by one, so give the remainder to validation.
    n_train = int(len(data) * train_fraction)
    lengths = [n_train, len(data) - n_train]
    train_data, val_data = random_split(data, lengths)

    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True)

    # Create PointNet model (change architecture here)
    shape = [2, 500, 500, 500, 2]
    model = SimpleSetPoolingNet(shape, pool='max')  # Make a PointNet
    N_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Hidden-layer widths, e.g. "500-500-500"; works for any number of layers.
    hidden = '-'.join(str(width) for width in shape[1:-1])

    # Choose learning rate, criterion (loss fn) and optimizer
    lr = 0.01
    fn = nn.MSELoss(reduction='sum')

    def criterion(x, y):
        # Divide by batch size instead of number of elements in tensor
        return fn(x, y) / x.size(0)

    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Create train_step with above, model, criterion & optimizer
    step = make_step(model, criterion, optimizer)

    # Train & validate model, and measure running time
    start = time.perf_counter()
    losses, val_losses = trainval(step, train_loader, val_loader, N_epochs)
    elapsed = time.perf_counter() - start
    print('Elapsed time: ', elapsed)

    # Save model (whole pickled module, as before; loading it requires the
    # `models` module to be importable)
    directory = 'numclds_' + str(data.numclds)
    os.makedirs(directory, exist_ok=True)

    model_path = os.path.join(
        directory, 'CoMptnet%s_%depoch_%dbs.pkl' % (hidden, N_epochs, batch_size))
    torch.save(model, model_path)

    # Create, plot & save training curve
    epoch_list = np.arange(1, N_epochs + 1)

    plt.figure()
    fig_path = os.path.join(
        directory,
        'tc-ptnet%d_%s_%depoch%dbs.png' % (N_params, hidden, N_epochs, batch_size))
    plt.plot(epoch_list, losses, label="Training Loss")
    plt.plot(epoch_list, val_losses, label="Validation Loss")
    plt.axhline(y=0.0, color='grey', linestyle='-')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Curve for PointNet(%s) \n (%d epochs, batchsize=%d)'
              % (hidden, N_epochs, batch_size))
    plt.legend()
    plt.savefig(fig_path)
    plt.show()
    
    