import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from utils import BatchPool
import matplotlib.pyplot as plt


def WTA(x, args, train):
	"""Build winner-take-all (WTA) masks for ``x`` and return ``x`` filtered by them.

	Args:
		x: tensor indexed as [batch, num_point, num_neuron]. Larger values rank first.
		args: namespace providing
			k                 -- number of winners kept in the k-winner masks,
			rank_num_neuron   -- rank offset of the first winner along the neuron axis,
			rank_num_point    -- rank offset of the first winner along the point axis,
			delta_num_neuron  -- value written at winners along the neuron axis,
			delta_num_point   -- value written at winners along the point axis,
			a                 -- factor given to neurons that win less often than optimal,
			b                 -- factor given to neurons that win more often than optimal.
		train: if truthy, ``x_filtered`` uses the single-winner masks, otherwise
			the k-winner masks.

	Returns:
		Tuple ``(x_filtered, mask_num_neuron, mask_num_point, mask_num_neuron_k1,
		mask_num_point_k1)``. The two ``*_k1`` masks mark only the first winner and
		are additionally scaled per neuron by ``a`` / ``b`` (or 0 when the neuron
		wins exactly the optimal number of times).

	Raises:
		ValueError: if both ``delta_num_neuron`` and ``delta_num_point`` are 0,
			because no filter is defined in that case.
	"""
	use_neuron_axis = args.delta_num_neuron != 0
	use_point_axis = args.delta_num_point != 0
	if not (use_neuron_axis or use_point_axis):
		raise ValueError("WTA needs a non-zero delta_num_neuron or delta_num_point")

	num_point, num_neuron = x.shape[1], x.shape[2]

	# order_by_neuron[b, p, r]: neuron with rank r for point p.
	# order_by_point[b, r, n]: point with rank r for neuron n.
	_, order_by_neuron = x.sort(dim=-1, descending=True)
	_, order_by_point = x.sort(dim=-2, descending=True)

	# Ranks selected as the k winners along each axis.
	neuron_ranks = torch.arange(args.k, device=x.device) + args.rank_num_neuron
	point_ranks = torch.arange(args.k, device=x.device) + args.rank_num_point

	# k-winner masks: write delta at every selected winner, for the whole batch at once.
	mask_num_neuron = torch.zeros_like(x).scatter_(
		-1, order_by_neuron[..., neuron_ranks], args.delta_num_neuron)
	mask_num_point = torch.zeros_like(x).scatter_(
		-2, order_by_point[:, point_ranks, :], args.delta_num_point)

	# Single-winner masks: only the first selected rank.
	mask_num_neuron_k1 = torch.zeros_like(x).scatter_(
		-1, order_by_neuron[..., args.rank_num_neuron].unsqueeze(-1), args.delta_num_neuron)
	mask_num_point_k1 = torch.zeros_like(x).scatter_(
		-2, order_by_point[:, args.rank_num_point, :].unsqueeze(-2), args.delta_num_point)

	# How often each neuron wins over all points, compared with the activity it
	# would have if wins were spread evenly: num_point * k / num_neuron.
	activity = torch.sum(mask_num_neuron, dim=-2, keepdim=True)  # [batch, 1, num_neuron]
	optimal_activity = num_point * torch.full((1, 1, 1), args.k / num_neuron, device=x.device)
	surplus = activity - optimal_activity

	# Both comparisons are taken from `surplus` before anything is written, so the
	# value stored for over-active neurons can never be re-matched by the
	# under-active test (which happened with in-place updates whenever b > 0).
	activity_sign = torch.zeros_like(surplus)
	activity_sign[surplus > 0] = args.b  # high-active neuron: anti-Hebbian factor b
	activity_sign[surplus < 0] = args.a  # low-active neuron: Hebbian factor a

	# Broadcasts over the point axis.
	mask_num_neuron_k1 = mask_num_neuron_k1 * activity_sign
	mask_num_point_k1 = mask_num_point_k1 * activity_sign

	# x_filtered is not used for learning: single winner in training, k winners in inference.
	if train:
		neuron_mask, point_mask = mask_num_neuron_k1, mask_num_point_k1
	else:
		neuron_mask, point_mask = mask_num_neuron, mask_num_point

	x_filtered = x
	if use_neuron_axis:
		x_filtered = x_filtered * neuron_mask
	if use_point_axis:
		x_filtered = x_filtered * point_mask

	return x_filtered, mask_num_neuron, mask_num_point, mask_num_neuron_k1, mask_num_point_k1


class HebbianEncoder(nn.Module):
	"""Point-wise encoder (3 -> 64 -> 128 -> 1024) followed by max pooling over points.

	``args.model`` selects the mode:
		"supervised"   -- plain Linear + ReLU stack with biases.
		"unsupervised" -- bias-free layers; each layer output is max-normalised and
		                  filtered by the winner-take-all masks returned by ``WTA``.
	"""

	def __init__(self, args):
		super().__init__()

		if args.model not in ("supervised", "unsupervised"):
			raise ValueError(
				"args.model must be 'supervised' or 'unsupervised', got %r" % (args.model,))

		# Only the supervised variant uses biases. Without a bias, layer(h) is the
		# plain dot product <w, h>, which the distance computation below relies on.
		use_bias = args.model == "supervised"
		self.fc1 = nn.Linear(3, 64, bias=use_bias)
		self.fc2 = nn.Linear(64, 128, bias=use_bias)
		self.fc3 = nn.Linear(128, 1024, bias=use_bias)

		self.relu = nn.ReLU()
		self.args = args

	def _wta_layer(self, layer, inputs, train):
		"""Apply one unsupervised layer: linear, ReLU, max-normalise, WTA filter.

		Returns ``(activations, mask_num_point_k1, variance)`` where ``variance`` is
		the batch sum of  sum_i <y_i, y_i> / N - <mu_y, mu_y>  computed on the
		k-winner neuron mask (y_i = mask row of point i, N = number of points).
		"""
		projection = layer(inputs)  # [batch, num_point, num_neuron]
		activations = self.relu(projection)
		# Scale each point's activations by its maximum; 1e-10 avoids 0 / 0 for all-zero rows.
		activations = activations / (torch.max(activations, dim=-1, keepdim=True)[0] + 1e-10)

		# Find winners using the Euclidean distance between input and weight rows:
		# ||w - h||^2 = ||w||^2 + ||h||^2 - 2 <w, h>. Negated so the closest ranks first.
		weight = layer.weight.detach()
		w_norm = torch.sum(torch.square(weight), dim=-1)  # [num_neuron]
		h_norm = torch.sum(torch.square(inputs), dim=-1, keepdim=True)  # [batch, num_point, 1]
		dist = -torch.sqrt(torch.abs(w_norm + h_norm - 2 * projection))
		_, mask_num_neuron, _, _, mask_num_point_k1 = WTA(dist, self.args, train)

		# Winner take all: keep only the winning neurons of every point.
		activations = activations * mask_num_neuron

		num_point = mask_num_neuron.shape[1]
		yi_yi = torch.sum(mask_num_neuron * mask_num_neuron, dim=-1)  # <y_i, y_i>
		var_first = torch.sum(yi_yi, dim=-1) / num_point
		mu_y = torch.sum(mask_num_neuron, dim=-2) / num_point
		var_second = torch.sum(mu_y * mu_y, dim=-1)  # <mu_y, mu_y>
		variance = torch.sum(var_first - var_second)

		return activations, mask_num_point_k1, variance

	def forward(self, x, train=False):
		"""Encode a batch of point clouds.

		Args:
			x: input indexed as [batch, num_point, 3].
			train: forwarded to ``WTA`` (unsupervised mode only).

		Returns:
			supervised: the max-pooled feature tensor.
			unsupervised: ``(z, hidden, hidden_k1, var)`` where ``z`` is the detached
				max-pooled feature, ``hidden`` is ``[input, h1, h2, h3]`` (filtered
				layer outputs), ``hidden_k1`` is ``[input, m1, m2, m3]`` (the
				``mask_num_point_k1`` of each layer) and ``var`` holds one variance
				scalar per layer.
		"""
		layers = (self.fc1, self.fc2, self.fc3)

		if self.args.model == "unsupervised":
			h0 = x.clone()
			hidden_vectors = [h0]
			hidden_vectors_k1 = [h0]
			var = []

			for layer in layers:
				x, point_winner_mask, layer_var = self._wta_layer(layer, x, train)
				hidden_vectors.append(x.clone())
				hidden_vectors_k1.append(point_winner_mask)
				var.append(layer_var)

			x = torch.max(x, dim=-2)[0].detach()  # max pooling over points
			return x, hidden_vectors, hidden_vectors_k1, var

		if self.args.model == "supervised":
			for layer in layers:
				x = self.relu(layer(x))
			return torch.max(x, dim=-2)[0]  # max pooling over points

		raise ValueError(
			"args.model must be 'supervised' or 'unsupervised', got %r" % (self.args.model,))


class MLPClassifier(nn.Module):
	"""MLP head mapping a 1024-d feature to class probabilities (1024 -> 512 -> 256 -> classes).

	The output size follows ``args.dataset``: 10 for "ModelNet10" and "MNIST",
	40 for "ModelNet40".
	"""

	def __init__(self, args):
		super().__init__()
		self.args = args
		self.fc1 = nn.Linear(1024, 512)
		self.fc2 = nn.Linear(512, 256)

		# For any other dataset name no output layer is created, so forward()
		# will fail with an AttributeError on self.fc3.
		num_classes = {"ModelNet10": 10, "MNIST": 10, "ModelNet40": 40}.get(self.args.dataset)
		if num_classes is not None:
			self.fc3 = nn.Linear(256, num_classes)

		self.ln1 = nn.LayerNorm([512])
		self.ln2 = nn.LayerNorm([256])

		self.relu = nn.ReLU()
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, z):
		"""Return softmax probabilities over the classes for feature batch ``z``.

		NOTE(review): the output is already normalised by softmax, not raw logits;
		confirm the loss used by the training code expects probabilities.
		"""
		hidden = z
		# Two identical stages: Linear -> LayerNorm -> ReLU.
		for linear, norm in ((self.fc1, self.ln1), (self.fc2, self.ln2)):
			hidden = self.relu(norm(linear(hidden)))
		return self.softmax(self.fc3(hidden))


class PointCloudClassifer(nn.Module):
	"""Full model: a HebbianEncoder followed by an MLPClassifier head."""

	def __init__(self, args):
		super().__init__()
		self.encoder = HebbianEncoder(args)
		self.classifier = MLPClassifier(args)
		self.args = args

	def forward(self, x):
		"""Encode ``x`` and return the classifier output for the encoded feature.

		The encoder is called without a ``train`` argument, i.e. with its default.
		"""
		mode = self.args.model

		if mode == "unsupervised":
			# The unsupervised encoder returns four values; only the first
			# (the encoded feature) is passed on to the classifier.
			z, _hidden_vectors, _hidden_vectors_k1, _distance = self.encoder(x)
		elif mode == "supervised":
			z = self.encoder(x)

		return self.classifier(z)
