import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import time
import sys
import random
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import optuna


# One seed for every RNG the script draws from: Python, NumPy and PyTorch.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Favor reproducibility over speed: request deterministic algorithms from
# PyTorch, turn off cuDNN autotuning and ask cuDNN for deterministic kernels.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def skew_symmetric_fn(l, r, dim):
    """
    Score each positional pair (l[i], r[i]) with the bilinear form l[i] @ A @ r[i].

    A is the (dim, dim) block-diagonal matrix made of dim // 2 copies of
    [[0, 1], [-1, 0]]. Because A is skew-symmetric, swapping the two
    arguments flips the sign of every score.

    Args:
        l (torch.Tensor): Left embeddings of shape (N, dim).
        r (torch.Tensor): Right embeddings of shape (M, dim).
        dim (int): Embedding width; must be a positive even integer.

    Returns:
        torch.Tensor: 1-D float32 tensor with one score per pair. When
        N != M only the first min(N, M) pairs are scored, which matches the
        diagonal of the full (N, M) product.

    Raises:
        ValueError: If dim is not a positive even integer, or if r is not a
            2-D tensor with dim columns.
    """
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a positive even integer, got {dim}")
    if r.dim() != 2 or r.shape[1] != dim:
        raise ValueError(
            f"r must have shape (M, {dim}), got {tuple(r.shape)}"
        )

    # Build A directly on the inputs' device instead of creating it on the CPU
    # and copying it over on every call.
    block = torch.tensor([[0., 1.], [-1., 0.]], device=l.device)
    A = torch.block_diag(*[block] * (dim // 2))

    l2 = torch.mm(l.float(), A)
    r = r.float()

    # Only the diagonal of l2 @ r.T is needed, so take row-wise dot products
    # rather than materializing the full (N, M) matrix.
    n_pairs = min(l2.shape[0], r.shape[0])
    return (l2[:n_pairs] * r[:n_pairs]).sum(dim=1)

def denormalize_tensor(tensor, mean, std):
    """
    Undo standardization by mapping a normalized tensor back to its original scale.

    Args:
        tensor (torch.Tensor): Normalized input tensor of shape (N, D).
        mean (np.ndarray | torch.Tensor | sequence): Training mean of shape (1, D) or (D,).
        std (np.ndarray | torch.Tensor | sequence): Training std of shape (1, D) or (D,).

    Returns:
        torch.Tensor: ``tensor * std + mean``, computed in the dtype and on
        the device of ``tensor``.
    """
    # as_tensor accepts arrays, sequences and tensors alike; torch.tensor()
    # warns about copy-construction when it is handed an existing tensor.
    mean_tensor = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std_tensor = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    return tensor * std_tensor + mean_tensor

    
class CustomFunctionLayer(nn.Module):
    """Parameter-free layer that scores a pair of embeddings via skew_symmetric_fn."""

    def forward(self, x1, x2, dim):
        """
        Return ``skew_symmetric_fn(x1, x2, dim)`` as-is.

        No sigmoid is applied here; the raw score is handed back to the caller.
        """
        score = skew_symmetric_fn(x1, x2, dim)
        return score

class CustomDataset(torch.utils.data.Dataset):
    """Dataset of (left, right, target) triples cut from a side-by-side feature matrix."""

    def __init__(self, X, y, input_size):
        """
        Args:
            X: 2-D indexable whose first ``input_size`` columns hold the left
                item and whose next ``input_size`` columns hold the right item.
            y: Targets, indexed in step with the rows of ``X``.
            input_size (int): Number of feature columns per item.
        """
        left_cols = slice(0, input_size)
        right_cols = slice(input_size, 2 * input_size)
        self.X1 = X[:, left_cols]
        self.X2 = X[:, right_cols]
        self.y = y

    def __len__(self):
        """Number of samples, taken from the length of the targets."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(left, right, target)`` triple stored at ``idx``."""
        sample = (self.X1[idx], self.X2[idx], self.y[idx])
        return sample


class TwinNetwork(nn.Module):
    """Encoder MLP: two Linear -> BatchNorm1d -> ReLU stages, then a linear projection."""

    def __init__(self, input_size, hidden_size, embedding_size):
        """
        Args:
            input_size (int): Number of input features.
            hidden_size (int): Width of both hidden layers.
            embedding_size (int): Width of the output embedding.
        """
        super().__init__()
        # Layers are instantiated in forward order, so their positions inside
        # the Sequential (0..6) follow the order listed here.
        # Dropout(0.2) after each ReLU and a third hidden stage are left out
        # of the current architecture.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        ]
        self.base_network = nn.Sequential(*layers)

    def forward(self, input1):
        """Encode ``input1`` and return its embedding."""
        embedding = self.base_network(input1)
        return embedding

class SiameseNetwork(nn.Module):
    """Runs one shared TwinNetwork over both inputs and scores the resulting embedding pair."""

    def __init__(self, input_size, hidden_size, embedding_size, seed=None):
        """
        Args:
            input_size (int): Number of features of each input.
            hidden_size (int): Hidden width passed to the TwinNetwork.
            embedding_size (int): Embedding width passed to the TwinNetwork
                and, at forward time, to the scoring layer.
            seed (int | None): When given, reseeds the PyTorch, NumPy and
                Python RNGs before the encoder is built.
        """
        super().__init__()
        if seed is not None:
            # Reseed before any layer is created so weight initialization is repeatable.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
        self.twin_network = TwinNetwork(input_size, hidden_size, embedding_size)
        self.custom_layer = CustomFunctionLayer()
        self.embedding_size = embedding_size

    def forward(self, input1, input2):
        """
        Encode both inputs with the same weights and score the pair.

        Returns:
            tuple: ``(score, input1, input2, embedding1, embedding2)`` where
            ``score`` has one column and one row per entry of the scoring
            layer's output.
        """
        emb1 = self.twin_network(input1)
        emb2 = self.twin_network(input2)
        scores = self.custom_layer(emb1, emb2, self.embedding_size)
        column = scores.reshape(len(scores), 1)
        return column, input1, input2, emb1, emb2

# Dedicated, seeded generator that is handed to the shuffling DataLoader.
g = torch.Generator().manual_seed(seed)

# -------------------------
# Load preprocessed Pokémon data
# -------------------------
input1 = np.load("pokemon/l_processed.npy")
input2 = np.load("pokemon/r_processed.npy")
S = np.load("pokemon/S.npy")

# Labels are stored as -1 / +1; remap them to 0 / 1 (int64).
labels = np.load("pokemon/y.npy")
labels = np.floor_divide(labels + 1, 2).astype(np.int64)


# -------------------------
# Model setup
# -------------------------
input_size = input1.shape[1]      # number of features per Pokémon
hidden_size = 128            # hidden layer size

# The embedding width is the first command-line argument. Validate it here so
# a bad value fails with a readable message instead of an IndexError or a
# matrix-shape error in the forward pass: the scoring matrix is built from
# 2x2 blocks, so the width has to be a positive even number.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} EMBEDDING_SIZE")
try:
    embedding_size = int(sys.argv[1])
except ValueError:
    sys.exit(f"EMBEDDING_SIZE must be an integer, got {sys.argv[1]!r}")
if embedding_size <= 0 or embedding_size % 2 != 0:
    sys.exit(f"EMBEDDING_SIZE must be a positive even integer, got {embedding_size}")

num_samples = input1.shape[0]
print("input_size, num_samples = ", input_size, num_samples)


# -------------------------
# Training hyperparameters
# -------------------------
num_epochs = 100   # upper bound on epochs per training run
patience = 30      # epochs without validation improvement before stopping early
alpha = 0          # NOTE(review): not referenced anywhere in the visible code
batch_size = 64

# Keep the first input_size columns of each side and lay them side by side:
# columns [0, input_size) are the left item, [input_size, 2*input_size) the right.
input1_pruned = input1[:, :input_size]
input2_pruned = input2[:, :input_size]
X = np.hstack((input1_pruned, input2_pruned))

# Hold out 25% of the rows; the original row indices are split alongside so
# the original order can be rebuilt later.
all_indices = np.arange(len(labels))
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
    X, labels, all_indices, test_size=0.25, random_state=42
)

# Standardize both splits with statistics of the training rows only; the
# small constant keeps zero-variance columns from dividing by zero.
train_mean = np.mean(X_train, axis=0, keepdims=True)
train_std = np.std(X_train, axis=0, keepdims=True) + 1e-6
X_train = (X_train - train_mean) / train_std
X_test = (X_test - train_mean) / train_std

# float32 tensors for features and targets alike.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

train_dataset = CustomDataset(X_train_tensor, y_train_tensor, input_size)
test_dataset = CustomDataset(X_test_tensor, y_test_tensor, input_size)

# Only the training loader shuffles, drawing its permutations from g.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, generator=g
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# The model emits raw scores, so the loss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()

def objective(trial):
    """
    Optuna objective: train a SiameseNetwork and return its lowest validation loss.

    Every hyperparameter is pinned to a fixed value, so nothing is sampled
    from ``trial``. The function reads the module-level ``input_size``,
    ``embedding_size``, ``num_epochs``, ``patience``, ``train_loader``,
    ``test_loader`` and ``criterion``.

    Args:
        trial (optuna.trial.Trial): Receives the per-epoch loss curves as the
            user attributes ``"train_losses"`` and ``"val_losses"``.

    Returns:
        float: Smallest sample-weighted mean validation loss over all epochs run.
    """
    seed = 42  # fixed; trial.number is the disabled alternative
    print("trial number = ", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Disabled Optuna search spaces, kept for reference:
    #   hidden_size:   categorical [32, 64, 96, 128]
    #   learning_rate: log-uniform in [1e-4, 1e-1]
    #   weight_decay:  log-uniform in [1e-6, 1e-2]
    #   optimizer:     "Adam" or "SGD" (SGD with momentum=0.9)
    hidden_size = 128
    learning_rate = 0.001
    weight_decay = 1e-4

    model = SiameseNetwork(input_size, hidden_size, embedding_size, seed)
    optimizer = optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    best_val_loss = float("inf")
    epochs_without_gain = 0
    trial_train_losses = []
    trial_val_losses = []

    for _ in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_total = 0.0
        for left, right, batch_targets in train_loader:
            optimizer.zero_grad()
            scores = model(left, right)[0]
            batch_loss = criterion(scores, batch_targets.unsqueeze(1))
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            train_total += batch_loss.item() * left.size(0)
        epoch_train_loss = train_total / len(train_loader.dataset)

        # ---- validation pass ----
        model.eval()
        val_total = 0.0
        with torch.no_grad():
            for left, right, batch_targets in test_loader:
                scores = model(left, right)[0]
                batch_loss = criterion(scores, batch_targets.unsqueeze(1))
                val_total += batch_loss.item() * left.size(0)
        epoch_val_loss = val_total / len(test_loader.dataset)

        # Record both curves for plotting after the study.
        trial_train_losses.append(epoch_train_loss)
        trial_val_losses.append(epoch_val_loss)

        # ---- early stopping ----
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_without_gain = 0
            continue
        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            break

    # Attach the curves to the trial so they can be read back from the study.
    trial.set_user_attr("train_losses", trial_train_losses)
    trial.set_user_attr("val_losses", trial_val_losses)

    return best_val_loss


# Run the study with a seeded TPE sampler (one trial, one-hour cap).
sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(direction="minimize", sampler=sampler)
study.optimize(objective, n_trials=1, timeout=3600)

# Report the best trial's value and any sampled parameters.
print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

# Pull the loss curves that the objective stored on the trial.
best_trial = study.best_trial
train_losses = best_trial.user_attrs["train_losses"]
val_losses = best_trial.user_attrs["val_losses"]

# Plot both curves on one axis, save the figure, then display it.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label="Train Loss", color="blue")
ax.plot(val_losses, label="Validation Loss", color="red")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Train/Validation Loss for Best Trial")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig("optuna_best_model_loss_curve.png")
plt.show()


# === Re-train the pinned configuration from scratch ===
# NOTE: this fits on train_dataset only and uses test_dataset for early
# stopping; it does not train on the full data.
best_params = study.best_trial.params
best_model = SiameseNetwork(input_size, 128, embedding_size)

# Adam with the same learning rate and weight decay as in the objective.
optimizer = optim.Adam(best_model.parameters(), lr=0.001, weight_decay=1e-4)

# Rebuild the loaders; the training loader shuffles with the seeded generator g.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

best_val_loss = float("inf")
patience_counter = 0
# Snapshot the initial weights so there is always something to restore, even
# if the validation loss never improves (for example when it is NaN).
best_state_dict = {
    name: tensor.detach().clone()
    for name, tensor in best_model.state_dict().items()
}

for epoch in range(num_epochs):
    # Re-enter training mode every epoch: the validation pass below leaves the
    # model in eval mode, which would otherwise freeze the BatchNorm statistics
    # for all remaining epochs.
    best_model.train()
    epoch_loss = 0
    for x1, x2, targets in train_loader:
        optimizer.zero_grad()
        outputs, _, _, _, _ = best_model(x1, x2)
        loss = criterion(outputs, targets.unsqueeze(1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * x1.size(0)
    epoch_loss /= len(train_loader.dataset)

    # Early stopping on validation loss
    best_model.eval()
    val_loss = 0
    with torch.no_grad():
        for x1, x2, targets in test_loader:
            outputs, _, _, _, _ = best_model(x1, x2)
            loss = criterion(outputs, targets.unsqueeze(1))
            val_loss += loss.item() * x1.size(0)
    val_loss /= len(test_loader.dataset)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # state_dict() hands back references to the live tensors, so clone
        # them; otherwise later optimizer steps overwrite the "best" snapshot.
        best_state_dict = {
            name: tensor.detach().clone()
            for name, tensor in best_model.state_dict().items()
        }
    else:
        patience_counter += 1
        # Same stopping rule as in objective(): stop once `patience` epochs
        # have passed without improvement.
        if patience_counter >= patience:
            break

# Restore best weights
best_model.load_state_dict(best_state_dict)
best_model.eval()

# --- Export inputs, embeddings and labels for every sample, in original row order ---

# idx_train / idx_test hold each split row's original position; sorting their
# concatenation gives the permutation that undoes the train/test shuffle.
y_full = np.concatenate((y_train, y_test), axis=0)
idx_full = np.concatenate((idx_train, idx_test), axis=0)
sort_idx = np.argsort(idx_full)
y_full_ordered = y_full[sort_idx]

# Raw (un-normalized) inputs, put back into the original order the same way.
input1_full = np.concatenate((input1[idx_train], input1[idx_test]), axis=0)
input2_full = np.concatenate((input2[idx_train], input2[idx_test]), axis=0)
input1_full_ordered = input1_full[sort_idx]
input2_full_ordered = input2_full[sort_idx]

# Build the side-by-side feature matrix and normalize it with the training statistics.
X_full_ordered = np.concatenate(
    [input1_full_ordered[:, :input_size], input2_full_ordered[:, :input_size]], axis=1
)
X_full_tensor = torch.tensor((X_full_ordered - train_mean) / train_std, dtype=torch.float32)
y_full_tensor = torch.tensor(y_full_ordered, dtype=torch.float32)

# Unshuffled loader so batches come out in dataset order.
merged_dataset = CustomDataset(X_full_tensor, y_full_tensor, input_size)
merged_loader = DataLoader(merged_dataset, batch_size=batch_size, shuffle=False)

# Collect normalized inputs, embeddings and labels batch by batch.
f1_all, f2_all, phi_1_all, phi_2_all, y_all = [], [], [], [], []

with torch.no_grad():
    for x1, x2, y in merged_loader:
        # The model's score (first return value) is discarded; only the two
        # embeddings are kept.
        _, _, _, phi1, phi2 = best_model(x1, x2)
        f1_all.append(x1)
        f2_all.append(x2)
        phi_1_all.append(phi1)
        phi_2_all.append(phi2)
        y_all.append(y)

f1_all = torch.cat(f1_all, dim=0)
f2_all = torch.cat(f2_all, dim=0)
phi_1_all = torch.cat(phi_1_all, dim=0)
phi_2_all = torch.cat(phi_2_all, dim=0)
y_all = torch.cat(y_all, dim=0).unsqueeze(1)

# Denormalize f1 and f2. The left item was standardized with the first
# input_size columns of the training statistics and the right item with the
# next input_size columns, so each side needs its own slice.
f1_all = denormalize_tensor(
    f1_all, train_mean[0, :input_size], train_std[0, :input_size]
)
f2_all = denormalize_tensor(
    f2_all,
    train_mean[0, input_size:2 * input_size],
    train_std[0, input_size:2 * input_size],
)

# Combine the raw input features of both sides.
final_inputs = np.concatenate([input1_full_ordered, input2_full_ordered], axis=1)

# Combine both embeddings with y and convert to numpy.
# NOTE(review): y_all holds the labels yielded by the loader, not the model's
# prediction — confirm that this is what consumers of the file expect.
final_outputs = torch.cat([phi_1_all, phi_2_all, y_all], dim=1).cpu().numpy()

print("input1_full_ordered shape:", input1_full_ordered.shape)
print("input2_full_ordered shape:", input2_full_ordered.shape)
print("final_inputs shape:", final_inputs.shape)  # Should be number_of_samples x 2*input_size


# Sanity checks before concatenation
assert final_inputs.shape[0] == final_outputs.shape[0], "Mismatch in number of samples (inputs vs outputs)"

# Concatenate horizontally: inputs | outputs
final_array = np.concatenate([final_inputs, final_outputs], axis=1)

# Save as plain text; np.savetxt separates columns with a single space by default.
np.savetxt("merged_output.txt", final_array)

print(f"Saved merged_output.txt with shape {final_array.shape} (samples x features+outputs).")

