import torch
import torch.nn.functional as F
import yaml
import time
import matplotlib.pyplot as plt

def _hebbian_delta(rule, weight, h_prior, h):
	"""Return the sample-averaged weight change of one layer for ``rule``.

	All leading dimensions of ``h_prior`` and ``h`` are flattened into a
	single sample axis, so the average is taken over every sample without
	materialising a per-sample copy of the weight matrix.

	Args:
		rule: one of "hebb", "instar" or "oja".
		weight: layer weight, shape (output, input).
		h_prior: layer input, shape (..., input).
		h: layer output, shape (..., output).

	Returns:
		Tensor of shape (output, input).

	Raises:
		ValueError: if ``rule`` is not one of the names above.
	"""
	post = h.reshape(-1, h.shape[-1])  # (samples, output)
	pre = h_prior.reshape(-1, h_prior.shape[-1])  # (samples, input)
	# mean over samples of output * input, computed as one matrix product
	corr = (post.t() @ pre) / post.shape[0]

	if rule == "hebb":
		# mean(h * x)
		return corr
	if rule == "instar":
		# mean(h * (x - w)) = mean(h * x) - mean(h) * w
		return corr - post.mean(dim=0).unsqueeze(-1) * weight
	if rule == "oja":
		# mean(h * (x - h * w)) = mean(h * x) - mean(h^2) * w
		return corr - (post * post).mean(dim=0).unsqueeze(-1) * weight
	raise ValueError(
		f"unknown Hebbian rule {rule!r}; expected 'hybrid', 'hebb', 'instar' or 'oja'"
	)


def optimizer_Hebbian(model, hidden_vectors, hidden_vectors_k1, lr_hebb, layer, args):
	"""Apply one in-place Hebbian-style update to the weights of ``model``.

	Parameters whose name contains 'weight' are visited in
	``model.named_parameters()`` order and paired with the activations:
	``hidden_vectors[k-1]`` is the input and ``hidden_vectors[k]`` the output
	of the k-th paired weight (k starts at 1). Each selected weight receives
	``lr_hebb * dw``, with ``dw ~ input * output`` averaged over all samples.
	Parameters without 'weight' in their name (e.g. biases) are not modified.

	Args:
		model: module providing ``named_parameters()``; weights are expected
			to have shape (output, input).
		hidden_vectors: sequence of activations, each (batch, num_point,
			features); entries may be None.
		hidden_vectors_k1: same layout as ``hidden_vectors`` but holding the
			outputs with WTA applied. Only read when ``args.rule == "hybrid"``.
		lr_hebb: step size multiplying the weight change.
		layer: -1 to update every paired weight, otherwise the index k of the
			single weight to update.
		args: object with a ``rule`` attribute: "hybrid", "hebb", "instar"
			or "oja".

	Raises:
		ValueError: if a weight is about to be updated and ``args.rule`` is
			not a known rule.
	"""
	hidden_idx = 1

	for name, params in model.named_parameters():
		if 'weight' not in name:
			continue
		# NOTE(review): the index only advances when an activation is recorded
		# for it, so a None entry (or running past the end) keeps every later
		# weight from being updated -- confirm this is intended.
		if hidden_idx >= len(hidden_vectors) or hidden_vectors[hidden_idx] is None:
			continue

		if layer == -1 or layer == hidden_idx:
			h_prior = hidden_vectors[hidden_idx - 1]  # input: (batch, num_point, input)
			h = hidden_vectors[hidden_idx]  # output: (batch, num_point, output)

			# The update is applied directly to the parameter data, so no
			# autograd graph is needed for any of the intermediate tensors.
			with torch.no_grad():
				if args.rule == "hybrid":
					# instar rule driven by the WTA output, memory efficient form
					h_k1 = hidden_vectors_k1[hidden_idx]  # (batch, num_point, output)
					batch, num_point = hidden_vectors[0].shape[:2]
					dw1 = torch.matmul(h_k1.permute(0, 2, 1), h_prior).sum(dim=0)
					dw2 = h_k1.sum(dim=(0, 1)).unsqueeze(-1) * params
					dw = (dw1 - dw2) / (batch * num_point)
				else:
					dw = _hebbian_delta(args.rule, params, h_prior, h)

				params.data += lr_hebb * dw

		hidden_idx += 1


def ChamferDistance(point_rec, point):
	"""Return the symmetric squared Chamfer distance between two point sets.

	For every point of ``point`` the squared Euclidean distance to its nearest
	neighbour in ``point_rec`` is taken, and vice versa. Each direction is
	averaged over the batch and its own points, and the two averages are
	added. Averaging the directions separately lets the two clouds hold a
	different number of points; for equal counts the value equals the mean of
	the elementwise sum of both directions.

	Args:
		point_rec: tensor of shape (batch, n_rec, dim).
		point: tensor of shape (batch, n_point, dim).

	Returns:
		Scalar tensor.
	"""
	# diff[b, i, j] = point[b, j] - point_rec[b, i]
	diff = point[:, None, :, :] - point_rec[:, :, None, :]
	# squared distance summed directly, without a sqrt / square round trip
	sq_dist = (diff * diff).sum(dim=-1)  # (batch, n_rec, n_point)

	to_rec = sq_dist.min(dim=1)[0]  # nearest point_rec for each point: (batch, n_point)
	to_point = sq_dist.min(dim=2)[0]  # nearest point for each point_rec: (batch, n_rec)

	return to_rec.mean() + to_point.mean()