from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Feed-forward network: two ReLU hidden layers followed by a linear head.

    Dropout is applied once, between the first and second hidden layers.

    Args:
        in_dim: Size of each input feature vector.
        hidden: Width shared by both hidden layers.
        out_dim: Number of values produced by the output layer.
        dropout: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_dim: int, hidden: int, out_dim: int, dropout: float = 0.0):
        super().__init__()
        # Attribute names double as state_dict keys, and creation order fixes
        # the order in which the RNG is consumed for weight initialisation.
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=hidden)
        self.fc_out = nn.Linear(in_features=hidden, out_features=out_dim)
        self.dropout = nn.Dropout(p=dropout)

    def embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Return the post-ReLU activations of the second hidden layer."""
        first_hidden = self.dropout(F.relu(self.fc1(x)))
        return F.relu(self.fc2(first_hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output-layer values (no softmax) computed from ``embedding(x)``."""
        return self.fc_out(self.embedding(x))


@dataclass
class MLPConfig:
    """Hyper-parameters for building and fitting an ``MLP``.

    Field order is part of the interface: instances may be built positionally.
    """

    # --- architecture ---
    hidden: int = 64  # width of the hidden layers
    out_dim: int = 2  # NOTE(review): not read by train_mlp_classifier as written
    # --- optimisation ---
    lr: float = 1e-3  # optimizer learning rate
    weight_decay: float = 0.0  # optimizer weight-decay coefficient
    epochs: int = 50  # number of training iterations
    batch_size: int = 256  # NOTE(review): not read by train_mlp_classifier as written
    # --- runtime ---
    device: str = "cpu"  # passed to torch.device(...)
    # --- regularisation ---
    dropout: float = 0.0  # dropout probability forwarded to the model


def train_mlp_classifier(X: np.ndarray, y: np.ndarray, config: MLPConfig) -> MLP:
    """Fit an ``MLP`` classifier with full-batch AdamW and return it.

    The whole dataset is passed through the network once per epoch, giving
    exactly one optimizer step per epoch. Parameters and inputs are float64.
    The number of output classes is derived from the labels as
    ``max(y) + 1``; ``config.out_dim`` and ``config.batch_size`` are not
    consulted here.

    Args:
        X: Feature matrix of shape ``(n_samples, n_features)``.
        y: Class labels, one per row of ``X``; cast to int64. A column
            vector of shape ``(n_samples, 1)`` is flattened.
        config: Supplies ``hidden``, ``dropout``, ``lr``, ``weight_decay``,
            ``epochs`` and ``device``.

    Returns:
        The trained model, switched to evaluation mode so dropout is inactive
        when the caller runs inference. Call ``model.train()`` before any
        further training.

    Raises:
        ValueError: If ``X`` is not 2-D, the dataset is empty, or ``y`` does
            not contain exactly one label per row of ``X``.
    """
    X = np.asarray(X)
    y = np.asarray(y).reshape(-1)

    # Fail early with a clear message instead of an opaque NumPy/torch error.
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (n_samples, n_features), got shape {X.shape}")
    if X.shape[0] == 0:
        raise ValueError("cannot train on an empty dataset")
    if y.shape[0] != X.shape[0]:
        raise ValueError(
            f"X and y disagree on sample count: {X.shape[0]} rows vs {y.shape[0]} labels"
        )

    dev = torch.device(config.device)
    in_dim = X.shape[1]
    num_classes = int(np.max(y)) + 1
    model = MLP(in_dim, config.hidden, num_classes, dropout=config.dropout).to(dev).double()

    opt = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    X_t = torch.tensor(X, dtype=torch.float64, device=dev)
    y_t = torch.tensor(y.astype(np.int64), dtype=torch.long, device=dev)

    model.train()
    for _ in range(config.epochs):
        # single-batch optimizer (data sizes are moderate); keep deterministic
        logits = model(X_t)
        loss = F.cross_entropy(logits, y_t)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    # Hand back an inference-ready model: in training mode dropout would keep
    # randomly zeroing activations in the caller's forward passes.
    model.eval()
    return model




