import sys
import os
import argparse
import datetime
import time
import numpy as np
import torch
import torch.optim as optim

from model.model_Hebbian import PointCloudClassifer
from data import get_dataloader
from train import train

def main():
	""" START argument setting """
	parser = argparse.ArgumentParser()
	parser.add_argument("--device", type=int, default=0)
	parser.add_argument("--num_workers", type=int, default=4)
	parser.add_argument("--seed", type=int, default=5)
	parser.add_argument("--lr", type=float, default=1e-3, help="1e-3 for ModelNet10 and ModelNet40")
	parser.add_argument("--epoch", type=int, default=100)
	parser.add_argument("--train_batch", type=int, default=4, help="32 for supervised and untrained, 4 for unsupervised")
	parser.add_argument("--test_batch", type=int, default=32)
	parser.add_argument("--dataset", type=str, default="ModelNet10", help="ModelNet10, ModelNet40, MNIST")
	parser.add_argument("--model", type=str, default="unsupervised", help="unsupervised, supervised")
	parser.add_argument("--encoder_train", type=bool, default=False, help="allow gradient flow to the encoder")
	parser.add_argument("--task", type=str, default="classification", choices=["classification", "reconstruction"])

	parser.add_argument("--lr_hebb", type=float, default=1 / 100, help="default: 1e-2, but change it depending on the amount of data")
	parser.add_argument("--k", type=int, default=1, help="k-WTA")
	parser.add_argument("--rank_num_point", type=int, default=0)
	parser.add_argument("--rank_num_neuron", type=int, default=0)
	parser.add_argument("--delta_num_point", type=float, default=1.0)
	parser.add_argument("--delta_num_neuron", type=float, default=1.0)
	parser.add_argument("--epoch_hebb", type=int, default=50)
	parser.add_argument("--rule", type=str, default="hybrid", choices=["hebb", "instar", "oja", "hybrid"])
	parser.add_argument("--a", type=float, default=1.0, help="importance of Hebbian learning")
	parser.add_argument("--b", type=float, default=-1.0, help="importance of anti-Hebbian learning")
	parser.add_argument("--amount_data", type=str, default="100%")
	args = parser.parse_args()

	torch.cuda.set_device(int(args.device))
	args.device = torch.device("cuda:{}".format(args.device))

	print("\n[info] arguments")
	for arg in vars(args):
		print("{}: {}".format(arg, getattr(args, arg) or ''))

	timestamp = str(datetime.datetime.now()).replace(':', '-').replace(' ', '_')
	args.timestamp = timestamp

	np.random.seed(args.seed)
	torch.manual_seed(args.seed)
	torch.backends.cudnn.deterministic = True
	""" END argument setting """

	train_loader, partial_train_loader, test_loader = get_dataloader(args, N=None)

	model = PointCloudClassifer(args)
	model.encoder.args = args
	model.classifier.args = args

	model.to(args.device)
	optimizer = optim.Adam(model.parameters(), lr=args.lr)
	scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

	print("\n[info] Model named parameters")
	for name, params in model.named_parameters():
		print("name: ", name, params.shape)
	num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
	print('[Info] Number of Trainable Parameters: {}\n'.format(num_params))

	train(model, train_loader, partial_train_loader, test_loader, optimizer, scheduler, args)

if __name__ == "__main__":
	# Start training only when run as a script, never on import.
	main()
