import torch
import sys
import os
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
################################################
# N-dimensional Forward-Forward function regression, here for area = f(x, y, z).
# Can be scaled to any input dimension (set FUN_DIM and the data accordingly).
################################################

SEED = 43
torch.manual_seed(SEED)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 512  # base width: Net's layers output 2*HIDDEN_DIM, HIDDEN_DIM and HIDDEN_DIM // 2 units
FUN_DIM = 3  # REPLACE WITH N - dimension of the function (number of input features)
RUN_ID = f"{FUN_DIM}D_Regression"  # output directory for checkpoints
N = 100  # passed to Net.train, which currently does not use it
# Optimisation steps per layer.  When this file is run as a script the value
# comes from the first command-line argument (e.g. `python script.py 500`);
# otherwise -- or when no argument is given -- DEFAULT_N_EPOCHS is used, so
# the module can be imported without its argv being misread as an epoch count.
DEFAULT_N_EPOCHS = 100
if __name__ == "__main__" and len(sys.argv) > 1:
	N_EPOCHS = int(sys.argv[1])
else:
	N_EPOCHS = DEFAULT_N_EPOCHS
TOL = 0.001  # half-width of the in-tolerance band around each target value
N_OUTTOL = 50  # out-of-tolerance samples generated (at most) per data point
N_INTOL = 30  # in-tolerance samples generated per data point
N_TRIAL = 300  # candidate target values scanned per query in Net.predict
THETA = 10  # hyperparameter: sharpness of the per-layer loss


def function_3d(x1, x2, x3):
	"""Test surface f(x1, x2, x3) used as the regression target ("area").

	Sum of three cyclic terms exp(a**2 / 5) * sin(b * c / 5), where ``a``
	runs over x1, x2, x3 and (b, c) are the other two inputs.  Works
	element-wise on tensors.  (A simpler alternative surface is
	x1**2 + x2**2 + x3**2.)
	"""
	def _term(a, b, c):
		return torch.exp(a**2 / 5) * torch.sin(b * c / 5)

	return _term(x1, x2, x3) + _term(x2, x1, x3) + _term(x3, x1, x2)

# Module-level list that Layer.train appends per-step loss values to.
lossy = []


def _unit_row(width):
	"""Return a (1, width) standard-normal draw on DEVICE scaled to ~unit L2 norm.

	The 1e-4 added to the norm guards against division by zero, so the
	result is very slightly shorter than unit length.
	"""
	row = torch.normal(0, 1, size=(1, width)).to(DEVICE)
	return row / (row.norm(2, 1, keepdim=True) + 1e-4)


# Fixed reference direction per layer: entry k has the output width of Net's
# k-th layer (2*HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM // 2).  Drawn in this
# order from the seeded torch RNG.
p_vector0 = _unit_row(2 * HIDDEN_DIM)
p_vector1 = _unit_row(HIDDEN_DIM)
p_vector2 = _unit_row(HIDDEN_DIM // 2)
p_vector = [p_vector0, p_vector1, p_vector2]


class Net(nn.Module):
	"""Stack of independently trained Forward-Forward ``Layer`` objects.

	Each layer is trained greedily on the detached outputs of the previous
	one (see ``train``); ``predict`` turns the per-layer goodness scores into
	a regression estimate by scanning candidate target values.
	"""

	def __init__(self, DEVICE):
		super(Net, self).__init__()
		self.theta = THETA
		# Layer input: FUN_DIM features + 1 target value + 1 label column.
		dims = [FUN_DIM + 1 + 1, 2*HIDDEN_DIM, 2*HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM//2]
		# Consecutive (in, out) pairs give three layers.  nn.ModuleList (not a
		# plain list) registers them, so parameters(), state_dict() and .to()
		# see the layers.
		self.layers = nn.ModuleList(
			Layer(dims[d], dims[d + 1], self.theta).to(DEVICE)
			for d in range(0, len(dims), 2)
		)

	def predict(self, xi, y_min, y_max):
		"""Estimate the target value of each query point.

		For every query row, N_TRIAL candidate targets are spaced evenly over
		[y_min, y_max] widened by 5% of its span on each side.  Each
		(features, candidate, label) row is scored for label 0.0 and 1.0 as
		the sum over layers of the cosine similarity between the layer output
		and ``p_vector[k]``.  Candidates whose label-1.0 score is below their
		label-0.0 score are kept; the prediction is their mean (0 when no
		candidate is kept).

		Args:
			xi: query features, shape (n_datapoints, n_features) -- a pandas
				DataFrame, NumPy array or tensor.
			y_min, y_max: target range used to place the candidates.

		Returns:
			Tensor of shape (n_datapoints,) on DEVICE.

		Raises:
			ValueError: if n_features does not match the first layer's input
				width minus the target and label columns.
		"""
		if isinstance(xi, torch.Tensor):
			xi = xi.detach().to(dtype=torch.float32)
		else:
			# DataFrames convert via np.asarray exactly like their .values.
			xi = torch.tensor(np.asarray(xi), dtype=torch.float32)
		batch_size, n_features = xi.shape
		expected_features = self.layers[0].in_features - 2
		if n_features != expected_features:
			raise ValueError(
				f"expected {expected_features} feature columns, got {n_features}"
			)

		span = abs(y_max - y_min)
		y_min, y_max = y_min - 0.05 * span, y_max + 0.05 * span
		print("y_min:", y_min, "y_max:", y_max)

		# Candidate grid: row b * N_TRIAL + t pairs query b with candidate t.
		y_trial = torch.linspace(y_min, y_max, N_TRIAL)
		x_rep = xi.to(DEVICE).repeat_interleave(N_TRIAL, dim=0)
		y_rep = y_trial.repeat(batch_size).unsqueeze(1).to(DEVICE)
		candidates = torch.cat([x_rep, y_rep], dim=1)

		cos = nn.CosineSimilarity(dim=1, eps=1e-6)
		goodness_per_label = []
		# Inference only: avoid building an autograd graph over
		# n_datapoints * N_TRIAL rows.
		with torch.no_grad():
			for label in [0.0, 1.0]:  # the two labels used for in/out-of-tolerance data
				print("Processing label:", label)
				label_col = torch.full((candidates.shape[0], 1), label, device=DEVICE)
				h = torch.cat((candidates, label_col), 1)
				goodness = []
				for k, layer in enumerate(self.layers):
					h = layer(h, k)
					goodness.append(cos(h, p_vector[k]))
					if k > 0:
						print("Goodness shapes:", [g.shape for g in goodness])
				# Sum of all layers' goodness for this label.
				goodness_per_label.append(torch.column_stack(goodness).sum(dim=1))

		keep = (goodness_per_label[1] < goodness_per_label[0]).reshape(-1, N_TRIAL)
		count = keep.sum(dim=-1)
		y_grid = y_rep.reshape(-1, N_TRIAL)
		y_grid = torch.where(keep, y_grid, torch.zeros_like(y_grid))
		return y_grid.sum(dim=1) / (count + 1e-6)

	def forward(self, x):
		"""Pass x through every layer in turn and return the last output."""
		for k, layer in enumerate(self.layers):
			x = layer(x, k)
		return x

	def train(self, dataloader, n):
		"""Greedily train each layer on its own Forward-Forward objective.

		NOTE: this overrides ``nn.Module.train(mode)`` with a different
		signature, so ``Module.eval()`` / ``Module.train(False)`` (which call
		``self.train(mode)``) must not be used on a Net.

		Args:
			dataloader: yields (x, y) batches; x is concatenated with y as an
				extra last column before being fed to the first layer (in
				__main__, x holds the features plus target value and y the
				label).
			n: unused; kept for backward compatibility.
		"""
		for x, y in dataloader:
			x, y = x.to(DEVICE), y.to(DEVICE)
			half = y.shape[0] // 2
			# First half of the batch is the layers' "negative" input, second
			# half the "positive" one.  predict() keeps candidates whose
			# label-1.0 goodness is below their label-0.0 goodness, which
			# depends on this assignment -- don't swap one without the other.
			x_neg = torch.cat((x[:half], y[:half].unsqueeze(1)), 1)
			x_pos = torch.cat((x[half:], y[half:].unsqueeze(1)), 1)
			h_pos, h_neg = x_pos, x_neg
			# k indexes p_vector, so it restarts for every batch.
			for k, layer in enumerate(self.layers):
				h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
				print("Layer", k + 1, "Loss", loss)

class Layer(nn.Module):
	"""One Forward-Forward layer: Linear -> GELU -> element-wise square.

	The goodness of a sample is the cosine similarity between the layer
	output and the fixed reference direction ``p_vector[k]`` for layer
	index ``k``.

	NOTE: ``train`` overrides ``nn.Module.train(mode)`` with a different
	signature, so ``Module.eval()`` / ``Module.train(False)`` must not be
	called on a Layer.
	"""

	def __init__(self, in_features, out_features, theta):
		super(Layer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.theta = theta  # sharpness of the loss on the goodness gap
		self.layer = nn.Linear(in_features, out_features)
		self.main = nn.Sequential(self.layer, nn.GELU())
		self.layer_epochs = N_EPOCHS  # optimisation steps per call to train()
		self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001)

	def forward(self, x, k):
		"""Reshape x to (-1, in_features) and return GELU(Linear(x)) ** 2.

		Squaring makes every output entry >= 0.  ``k`` is not used here.
		"""
		x = x.view(-1, self.in_features)
		return self.main(x)**2

	def goodness(self, x_pos, x_neg, k):
		"""Return per-sample cosine similarity to p_vector[k] for both inputs."""
		h_pos = self.forward(x_pos, k)  # squared output: all entries >= 0
		h_neg = self.forward(x_neg, k)  # squared output: all entries >= 0
		cos = nn.CosineSimilarity()
		g_pos = cos(h_pos, p_vector[k].repeat(h_pos.shape[0], 1))
		g_neg = cos(h_neg, p_vector[k].repeat(h_neg.shape[0], 1))
		return g_pos, g_neg

	def train(self, x_pos, x_neg, k):
		"""Train this layer so that x_pos scores higher goodness than x_neg.

		Runs ``layer_epochs`` Adam steps on
		``softplus(-theta * (g_pos - g_neg)).mean()``.  After every step the
		running loss divided by ``layer_epochs`` is appended to the
		module-level ``lossy`` list (the final entry is the mean loss).

		Returns:
			(forward(x_pos), forward(x_neg), mean loss); the outputs are
			computed without autograd so the next layer trains independently.
		"""
		self.running_loss = 0.0

		for _ in range(self.layer_epochs):
			g_pos, g_neg = self.goodness(x_pos, x_neg, k)
			delta = g_pos - g_neg
			# softplus(z) == log(1 + exp(z)) without overflowing for large theta.
			loss = nn.functional.softplus(-self.theta * delta).mean()
			self.opt.zero_grad()
			loss.backward()
			self.opt.step()
			self.running_loss += loss.item()
			lossy.append(self.running_loss / self.layer_epochs)
		with torch.no_grad():
			h_pos = self.forward(x_pos, k)
			h_neg = self.forward(x_neg, k)
		return h_pos, h_neg, self.running_loss / self.layer_epochs


def weights_init(m):
	"""Xavier-normal init for nn.Linear weights and zero init for their bias.

	Can be passed to ``nn.Module.apply``.  Modules that are not nn.Linear
	are left untouched; a Linear created with ``bias=False`` only has its
	weight initialised.
	"""
	if isinstance(m, nn.Linear):
		nn.init.xavier_normal_(m.weight)
		if m.bias is not None:  # nn.Linear(..., bias=False) has bias None
			nn.init.zeros_(m.bias)

		
def get_dataset(df, N_OUTTOL, N_INTOL, TOL, rng=None):
	"""Build labelled in-tolerance / out-of-tolerance samples around each target.

	For every row of ``df`` with target ``y0 = area``:

	* ``N_INTOL`` in-tolerance samples are drawn uniformly from
	  ``[y0 - TOL, y0 + TOL)``;
	* ``N_OUTTOL // 2`` out-of-tolerance samples are drawn uniformly from
	  ``[min(area), y0 - TOL)`` when ``y0 - TOL > min(area)``, and the
	  remaining ``N_OUTTOL - N_OUTTOL // 2`` from ``[y0 + TOL, max(area))``
	  when ``y0 + TOL < max(area)``.

	Each sample keeps the row's feature values.

	Args:
		df: DataFrame with feature columns and the target column ``"area"``.
		N_OUTTOL: out-of-tolerance samples per row (before skipping sides).
		N_INTOL: in-tolerance samples per row.
		TOL: half-width of the tolerance band.
		rng: optional ``np.random.Generator`` (or anything with a NumPy-style
			``uniform``); defaults to the global ``np.random`` state.

	Returns:
		(positive_data, negative_data): DataFrames with the feature columns,
		``"area"`` and ``"label"``.  All in-tolerance rows come first, then
		all out-of-tolerance rows, each grouped by source row in ``df``
		order.  ``positive_data`` labels in-tolerance rows 1.0 and
		out-of-tolerance rows 0.0; ``negative_data`` uses the opposite labels.

	Raises:
		ValueError: if ``df`` has no ``"area"`` column.
	"""
	if "area" not in df.columns:
		raise ValueError("DataFrame must contain 'area' as target column.")
	if rng is None:
		rng = np.random

	y_min = df["area"].min()
	y_max = df["area"].max()
	feature_cols = [c for c in df.columns if c != "area"]
	features = df[feature_cols].to_numpy()
	y0 = df["area"].to_numpy(dtype=np.float64)
	y_low = y0 - TOL
	y_high = y0 + TOL
	n_rows = len(df)

	# In-tolerance: one row of N_INTOL draws per data point.  Row-major
	# flattening keeps each point's samples together, in df order, which
	# matches np.repeat of the feature rows.
	y_in = rng.uniform(y_low[:, None], y_high[:, None], size=(n_rows, N_INTOL))
	df_intol = pd.DataFrame(np.repeat(features, N_INTOL, axis=0), columns=feature_cols)
	df_intol["area"] = y_in.ravel()

	# Out-of-tolerance: per data point, the "below" draws then the "above"
	# draws.  Sides whose band already reaches the end of the target range
	# are drawn with clamped bounds (so low <= high) and dropped by `keep`.
	n_low = N_OUTTOL // 2
	n_high = N_OUTTOL - n_low
	has_low = y_low > y_min
	has_high = y_high < y_max
	low_draws = rng.uniform(y_min, np.maximum(y_low, y_min)[:, None], size=(n_rows, n_low))
	high_draws = rng.uniform(np.minimum(y_high, y_max)[:, None], y_max, size=(n_rows, n_high))
	y_out = np.hstack([low_draws, high_draws])
	keep = np.hstack([
		np.broadcast_to(has_low[:, None], (n_rows, n_low)),
		np.broadcast_to(has_high[:, None], (n_rows, n_high)),
	]).ravel()
	df_outtol = pd.DataFrame(np.repeat(features, N_OUTTOL, axis=0)[keep], columns=feature_cols)
	df_outtol["area"] = y_out.ravel()[keep]

	positive_data = pd.concat([df_intol.assign(label=1.0), df_outtol.assign(label=0.0)], ignore_index=True)
	negative_data = pd.concat([df_intol.assign(label=0.0), df_outtol.assign(label=1.0)], ignore_index=True)
	return positive_data, negative_data

def _make_grid_dataframe(n_points=25):
	"""Sample function_3d on a regular n_points**3 grid over [-1, 1]^3.

	Returns a DataFrame with feature columns 'x', 'y', 'z' and target
	'area' (row order follows torch.meshgrid with indexing='ij').
	"""
	axis = torch.linspace(-1, 1, n_points)
	gx, gy, gz = torch.meshgrid(axis, axis, axis, indexing='ij')
	return pd.DataFrame({
		'x': gx.flatten().numpy(),
		'y': gy.flatten().numpy(),
		'z': gz.flatten().numpy(),
		'area': function_3d(gx, gy, gz).flatten().numpy(),
	})


def _make_dataloader(df):
	"""Build the single-batch training DataLoader from get_dataset's frames.

	Each item is (features + area, label).  The whole dataset is one batch
	because Net.train splits every batch in half; with one batch the halves
	are exactly the equally sized positive_data and negative_data frames.
	"""
	positive_data, negative_data = get_dataset(df, N_OUTTOL, N_INTOL, TOL)
	all_data = torch.cat((
		torch.tensor(positive_data.values, dtype=torch.float32),
		torch.tensor(negative_data.values, dtype=torch.float32),
	), 0)
	dataset = torch.utils.data.TensorDataset(all_data[:, :-1], all_data[:, -1])
	return torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)


def _load_or_train(df):
	"""Load the checkpoint for RUN_ID / N_EPOCHS, or train and save a new model.

	The training data is only generated when no checkpoint exists.
	"""
	model = Net(DEVICE)
	path = f'{RUN_ID}/EPOCHS_{N_EPOCHS}.pth'
	if os.path.exists(path):
		print("Model found. Loading model...")
		# weights_only=False unpickles arbitrary Python objects: only load
		# checkpoints you created yourself.  map_location lets a checkpoint
		# saved on another device (e.g. GPU) load onto DEVICE.
		return torch.load(path, map_location=DEVICE, weights_only=False)
	print("Model not found. Training model...")
	model.train(_make_dataloader(df), N)
	torch.save(model, path)
	print("Training done and model saved.")
	return model


def _evaluate_line(model, x1, x2, x3, y_min, y_max):
	"""Compare function_3d with model.predict along one parametrised line.

	Returns (t, true, pred, r2): NumPy arrays of x1 (the plot's x-axis),
	the true and predicted values, plus the R^2 score.
	"""
	from sklearn.metrics import r2_score

	true = function_3d(x1, x2, x3)
	features = pd.DataFrame({
		'x': x1.cpu().numpy(),
		'y': x2.cpu().numpy(),
		'z': x3.cpu().numpy(),
	})
	pred = model.predict(features, y_min=y_min, y_max=y_max)
	true_np = true.detach().cpu().numpy()
	pred_np = pred.detach().cpu().numpy()
	return x1.detach().cpu().numpy(), true_np, pred_np, r2_score(true_np, pred_np)


def _plot_lines(results):
	"""Plot true vs predicted values for up to 8 lines in a 2x4 grid and show it."""
	import seaborn as sns

	sns.set(style="whitegrid")
	fig, axes = plt.subplots(2, 4, figsize=(18, 12))
	axes = axes.flatten()
	for ax, (label, (t_vals, true_vals, pred_vals, r2_val)) in zip(axes, results.items()):
		# tab10 colours; predictions dashed.  The labels feed the shared legend.
		ax.plot(t_vals, true_vals, color='tab:blue', label='True')
		ax.plot(t_vals, pred_vals, color='tab:orange', linestyle='--', label='Pred')
		ax.set_title(f"{label} R2: {r2_val:.4f}", fontsize=10)
		ax.set_xlabel("t")
		ax.set_ylabel("x4")
		ax.grid(True)

	# Remove axes left over when there are fewer lines than subplots.
	for ax in axes[len(results):]:
		fig.delaxes(ax)

	# One legend for the whole figure, above the subplots.
	handles, labels = axes[0].get_legend_handles_labels()
	fig.legend(handles, labels, loc='upper center', ncol=2, fontsize=12)

	plt.tight_layout(rect=[0, 0, 1, 0.95])
	plt.suptitle("4D Regression: True vs Predicted Across Diagonals", fontsize=16)
	plt.show()


def main():
	"""Train (or load) the model, report metrics and plot line slices."""
	from sklearn.metrics import r2_score

	os.makedirs(RUN_ID, exist_ok=True)
	torch.manual_seed(SEED)
	df = _make_grid_dataframe()
	model = _load_or_train(df)
	y_min, y_max = df["area"].min(), df["area"].max()

	# Evaluate on n_test random points of the (training) grid.  The rows are
	# sampled once so features and targets are guaranteed to line up.
	n_test = 500
	test_rows = df.sample(n=n_test, random_state=SEED)
	y_pred = model.predict(test_rows.drop(columns=["area"]), y_min=y_min, y_max=y_max)
	y_actual = torch.tensor(test_rows["area"].values, dtype=torch.float32).to(DEVICE)
	y_pred = y_pred.reshape(y_actual.shape)
	test_loss = nn.MSELoss()(y_pred, y_actual)
	print("MSE Loss on test data: ", test_loss.item())
	r2 = r2_score(y_actual.cpu().numpy(), y_pred.cpu().detach().numpy())
	print("R2 Score on test data: ", r2)

	# Eight lines through the input cube, parametrised by t in [-1, 1].
	t = torch.linspace(-1, 1, 50).to(DEVICE)
	zero = torch.zeros_like(t)
	lines = {
		"x1 = x2 = x3": (t, t, t),
		"x1 = x2 = -x3": (t, t, -t),
		"-x1 = x2 = x3": (-t, t, t),
		"x1 = -x2 = x3": (t, -t, t),
		"x1 = x2, x3 = 0": (t, t, zero),
		"x1 = -x2, x3 = 0": (t, -t, zero),
		"x1 = x3, x2 = 0": (t, zero, t),
		"x1 = -x3, x2 = 0": (t, zero, -t),
	}
	results = {
		label: _evaluate_line(model, *xs, y_min=y_min, y_max=y_max)
		for label, xs in lines.items()
	}
	_plot_lines(results)


if __name__ == '__main__':
	main()