from bipointnet_cls import BiPointNetCls
from bipointnet2 import BiPointNet2SSGCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import urllib
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
# NOTE(review): cuDNN is disabled globally; the reason is not visible here
# (presumably an incompatibility or nondeterminism with the binarized
# layers) -- confirm before re-enabling.
torch.backends.cudnn.enabled = False


parser = argparse.ArgumentParser(
    description='Train a binarized PointNet / PointNet++ classifier on ModelNet40.')
parser.add_argument('--model', type=str, default='bipointnet',
                    choices=['bipointnet', 'bipointnet2_ssg'],
                    help='Network architecture to train.')
parser.add_argument('--dataset-path', type=str, default='',
                    help='Existing directory with the extracted '
                         'modelnet40_normal_resampled dataset. If empty, the '
                         'dataset is downloaded into the DGL download directory.')
parser.add_argument('--load-model-path', type=str, default='',
                    help='Optional state_dict checkpoint to initialise the network from.')
parser.add_argument('--save-model-path', type=str, default='',
                    help='Where to save the state_dict with the best test accuracy '
                         '(nothing is saved if empty).')
parser.add_argument('--num-epochs', type=int, default=200,
                    help='Number of training epochs.')
parser.add_argument('--num-workers', type=int, default=0,
                    help='Number of DataLoader worker processes.')
parser.add_argument('--batch-size', type=int, default=32,
                    help='Mini-batch size for training and testing.')
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), 'modelnet40_normal_resampled')
if not os.path.exists(local_path):
    if args.dataset_path:
        # The archive is always extracted into get_download_dir(), so a
        # download could never populate a user-supplied path. Fail clearly
        # here instead of downloading and then crashing when the data is read.
        parser.error('--dataset-path %r does not exist' % args.dataset_path)
    # NOTE(review): TLS certificate verification is disabled for this download.
    download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
             download_path, verify_ssl=False)
    from zipfile import ZipFile
    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())

# DataLoader factory pre-configured for training: shuffled, and the final
# incomplete batch is dropped.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)


def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch with data augmentation, then step the LR scheduler.

    Each batch is converted to a NumPy array and augmented on the CPU:
    random point dropout is applied to the whole array, and random scaling,
    jitter and shift are applied to the first three channels of the last
    axis. The augmented batch is then moved to ``dev``.

    Args:
        net: classifier returning logits of shape (batch, num_classes).
            Its input is presumably (batch, num_points, channels) with xyz
            in the first three channels (TODO confirm against the model).
        opt: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once after the epoch.
        train_loader: iterable of ``(data, label)`` tensor batches, where
            ``label[:, 0]`` holds the integer class index.
        dev: ``torch.device`` to train on.

    Returns:
        ``(avg_loss, avg_acc)``: the mean per-batch loss and the per-sample
        accuracy over the epoch. Both are 0.0 if the loader yields no batches.
    """
    net.train()
    loss_f = nn.CrossEntropyLoss()

    total_loss = 0.0
    num_batches = 0
    total_correct = 0
    count = 0
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # Augment in NumPy. Only the first three channels (coordinates)
            # are scaled, jittered and shifted.
            data = data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(
                data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)

            # reshape(-1) rather than squeeze() keeps the target 1-D even for
            # a single-sample batch. squeeze() would produce a 0-d tensor,
            # which CrossEntropyLoss rejects.
            label = label[:, 0].reshape(-1)
            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).long()

            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            num_batches += 1
            count += num_examples
            total_loss += loss.item()
            total_correct += (preds == label).sum().item()

            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)})
    scheduler.step()

    avg_loss = total_loss / num_batches if num_batches else 0.0
    avg_acc = total_correct / count if count else 0.0
    return avg_loss, avg_acc


def evaluate(net, test_loader, dev):
    """Return the classification accuracy of ``net`` on ``test_loader``.

    The network is put in eval mode and run under ``torch.no_grad()``. No
    augmentation is applied.

    Args:
        net: classifier returning logits of shape (batch, num_classes).
        test_loader: iterable of ``(data, label)`` tensor batches, where
            ``label[:, 0]`` holds the integer class index.
        dev: ``torch.device`` to evaluate on.

    Returns:
        The fraction of correctly classified samples, or 0.0 if the loader
        yields no samples (for example, when ``drop_last`` discards
        everything).
    """
    net.eval()

    total_correct = 0
    count = 0

    with torch.no_grad(), tqdm.tqdm(test_loader, ascii=True) as tq:
        for data, label in tq:
            # reshape(-1) rather than squeeze() keeps a 1-D label tensor even
            # for a single-sample batch.
            label = label[:, 0].reshape(-1)
            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).long()
            logits = net(data)
            _, preds = logits.max(1)

            total_correct += (preds == label).sum().item()
            count += num_examples

            tq.set_postfix({
                'AvgAcc': '%.5f' % (total_correct / count)})

    # Guard against an empty loader, which would otherwise raise
    # ZeroDivisionError.
    if count == 0:
        return 0.0
    return total_correct / count


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 40 is presumably the number of ModelNet40 classes. input_dims=6 matches the
# "normal_resampled" data (presumably xyz + normals). BiPointNet2SSGCls also
# receives batch_size, which suggests it expects every batch to have exactly
# that size.
if args.model == 'bipointnet':
    net = BiPointNetCls(40, input_dims=6)
elif args.model == 'bipointnet2_ssg':
    net = BiPointNet2SSGCls(40, batch_size, input_dims=6)
else:
    # Without this branch, an unknown model name surfaced later as a
    # NameError on `net`.
    raise ValueError('Unknown model %r; expected "bipointnet" or "bipointnet2_ssg"'
                     % args.model)

net = net.to(dev)
if args.load_model_path:
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

# Multiply the learning rate by 0.7 every 20 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)

# The second argument (1024) is presumably the number of points per sample.
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
# The training configuration (shuffle, drop_last) is exactly CustomDataLoader's.
train_loader = CustomDataLoader(train_dataset)
# Keep the final partial test batch so the reported accuracy covers the whole
# test split. The exception is bipointnet2_ssg, which is built for a fixed
# batch size.
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    drop_last=(args.model == 'bipointnet2_ssg'))

best_test_acc = 0

for epoch in range(args.num_epochs):
    train(net, opt, scheduler, train_loader, dev)
    print('Epoch #%d Testing' % epoch)
    test_acc = evaluate(net, test_loader, dev)
    # Checkpoint only when test accuracy improves on the best seen so far.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print('Current test acc: %.5f (best: %.5f)' % (
        test_acc, best_test_acc))
