import torch
import os
import sys
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
torch.manual_seed(43)
# Output directory / file stem for this run, built from the CLI arguments
# (argv[1] = number of x samples, argv[2] = epochs per layer).
# `device` is defined once, further below, together with the other globals.
run_id = f"PERIODIC_n_{int(sys.argv[1])}_epochs_{int(sys.argv[2])}"

def function(x):
	"""Target curve to learn: elementwise ``sin(2*pi*x) + 1`` (values in [0, 2])."""
	angle = 2 * torch.pi * x
	return torch.sin(angle) + 1
lossy = []  # loss log appended to by Layer.train
n_epochs = int(sys.argv[2])  # optimiser steps per layer
hidden_dim = 128  # Example value, adjust as needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _unit_prototype(width):
	"""Draw a (1, width) standard-normal row on ``device``, scaled to ~unit L2 norm."""
	vec = torch.normal(0, 1, size=(1, width)).to(device)
	return vec / (vec.norm(2, 1, keepdim=True) + 1e-4)


# Fixed random target directions, one per layer index k (0..3). Their widths
# match the output sizes of the four layers built in Net.__init__.
# Drawn in this exact order so seeded runs are reproducible.
p_vector0 = _unit_prototype(hidden_dim)
p_vector1 = _unit_prototype(hidden_dim * 2)
p_vector2 = _unit_prototype(hidden_dim)
p_vector3 = _unit_prototype(hidden_dim // 2)


class Net(nn.Module):
	"""Four ``Layer`` blocks trained greedily with a cosine-similarity goodness.

	Inputs are rows ``(x, y, label)``. For layer ``k``, the goodness of a row is
	the cosine similarity between that layer's activation and the module-level
	prototype vector ``p_vector{k}``.
	"""

	def __init__(self, device):
		super(Net, self).__init__()
		self.theta = 10  # loss sharpness handed to every Layer
		dims = [3, hidden_dim, hidden_dim, hidden_dim*2, hidden_dim*2, hidden_dim, hidden_dim, hidden_dim//2]
		# Consecutive (in, out) pairs give four layers: 3->h, h->2h, 2h->h, h->h/2.
		# nn.ModuleList (not a plain list) registers them as submodules, so
		# parameters(), state_dict() and .to() can see them.
		self.layers = nn.ModuleList(
			Layer(dims[d], dims[d + 1], self.theta).to(device)
			for d in range(0, len(dims), 2)
		)

	@torch.no_grad()
	def predict(self, x, y_min, y_max, n_y=1000, y_range=(-0.5, 2.0)):
		"""Return the candidate y values the network accepts at a single x.

		A fixed grid of ``n_y`` y values spanning ``y_range`` is scored twice:
		once with label 0 and once with label 1 appended as the third feature.
		``train`` feeds the correctly labelled rows as ``x_neg``, so a lower
		summed goodness means the label fits. A grid point is kept when label 1
		(the in-tolerance label in create_dataset's positive data) scores lower
		than label 0.

		Args:
			x: tensor holding one x coordinate (the caller passes ``x[i]`` of a
				(100, 1) tensor); it is repeated to shape ``(n_y, 1)``.
			y_min, y_max: currently unused; kept for interface compatibility.
			n_y: number of grid points in y.
			y_range: ``(low, high)`` bounds of the y grid.

		Returns:
			1-D tensor of accepted y values. It may be empty, in which case
			``.mean()`` on it is NaN.
		"""
		cos = nn.CosineSimilarity(dim=1, eps=1e-6)
		prototypes = (p_vector0, p_vector1, p_vector2, p_vector3)
		x_col = x.repeat(n_y, 1)  # (n_y, 1)
		y = torch.linspace(y_range[0], y_range[1], n_y).view(n_y, 1).to(device)
		xy = torch.cat((x_col, y), 1)  # (n_y, 2) candidate points (x, y)

		scores = []
		for label in (0.0, 1.0):
			h = torch.cat((xy, torch.full_like(y, label)), 1)  # append label column
			total = 0
			for k, layer in enumerate(self.layers):
				h = layer(h, k)
				total = total + cos(h, prototypes[k].expand_as(h))
			scores.append(total)

		mask = scores[1] < scores[0]
		return y[mask]

	def forward(self, x):
		"""Run ``x`` through every layer in order (inference only).

		During training the layers are fitted one by one in ``train`` instead.
		"""
		for k, layer in enumerate(self.layers):
			x = layer(x, k)  # Layer.forward takes the layer index as well
		return x

	def train(self, dataloader, n):
		"""Greedily fit each layer, in order, on every batch of ``dataloader``.

		Each batch is expected to hold the ``n * 20`` rows of create_dataset's
		positive data, followed by its negative data (see the DataLoader built
		under ``__main__``). The two sets are intentionally swapped: the
		correctly labelled rows are used as ``x_neg`` and the flipped-label rows
		as ``x_pos``.

		Note: this shadows ``nn.Module.train(mode)``, so ``model.eval()`` and
		``model.train(True)`` do not work on this class.
		"""
		split = n * 20  # 10 in-tolerance + 10 out-of-tolerance rows per x value
		for x, y in dataloader:
			x, y = x.to(device), y.to(device)
			# x holds the (x, y) coordinates and y the label; re-attach the label.
			x_neg = torch.cat((x[:split], y[:split].unsqueeze(1)), 1)
			x_pos = torch.cat((x[split:], y[split:].unsqueeze(1)), 1)

			h_pos, h_neg = x_pos, x_neg
			# k restarts at 0 for every batch because it selects the layer's prototype.
			for k, layer in enumerate(self.layers):
				h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
				print("Layer", k + 1, "Loss", loss)


class Layer(nn.Module):
	"""A Linear + GELU block whose squared output is trained with a local loss.

	``goodness`` scores a batch by the cosine similarity between each output row
	and the module-level prototype ``p_vector{k}`` for this layer's index ``k``.
	``train`` optimises the layer so that ``x_pos`` scores higher than ``x_neg``.
	"""

	def __init__(self, in_features, out_features, theta):
		super(Layer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.theta = theta  # sharpness of the logistic loss in ``train``
		self.layer = nn.Linear(in_features, out_features)
		self.main = nn.Sequential(self.layer, nn.GELU())
		self.layer_epochs = n_epochs  # optimiser steps per call to ``train``
		# Only the Linear layer has parameters (GELU has none).
		self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001)

	def forward(self, x, k):
		"""Return ``GELU(Linear(x)) ** 2``, which is elementwise non-negative.

		``x`` is first reshaped to ``(-1, in_features)``. ``k`` is accepted for
		symmetry with ``goodness`` and is not used here.
		"""
		x = x.view(-1, self.in_features)
		return self.main(x) ** 2

	def goodness(self, x_pos, x_neg, k):
		"""Return the per-row cosine similarity of both batches to prototype ``k``.

		Raises:
			ValueError: if ``k`` is not in 0..3 (only four prototypes exist).
		"""
		if not 0 <= k <= 3:
			raise ValueError(f"no prototype vector for layer index {k}; expected 0..3")
		target = (p_vector0, p_vector1, p_vector2, p_vector3)[k]
		h_pos = self.forward(x_pos, k)  # non-negative because forward squares
		h_neg = self.forward(x_neg, k)
		cos = nn.CosineSimilarity()  # dim=1 by default: one score per row
		g_pos = cos(h_pos, target.repeat(h_pos.shape[0], 1))
		g_neg = cos(h_neg, target.repeat(h_neg.shape[0], 1))
		return g_pos, g_neg

	def train(self, x_pos, x_neg, k):
		"""Run ``layer_epochs`` Adam steps on this layer, then return its outputs.

		The loss is ``mean(log(1 + exp(-theta * (g_pos - g_neg))))``. It is
		computed with ``softplus`` so it stays finite, with finite gradients,
		when ``theta * |g_pos - g_neg|`` is large.

		Returns:
			A tuple of the detached outputs for ``x_pos`` and ``x_neg`` (the
			inputs to the next layer) and the mean loss over the epochs
			(0.0 when ``layer_epochs`` is 0).

		Note: this shadows ``nn.Module.train(mode)``.
		"""
		self.running_loss = 0.0

		for _ in range(self.layer_epochs):
			g_pos, g_neg = self.goodness(x_pos, x_neg, k)
			delta = g_pos - g_neg
			loss = nn.functional.softplus(-self.theta * delta).mean()

			self.opt.zero_grad()
			loss.backward()
			self.opt.step()
			self.running_loss += loss.item()
			# NOTE(review): logs the running sum divided by the *total* epoch count
			# (a partial mean) on every epoch. Confirm this is the intended curve.
			lossy.append(self.running_loss / self.layer_epochs)

		epochs = max(self.layer_epochs, 1)  # avoid ZeroDivisionError when n_epochs == 0
		with torch.no_grad():
			h_pos = self.forward(x_pos, k)
			h_neg = self.forward(x_neg, k)
		return (
			h_pos.detach(),
			h_neg.detach(),
			self.running_loss / epochs,
		)


def weights_init(m):
	"""Apply Xavier-normal init to ``nn.Linear`` weights and zero their biases.

	Suitable for ``module.apply(weights_init)``. Non-Linear modules are left
	untouched. Linear layers created with ``bias=False`` only get their weight
	initialised.
	"""
	if not isinstance(m, nn.Linear):
		return
	nn.init.xavier_normal_(m.weight)
	if m.bias is not None:  # nn.Linear(..., bias=False) stores bias as None
		nn.init.zeros_(m.bias)


def create_dataset(x, n, tol=None):
	"""Build labelled (x, y, label) samples around ``function``.

	For each of the ``n`` x values, 10 in-tolerance samples are drawn as
	``function(x) + N(0, 1) * tol``. Ten out-of-tolerance samples are drawn
	uniformly: 5 between ``function(x) + tol`` and (max noisy y + 2), and 5
	between (min noisy y - 2) and ``function(x) - tol``.

	Args:
		x: the n x coordinates. ``function(x)`` is reshaped with ``view(1, n)``
			(the caller passes a (1, n) tensor).
		n: number of x coordinates.
		tol: half-width of the tolerance band. ``None`` keeps the original
			behaviour of reading the module-level ``tol`` (set under ``__main__``).

	Returns:
		positive_data: (3, 20 * n) tensor of (x, y, label) columns with the
			true label: 1 for in-tolerance points, 0 for out-of-tolerance points.
		negative_data: the same points with the labels flipped.
		y_max, y_min: (5, n) tensors with the out-of-tolerance sampling bounds.

	Raises:
		ValueError: if ``tol`` is None and no module-level ``tol`` exists.
	"""
	if tol is None:
		tol = globals().get("tol")
		if tol is None:
			raise ValueError("create_dataset needs `tol`: pass it or define a module-level `tol`")

	n_rep = 10  # samples per x value in each of the in- and out-of-tolerance sets
	n_half = n_rep // 2  # out-of-tolerance samples above / below the band

	y_noised = function(x).view(1, n).repeat(n_rep, 1) + torch.randn(n_rep, n) * tol
	x = x.repeat(n_rep, 1)
	ones = torch.ones_like(x, dtype=torch.float)
	zeros = torch.zeros_like(x, dtype=torch.float)

	in_tol_pos = torch.cat((x, y_noised, ones), 0).view(3, n_rep, n)
	in_tol_neg = torch.cat((x, y_noised, zeros), 0).view(3, n_rep, n)

	y_high = function(x[:n_half]) + tol  # upper edge of the tolerance band
	y_low = function(x[:n_half]) - tol  # lower edge of the tolerance band

	# Sampling limits: 2 units beyond the extreme noisy in-tolerance values.
	y_max = y_noised.max().repeat(n_half, n) + 2
	y_min = y_noised.min().repeat(n_half, n) - 2

	y_above = torch.zeros_like(y_noised[:n_half])
	y_below = torch.zeros_like(y_noised[:n_half])
	# Row-by-row draws keep the original RNG consumption order, so seeded runs match.
	for i in range(n_half):
		y_above[i] = (y_max[i] - y_high[i]) * torch.rand(1, n) + y_high[i]
		y_below[i] = (y_low[i] - y_min[i]) * torch.rand(1, n) + y_min[i]
	y_out_tol = torch.cat((y_above, y_below), 0)  # (10, n): 5 above, 5 below the band

	# The correct label for out-of-tolerance data is 0.
	out_tol_pos = torch.cat((x, y_out_tol, zeros), 0).view(3, n_rep, n)
	out_tol_neg = torch.cat((x, y_out_tol, ones), 0).view(3, n_rep, n)

	positive_data = torch.cat((in_tol_pos, out_tol_pos), 1).flatten(1, 2)
	negative_data = torch.cat((in_tol_neg, out_tol_neg), 1).flatten(1, 2)
	return positive_data, negative_data, y_max, y_min


model = Net(device).to(device)
if __name__ == '__main__':
	n = int(sys.argv[1])	# Number of x coordinates in the dataset
	x_low = 0
	x_high = 10

	x = torch.linspace(x_low, x_high, n).view(1, n)
	x_train = x.clone()
	tol = 2e-2  # half-width of the tolerance band; read by create_dataset
	# tol = 0.005
	# The loss sharpness is fixed inside Net (self.theta = 10).
	pos, neg, y_max, y_min = create_dataset(x, n)
	print(pos.shape)
	print(neg.shape)
	pos = pos.to(device)
	neg = neg.to(device)

	def plot():
		"""Scatter the training data together with the ±tol band around ``function``."""
		pos_cpu = pos.to("cpu")
		neg_cpu = neg.to("cpu")
		x_row = x.to("cpu")[0, :]
		f_row = function(x).to("cpu")[0, :]
		plt.figure(figsize=(6, 6))
		plt.scatter(pos_cpu[0, :], pos_cpu[1, :], c="g", s=3)
		plt.scatter(neg_cpu[0, :], neg_cpu[1, :], c="r", s=3)
		plt.legend(["Positive data", "Negative data"])
		plt.plot(x_row, f_row + tol, c="black")
		plt.plot(x_row, f_row - tol, c="black")
		plt.title("Training Dataset")
		# plt.show()

	# plot()	# Uncomment to plot the training dataset

	dataset = torch.cat((pos, neg), 1)
	dataset = torch.utils.data.TensorDataset(dataset[:2, :].T, dataset[2, :])
	# A single batch: every positive sample followed by every negative sample.
	# Net.train splits the batch at n * 20.
	dataloader = torch.utils.data.DataLoader(dataset, batch_size=2*pos.shape[1], shuffle=False)

	model.train(dataloader, n)
	print("Training Done")

	n_eval = 100
	x = torch.linspace(x_low, x_high, n_eval).view(n_eval, 1).to(device)
	y = torch.zeros(n_eval, 1, device=device)
	std = torch.zeros(n_eval, 1, device=device)
	with torch.no_grad():
		for i in range(n_eval):
			# Evaluate once per x and reuse the accepted set for both statistics.
			# An empty accepted set yields NaN for the mean and std.
			accepted = model.predict(x[i], y_min, y_max)
			y[i] = accepted.mean()
			std[i] = accepted.std(unbiased=False)
	os.makedirs(run_id, exist_ok=True)

	# Move tensors to the CPU for plotting / export
	x_cpu = x.to("cpu")
	y_cpu = y.to("cpu")
	std_cpu = std.to("cpu")
	# Flatten to 1-D so the training points become one line instead of n single-point lines
	x_train_cpu = x_train.to("cpu").flatten()
	function_x_train_cpu = function(x_train).to("cpu").flatten()

	# Plotting
	plt.figure(figsize=(6, 6))
	plt.plot(x_cpu, y_cpu)  # Predicted curve
	plt.plot(x_train_cpu, function_x_train_cpu, "rx")  # True function at the training x
	plt.fill_between(
		x_cpu.squeeze().numpy(),  # ensure x is 1D
		(y_cpu - std_cpu).squeeze().numpy(),  # lower bound
		(y_cpu + std_cpu).squeeze().numpy(),  # upper bound
		color="grey",
		alpha=0.3,
	)
	plt.xlabel("x")
	plt.ylabel("y")
	plt.title("Function approximation")
	plt.legend(
		["Predicted Curve", "Actual Training Points"],
		loc="upper right",
		frameon=True,
		framealpha=0.7,
	)

	# Print MSE against the true function on the evaluation grid
	print("MSE", torch.nn.functional.mse_loss(y_cpu, function(x_cpu)))

	# Save figure as PDF inside run_id folder
	pdf_path = os.path.join(run_id, f"{run_id}.pdf")
	plt.savefig(pdf_path)
	plt.close()

	# Prepare data for CSV
	df = pd.DataFrame({
		"x": x_cpu.numpy().flatten(),
		"y": y_cpu.numpy().flatten(),
		"upper_bound": (y_cpu + std_cpu).numpy().flatten(),
		"lower_bound": (y_cpu - std_cpu).numpy().flatten()
	})

	# Save CSV inside run_id folder
	csv_path = os.path.join(run_id, f"{run_id}.csv")
	df.to_csv(csv_path, index=False)


