#!/usr/bin/env python
# -*- coding: utf-8 -*-


# libraries and imports
import argparse
import os
import sys

import torch
import torch.nn.functional as F
import tqdm

# Make the parent directory importable before pulling in the local `lib` package
sys.path.append('../')

from lib.helpers import load_3d_mnist
from lib.network_architectures import PointNetCls, feature_transform_regularizer

# Small helper script to train a PointNet model on 3DMNIST.

# We use a PointNet model to classify the point clouds.
# The implementation is based on the code from https://github.com/fxia22/pointnet.pytorch/blob/master/pointnet/model.py


# Script arguments / parameter set up


def _int_flag(value):
    """Convert an integer-like command-line value ("0", "1", ...) to a bool.

    Accepts exactly the inputs ``int()`` accepts, so existing invocations such
    as ``--feature_transform 0`` keep working, but the parsed value is always
    a bool, matching the type of the option's default.

    Raises:
        argparse.ArgumentTypeError: if ``value`` cannot be parsed as an integer.
    """
    try:
        return bool(int(value))
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected an integer flag (0 or 1), got %r" % (value,))


parser = argparse.ArgumentParser(description='Training a pointnet model on 3DMNIST')

parser.add_argument('--nepochs', default=100, help="Number of training epochs", type=int)
parser.add_argument('--num_classes', default=10, help="Number of classes in the dataset (MNIST:10)", type=int)
parser.add_argument('--batch_size', default=32, help="training batch size", type=int)
# The default is a bool; the converter makes user-supplied values bools as well.
parser.add_argument('--feature_transform', default=True, type=_int_flag,
                    help="whether features should be transformed (1) or not (0)")

# The same directory is passed to the data loader and used as the save
# location of the trained model.
parser.add_argument('--source_dir', default="../drafts", type=str,
                    help="Source directory of the training data (the trained model is saved here too)")

args = parser.parse_args()


# Data: the loader helper's result is unpacked as a (loader, dataset) pair for
# the test split followed by a (loader, dataset) pair for the training split.
(test_loader, test_ds), (train_loader, train_ds) = load_3d_mnist(
    args.source_dir,
    args.batch_size,
    num_points=1024,
    train=True,
    seed=42,
)

# Run configuration taken from the command line
nepochs, num_classes = args.nepochs, args.num_classes
feature_transform = args.feature_transform
# True division: this is a float and is only used for progress reporting
num_batch = len(train_ds) / args.batch_size

# Prefer the GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Using device:', device)

# Model, Adam optimizer, and a step scheduler (lr multiplied by 0.5 every 20 scheduler steps)
classifier = PointNetCls(k=num_classes, feature_transform=feature_transform)
optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Put the model's parameters on the selected device
classifier.to(device)

def prepare_batch(batch, device):
    """Unpack a ``(points, target)`` batch and move both tensors to ``device``.

    The points tensor has its dimensions 1 and 2 swapped
    (B, N, C => B, C, N) before it is handed to the classifier.

    Args:
        batch: a pair ``(points, target)`` as yielded by the data loaders.
        device: the ``torch.device`` to move both tensors to.

    Returns:
        The tuple ``(points, target)``, transposed and moved to ``device``.
    """
    points, target = batch
    points = points.transpose(2, 1)  # B, N, C => B, C, N
    return points.to(device), target.to(device)


if __name__ == '__main__':

    # Training loop
    for epoch in range(nepochs):
        # The periodic evaluation below switches the model to eval mode, so
        # training mode has to be restored at the start of every epoch.
        classifier.train()
        for i, data in enumerate(train_loader, 0):
            points, target = prepare_batch(data, device)
            optimizer.zero_grad()
            pred, _, trans_feat = classifier(points)
            loss = F.nll_loss(pred, target)
            if feature_transform:
                # Regularization term on the feature transform, weighted by 0.001
                loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            # Evaluate a single batch of test data; no gradients are needed
            classifier.eval()
            with torch.no_grad():
                points, target = prepare_batch(next(iter(test_loader)), device)
                pred, _, _ = classifier(points)
                loss = F.nll_loss(pred, target)
                correct = pred.argmax(dim=1).eq(target).sum()
            # Normalize by the real size of this batch, which may be smaller
            # than args.batch_size.
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, 'test', loss.item(), correct.item()/float(target.size(0))))

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer updates of that epoch.
        scheduler.step()

    # Evaluate the trained model on the full test loader
    classifier.eval()
    total_correct = 0
    total_testset = 0
    with torch.no_grad():
        # `tqdm` is imported as a module, so the progress bar is `tqdm.tqdm`
        for data in tqdm.tqdm(test_loader):
            points, target = prepare_batch(data, device)
            pred, _, _ = classifier(points)
            total_correct += pred.argmax(dim=1).eq(target).sum().item()
            total_testset += points.size(0)

    print("final accuracy {}".format(total_correct / float(total_testset))) # Should be around 0.95

    # Save the model weights into the source directory
    torch.save(classifier.state_dict(), os.path.join(
        args.source_dir, 'pointNet.pth'))