import numpy as np
import csv
import sys
import glob
import torch
from utils import get_non_pad_mask
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
from utils import optimizer_Hebbian, ChamferDistance
from torch.autograd import Variable
import time
from tqdm import tqdm 
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def train_epoch(model, train_loader, optimizer, args):
	"""Train a supervised ``model`` for one epoch with cross-entropy loss.

	Args:
		model: classifier mapping a batch of inputs to class scores
			(argmax is taken over dim 1).
		train_loader: iterable of batches where ``batch[0]`` is the input
			tensor and ``batch[1]`` the integer class labels; any further
			entries are ignored.
		optimizer: optimizer zeroed and stepped once per batch.
		args: namespace providing ``device``.

	Returns:
		Tuple ``(mean loss per batch, accuracy over all samples)``.

	Raises:
		ZeroDivisionError: if ``train_loader`` yields no batches.
	"""
	epoch_loss = 0
	epoch_correct = 0
	epoch_count = 0
	epoch_sample = 0
	criterion = nn.CrossEntropyLoss()

	model.train()
	for batch in tqdm(train_loader, mininterval=2, ncols=70, desc=' (Training) ', leave=False):
		point = batch[0].to(args.device)
		label = batch[1].to(args.device)
		batch_size = point.shape[0]

		optimizer.zero_grad()
		score = model(point)
		loss = criterion(score, label)
		# argmax is non-differentiable, so no copy of ``score`` and no autograd node is needed.
		pred = score.argmax(dim=1)
		epoch_correct += (pred == label).sum().item()

		loss.backward()
		optimizer.step()

		epoch_loss += loss.item()
		epoch_count += 1
		epoch_sample += batch_size

	return epoch_loss / epoch_count, epoch_correct / epoch_sample


def eval_epoch(model, test_data, args):
	"""Evaluate a supervised ``model`` for one epoch without gradient tracking.

	Args:
		model: classifier mapping a batch of inputs to class scores
			(argmax is taken over dim 1).
		test_data: iterable of batches where ``batch[0]`` is the input
			tensor and ``batch[1]`` the integer class labels; any further
			entries are ignored.
		args: namespace providing ``device``.

	Returns:
		Tuple ``(mean loss per batch, accuracy over all samples)``.

	Raises:
		ZeroDivisionError: if ``test_data`` yields no batches.
	"""
	epoch_loss = 0
	epoch_correct = 0
	epoch_count = 0
	epoch_sample = 0
	criterion = nn.CrossEntropyLoss()

	model.eval()
	with torch.no_grad():
		for batch in tqdm(test_data, mininterval=2, ncols=70, desc=' (Test) ', leave=False):
			point = batch[0].to(args.device)
			label = batch[1].to(args.device)
			batch_size = point.shape[0]

			score = model(point)
			loss = criterion(score, label)
			pred = score.argmax(dim=1)
			epoch_correct += (pred == label).sum().item()

			epoch_loss += loss.item()
			epoch_count += 1
			epoch_sample += batch_size

	return epoch_loss / epoch_count, epoch_correct / epoch_sample


def train_epoch_hebb(encoder, decoder, train_loader, optimizer, args, epoch, decoder_train, layer=-1):
	"""Run one epoch of the Hebbian-encoder / supervised-decoder scheme.

	When ``decoder_train`` is false (Hebbian phase), the encoder is updated
	per batch by ``optimizer_Hebbian`` with a learning rate of
	``args.lr_hebb * (1 - epoch / args.epoch_hebb)``. The decoder is still run,
	but only to report loss and accuracy; no back-propagation happens.
	When ``decoder_train`` is true, the cross-entropy loss is
	back-propagated and ``optimizer`` is stepped, and the Hebbian update is
	skipped.

	Args:
		encoder: module called as ``encoder(point, train=True)`` that returns
			``(encoded_vectors, hidden_vectors, hidden_vectors_k1, variance)``.
		decoder: classifier applied to ``encoded_vectors``.
		train_loader: iterable of batches where ``batch[0]`` is the input and
			``batch[1]`` the integer labels; further entries are ignored.
		optimizer: optimizer zeroed every batch and stepped when
			``decoder_train`` is true.
		args: namespace providing ``device`` and, for the Hebbian phase,
			``lr_hebb`` and ``epoch_hebb``.
		epoch: current epoch index, used by the Hebbian learning-rate decay.
		decoder_train: selects the phase described above.
		layer: forwarded unchanged to ``optimizer_Hebbian``.

	Returns:
		Tuple ``(mean loss per batch, accuracy over all samples, [v0, v1, v2])``
		where each ``v`` is the sum over batches of the matching entry of the
		encoder's ``variance`` output, divided by the number of samples.
		NOTE(review): if ``variance`` is a per-batch mean rather than a
		per-batch sum, dividing by the sample count scales it down by the
		batch size; confirm the intended normaliser.

	Raises:
		ZeroDivisionError: if ``train_loader`` yields no batches.
	"""
	hebbian_phase = not decoder_train
	if hebbian_phase:
		# Linear decay from args.lr_hebb at epoch 0.
		lr_hebb = args.lr_hebb * (1 - epoch / args.epoch_hebb)

	epoch_loss = 0
	epoch_correct = 0
	epoch_count = 0
	epoch_sample = 0
	epoch_variance = [0, 0, 0]
	criterion = nn.CrossEntropyLoss()

	encoder.train()
	decoder.train()
	for batch in tqdm(train_loader, mininterval=2, ncols=70, desc=' (Training) ', leave=False):
		point = batch[0].to(args.device)
		label = batch[1].to(args.device)
		batch_size = point.shape[0]

		optimizer.zero_grad()

		encoded_vectors, hidden_vectors, hidden_vectors_k1, variance = encoder(point, train=True)
		score = decoder(encoded_vectors)
		loss = criterion(score, label)
		epoch_correct += (score.argmax(dim=1) == label).sum().item()

		if hebbian_phase:
			optimizer_Hebbian(encoder, hidden_vectors, hidden_vectors_k1, lr_hebb, layer, args)

		epoch_loss += loss.item()
		epoch_count += 1
		epoch_sample += batch_size

		# Accumulated after the Hebbian update, matching the original statement order.
		for i in range(3):
			epoch_variance[i] += variance[i]

		if not hebbian_phase:
			loss.backward()
			optimizer.step()

	epoch_variance = [v / epoch_sample for v in epoch_variance]

	return epoch_loss / epoch_count, epoch_correct / epoch_sample, epoch_variance


def eval_epoch_hebb(encoder, decoder, test_data, args):
	"""Evaluate an encoder/decoder pair for one epoch without gradient tracking.

	The encoder is called as ``encoder(point, train=False)`` and must return a
	4-tuple. Only its first element (the encoded vectors) is used; the
	hidden vectors and the ``variance`` output are ignored here.

	Args:
		encoder: module returning ``(encoded_vectors, _, _, _)``.
		decoder: classifier applied to ``encoded_vectors``.
		test_data: iterable of batches where ``batch[0]`` is the input and
			``batch[1]`` the integer labels; further entries are ignored.
		args: namespace providing ``device``.

	Returns:
		Tuple ``(mean loss per batch, accuracy over all samples)``.

	Raises:
		ZeroDivisionError: if ``test_data`` yields no batches.
	"""
	epoch_loss = 0
	epoch_count = 0
	epoch_correct = 0
	epoch_sample = 0
	criterion = nn.CrossEntropyLoss()

	encoder.eval()
	decoder.eval()
	with torch.no_grad():
		for batch in tqdm(test_data, mininterval=2, ncols=70, desc=' (Test) ', leave=False):
			point = batch[0].to(args.device)
			label = batch[1].to(args.device)
			batch_size = point.shape[0]

			# Exactly four outputs are still required, as before.
			encoded_vectors, _, _, _ = encoder(point, train=False)
			score = decoder(encoded_vectors)
			loss = criterion(score, label)
			epoch_correct += (score.argmax(dim=1) == label).sum().item()

			epoch_loss += loss.item()
			epoch_count += 1
			epoch_sample += batch_size

	return epoch_loss / epoch_count, epoch_correct / epoch_sample


def _print_epoch_summary(phase, epoch, loss, accuracy, start):
	"""Print one per-epoch summary line; elapsed time is measured from ``start`` in minutes."""
	print('({phase} Epoch: {epoch: d}) loss: {loss: e}, accuracy: {accuracy:.4f}, elapse: {elapse:3.3f} min' \
		.format(phase=phase, epoch=epoch, loss=loss, accuracy=accuracy, elapse=(time.time() - start) / 60))


def train(model, train_loader, partial_train_loader, test_loader, optimizer, scheduler, args):
	"""Run the training schedule selected by ``args.model`` and print progress.

	``args.model == "unsupervised"``:
		1. Train the encoder for ``args.epoch_hebb`` epochs with
		   ``train_epoch_hebb(..., decoder_train=False)``. This uses
		   ``partial_train_loader`` when it is not ``None``, otherwise
		   ``train_loader``.
		2. With the encoder frozen, train the classifier for ``args.epoch``
		   epochs. Each epoch has an optional pass over
		   ``partial_train_loader``, a pass over ``train_loader``, and an
		   evaluation on ``test_loader``.

	``args.model == "supervised"``:
		Train the whole model for ``args.epoch`` epochs with ``train_epoch``
		and evaluate with ``eval_epoch``.

	For any other ``args.model`` value, only the ``requires_grad`` flags are
	set and printed.

	Args:
		model: module with ``encoder`` and ``classifier`` attributes.
		train_loader: loader over the full training set.
		partial_train_loader: optional loader over a subset of the training
			set, or ``None``.
		test_loader: loader over the evaluation set.
		optimizer: optimizer for the back-propagation steps.
		scheduler: accepted for interface compatibility but not used here.
			NOTE(review): it is never stepped; confirm whether the caller
			expects ``scheduler.step()`` once per epoch.
		args: namespace read for ``model``, ``encoder_train``, ``epoch`` and
			``epoch_hebb``, plus whatever the epoch helpers read.

	Returns:
		None. Results are printed.
	"""
	encoder = model.encoder
	decoder = model.classifier

	for name, params in model.named_parameters():
		if 'encoder' in name:
			# The unsupervised encoder is updated by the Hebbian rule, not by
			# autograd, so it is always frozen in that mode.
			params.requires_grad = False if args.model == "unsupervised" else args.encoder_train
			print(name, ": requires_grad =", params.requires_grad)
		if 'classifier' in name:
			params.requires_grad = True
			print(name, ": requires_grad =", params.requires_grad)
	print("")

	if args.model == "unsupervised":
		print("Start unsupervised learning for encoder...")
		max_variance = 0

		# Phase 1: Hebbian training of all encoder layers together.
		hebb_loader = partial_train_loader if partial_train_loader is not None else train_loader
		for epoch in range(args.epoch_hebb):
			torch.cuda.empty_cache()
			_, _, variance = train_epoch_hebb(encoder, decoder, hebb_loader, optimizer, args, epoch, decoder_train=False)
			print(epoch, "layers variance: ", float(variance[0]), float(variance[1]), float(variance[2]))

			# Track the epoch with the largest last-layer variance (checkpoint hook).
			if float(variance[2]) > max_variance:
				max_variance = float(variance[2])
				#torch.save(model, "./save/unsupervised_k1_encoder_modelnet10_1024.pt")
			#torch.save(model, './save/unsupervised_k1_encoder_modelnet10_1024.pt')

		print("Finished unsupervised learning for encoder!")
		print("Start supervised learning for decoder...")

		# Phase 2: back-propagation on the classifier only.
		min_loss = float("inf")
		max_accuracy = 0
		for epoch in range(args.epoch):
			torch.cuda.empty_cache()
			if partial_train_loader is not None:
				# First, fit the classifier on the partial training data (encoder frozen).
				for name, params in model.named_parameters():
					if 'encoder' in name:
						params.requires_grad = False
				start = time.time()
				train_loss, train_accuracy, _ = train_epoch_hebb(encoder, decoder, partial_train_loader, optimizer, args, epoch, decoder_train=True)
				_print_epoch_summary("Partial Training", epoch, train_loss, train_accuracy, start)

			# Then fit the classifier on the full training data; timed on its own.
			start = time.time()
			train_loss, train_accuracy, _ = train_epoch_hebb(encoder, decoder, train_loader, optimizer, args, epoch, decoder_train=True)
			_print_epoch_summary("Training", epoch, train_loss, train_accuracy, start)

			start = time.time()
			test_loss, test_accuracy = eval_epoch_hebb(encoder, decoder, test_loader, args)
			_print_epoch_summary("Test", epoch, test_loss, test_accuracy, start)

			if test_loss < min_loss:
				min_loss = test_loss
				#torch.save(model, './save/mnist/unsupervised_k1_modelnet10_1024.pt')

			if test_accuracy > max_accuracy:
				max_accuracy = test_accuracy

		print("min loss: ", min_loss, "max accuracy: ", max_accuracy)

	elif args.model == "supervised":
		min_loss = float("inf")
		max_accuracy = 0
		for epoch in range(args.epoch):
			start = time.time()
			train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, args)
			_print_epoch_summary("Training", epoch, train_loss, train_accuracy, start)

			start = time.time()
			test_loss, test_accuracy = eval_epoch(model, test_loader, args)
			_print_epoch_summary("Test", epoch, test_loss, test_accuracy, start)

			if test_loss < min_loss:
				min_loss = test_loss
				#torch.save(model, "./save/supervised_modelnet10_1024.pt")

			if test_accuracy > max_accuracy:
				max_accuracy = test_accuracy

		print("min loss: ", min_loss, "max accuracy: ", max_accuracy)
