import torch
import sys
import os
import torch.nn as nn
import matplotlib.pyplot as plt
torch.manual_seed(43)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The first CLI argument is the number of training epochs; it also names the
# output directory. Fail with a usage message instead of a bare IndexError.
if len(sys.argv) < 2:
	sys.exit("usage: python <script> <n_epochs>")
run_id = f"MULTI_epochs_{int(sys.argv[1])}"
os.makedirs(run_id, exist_ok=True)

# Our space is 4D - (x1,x2,x3,x4)
def function(x1,x2,x3):
	"""Ground-truth surface x4 = sin(x1*x2/5) + cos(x3/5)**2 + x1*x2*x3.

	Evaluated elementwise with torch.sin / torch.cos, so tensors of any
	mutually broadcastable shapes are accepted.
	"""
	x1x2 = x1 * x2
	smooth_part = torch.sin(x1x2 / 5) + torch.cos(x3 / 5) ** 2
	return smooth_part + x1x2 * x3

lossy = []  # loss trace appended to by Layer.train at every optimisation step
hidden_dim = 64


def _random_unit_row(width):
	"""Draw a (1, width) standard-normal row on `device` and scale it to ~unit L2 norm.

	The +1e-4 in the denominator guards against division by zero, so the
	norm is slightly below 1.
	"""
	row = torch.normal(0, 1, size=(1, width)).to(device)
	return row / (row.norm(2, 1, keepdim=True) + 1e-4)


# Fixed random target directions, one per layer, whose widths match the
# layer outputs (2*hidden_dim, hidden_dim, hidden_dim//2). Drawn in this
# order so the seeded RNG stream is consumed exactly as before.
p_vector0 = _random_unit_row(2 * hidden_dim)
p_vector1 = _random_unit_row(hidden_dim)
p_vector2 = _random_unit_row(hidden_dim // 2)
 

class Net(nn.Module):
	"""Stack of three forward-forward style ``Layer``s trained greedily, layer by layer.

	Inputs are 5-D rows ``(x1, x2, x3, x4, label)``. Layer ``k`` is trained so
	that the cosine similarity between its output and the fixed vector
	``p_vector{k}`` (its "goodness") separates the two halves of each batch.
	"""

	def __init__(self, device):
		super(Net, self).__init__()
		self.theta = 10
		# Consecutive (in, out) pairs: 5 -> 2*hidden_dim -> hidden_dim -> hidden_dim//2.
		dims = [5, 2*hidden_dim, 2*hidden_dim, hidden_dim, hidden_dim , hidden_dim//2]
		# nn.ModuleList (not a plain list) registers the layers as submodules, so
		# model.parameters(), model.to(...) and state_dict() can see them.
		self.layers = nn.ModuleList(
			[Layer(dims[d], dims[d + 1], self.theta).to(device) for d in range(0, len(dims), 2)]
		)

	def predict(self, x,y,z):
		"""Return the x4 candidates the model accepts at the point (x, y, z).

		A grid of ``n_w`` candidate x4 values in [-30, 30] is scored twice: once
		with label 0.0 and once with label 1.0 as the 5th input feature. The
		score is the sum over layers of the cosine similarity between the layer
		output and that layer's p_vector. Candidates whose label-0 score beats
		their label-1 score are returned. Presumably these are the "in tolerance"
		values, because Net.train treats (in-tol, 0) / (out-tol, 1) as the
		high-goodness set.

		Args:
			x, y, z: assumed single-element tensors (the evaluation code passes
				``t[j]``); each is tiled to shape (n_w, 1).

		Returns:
			Tensor of shape (k, 1) with the accepted x4 values. It may be empty,
			in which case a caller's ``mean()`` is NaN.
		"""
		cos = nn.CosineSimilarity(dim=1, eps=1e-6)
		n_w = 1000  # number of candidate x4 values scanned
		targets = (p_vector0, p_vector1, p_vector2)

		# Inference only: recording an autograd graph would just waste memory.
		with torch.no_grad():
			x = x.repeat(n_w, 1)
			y = y.repeat(n_w, 1)
			z = z.repeat(n_w, 1)
			w1 = torch.linspace(-30, 30, n_w).view(n_w, 1).to(device)
			base = torch.cat((x, y, z, w1), 1)  # (n_w, 4)

			goodness_per_label = []
			for label in (0.0, 1.0):  # the two labels used in training: 0 and 1
				h = torch.cat((base, torch.full_like(w1, label)), 1)  # (n_w, 5)
				total = 0
				for k, (layer, target) in enumerate(zip(self.layers, targets)):
					h = layer(h, k)
					total = total + cos(h, target.expand(h.shape[0], -1))
				goodness_per_label.append(total.unsqueeze(1))
			goodness_per_label = torch.cat(goodness_per_label, 1)  # (n_w, 2)

			mask = goodness_per_label[:, 1] < goodness_per_label[:, 0]
			return w1[mask]

	def forward(self, x):
		"""Run ``x`` through every layer (inference only; training is layer-wise)."""
		# Layer.forward requires the layer index as its second argument.
		for k, layer in enumerate(self.layers):
			x = layer(x, k)
		return x

	def train(self, dataloader, n):
		"""Greedily train each layer on every batch produced by ``dataloader``.

		Args:
			dataloader: yields ``(x, y)``, where x holds (x1, x2, x3, x4) rows and
				y the 0/1 labels. The first half of each batch is used as the
				negative set and the second half as the positive set.
			n: unused; kept for interface compatibility.

		Note: this overrides ``nn.Module.train(mode)`` with a different
		signature, so ``.eval()`` must not be called on this model.
		"""
		for x, y in dataloader:
			x, y = x.to(device), y.to(device)
			half = y.shape[0] // 2
			# Append the label as a 5th feature. The first half becomes the
			# negative set and the second half the positive set. In the
			# __main__ block these halves are named positive_data /
			# negative_data, so the roles are deliberately inverted here.
			x_neg = torch.cat((x[:half], y[:half].unsqueeze(1)), 1)
			x_pos = torch.cat((x[half:], y[half:].unsqueeze(1)), 1)

			h_pos, h_neg = x_pos, x_neg
			# Restart the layer index for every batch: Layer.goodness only
			# accepts 0..2, so a counter carried across batches would break.
			for k, layer in enumerate(self.layers):
				h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
				print("Layer", k + 1, "Loss", loss)


class Layer(nn.Module):
	"""One Linear + GELU layer trained locally with a forward-forward style loss.

	The layer output is ``GELU(Linear(x)) ** 2``. Its "goodness" on a batch is
	the cosine similarity between each output row and a fixed vector selected
	by the layer index ``k`` (``p_vector0``, ``p_vector1`` or ``p_vector2``).
	``train`` pushes the goodness of ``x_pos`` above that of ``x_neg``.
	"""

	def __init__(self, in_features, out_features, theta, epochs=None):
		"""
		Args:
			in_features, out_features: sizes of the Linear layer.
			theta: sharpness of the loss softplus(-theta * (g_pos - g_neg)).
			epochs: optimisation steps per ``train`` call. Defaults to the
				script-level global ``n_epochs`` (set in ``__main__``), which
				existing callers rely on.
		"""
		super(Layer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.theta = theta
		self.layer = nn.Linear(in_features, out_features)
		self.main = nn.Sequential(self.layer, nn.GELU())
		self.layer_epochs = n_epochs if epochs is None else epochs
		self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001)

	def forward(self, x, k):
		"""Return ``GELU(Linear(x)) ** 2``; ``k`` is accepted for API symmetry but unused."""
		x = x.view(-1, self.in_features)
		return self.main(x)**2

	@staticmethod
	def _target(k):
		"""Return the fixed goodness target vector for layer index ``k``."""
		targets = (p_vector0, p_vector1, p_vector2)
		if not 0 <= k < len(targets):
			raise ValueError(f"layer index k must be 0, 1 or 2, got {k}")
		return targets[k]

	def goodness(self, x_pos, x_neg, k):
		"""Cosine similarity of each output row with the layer-``k`` target, for both inputs."""
		target = self._target(k)
		h_pos = self.forward(x_pos, k)  # non-negative: the activations are squared
		h_neg = self.forward(x_neg, k)

		cos = nn.CosineSimilarity()
		g_pos = cos(h_pos, target.repeat(h_pos.shape[0], 1))
		g_neg = cos(h_neg, target.repeat(h_neg.shape[0], 1))
		return g_pos, g_neg

	def train(self, x_pos, x_neg, k):
		"""Optimise this layer for ``self.layer_epochs`` full-batch Adam steps.

		Args:
			x_pos, x_neg: inputs whose goodness should be high / low.
			k: layer index (0, 1 or 2) selecting the target vector.

		Returns:
			``(h_pos, h_neg, mean_loss)``: the layer outputs for both inputs
			after training (without autograd history) and the loss averaged over
			the steps (0.0 when no steps run).

		Note: this overrides ``nn.Module.train(mode)`` with a different
		signature, so ``.eval()`` must not be called on this module.
		"""
		self.running_loss = 0.0

		for _ in range(self.layer_epochs):
			g_pos, g_neg = self.goodness(x_pos, x_neg, k)
			delta = g_pos - g_neg
			# softplus(z) == log(1 + exp(z)), but it does not overflow for large z.
			loss = nn.functional.softplus(-self.theta * delta).mean()

			self.opt.zero_grad()
			loss.backward()
			self.opt.step()
			self.running_loss += loss.item()
			# NOTE(review): this appends the *cumulative* loss divided by the
			# total step count (a growing partial mean) at every step; kept as-is.
			lossy.append(self.running_loss / self.layer_epochs)

		mean_loss = self.running_loss / self.layer_epochs if self.layer_epochs else 0.0
		with torch.no_grad():
			h_pos = self.forward(x_pos, k)
			h_neg = self.forward(x_neg, k)
		return h_pos, h_neg, mean_loss


def weights_init(m):
	"""Xavier-normal weights and zero bias for nn.Linear modules; other modules are left untouched.

	Intended for use with ``module.apply(weights_init)``.
	"""
	if not isinstance(m, nn.Linear):
		return
	nn.init.xavier_normal_(m.weight)
	nn.init.zeros_(m.bias)


def create_4d_dataset(x_range, y_range, z_range, N, tol, n_intol, n_outtol, fn=None):
	"""Build 4-D point clouds inside and outside a tolerance band around a target surface.

	A regular N x N x N grid is laid over the (x1, x2, x3) box, and the
	surface ``x4 = fn(x1, x2, x3)`` is evaluated on it.

	* In-tolerance points: ``n_intol`` copies of the surface with Gaussian
	  noise of standard deviation ``tol`` added to x4.
	* Out-of-tolerance points: for every grid point, x4 values spread evenly
	  from ``W + tol`` up to the largest in-tol x4, followed by values spread
	  evenly from the smallest in-tol x4 up to ``W - tol``.

	Args:
		x_range, y_range, z_range: (low, high) bounds of each grid axis.
		N: grid points per axis.
		tol: noise scale of the in-tol points and half-width of the excluded band.
		n_intol: noisy copies per grid point.
		n_outtol: out-of-tolerance samples per grid point. Odd values are
			allowed; the extra sample goes to the lower half.
		fn: target function; defaults to the module-level ``function``.

	Returns:
		``(in_tol_points, out_tol_points)`` with shapes (4, n_intol * N**3) and
		(4, n_outtol * N**3). The rows are x1, x2, x3, x4.
	"""
	if fn is None:
		fn = function
	X = torch.linspace(x_range[0], x_range[1], N)
	Y = torch.linspace(y_range[0], y_range[1], N)
	Z = torch.linspace(z_range[0], z_range[1], N)
	X, Y, Z = torch.meshgrid(X, Y, Z, indexing='ij')
	W = fn(X, Y, Z)

	W_rep = W.repeat(n_intol, 1, 1)
	w_in_tol = W_rep + torch.randn(W_rep.shape) * tol
	# Four equally sized blocks concatenated on dim 0 and viewed as (4, -1)
	# put x1, x2, x3 and x4 on separate rows.
	in_tol_points = torch.cat(
		(X.repeat(n_intol, 1, 1), Y.repeat(n_intol, 1, 1), Z.repeat(n_intol, 1, 1), w_in_tol)
	).view(4, -1)

	ceiling = w_in_tol.max()
	floor = w_in_tol.min()

	# Vectorised form of a per-point torch.linspace(start, end, steps): row j is
	# start + (end - start) * j / (steps - 1), for all grid points at once.
	# It matches linspace up to float rounding.
	W_flat = W.flatten()  # (N**3,)
	n_above = n_outtol // 2
	n_below = n_outtol - n_above
	frac_above = torch.linspace(0, 1, n_above).unsqueeze(1)  # (n_above, 1)
	frac_below = torch.linspace(0, 1, n_below).unsqueeze(1)  # (n_below, 1)
	start_above = W_flat + tol
	end_below = W_flat - tol
	above = start_above + (ceiling - start_above) * frac_above  # (n_above, N**3)
	below = floor + (end_below - floor) * frac_below  # (n_below, N**3)
	w_out_tol = torch.cat((above, below), dim=0)  # (n_outtol, N**3)

	rep_shape = X.repeat(n_outtol, 1, 1).shape
	out_tol_points = torch.cat(
		(X.repeat(n_outtol, 1, 1), Y.repeat(n_outtol, 1, 1), Z.repeat(n_outtol, 1, 1), w_out_tol.reshape(rep_shape))
	).view(4, -1)

	return in_tol_points, out_tol_points

if __name__ == '__main__':
	N = 25  # grid points per axis of the training grid
	n_epochs = int(sys.argv[1])  # Take epochs from the first argument! (read by Layer.__init__ as a global)
	x_range = (-3,3)
	y_range = (-3,3)
	z_range = (-3,3)
	tol = 2
	n_intol = 30
	n_outtol = 50

	def append_label(tensor, value):
		"""Return ``tensor`` (rows, n) with a constant row ``value`` appended -> (rows + 1, n)."""
		# Match the data dtype so an int label such as 1 does not create an int64 row.
		label_row = torch.full((1, tensor.shape[1]), value, dtype=tensor.dtype)
		return torch.cat((tensor, label_row), dim=0)

	# Create dataset: (4, n) point clouds with rows x1, x2, x3, x4.
	in_tol_points, out_tol_points = create_4d_dataset(x_range, y_range, z_range, N, tol, n_intol, n_outtol)

	# The same points under two opposite label conventions. Net.train uses the
	# FIRST half of the batch (positive_data) as its negative set and the SECOND
	# half (negative_data) as its positive set, so the names are inverted there.
	positive_data = torch.cat((append_label(in_tol_points, 1), append_label(out_tol_points, 0)), dim=1)
	negative_data = torch.cat((append_label(in_tol_points, 0), append_label(out_tol_points, 1)), dim=1)
	dataset = torch.cat((positive_data, negative_data), 1)
	# Features are rows 0-3 (transposed to (n, 4)); the label is row 4.
	dataset = torch.utils.data.TensorDataset(dataset[:4, :].T, dataset[4, :])
	len_data = len(dataset)
	# One full-size, unshuffled batch keeps the two halves intact for Net.train.
	dataloader = torch.utils.data.DataLoader(dataset, batch_size=len_data, shuffle=False)
	model = Net(device)

	PATH = f'shaded_{run_id}/4d_model_new_fun2_{n_epochs}.pth'

	if os.path.exists(PATH):
		print("Model found. Loading model...")
		# weights_only=False unpickles arbitrary objects: only load trusted files.
		model = torch.load(PATH, weights_only=False, map_location=device)
	else:
		print("Model not found. Training model...")
		model.train(dataloader, N)
		# The checkpoint directory differs from the one created at startup
		# (run_id), so create it; otherwise torch.save raises FileNotFoundError.
		os.makedirs(os.path.dirname(PATH), exist_ok=True)
		torch.save(model, PATH)
		print("Training done and model saved.")

	
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# NOTE: rebinds the global `device` (a torch.device earlier in the file) to a
# plain string. Both forms are accepted by Tensor.to().
if torch.cuda.is_available():
	device = 'cuda'
else:
	device = 'cpu'
sns.set(style="whitegrid")

# Accumulators filled by evaluate_and_store():
#   plot_data - one dict per plotted point (keys: t, value, type, diagonal[, std])
#   mse_data  - diagonal label -> mean squared error of the prediction
plot_data, mse_data = [], {}

def evaluate_and_store(label, x1, x2, x3):
	"""Evaluate the model along one line through (x1, x2, x3) space and record the results.

	At each point j, the x4 candidates accepted by ``model.predict`` are
	summarised by their mean (the prediction) and standard deviation. The MSE
	against ``function`` is stored in ``mse_data[label]``, and one 'True' and
	one 'Pred' row per point are appended to ``plot_data``.

	Args:
		label: name of the line; stored under the ``diagonal`` key.
		x1, x2, x3: equal-length tensors indexed point by point (the callers
			pass 1-D tensors).

	Note:
		The 't' column holds the x1 values, which equal the line parameter only
		when x1 = t. If predict() accepts no candidate at some point, the mean
		there is NaN, and so is the MSE.
	"""
	x4_true = function(x1, x2, x3)
	x4_pred = torch.zeros_like(x4_true)
	std = torch.zeros_like(x4_true)

	# Pure inference: skip recording autograd graphs for the forward passes.
	with torch.no_grad():
		for j in range(len(x1)):
			output = model.predict(x1[j], x2[j], x3[j])
			x4_pred[j] = output.mean()
			std[j] = output.std()

	mse = ((x4_true - x4_pred) ** 2).mean().item()
	mse_data[label] = mse

	t_vals = x1.detach().cpu().numpy()
	true_vals = x4_true.detach().cpu().numpy()
	pred_vals = x4_pred.detach().cpu().numpy()
	std_vals = std.detach().cpu().numpy()

	for t, true, pred, s in zip(t_vals, true_vals, pred_vals, std_vals):
		plot_data.append({'t': t, 'value': true, 'type': 'True', 'diagonal': label})
		plot_data.append({'t': t, 'value': pred, 'std': s, 'type': 'Pred', 'diagonal': label})


# Prepare t: the parameter shared by all evaluation lines.
t = torch.linspace(-3, 3, 50).to(device)
plus_three = torch.ones_like(t) * 3
minus_three = torch.ones_like(t) * -3

# The 8 evaluation lines as (label, x1, x2, x3). The list order fixes the
# order of the subplots and of the rows in the saved CSV.
evaluation_lines = [
	("x1 = x2 = x3", t, t, t),
	("x1 = x2 = -x3", t, t, -t),
	("-x1 = x2 = x3", -t, t, t),
	("x1 = -x2 = x3", t, -t, t),
	("x1 = x2, x3 = 3", t, t, plus_three),
	("x1 = -x2, x3 = -3", t, -t, minus_three),
	("x1 = x3, x2 = 3", t, plus_three, t),
	("x1 = -x3, x2 = -3", t, minus_three, -t),
]
for line_label, line_x1, line_x2, line_x3 in evaluation_lines:
	evaluate_and_store(line_label, line_x1, line_x2, line_x3)

# Convert to DataFrame
df = pd.DataFrame(plot_data)
unique_diagonals = list(df['diagonal'].unique())

# Four subplots per row. 8 lines give the same 2 x 4 grid at 18 x 12 inches as
# before, but more lines no longer overflow a fixed 2 x 4 grid.
n_cols = 4
n_rows = max(1, -(-len(unique_diagonals) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
axes = axes.flatten()

line_styles = {
	'True': {'linestyle': '--', 'color': 'orange', 'label': 'True'},
	'Pred': {'linestyle': '-', 'color': 'blue', 'label': 'Predicted'}
}

for i, diag in enumerate(unique_diagonals):
	ax = axes[i]
	diag_df = df[df['diagonal'] == diag]

	for typ in ['True', 'Pred']:
		style = line_styles[typ]
		sub_df = diag_df[diag_df['type'] == typ]
		ax.plot(sub_df['t'], sub_df['value'],
				linestyle=style['linestyle'],
				color=style['color'],
				label=style['label'] if i == 0 else None)  # Only label once
		if typ == 'True':
			# Mark 25 evenly spaced points of the true curve. The positions are
			# truncated to int and passed as a plain list so pandas' iloc does
			# not have to interpret a torch tensor.
			sub_df_sorted = sub_df.sort_values('t')
			evenly_spaced_indices = torch.linspace(0, len(sub_df_sorted) - 1, 25, dtype=int).tolist()
			scatter_points = sub_df_sorted.iloc[evenly_spaced_indices]
			ax.scatter(scatter_points['t'], scatter_points['value'],
					   color='red', marker='x', s=40, label='Trai' if i == 0 else None)

		if typ == 'Pred':
			if 'std' in sub_df:  # make sure std exists
				# Shade +/- 2 standard deviations of the accepted x4 candidates.
				ax.fill_between(sub_df['t'],
								sub_df['value'] - 2*sub_df['std'],
								sub_df['value'] + 2*sub_df['std'],
								color='grey', alpha=0.2)

	ax.set_title(f"{diag}\nMSE: {mse_data[diag]:.4f}", fontsize=10)
	ax.set_xlabel("t")
	ax.set_ylabel("x4")
	ax.grid(True)

# Remove any unused subplot slots.
for j in range(len(unique_diagonals), len(axes)):
	fig.delaxes(axes[j])

handles = [plt.Line2D([], [], **line_styles['True']),
		   plt.Line2D([], [], **line_styles['Pred'])]
labels = ['True', 'Predicted']
fig.legend(handles, labels, loc='upper right', ncol=2, fontsize=12)

fig.suptitle(f"3D Regression - Epochs: {n_epochs}", fontsize=16)

plt.tight_layout()  # Adjust to make room for legend and title

csv_path = os.path.join(run_id, f"shaded_{run_id}_fun2.csv")
pdf_path = os.path.join(run_id, f"shaded_{run_id}_fun2.pdf")

plt.savefig(pdf_path)
df.to_csv(csv_path, index=False)
plt.close(fig)  # release the figure once it has been written

# plt.show() intentionally removed
