import torch
import sys
import os
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_diabetes

# Load the scikit-learn diabetes dataset into a DataFrame: one column per
# feature plus the regression target in a "target" column.
data = load_diabetes()
df = pd.DataFrame(data.data, columns=data.feature_names)
df["target"] = data.target

# Diabetes Dataset Regression
SEED = 43
# Seed both generators: torch draws the per-layer reference vectors below,
# and get_dataset() samples its trial targets through np.random.
torch.manual_seed(SEED)
np.random.seed(SEED)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 512
FUN_DIM = 10  # number of feature columns fed to the network
RUN_ID = f"{FUN_DIM}D_Regression_Diabetes"
N = 100
# Raises IndexError / ValueError when the argument is missing or not an integer.
N_EPOCHS = int(sys.argv[1]) # Take epochs from the first argument!

N_OUTTOL = 50  # out-of-tolerance samples requested per data row
N_INTOL = 20  # in-tolerance samples drawn per data row
N_TRIAL = 500  # candidate target values scanned per row in Net.predict
THETA = 10 # hyperparameter: scales the goodness difference in the Layer loss
os.makedirs(RUN_ID, exist_ok=True)

lossy = [] # to append loss there


def _unit_row(width):
	"""Return a (1, width) standard-normal row on DEVICE divided by (its L2 norm + 1e-4)."""
	row = torch.normal(0, 1, size=(1, width)).to(DEVICE)
	return row / (row.norm(2, 1, keepdim=True) + 1e-4)


# One fixed random reference direction per layer; the widths match the
# output sizes of the three layers built in Net.__init__.
p_vector0 = _unit_row(2 * HIDDEN_DIM)
p_vector1 = _unit_row(HIDDEN_DIM)
p_vector2 = _unit_row(HIDDEN_DIM // 2)

p_vector = [p_vector0, p_vector1, p_vector2]  # List of all vectors

class Net(nn.Module):
	"""Stack of three `Layer` objects that are trained one after another.

	Each input row is FUN_DIM features, one target value and one label.
	Every layer is scored by the cosine similarity between its output and
	the matching module-level reference vector in `p_vector`.

	Note: `train` here takes a dataloader and therefore shadows
	`nn.Module.train(mode)`.
	"""

	def __init__(self, DEVICE):
		"""Build the three layers and move each of them to DEVICE."""
		super(Net, self).__init__()
		self.theta = THETA
		# FUN_DIM features + 1 target + 1 label = FUN_DIM + 2 input features
		dims = [FUN_DIM + 1 + 1, 2*HIDDEN_DIM, 2*HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM//2]
		# Consecutive (in, out) pairs of `dims` describe the three layers.
		# nn.ModuleList registers them as submodules so parameters(),
		# state_dict() and .to() can see them; a plain list would not.
		self.layers = nn.ModuleList(
			[Layer(dims[d], dims[d + 1], self.theta).to(DEVICE) for d in range(0, len(dims), 2)]
		)

	def predict(self, xi, y_min, y_max):
		"""Predict one target per row of `xi` by scanning N_TRIAL candidate targets.

		Args:
			xi: object with a 2-D `.values` array of features (the script
				passes a DataFrame with FUN_DIM columns).
			y_min, y_max: target range to scan; it is widened by 5% of its
				span on both sides before use.

		Returns:
			1-D tensor with one prediction per input row: the mean of the
			candidate targets whose summed goodness with label 0 exceeds
			that with label 1. A row with no such candidate yields NaN
			(0 / 0); the caller is expected to filter those.
		"""
		xi = torch.tensor(xi.values, dtype=torch.float32)
		margin = 0.05 * abs(y_max - y_min)
		y_min, y_max = y_min - margin, y_max + margin
		print("y_min:", y_min, "y_max:", y_max)
		batch_size, n_features = xi.shape
		cos = nn.CosineSimilarity(dim=1, eps=1e-6)
		# Inference only: no autograd graph for the batch_size * N_TRIAL rows.
		with torch.no_grad():
			# Each feature row appears N_TRIAL times consecutively ...
			xi = xi.repeat(1, N_TRIAL).to(DEVICE).reshape(-1, n_features)
			# ... and is paired with every candidate target in turn.
			y_trial = torch.linspace(y_min, y_max, N_TRIAL)
			y_trial = y_trial.unsqueeze(-1).repeat(batch_size, 1).to(DEVICE)
			# Shape (batch_size * N_TRIAL, n_features + 1).
			candidates = torch.cat([xi, y_trial], dim=1)
			goodness_per_label = []
			for label in [0.0, 1.0]:  # we have two labels only 1 and 0 for intol and outtol data
				label_col = torch.full(
					(candidates.shape[0], 1), label, dtype=torch.float32, device=DEVICE
				)
				g = torch.cat((candidates, label_col), 1)
				goodness = []
				for k, layer in enumerate(self.layers):
					g = layer(g, k)
					goodness.append(cos(g, p_vector[k]))
				# Sum of all layer goodness values for this label.
				goodness_per_label.append(torch.column_stack(goodness).sum(dim=1))
			mask = (goodness_per_label[1] < goodness_per_label[0]).reshape(-1, N_TRIAL)
			count = mask.sum(dim=-1)
			y_trial = y_trial.reshape(-1, N_TRIAL)
			# Average only the accepted candidates of each row.
			accepted_sum = torch.where(mask, y_trial, torch.zeros_like(y_trial)).sum(dim=1)
			y_pred = accepted_sum / count
		return y_pred

	def forward(self, x):
		"""Pass `x` through all layers in order and return the last output."""
		for k, layer in enumerate(self.layers):
			# Layer.forward requires the layer index as second argument.
			x = layer(x, k)
		return x

	def train(self, dataloader):
		"""Train every layer in sequence on each batch of `dataloader`.

		Each batch yields (x, y); y is appended to x as the last column.
		The FIRST half of the batch is used as the negative set and the
		SECOND half as the positive set. Each layer is trained on the
		detached outputs of the previous one, and its loss is printed.
		"""
		for x, y in dataloader:
			x, y = x.to(DEVICE), y.to(DEVICE)
			half = y.shape[0] // 2
			x_neg = torch.cat((x[:half], y[:half].unsqueeze(1)), 1)
			x_pos = torch.cat((x[half:], y[half:].unsqueeze(1)), 1)
			h_pos, h_neg = x_pos, x_neg
			# The layer index restarts for every batch so that it always
			# addresses the reference vector belonging to that layer.
			for k, layer in enumerate(self.layers):
				h_pos, h_neg, loss = layer.train(h_pos, h_neg, k)
				print("Layer", k + 1, "Loss", loss)

class Layer(nn.Module):
	"""One locally trained layer: Linear -> GELU -> element-wise square.

	The layer owns its Adam optimizer (lr=0.001, over the linear weights)
	and is trained for N_EPOCHS steps per call to `train`. Its "goodness"
	is the cosine similarity between its output and `p_vector[k]`.

	Note: `train` here shadows `nn.Module.train(mode)`.
	"""

	def __init__(self, in_features, out_features, theta):
		"""Create the linear map, activation and optimizer.

		Args:
			in_features: width of each input row.
			out_features: width of each output row.
			theta: factor applied to the goodness difference in the loss.
		"""
		super(Layer, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.theta = theta
		self.layer = nn.Linear(in_features, out_features)
		self.main = nn.Sequential(self.layer, nn.GELU())
		self.layer_epochs = N_EPOCHS
		self.opt = torch.optim.Adam(self.layer.parameters(), lr=0.001)

	def forward(self, x, k):
		"""Return the squared GELU activations for `x`; `k` is accepted but unused."""
		x = x.view(-1, self.in_features)
		# Squaring makes every output component non-negative.
		return self.main(x)**2

	def goodness(self, x_pos, x_neg, k):
		"""Return per-row cosine similarity to `p_vector[k]` for both inputs.

		Returns:
			(g_pos, g_neg): one similarity per row of x_pos / x_neg.
		"""
		h_pos = self.forward(x_pos, k)	# non-negative because of the square
		h_neg = self.forward(x_neg, k)	# non-negative because of the square
		cos = nn.CosineSimilarity()
		# expand_as gives the (1, D) reference vector one row per sample
		# without copying it.
		g_pos = cos(h_pos, p_vector[k].expand_as(h_pos))
		g_neg = cos(h_neg, p_vector[k].expand_as(h_neg))

		return g_pos, g_neg

	def train(self, x_pos, x_neg, k):
		"""Run `layer_epochs` optimizer steps pushing g_pos above g_neg.

		The loss is mean(softplus(-theta * (g_pos - g_neg))). x_pos and
		x_neg must have the same number of rows, because the two goodness
		vectors are subtracted element-wise. After every step, the loss
		summed so far divided by the total epoch count is appended to the
		module-level `lossy` list.

		Returns:
			(h_pos, h_neg, mean_loss): this layer's outputs for both inputs
			(without gradient history) and the mean loss over the epochs,
			or 0.0 when no epoch was run.
		"""
		self.running_loss = 0.0

		for _ in range(self.layer_epochs):
			g_pos, g_neg = self.goodness(x_pos, x_neg, k)
			delta = g_pos - g_neg
			# softplus(x) == log(1 + exp(x)) but does not overflow for large x.
			loss = nn.functional.softplus(-self.theta * delta).mean()
			self.opt.zero_grad()
			loss.backward()
			self.opt.step()
			self.running_loss += loss.item()
			lossy.append(self.running_loss / self.layer_epochs)

		# Guard against division by zero when zero epochs were requested.
		mean_loss = self.running_loss / self.layer_epochs if self.layer_epochs else 0.0
		# The next layer trains on these outputs as plain data.
		with torch.no_grad():
			out_pos = self.forward(x_pos, k)
			out_neg = self.forward(x_neg, k)
		return out_pos, out_neg, mean_loss
def weights_init(m):
	"""Re-initialise `m` in place if it is an nn.Linear; leave anything else alone.

	Linear weights get Xavier-normal values and the bias is set to zero.
	"""
	if not isinstance(m, nn.Linear):
		return
	nn.init.xavier_normal_(m.weight)
	nn.init.zeros_(m.bias)
		
def get_dataset(df, N_OUTTOL, N_INTOL, TOL):
	"""Build labelled in-tolerance / out-of-tolerance samples from `df`.

	For every row of `df`, N_INTOL targets are drawn uniformly from
	[target - TOL, target + TOL] and paired with the row's features.
	Out-of-tolerance targets are drawn from below and above that band,
	bounded by min - 0.2*|min| and max + 0.2*|max| of the "target" column:
	N_OUTTOL // 2 below and the remainder above. A side is skipped when
	the band already reaches the respective bound.

	Args:
		df: DataFrame with a "target" column; all other columns are features.
		N_OUTTOL: out-of-tolerance samples requested per row.
		N_INTOL: in-tolerance samples per row.
		TOL: half-width of the tolerance band.

	Returns:
		(positive_data, negative_data): both hold the same samples with an
		extra "label" column. positive_data labels in-tolerance rows 1.0
		and out-of-tolerance rows 0.0; negative_data uses the opposite labels.
	"""
	feature_cols = [name for name in df.columns if name != "target"]
	target_min = df["target"].min()
	target_max = df["target"].max()
	# Outer limits of the region out-of-tolerance targets are drawn from.
	floor = target_min - 0.2 * abs(target_min)
	ceiling = target_max + 0.2 * abs(target_max)
	n_low = N_OUTTOL // 2
	n_high = N_OUTTOL - n_low

	inside = []
	outside = []
	for _, record in df.iterrows():
		features = list(record[feature_cols].values)
		centre = record["target"]
		lower, upper = centre - TOL, centre + TOL

		# Draw order (inside, below, above) is kept fixed so a seeded
		# np.random produces the same samples.
		inside.extend(features + [y] for y in np.random.uniform(lower, upper, N_INTOL))
		below = np.random.uniform(floor, lower, n_low) if lower > floor else []
		above = np.random.uniform(upper, ceiling, n_high) if upper < ceiling else []
		outside.extend(features + [y] for y in [*below, *above])

	columns = feature_cols + ["target"]
	df_intol = pd.DataFrame(inside, columns=columns)
	df_outtol = pd.DataFrame(outside, columns=columns)

	def labelled(in_label, out_label):
		# In-tolerance rows first, then out-of-tolerance rows.
		return pd.concat(
			[df_intol.assign(label=in_label), df_outtol.assign(label=out_label)],
			ignore_index=True,
		)

	return labelled(1.0, 0.0), labelled(0.0, 1.0)

if __name__ == '__main__':
	# The random-forest baseline is only needed when run as a script.
	from sklearn.ensemble import RandomForestRegressor

	# Half-width of the in-tolerance band: 1% of the observed target range.
	TOL = 0.01 * np.abs(df["target"].max() - df["target"].min())
	positive_data, negative_data = get_dataset(df,N_OUTTOL,N_INTOL,TOL)
	positive_data = torch.tensor(positive_data.values, dtype=torch.float32)
	negative_data = torch.tensor(negative_data.values, dtype=torch.float32)
	# Net.train splits each batch in half, so both sets are stacked and
	# served as one unshuffled batch.
	all_data = torch.cat((positive_data, negative_data), 0)
	# The last column (label) becomes y; every other column is x.
	dataset = torch.utils.data.TensorDataset(all_data[:, :-1], all_data[:, -1])
	batch_size = len(dataset) # give complete data at once
	dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

	#Define Model
	model = Net(DEVICE)
	PATH = f'{RUN_ID}/EPOCHS_{N_EPOCHS}.pth'

	if os.path.exists(PATH):
		print("Model found. Loading model...")
		# NOTE(security): weights_only=False unpickles arbitrary objects;
		# only load checkpoint files written by this script.
		model = torch.load(PATH,weights_only=False)
	else:
		print("Model not found. Training model...")
		model.train(dataloader)
		torch.save(model, PATH)
		print("Training done and model saved.")

	# Evaluate on the same rows the training samples were generated from.
	test_df = df.drop(columns=["target"])
	y_pred = model.predict(test_df,y_min=df["target"].min(),y_max=df["target"].max())
	y_actual = torch.tensor(df["target"].values, dtype=torch.float32).to(DEVICE)
	# Calculate Mean Squared Error
	y_pred = y_pred.reshape(y_actual.shape)
	# predict() yields NaN for rows where no candidate target was accepted.
	not_nan = ~torch.isnan(y_pred)
	mse_ff = nn.MSELoss()(y_pred[not_nan], y_actual[not_nan]).item()
	print("FF MSE Loss: ", mse_ff)

	# Perform Random forest regression for comparison (fit and scored on
	# the same data, like the model above).
	X_train = test_df
	y_train = df["target"].values
	rf = RandomForestRegressor(n_estimators=100, random_state=SEED)
	rf.fit(X_train, y_train)
	y_pred_rf = rf.predict(X_train)
	mse_rf = mean_squared_error(y_train, y_pred_rf)
	print("Random Forest MSE Loss: ", mse_rf)
	print("R2: ", r2_score(y_train, y_pred_rf))

	# Both are tensors at this point; move them to host memory for plotting.
	y_pred_np = y_pred.detach().cpu().numpy()
	y_actual_np = y_actual.detach().cpu().numpy()

	plt.figure(figsize=(8,6))
	plt.scatter(y_pred_rf, y_actual_np, alpha=0.6,c="blue", label="Random Forest Prediction")
	plt.scatter(y_pred_np, y_actual_np, alpha=0.6,c="red", label="Forward Forward Prediction")

	plt.xlabel("Predicted Target variable")
	plt.ylabel("Actual Target variable")
	plt.suptitle("Diabetes Dataset")
	plt.title(f"FF MSE: {mse_ff:.4f}, Random Forest MSE: {mse_rf:.4f}")

	plt.legend()
	plt.grid(True)
	plt.show()