import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import copy

class SimpleNN(nn.Module):
    """
    Fully connected classifier with 5 linear layers in total:
      - 4 hidden layers (all of width hidden_dim), each followed by ReLU
        and dropout
      - 1 output layer (output_dim) producing raw logits

    Provides a predict_proba method for classification (similar to
    scikit-learn).

    Args:
        input_dim: number of input features.
        hidden_dim: width shared by all four hidden layers.
        output_dim: number of output classes (logits).
        dropout: dropout probability applied after each hidden layer
            (default 0.1).
    """
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass:
          x shape: [batch_size, input_dim]
        Returns raw logits of shape: [batch_size, output_dim]
        """
        # Each hidden layer is Linear -> ReLU -> Dropout, applied in order.
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = self.dropout(F.relu(hidden(x)))
        return self.fc5(x)

    def predict_proba(self, X):
        """
        Mimics scikit-learn's predict_proba:
         - Takes a NumPy array, a torch.Tensor, or any array-like accepted by
           torch.as_tensor, of shape [n_samples, input_dim].
         - Converts it to the dtype and device of the model's parameters.
         - Runs a forward pass in eval mode (dropout disabled) without
           tracking gradients, then applies softmax over the class axis.
         - Returns a NumPy array of shape [n_samples, output_dim].

        The module's train/eval mode is restored before returning, so calling
        this in the middle of training does not leave dropout disabled.
        """
        # Use the first parameter as the reference for dtype and device so
        # float64 arrays/tensors and CPU inputs work with any model placement.
        reference = next(self.parameters())
        X = torch.as_tensor(X, dtype=reference.dtype, device=reference.device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = F.softmax(self(X), dim=1)
        finally:
            # Put the module back in whatever mode the caller had it in.
            self.train(was_training)
        return probs.cpu().numpy()


def compute_loss(model, criterion, X_val, Y_val, batch_size=64, device='cpu'):
    """
    Return the per-sample average of `criterion` over (X_val, Y_val).

    The data is processed in order in mini-batches of `batch_size`; each
    batch is moved to `device` before the forward pass. The model is switched
    to eval mode and gradients are not tracked.

    Each batch loss is weighted by the batch size before averaging, which
    assumes the criterion returns a mean over the batch -- confirm if a
    criterion with a different reduction is used.
    """
    batches = DataLoader(
        TensorDataset(X_val, Y_val),
        batch_size=batch_size,
        shuffle=False,
    )

    model.eval()
    loss_sum = 0.0
    seen = 0

    with torch.no_grad():
        for features, targets in batches:
            features = features.to(device)
            targets = targets.to(device)
            n_batch = features.size(0)
            batch_loss = criterion(model(features), targets).item()
            # Undo the per-batch mean so a smaller final batch is weighted
            # correctly.
            loss_sum += batch_loss * n_batch
            seen += n_batch

    return loss_sum / seen


def train_nn(
    X_train, Y_train,
    X_val, Y_val,
    hidden_dim=32,
    epochs=15,
    batch_size=1000,
    learning_rate=1e-3,
    patience=4
):
    """
    Trains a simple neural network on (X_train, Y_train) with
    a validation set (X_val, Y_val). Uses early stopping if the
    validation loss does not improve for 'patience' consecutive epochs.

    Args:
        X_train, X_val: feature matrices of shape [n_samples, input_dim]
            (NumPy arrays or torch.Tensors); converted to float32.
        Y_train, Y_val: integer class labels of shape [n_samples]
            (NumPy arrays or torch.Tensors); converted to int64.
        hidden_dim: width of each hidden layer of the network.
        epochs: maximum number of passes over the training data.
        batch_size: mini-batch size for both training and validation.
        learning_rate: learning rate for the Adam optimizer.
        patience: number of consecutive non-improving epochs tolerated
            before training stops early.

    Returns:
        model: the best model (SimpleNN) with .predict_proba method,
        carrying the weights of the epoch with the lowest validation loss
        and set to eval mode.
    """

    # 1) Select device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 2) Convert data to tensors of the dtypes the model and loss expect and
    #    place them on the device once. Tensor inputs are coerced as well, so
    #    float64 features or int32 labels do not fail later in the loop.
    X_train = torch.as_tensor(X_train, dtype=torch.float32, device=device)
    Y_train = torch.as_tensor(Y_train, dtype=torch.long, device=device)
    X_val = torch.as_tensor(X_val, dtype=torch.float32, device=device)
    Y_val = torch.as_tensor(Y_val, dtype=torch.long, device=device)

    num_samples, input_dim = X_train.shape

    # 3) Size the output layer from the largest label in either split, so a
    #    class that only appears in the validation set does not crash the
    #    criterion during validation.
    max_label = int(Y_train.max().item())
    if Y_val.numel() > 0:
        max_label = max(max_label, int(Y_val.max().item()))
    num_classes = max_label + 1

    # 4) Create DataLoader for training
    dataset_train = TensorDataset(X_train, Y_train)
    loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

    # 5) Instantiate model & move to device
    model = SimpleNN(input_dim, hidden_dim, num_classes)
    model = model.to(device)

    # 6) Define optimizer & loss
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # 7) Early stopping setup
    best_val_loss = float('inf')
    # Deep copy: state_dict() returns references to the live parameters,
    # which keep changing as training continues.
    best_model_state = copy.deepcopy(model.state_dict())
    no_improve_epochs = 0

    for epoch in range(epochs):
        # compute_loss switches the model to eval mode; re-enable dropout.
        model.train()
        total_loss = 0.0

        # Batches are slices of tensors that already live on `device`,
        # so no per-batch transfer is needed.
        for xb, yb in loader_train:
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by batch size for a per-sample average.
            total_loss += loss.item() * xb.size(0)

        train_loss = total_loss / num_samples

        # Compute validation loss
        val_loss = compute_loss(model, criterion, X_val, Y_val, batch_size=batch_size, device=device)

        print(f"Epoch [{epoch+1}/{epochs}]  "
              f"Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Load the best model weights and hand back a model ready for inference
    # (eval mode even when no epoch was run).
    model.load_state_dict(best_model_state)
    model.eval()
    return model

