import time
import sys
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import optuna

seed = 42


def _seed_all(value):
    """Seed NumPy's global RNG, PyTorch's default generator and ``random``."""
    np.random.seed(value)
    torch.manual_seed(value)
    random.seed(value)


_seed_all(seed)

# Reproducibility over speed: no cuDNN autotuning, deterministic cuDNN
# kernels, and PyTorch's global deterministic-algorithms switch turned on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

def correlation_loss(y_pred, y_true):
    """Return ``1 - r`` where ``r`` is the Pearson correlation of the inputs.

    Both tensors are flattened before the correlation is computed, so e.g.
    ``(N, 1)`` and ``(N,)`` inputs are treated alike.

    Args:
        y_pred (torch.Tensor): Predictions.
        y_true (torch.Tensor): Targets with the same number of elements.

    Returns:
        torch.Tensor: Scalar in roughly ``[0, 2]``; 0 for perfectly correlated
        inputs, 2 for perfectly anti-correlated ones.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    vx = y_pred - torch.mean(y_pred)
    vy = y_true - torch.mean(y_true)
    # vector_norm instead of sqrt(sum(v**2)): sqrt's gradient at 0 is 0/0 = NaN,
    # which poisons the weights whenever all predictions in a batch are equal,
    # even when this loss is scaled by 0. vector_norm's gradient at the origin
    # is 0, and the forward value is the same.
    denom = torch.linalg.vector_norm(vx) * torch.linalg.vector_norm(vy) + 1e-8
    corr = torch.sum(vx * vy) / denom
    return 1 - corr

def skew_symmetric_fn(l, r, dim):
    """Row-wise skew-symmetric bilinear form ``l_i^T A r_i``.

    ``A`` is block-diagonal with ``dim // 2`` copies of ``[[0, 1], [-1, 0]]``,
    so ``score(l, r) == -score(r, l)`` and ``score(v, v) == 0``.

    Args:
        l (torch.Tensor): Left embeddings, shape ``(N, dim)``.
        r (torch.Tensor): Right embeddings, shape ``(N, dim)``.
        dim (int): Embedding width; must be even.

    Returns:
        torch.Tensor: float32 tensor of shape ``(N,)``, one score per row.

    Raises:
        ValueError: If ``dim`` is odd (``A`` would not match the input width).
    """
    if dim % 2:
        raise ValueError(f"skew_symmetric_fn requires an even dim, got {dim}")
    # Build A directly on the input's device instead of creating it on the CPU
    # and copying it over on every call.
    block = torch.tensor([[0., 1.], [-1., 0.]], device=l.device)
    A = torch.block_diag(*[block] * (dim // 2))
    l2 = torch.mm(l.float(), A)
    # Row-wise dot product: identical to diag(l2 @ r.T) without materialising
    # the (N, N) matrix whose off-diagonal entries were discarded.
    return (l2 * r.float()).sum(dim=1)

def denormalize_tensor(tensor, mean, std):
    """
    Denormalizes a PyTorch tensor using given mean and std from training data.

    Computes ``tensor * std + mean`` with broadcasting; ``mean``/``std`` are
    converted to ``tensor``'s dtype and device first.

    Args:
        tensor (torch.Tensor): Normalized input tensor of shape (N, D).
        mean (np.ndarray | torch.Tensor | float): Training mean of shape
            (1, D) or (D,), or a scalar.
        std (np.ndarray | torch.Tensor | float): Training std of shape
            (1, D) or (D,), or a scalar.

    Returns:
        torch.Tensor: Denormalized tensor.
    """
    # as_tensor (rather than torch.tensor) also accepts tensors and scalars
    # without an unnecessary copy or the "copy construct from a tensor" warning.
    mean_tensor = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std_tensor = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    return tensor * std_tensor + mean_tensor


class CustomFunctionLayer(nn.Module):
    """Parameter-free module that scores embedding pairs with ``skew_symmetric_fn``.

    Wrapping the function in a module lets it sit inside ``SiameseNetwork`` as
    a named sub-layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2, dim):
        """Return ``skew_symmetric_fn(x1, x2, dim)`` (one score per row)."""
        pair_scores = skew_symmetric_fn(x1, x2, dim)
        return pair_scores

class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(left_features, right_features, target)`` triples.

    ``X`` holds the left item's features in columns ``[0, input_size)`` and the
    right item's features in columns ``[input_size, 2 * input_size)``.
    """

    def __init__(self, X, y, input_size):
        split, stop = input_size, 2 * input_size
        self.X1 = X[:, :split]
        self.X2 = X[:, split:stop]
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        left, right, target = self.X1[idx], self.X2[idx], self.y[idx]
        return left, right, target

class TwinNetwork(nn.Module):
    """Shared encoder applied to both items of a pair.

    Architecture: two ``Linear -> BatchNorm1d -> ReLU`` stages of width
    ``hidden_size``, followed by a ``Linear`` projection to ``embedding_size``.
    """

    def __init__(self, input_size, hidden_size, embedding_size):
        super().__init__()
        layers = []
        in_features = input_size
        for _ in range(2):
            layers += [
                nn.Linear(in_features, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
            ]
            in_features = hidden_size
        layers.append(nn.Linear(hidden_size, embedding_size))
        self.base_network = nn.Sequential(*layers)

    def forward(self, input1):
        """Map ``input1`` of shape ``(N, input_size)`` to ``(N, embedding_size)``."""
        return self.base_network(input1)

class SiameseNetwork(nn.Module):
    """Pair scorer: one shared TwinNetwork encoder plus a skew-symmetric combiner.

    Both inputs pass through the *same* ``twin_network`` (shared weights); the
    two embeddings are then combined by ``custom_layer`` into one score per row.

    Args:
        input_size (int): Features per item.
        hidden_size (int): Hidden width of the encoder.
        embedding_size (int): Embedding width, also passed as ``dim`` to the
            combining layer.
        seed (int, optional): If given, re-seeds torch, NumPy and ``random``
            before the encoder is built so its initial weights are
            reproducible. This mutates the global RNG state.
    """

    def __init__(self, input_size, hidden_size, embedding_size, seed=None):
        super().__init__()
        if seed is not None:
            for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
                seed_rng(seed)
        self.twin_network = TwinNetwork(input_size, hidden_size, embedding_size)
        self.custom_layer = CustomFunctionLayer()
        self.embedding_size = embedding_size

    def forward(self, input1, input2):
        """Return ``(scores, input1, input2, embedding1, embedding2)``.

        ``scores`` has one column, one row per pair.
        """
        embedding1, embedding2 = map(self.twin_network, (input1, input2))
        pair_scores = self.custom_layer(embedding1, embedding2, self.embedding_size)
        return pair_scores.reshape(-1, 1), input1, input2, embedding1, embedding2

# Seeded generator so the training DataLoader's shuffle order is reproducible.
g = torch.Generator()
g.manual_seed(seed)

# --- Command-line arguments: INPUT_SIZE EMBEDDING_SIZE ---
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} INPUT_SIZE EMBEDDING_SIZE")
input_size = int(sys.argv[1])
embedding_size = int(sys.argv[2])
if input_size < 1:
    sys.exit(f"INPUT_SIZE must be a positive integer, got {input_size}")
# skew_symmetric_fn combines embedding coordinates in 2x2 blocks, so an odd
# embedding size cannot work; fail now rather than inside the first trial.
if embedding_size < 2 or embedding_size % 2:
    sys.exit(f"EMBEDDING_SIZE must be a positive even integer, got {embedding_size}")

num_epochs = 100
patience = 30  # epochs without validation improvement before early stopping
alpha = 0      # weight of the correlation term in the loss (0 = plain MSE)

# Load data (adjust your paths). ndmin=2 keeps single-row / single-column
# feature files two-dimensional so the column slicing below works.
input1 = np.loadtxt("../left_items.txt", ndmin=2)
input2 = np.loadtxt("../right_items.txt", ndmin=2)
labels = np.loadtxt("../y_original.txt")
labels_block = np.loadtxt("../y_block.txt")

# Slicing past the last column silently returns fewer columns, which would
# shift the left/right split in CustomDataset; reject that explicitly.
available_columns = min(input1.shape[1], input2.shape[1])
if input_size > available_columns:
    sys.exit(f"INPUT_SIZE={input_size} exceeds the number of feature columns ({available_columns})")

input1_pruned = input1[:, 0:input_size]
input2_pruned = input2[:, 0:input_size]

X = np.concatenate((input1_pruned, input2_pruned), axis=1)
all_indices = np.arange(len(labels))
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, labels, all_indices, test_size=0.25, random_state=42)

# Standardise features and targets with TRAINING statistics only; the 1e-6
# keeps zero-variance columns from dividing by zero.
train_mean = X_train.mean(axis=0, keepdims=True)
train_std = X_train.std(axis=0, keepdims=True) + 1e-6
X_train = (X_train - train_mean) / train_std
X_test = (X_test - train_mean) / train_std

y_train_mean = y_train.mean()
y_train_std = y_train.std() + 1e-6
y_train_normalized = (y_train - y_train_mean) / y_train_std
y_test_normalized = (y_test - y_train_mean) / y_train_std

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_normalized, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_normalized, dtype=torch.float32)

batch_size = 64
train_dataset = CustomDataset(X_train_tensor, y_train_tensor, input_size)
test_dataset = CustomDataset(X_test_tensor, y_test_tensor, input_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def objective(trial):
    """Optuna objective: train one SiameseNetwork configuration and score it.

    Samples ``hidden_size``, ``learning_rate`` and ``weight_decay`` from
    *trial* and trains with Adam for up to ``num_epochs`` epochs. Training
    stops early after ``patience`` epochs without validation improvement.
    The per-epoch loss curves are stored on the trial as the user attributes
    ``"train_losses"`` and ``"val_losses"``.

    Args:
        trial (optuna.trial.Trial): Trial supplying the hyper-parameters.

    Returns:
        float: Lowest per-sample validation loss seen across epochs.
    """
    # The trial number doubles as the seed, so each trial is reproducible on its own.
    seed = trial.number
    print("trial number = ", trial.number)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    hidden_size = trial.suggest_categorical("hidden_size", [32, 64, 96, 128])
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)

    model = SiameseNetwork(input_size, hidden_size, embedding_size, seed)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    # Each trial gets its own shuffle generator, seeded like the model. The
    # module-level `train_loader` shares a single generator across all trials,
    # so otherwise a trial's batch order would depend on how many epochs every
    # earlier trial happened to run.
    trial_generator = torch.Generator()
    trial_generator.manual_seed(seed)
    trial_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=trial_generator)

    best_val_loss = float("inf")
    patience_counter = 0
    trial_train_losses = []
    trial_val_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for x1, x2, targets in trial_train_loader:
            optimizer.zero_grad()
            outputs, _, _, _, _ = model(x1, x2)
            loss_mse = criterion(outputs, targets.unsqueeze(1))
            loss_corr = correlation_loss(outputs, targets.unsqueeze(1))
            loss = loss_mse + alpha * loss_corr
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * x1.size(0)

        epoch_train_loss = running_loss / len(trial_train_loader.dataset)

        # Validation. NOTE(review): this is the held-out test split, which is
        # therefore also used for model selection and early stopping; its loss
        # is not an unbiased estimate of generalisation error.
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for x1, x2, targets in test_loader:
                outputs, _, _, _, _ = model(x1, x2)
                loss_mse = criterion(outputs, targets.unsqueeze(1))
                loss_corr = correlation_loss(outputs, targets.unsqueeze(1))
                loss = loss_mse + alpha * loss_corr
                val_loss += loss.item() * x1.size(0)
        epoch_val_loss = val_loss / len(test_loader.dataset)

        # Save losses for plotting later
        trial_train_losses.append(epoch_train_loss)
        trial_val_losses.append(epoch_val_loss)

        # Early stopping
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # Save loss curves to trial for later plotting
    trial.set_user_attr("train_losses", trial_train_losses)
    trial.set_user_attr("val_losses", trial_val_losses)

    return best_val_loss


# Hyper-parameter search with a seeded TPE sampler (for repeatability).
sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(direction="minimize", sampler=sampler)
study.optimize(objective, n_trials=100, timeout=3600)

# Report the winning configuration.
best_trial = study.best_trial
trial = best_trial
print("Best trial:")
print(f"  Value: {trial.value}")
print("  Params:")
for param_name, param_value in trial.params.items():
    print(f"    {param_name}: {param_value}")

# Plot the per-epoch loss curves that `objective` stored on the best trial.
curves = best_trial.user_attrs
train_losses, val_losses = curves["train_losses"], curves["val_losses"]

plt.figure(figsize=(10, 6))
for series, series_label, series_color in (
    (train_losses, "Train Loss", "blue"),
    (val_losses, "Validation Loss", "red"),
):
    plt.plot(series, label=series_label, color=series_color)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Train/Validation Loss for Best Trial")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("optuna_best_model_loss_curve.png")
plt.show()


# === Re-train the best hyper-parameter configuration ===
# Trains on the training split only; the test split drives early stopping, as
# in `objective`. The weights from the best validation epoch are restored at
# the end.
best_params = study.best_trial.params
best_seed = study.best_trial.number

# Seed initialisation and batch shuffling with the best trial's number (the
# same seed `objective` passes to SiameseNetwork). This makes the run
# deterministic and independent of how many trials ran before it.
best_model = SiameseNetwork(input_size, best_params['hidden_size'], embedding_size, seed=best_seed)
optimizer = optim.Adam(best_model.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'])
criterion = nn.MSELoss()

retrain_generator = torch.Generator()
retrain_generator.manual_seed(best_seed)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=retrain_generator)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

best_val_loss = float("inf")
patience_counter = 0
# Start from a copy of the initial weights so `best_state_dict` is always
# defined, even if the validation loss never improves (e.g. NaN every epoch).
best_state_dict = {name: t.detach().clone() for name, t in best_model.state_dict().items()}

for epoch in range(num_epochs):
    # Re-enter training mode every epoch: the validation pass below switches
    # the model to eval mode, which changes how its BatchNorm layers behave.
    best_model.train()
    for x1, x2, targets in train_loader:
        optimizer.zero_grad()
        outputs, _, _, _, _ = best_model(x1, x2)
        loss_mse = criterion(outputs, targets.unsqueeze(1))
        loss_corr = correlation_loss(outputs, targets.unsqueeze(1))
        loss = loss_mse + alpha * loss_corr
        loss.backward()
        optimizer.step()

    # Early stopping on the per-sample validation loss.
    best_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x1, x2, targets in test_loader:
            outputs, _, _, _, _ = best_model(x1, x2)
            loss_mse = criterion(outputs, targets.unsqueeze(1))
            loss_corr = correlation_loss(outputs, targets.unsqueeze(1))
            loss = loss_mse + alpha * loss_corr
            val_loss += loss.item() * x1.size(0)
    val_loss /= len(test_loader.dataset)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # state_dict() returns tensors that share storage with the live
        # parameters; clone them, or later optimizer steps overwrite the
        # "best" snapshot with the final weights.
        best_state_dict = {name: t.detach().clone() for name, t in best_model.state_dict().items()}
    else:
        patience_counter += 1
        # `>=` matches the stopping rule used in `objective`.
        if patience_counter >= patience:
            break

# Restore the weights from the best validation epoch.
best_model.load_state_dict(best_state_dict)
best_model.eval()

# --- Export inputs, learned embeddings, targets and block labels ---
# Rows are written in the original file order. train_test_split only permuted
# the rows: concatenating the train/test parts and indexing them by
# argsort(idx_train ++ idx_test) is the identity on the original arrays. So X,
# labels, input1, input2 and labels_block are used directly.
best_model.eval()  # inference mode for the BatchNorm layers in TwinNetwork

X_full_tensor = torch.tensor((X - train_mean) / train_std, dtype=torch.float32)
y_full_tensor = torch.tensor((labels - y_train_mean) / y_train_std, dtype=torch.float32)
merged_dataset = CustomDataset(X_full_tensor, y_full_tensor, input_size)
merged_loader = DataLoader(merged_dataset, batch_size=batch_size, shuffle=False)

# Collect the two item embeddings for every sample.
phi_1_batches, phi_2_batches = [], []
with torch.no_grad():
    for x1, x2, _ in merged_loader:
        _, _, _, phi1, phi2 = best_model(x1, x2)
        phi_1_batches.append(phi1)
        phi_2_batches.append(phi2)
phi_1_all = torch.cat(phi_1_batches, dim=0).cpu().numpy()
phi_2_all = torch.cat(phi_2_batches, dim=0).cpu().numpy()

# All original feature columns of both items (raw, not normalised). This is
# wider than 2*input_size whenever the files hold more than input_size columns.
final_inputs = np.concatenate([input1, input2], axis=1)

# Embeddings followed by the TRUE target y, taken exactly from y_original.txt
# (no float32 normalise/denormalise round trip).
# NOTE(review): the model's own prediction (its first output) is not exported;
# add it as a column if a predicted-y column was intended.
final_outputs = np.concatenate([phi_1_all, phi_2_all, labels.reshape(-1, 1)], axis=1)

# Block labels as a column so they can be appended.
final_labels = labels_block[:final_inputs.shape[0]]
if final_labels.ndim == 1:
    final_labels = final_labels.reshape(-1, 1)

print("input1 shape:", input1.shape)
print("input2 shape:", input2.shape)
print("final_inputs shape:", final_inputs.shape)
print("labels_block shape, final_labels shape :", labels_block.shape, final_labels.shape)

# Explicit checks rather than `assert`, which `python -O` strips.
if final_inputs.shape[0] != final_outputs.shape[0]:
    raise ValueError("Mismatch in number of samples (inputs vs outputs)")
if final_inputs.shape[0] != final_labels.shape[0]:
    raise ValueError("Mismatch in number of samples (inputs vs labels)")

# Concatenate horizontally: inputs | embeddings + y | block labels
final_array = np.concatenate([final_inputs, final_outputs, final_labels], axis=1)

# Space-separated (np.savetxt's default delimiter).
np.savetxt("../merged_output.txt", final_array)

print(f"Saved merged_output.txt with shape {final_array.shape} (samples x features+outputs+labels).")
