"""
Train set pooling networks on problematic datasets to test 
the ability of PointNet and DeepSets to learn
"""

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt
import time
import os

from models import SimpleSetPoolingNet
from datasets import RingLine, DiscSquare

def main():
    """Train PointNet and DeepSets on each toy set dataset and save plots.

    For every dataset key in ``['rl', 'ds']`` this:
      1. builds the dataset (RingLine or DiscSquare) and splits it 80/20
         into train/validation subsets,
      2. trains a max-pooling ("PointNet") and an average-pooling
         ("DeepSets") SimpleSetPoolingNet with SGD, printing the wall-clock
         time of each run,
      3. saves and shows the validation-loss curves, a 4x4 grid of PointNet
         predictions on one validation batch, and one example point cloud
         from each class.

    Raises:
        ValueError: if a dataset key other than 'rl' or 'ds' is encountered.
    """
    # Training hyperparameters (epochs, batch, learning rate)
    N_epochs, batch_size, lr = 60, 32, 0.1
    cldsize = 1500
    numclds = 2000*2

    # Set random seeds
    np.random.seed(43)
    torch.manual_seed(43)

    # Every figure goes into one directory named after the run settings.
    directory = f'inductive_bias_cldsize{cldsize}_epochs{N_epochs}'
    os.makedirs(directory, exist_ok=True)

    for dataset in ['rl', 'ds']:
        # Generate data
        if dataset == 'rl':
            data = RingLine(class_size=numclds//2, cldsize=cldsize,
                            jitter=0.3, relsize=0.7, rotate=False, sand=True)
        elif dataset == 'ds':
            data = DiscSquare(class_size=numclds//2, cldsize=cldsize,
                              relsize=0.4)
        else:
            raise ValueError(f'Unknown dataset key: {dataset!r}')

        # 80/20 train/validation split. The validation length is the
        # remainder, so the two lengths always sum to len(data) as
        # random_split requires.
        n_train = int(len(data) * 0.8)
        train_data, val_data = random_split(data,
                                            [n_train, len(data) - n_train])

        # Create data loaders. val_loader is shuffled, so the prediction
        # grid below shows a random validation batch.
        train_loader = DataLoader(dataset=train_data, batch_size=batch_size,
                                  shuffle=True)
        val_loader = DataLoader(dataset=val_data, batch_size=batch_size,
                                shuffle=True)

        # Create Set Pooling Nets with identical layer shapes:
        # max pooling (PointNet) and average pooling (DeepSets).
        shape = [2, 32, 32, 32, 2]
        maxnet = SimpleSetPoolingNet(shape, 'max')
        avenet = SimpleSetPoolingNet(shape, 'ave')

        # Count number of trainable parameters (same for both)
        N_params = sum(p.numel() for p in maxnet.parameters()
                       if p.requires_grad)

        # Loss function, optimizers, and one step function per model
        criterion = nn.CrossEntropyLoss()
        step_max = make_step(maxnet, criterion,
                             optim.SGD(maxnet.parameters(), lr=lr))
        step_ave = make_step(avenet, criterion,
                             optim.SGD(avenet.parameters(), lr=lr))

        # Train & validate each model, timing each run.
        # trainval returns (train_losses, val_losses); only the validation
        # history is plotted below.
        start = time.time()
        _, val_loss_max = trainval(step_max, train_loader, val_loader,
                                   N_epochs)
        elapsed_max = time.time() - start

        start = time.time()
        _, val_loss_ave = trainval(step_ave, train_loader, val_loader,
                                   N_epochs)
        elapsed_ave = time.time() - start
        print('Elapsed time for PointNet: ', elapsed_max)
        print('Elapsed time for DeepSets: ', elapsed_ave)

        # Validation-loss curves (these match the plot's "Validation" title)
        _plot_loss_curves(
            val_loss_max, val_loss_ave, N_params,
            os.path.join(directory,
                         dataset + '_training_curve_maxave.png'))

        # PointNet predictions on one validation batch. The file name is
        # prefixed with the dataset key so 'ds' does not overwrite 'rl'.
        _plot_predictions(
            maxnet, val_loader,
            os.path.join(directory, dataset + '_eyeball_metric.png'))

        # One example point cloud from each class
        if dataset == 'rl':
            _plot_example_pair(
                data.rings, data.lines,
                'Example of Ring (blue), Line (red)\n'
                + f'N = {data.cldsize}',
                os.path.join(directory, 'rl_example_pair.png'))
        else:
            _plot_example_pair(
                data.discs, data.squares,
                'Example of Disc (blue), Square (red)\n'
                + f'N = {data.cldsize}',
                os.path.join(directory, 'ds_example_pair.png'))


def _plot_loss_curves(loss_max, loss_ave, n_params, path):
    """Plot per-epoch PointNet vs DeepSets validation loss; save to path."""
    epoch_list = np.arange(1, len(loss_max) + 1)
    plt.figure()
    plt.plot(epoch_list, loss_max, label="PointNet")
    plt.plot(epoch_list, loss_ave, label="DeepSets")
    plt.axhline(y=0.0, color='grey', linestyle='-')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Validation Cross-Entropy Loss for Set Pooling Net\n'
              + f'# of parameters = {n_params}')
    plt.savefig(path)
    plt.show()
    plt.close()


def _plot_predictions(model, loader, path):
    """Scatter up to 16 clouds from loader's first batch in a 4x4 grid.

    Each panel is titled with the model's predicted label: green when the
    prediction matches the true label, red otherwise. Saves to path.
    """
    # Set eval mode explicitly instead of relying on the last call made
    # during training.
    model.eval()
    with torch.no_grad():
        example, label = next(iter(loader))
        _, pred_label = torch.max(model(example), dim=1)

    fig, axs = plt.subplots(4, 4, figsize=(10, 10))
    axs = axs.reshape(16)
    # A final batch may be smaller than the 16 panels.
    for i in range(min(16, len(example))):
        axs[i].scatter(example[i, :, 0], example[i, :, 1], marker='.')
        axs[i].set_xlim(-1.1, 1.1)
        axs[i].set_ylim(-1.1, 1.1)
        color = 'green' if pred_label[i] == label[i] else 'red'
        axs[i].set_title('Predicted label %d' % pred_label[i], color=color)
    for ax in axs.flat:
        ax.label_outer()  # hide inner ticks and labels
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def _plot_example_pair(first, second, title, path):
    """Plot cloud 0 of `first` (blue) beside cloud 0 of `second` (red)."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].scatter(first[0, :, 0], first[0, :, 1], color='blue', alpha=0.6)
    ax[1].scatter(second[0, :, 0], second[0, :, 1], color='red', alpha=0.6)
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    fig.suptitle(title, y=0.65)
    fig.tight_layout(rect=[0, 0.03, 1, 0.8])
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def make_step(model, criterion, optimizer):
    """Build a closure that runs one training or evaluation step.

    Args:
        model: module mapping a batch ``x`` to predictions ``yhat``; must
            support ``train()`` / ``eval()``.
        criterion: loss function called as ``criterion(yhat, y)``.
        optimizer: optimizer over ``model``'s parameters. It is only used
            when ``train=True``.

    Returns:
        ``step(x, y, train=True) -> float``, which returns the batch loss
        as a Python float.
        - With ``train=True`` the model is put in training mode and one
          optimizer update is applied.
        - With ``train=False`` the model is put in evaluation mode. No
          autograd graph is built and the parameters are left unchanged.
    """
    def step(x, y, train=True):
        if train:
            model.train()
            optimizer.zero_grad()
            # forward + backward + optimize
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        else:
            model.eval()
            # Evaluation never backpropagates. Disable autograd here rather
            # than relying on every caller to wrap this in torch.no_grad().
            with torch.no_grad():
                loss = criterion(model(x), y)
        return loss.item()
    return step

def trainval(step, train_loader, val_loader, N_epochs=20, log_interval=20):
    """Train for ``N_epochs`` epochs, running a validation pass after each.

    Args:
        step: callable ``step(x, y, train=...) -> float``, such as the one
            returned by ``make_step``. Its return value is treated as the
            batch-mean loss and is weighted by batch size.
        train_loader: iterable of ``(x, y)`` training batches. It must also
            support ``len()`` and expose ``.dataset``; both are used only
            for progress logging.
        val_loader: iterable of ``(x, y)`` validation batches.
        N_epochs: number of passes over ``train_loader``.
        log_interval: print a progress line every ``log_interval``
            training batches. Batch 0 is never logged.

    Returns:
        ``(losses, val_losses)``: two lists of length ``N_epochs``. They
        hold the sample-weighted mean training loss and validation loss of
        each epoch. An epoch whose loader yields no samples records ``nan``
        instead of crashing.
    """
    losses = []
    val_losses = []

    for epoch in range(N_epochs):
        # Training pass. Weight each batch mean by its size so a smaller
        # final batch does not skew the epoch average.
        num_items = 0
        cumulative_loss = 0.0
        for batch_idx, (x, y) in enumerate(train_loader):
            batch_size = x.size(0)
            loss = step(x, y, train=True)
            num_items += batch_size
            cumulative_loss += loss * batch_size

            if batch_idx % log_interval == 0 and batch_idx != 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    cumulative_loss / num_items))
        losses.append(cumulative_loss / num_items if num_items
                      else float('nan'))

        # Validation pass, averaged over the samples actually seen (same
        # weighting as the training pass).
        num_items = 0
        cumulative_loss = 0.0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                batch_size = x_val.size(0)
                val_loss = step(x_val, y_val, train=False)
                num_items += batch_size
                cumulative_loss += val_loss * batch_size
        val_losses.append(cumulative_loss / num_items if num_items
                          else float('nan'))

    return losses, val_losses
    

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
