import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
start_time = time.time()  # wall-clock start, reported at the end of the script

# Seed the global RNG so the target vectors, layer weights and data are reproducible.
torch.manual_seed(43)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def function(x):
	"""Target function to approximate; edit this to regress a different curve.

	Applied element-wise when ``x`` is a tensor; also accepts plain numbers.
	"""
	return pow(x, 2)

# Running loss averages appended by Layer.train at every optimisation step.
lossy = []
n_epochs = 500  # optimisation steps per layer, read by Layer.__init__
hidden_dim = 128  # Example value, adjust as needed
# `device` is chosen once at the top of the file and reused below.


def _random_unit_row(dim):
	"""Draw a (1, dim) standard-normal row on ``device`` scaled to ~unit L2 norm.

	The ``+ 1e-4`` guards against division by zero, so the norm is slightly below 1.
	"""
	v = torch.normal(0, 1, size=(1, dim)).to(device)
	return v / (v.norm(2, 1, keepdim=True) + 1e-4)


# Fixed random target directions, one per layer (widths 2h, h, h/2). They are
# drawn from the seeded RNG in this order; keep it to reproduce earlier runs.
p_vector0 = _random_unit_row(2 * hidden_dim)
p_vector1 = _random_unit_row(hidden_dim)
p_vector2 = _random_unit_row(hidden_dim // 2)


class Net(nn.Module):
	"""Stack of three ``Layer`` blocks, each trained in isolation with its own optimiser.

	Layer ``k`` is scored by the cosine similarity between its output and the fixed
	direction ``p_vector{k}``; ``predict`` sums those scores across layers.
	"""

	def __init__(self, device):
		super(Net, self).__init__()
		self.theta = 10  # passed to every Layer as its loss sharpness
		# Input width 3 = (x, y, label). Consecutive pairs give the layer shapes
		# (3 -> 2h), (2h -> h), (h -> h/2), matching p_vector0/1/2.
		dims = [3, 2*hidden_dim, 2*hidden_dim, hidden_dim, hidden_dim, hidden_dim//2]
		# nn.ModuleList (not a plain list) so the layers' parameters are registered
		# with this module and visible to parameters()/state_dict()/.to().
		self.layers = nn.ModuleList(
			Layer(dims[d], dims[d + 1], self.theta).to(device)
			for d in range(0, len(dims), 2)
		)

	def predict(self, x):
		"""Return the candidate y values classified as in-tolerance at ``x``.

		A grid of ``n_y`` y values is paired with ``x``; every (x, y) pair is scored
		once with label 0 and once with label 1 by summing the per-layer cosine
		similarities to ``p_vector0/1/2``. Pairs whose label-0 score is higher are kept.

		Args:
			x: a single x coordinate such that ``x.repeat(n_y, 1)`` has shape
				(n_y, 1), e.g. a tensor of shape (1,) — TODO confirm for new callers.

		Returns:
			Tensor of shape (k, 1) with the accepted y values; k may be 0.
		"""
		cos = nn.CosineSimilarity(dim=1, eps=1e-6)
		n_y = 1000  # number of candidate y values sampled at this x
		p_vectors = (p_vector0, p_vector1, p_vector2)

		# Candidate y grid. It must cover the target's values over the evaluation
		# range (the default x**2 on [-1, 1] lies in [0, 1]). The original author
		# noted that a range passing exactly through the function value can error;
		# a wider alternative is:
		# y = torch.linspace(-0.5, 1.5, n_y).view(n_y, 1).to(device)
		y = torch.linspace(0, 1, n_y).view(n_y, 1).to(device)
		xy = torch.cat((x.repeat(n_y, 1), y), 1)  # (n_y, 2) candidate (x, y) points

		goodness_per_label = []
		with torch.no_grad():  # inference only: no autograd graph needed
			for label in (0.0, 1.0):  # the two labels used by create_dataset
				h = torch.cat((xy, torch.full_like(y, label)), 1)  # (n_y, 3)
				goodness = []
				for k, (layer, p_vec) in enumerate(zip(self.layers, p_vectors)):
					h = layer(h, k)
					goodness.append(cos(h, p_vec.repeat(h.shape[0], 1)))
				goodness_per_label.append(sum(goodness).unsqueeze(1))
		goodness_per_label = torch.cat(goodness_per_label, 1)  # (n_y, 2)
		# As wired in this script, Net.train uses the label-0 in-tolerance samples
		# as the set whose goodness is pushed up, so label 0 winning means "inside".
		mask = goodness_per_label[:, 1] < goodness_per_label[:, 0]
		return y[mask]

	def forward(self, x):
		"""Run ``x`` through every layer (inference helper; training is layer-wise)."""
		for k, layer in enumerate(self.layers):
			x = layer(x, k)  # Layer.forward takes the layer index as its 2nd argument
		return x

	def train(self, dataloader, n):
		"""Train each layer in turn, once per batch yielded by ``dataloader``.

		NOTE: this overrides ``nn.Module.train(mode)`` with a different signature, so
		``model.train()`` / ``model.eval()`` cannot be used to switch modes.

		Args:
			dataloader: yields ``(x, y)`` with ``x`` of shape (B, 2) holding the
				point coordinates and ``y`` of shape (B,) holding the labels.
			n: the ``n`` passed to ``create_dataset``. The first ``20 * n`` rows of a
				batch are taken as one set and the rest as the other, which assumes
				each batch holds the whole, unshuffled dataset.
		"""
		split = n * 20
		for x, y in dataloader:
			x, y = x.to(device), y.to(device)
			# In this script rows [0, 20n) are create_dataset's positive_data and the
			# rest its negative_data; the two are deliberately swapped here.
			x_neg = torch.cat((x[:split], y[:split].unsqueeze(1)), 1)
			x_pos = torch.cat((x[split:], y[split:].unsqueeze(1)), 1)

			h_pos, h_neg = x_pos, x_neg
			# Index the layers per batch so k always selects p_vector0/1/2; a counter
			# shared across batches would run past the last target vector.
			for k, layer in enumerate(self.layers):
				h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
				print("Layer", k + 1, "Loss", loss)


class Layer(nn.Module):
	"""One Linear -> GELU block with squared output, trained with its own Adam optimiser.

	The "goodness" of a sample at layer ``k`` is the cosine similarity between the
	layer's activations and the fixed direction ``p_vector{k}``.
	"""

	def __init__(self, in_features, out_features, theta):
		super(Layer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.theta = theta  # loss sharpness: scales the goodness gap in train()
		self.layer = nn.Linear(in_features, out_features)
		self.main = nn.Sequential(self.layer, nn.GELU())
		self.layer_epochs = n_epochs  # optimisation steps per call to train()
		self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001)

	def forward(self, x, k):
		"""Return ``GELU(Linear(x)) ** 2`` for ``x`` reshaped to (-1, in_features).

		``k`` (the layer index) is accepted for call-site symmetry but not used.
		"""
		x = x.view(-1, self.in_features)
		return self.main(x)**2

	def goodness(self, x_pos, x_neg, k):
		"""Cosine similarity of each row's activations with ``p_vector{k}``.

		Returns:
			(g_pos, g_neg): 1-D tensors with one score per row of x_pos / x_neg.

		Raises:
			ValueError: if ``k`` is not 0, 1 or 2 (only three target vectors exist).
		"""
		targets = (p_vector0, p_vector1, p_vector2)
		if not 0 <= k < len(targets):
			raise ValueError(f"layer index k must be in [0, {len(targets)}), got {k}")
		target = targets[k]

		h_pos = self.forward(x_pos, k)  # squared activations, hence >= 0
		h_neg = self.forward(x_neg, k)

		cos = nn.CosineSimilarity()  # dim=1, default eps
		g_pos = cos(h_pos, target.repeat(h_pos.shape[0], 1))
		g_neg = cos(h_neg, target.repeat(h_neg.shape[0], 1))
		return g_pos, g_neg

	def train(self, x_pos, x_neg, k):
		"""Optimise this layer for ``self.layer_epochs`` steps.

		Minimises ``mean(log(1 + exp(-theta * (g_pos - g_neg))))``, pushing the
		goodness of ``x_pos`` above that of ``x_neg``. It is computed with
		``softplus``: the literal ``exp`` overflows float32 once theta * |delta|
		exceeds ~88 (e.g. theta = 64 with delta near -2), giving inf/NaN.

		NOTE: overrides ``nn.Module.train(mode)`` with a different signature.

		Returns:
			(h_pos, h_neg, mean_loss): detached layer outputs for the next layer and
			the loss averaged over all steps.
		"""
		self.running_loss = 0.0

		for _ in range(self.layer_epochs):
			g_pos, g_neg = self.goodness(x_pos, x_neg, k)
			delta = g_pos - g_neg

			loss = nn.functional.softplus(-self.theta * delta).mean()

			self.opt.zero_grad()
			loss.backward()
			self.opt.step()
			self.running_loss += loss.item()
			# Running sum normalised by the total step count, not by steps so far.
			lossy.append(self.running_loss / self.layer_epochs)
		return (
			self.forward(x_pos, k).detach(),
			self.forward(x_neg, k).detach(),
			self.running_loss / self.layer_epochs,
		)


def weights_init(m):
	"""Give ``nn.Linear`` modules Xavier-normal weights and zero biases; ignore others.

	Shaped for use with ``nn.Module.apply``; it is not called in this file.
	"""
	if not isinstance(m, nn.Linear):
		return
	nn.init.xavier_normal_(m.weight)
	nn.init.zeros_(m.bias)


def create_dataset(x, n, tolerance=None):
	"""Sample labelled points inside and outside a tolerance band around ``function``.

	Args:
		x: x coordinates; ``function(x).view(1, n)`` and ``x.repeat(10, 1)`` below
			expect a (1, n) tensor such as ``torch.linspace(...).view(1, n)``.
		n: number of x coordinates.
		tolerance: half-width of the band and std of the in-band noise. Defaults to
			the module-level ``tol``.

	Returns:
		(positive_data, negative_data), each of shape (3, 20 * n) with rows
		(x, y, label):
		  * positive_data: in-band points labelled 1, out-of-band points labelled 0;
		  * negative_data: the same points with the labels swapped.
	"""
	if tolerance is None:
		tolerance = tol

	# 10 noisy copies of the target curve. NOTE(review): the noise is Gaussian with
	# std ``tolerance``, so roughly a third of these "in-band" samples actually lie
	# outside ±tolerance.
	y_noised = function(x).view(1, n).repeat(10, 1) + torch.randn(10, n) * tolerance
	xs = x.repeat(10, 1)
	ones = torch.ones_like(xs, dtype=torch.float)
	zeros = torch.zeros_like(xs, dtype=torch.float)

	in_tol_pos = torch.cat((xs, y_noised, ones), 0).view(3, 10, n)
	in_tol_neg = torch.cat((xs, y_noised, zeros), 0).view(3, 10, n)

	# Band edges, computed for 5 rows (every row of xs holds the same x values).
	y_high = function(xs[:5]) + tolerance  # upper tolerance band
	y_low = function(xs[:5]) - tolerance  # lower tolerance band

	# Out-of-band y is uniform between the band edge and 2 beyond the extreme noisy y.
	y_max = y_noised.max().repeat(5, n) + 2
	y_min = y_noised.min().repeat(5, n) - 2
	y_above = torch.zeros_like(y_noised[:5])
	y_below = torch.zeros_like(y_noised[:5])
	# Row by row, alternating above/below draws (keeps the seeded RNG order).
	for i in range(5):
		y_above[i] = (y_max[i] - y_high[i]) * torch.rand(1, n) + y_high[i]
		y_below[i] = (y_low[i] - y_min[i]) * torch.rand(1, n) + y_min[i]

	y_out_tol = torch.cat((y_above, y_below), 0)  # (10, n): 5 rows above, 5 below

	# Out-of-band points get label 0 in the positive set and label 1 in the negative set.
	out_tol_pos = torch.cat((xs, y_out_tol, zeros), 0).view(3, 10, n)
	out_tol_neg = torch.cat((xs, y_out_tol, ones), 0).view(3, 10, n)

	positive_data = torch.cat((in_tol_pos, out_tol_pos), 1).flatten(1, 2)
	negative_data = torch.cat((in_tol_neg, out_tol_neg), 1).flatten(1, 2)
	return positive_data, negative_data


# Statement order matters: Net's nn.Linear initialisation and create_dataset's
# torch.randn/torch.rand both draw from the seeded global RNG, so reordering
# these lines changes the generated weights and data.
model = Net(device).to(device)

n = 20	# number of x grid points; the dataset holds 20 * n positive and 20 * n negative samples
x_low = -1
x_high = 1

x = torch.linspace(x_low, x_high, n).view(1, n)  # shape (1, n), created on the CPU
x_train = x.clone()  # x is rebound later for evaluation; this keeps the training grid for plotting
tol = 1e-2  # read as a module global by create_dataset, so it must be set before the call below
# tol = 0.005
theta = 64.0  # NOTE(review): not read anywhere in this file; Net hard-codes its own theta = 10
pos, neg = create_dataset(x, n)
pos = pos.to(device)
neg = neg.to(device)

#
def plot():
	"""Scatter the training samples and overlay a ±0.2 band around ``function``.

	Reads the module-level ``pos``, ``neg`` and ``x`` tensors. ``plt.show()`` is
	left commented out, so the figure is only created.
	NOTE(review): the band is ±0.2 while the data uses ``tol``; confirm intent.
	"""
	pos_cpu = pos.to("cpu")
	neg_cpu = neg.to("cpu")
	x_row = x.to("cpu")[0, :]
	f_row = function(x).to("cpu")[0, :]

	plt.figure(figsize=(6, 6))
	plt.scatter(pos_cpu[0, :], pos_cpu[1, :], c="g", s=3)
	plt.scatter(neg_cpu[0, :], neg_cpu[1, :], c="r", s=3)
	# Legend is attached before the band lines, so only the two scatters get entries.
	plt.legend(["Positive data", "Negative data"])
	plt.plot(x_row, f_row + 0.2, c="black")
	plt.plot(x_row, f_row - 0.2, c="black")
	plt.title("Training Dataset")
	# plt.show()


# plot()	# Uncomment to plot the training dataset

# Features are the (x, y) rows and the target is the label row. Net.train splits
# each batch at row 20 * n (positives first, then negatives), so the loader must
# keep the order and deliver the whole dataset as a single batch.
dataset = torch.cat((pos, neg), 1)
dataset = torch.utils.data.TensorDataset(dataset[:2, :].T, dataset[2, :])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)

model.train(dataloader, n)
print("Training Done")

n_eval = 100  # size of the evaluation grid
x = torch.linspace(-1, 1, n_eval).view(n_eval, 1).to(device)
y = torch.zeros(n_eval).unsqueeze(1).to(device)
std = torch.zeros(n_eval).unsqueeze(1).to(device)
for i in range(n_eval):
	# predict() draws no random numbers, so one call serves both statistics.
	# If no candidate y is accepted the result is empty and mean/std are NaN.
	accepted = model.predict(x[i])
	y[i] = accepted.mean()
	std[i] = accepted.std(unbiased=False)
x = x.to("cpu")
y = y.to("cpu")
std = std.to("cpu")

plt.figure(figsize=(6, 6))
plt.plot(x, y, 'b', label="Predicted Curve")
# x_train has shape (1, n), so row 0 already holds every training x.
plt.plot(x_train[0].to("cpu"), function(x_train[0]).to("cpu"), "rx", label="Actual Training Points")
# Band = mean ± one population std of the accepted y values at each x.
plt.fill_between(
	x.flatten(),
	(y - std).flatten(),
	(y + std).flatten(),
	color="grey",
	alpha=0.3,  # Adjust transparency as needed
	label="±1 std of accepted y",
)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Function approximation")
plt.legend(loc='upper center')
print("MSE", torch.nn.functional.mse_loss(y, function(x)))

# Report before plt.show(), which blocks until the window is closed.
print("Time taken for computation:", time.time() - start_time)
plt.show()


