#!/usr/bin/env python
# -*- coding: utf-8 -*-


# libraries and imports
import sys
sys.path.append('../')

import argparse
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
from sklearn.metrics import accuracy_score
import numpy as np
from lib.network_architectures import VoxelModel
from lib.helpers import load_3dVoxel_mnist
import tqdm

# Small helper script to train a VoxelNet model on 3DMNIST voxels.

# The code is inspired by: https://www.kaggle.com/code/rajeevctrl/3d-mnist-classification-using-pointcloud-and-voxel#Create-Voxel-based-model

# Command-line arguments for the training run.
# NOTE: parsing happens at module level, so it runs whenever this file is
# executed or imported.
parser = argparse.ArgumentParser(
    description='Training a voxel model on 3DMNIST')

parser.add_argument('--nepochs', default=50, type=int,
                    help="Number of training epochs")
parser.add_argument('--num_classes', default=10, type=int,
                    help="Number of classes in the dataset (MNIST:10)")
parser.add_argument('--batch_size', default=32, type=int,
                    help="training batch size")

# Directory the data is loaded from; the trained weights are also saved here.
parser.add_argument('--source_dir', default="../drafts", type=str,
                    help="Source directory to the training data")

args = parser.parse_args()


# Load the data. At this call site the result is unpacked as the test
# (loader, dataset) pair first, then the train pair.
(test_loader, test_ds), (train_loader, train_ds) = load_3dVoxel_mnist(
    args.source_dir, args.batch_size, seed=42)

# Parameters
nepochs = args.nepochs
num_classes = args.num_classes
# Whole number of batches per epoch, as reported by the loader itself
# (dividing the dataset size by the batch size gives a fractional float).
num_batch = len(train_loader)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

# Initialize model, optimizer and loss function
classifier = VoxelModel(n_out_classes=num_classes)
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0001)
loss_fn = torch.nn.CrossEntropyLoss()

# Move model to device
classifier.to(device)


def train_on_epoch(model, dl_train, optimizer, loss_fn, device):
    """Run one optimisation pass over ``dl_train``.

    Args:
        model: Network to train; it is put in train mode and moved to
            ``device``.
        dl_train: Iterable yielding ``(x, y)`` batches. Only iteration is
            required, so generators and unsized loaders work as well.
        optimizer: Optimizer that steps ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor; targets are passed as ``int64``.
        device: Device the model and every batch are moved to.

    Returns:
        Tuple ``(mean_loss, loss)``: the unweighted mean of the per-batch
        losses, and the loss tensor of the last batch.

    Raises:
        ValueError: If ``dl_train`` yields no batches.
    """
    model.train()
    model = model.to(device)
    losses = []
    loss = None
    for x, y in dl_train:
        # Add a singleton axis at dim 1 (presumably the channel axis the
        # model expects -- confirm against the model definition).
        x = x.unsqueeze(1).to(device).float()
        # Cast labels straight to int64; no float detour is needed.
        y = y.to(device).long()

        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu().numpy())
    if not losses:
        # Without this guard there is no last-batch loss to return.
        raise ValueError("dl_train yielded no batches; nothing to train on")
    mean_loss = np.mean(losses)
    return mean_loss, loss

def test(model, dl_test, loss_fn, device):
    model.eval()
    model = model.to(device)
    losses = []
    with torch.no_grad():
        N = len(dl_test)
        for i,(x,y) in enumerate(dl_test):
            x,y  = x.unsqueeze(1).to(device).float(), y.to(device).float()
            
            y_pred = model(x)
            loss = loss_fn(y_pred, y.long())
            
            losses.append(loss.detach().cpu().numpy())
    mean_loss = np.mean(losses)
    return mean_loss


if __name__ == '__main__':

    # Per-epoch test losses, kept for later inspection.
    test_losses = []
    # Training Loop
    for epoch in range(nepochs):
        # The second return value (last-batch loss) is not needed here.
        train_loss, _ = train_on_epoch(classifier, train_loader, optimizer, loss_fn, device)

        test_loss = test(classifier, test_loader, loss_fn, device)
        test_losses.append(test_loss)
        print(f"\rEpoch:{epoch+1} train_loss:{train_loss} test_loss:{test_loss}")


    # Evaluate the trained model on the whole test set
    classifier.eval()

    all_labels = []
    all_preds = []
    with torch.no_grad():
        for images, labels in test_loader:
            # Use the selected device (not a hard-coded .cuda()) so the
            # evaluation also runs on CPU-only machines, and cast the inputs
            # exactly like the training path does.
            inputs = images.unsqueeze(1).to(device).float()
            outputs = classifier(inputs)
            all_preds.append(torch.argmax(outputs, dim=1).cpu().numpy())
            all_labels.append(labels.long().cpu().numpy())
    # Score all batches together; scoring inside the loop would only keep
    # the accuracy of the last batch.
    test_accuracy = accuracy_score(np.concatenate(all_labels),
                                   np.concatenate(all_preds))

    print("final accuracy {}".format(test_accuracy)) # Should be around 0.75

    # save the model
    torch.save(classifier.state_dict(), os.path.join(
        args.source_dir,'VoxelNet.pth'))
    





