import torch
import torch.nn as nn
from torch.utils.data import Dataset
from random import sample

# Module-wide default device string: "cuda" when PyTorch reports CUDA support,
# otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class LogReg(nn.Module):
    """Logistic regression: one linear unit followed by a sigmoid.

    The layer is converted with ``.double()``, so its parameters are float64
    and inputs are expected to be float64 tensors as well.
    """

    def __init__(self, input_dim):
        super().__init__()
        layer = nn.Linear(input_dim, 1)
        self.linear = layer.double()

    def forward(self, x):
        """Return ``sigmoid(linear(x))``, i.e. values in (0, 1)."""
        logits = self.linear(x)
        return logits.sigmoid()

    def score(self, x):
        """Return the raw linear output (the pre-sigmoid logit)."""
        return self.linear(x)


class MLP(nn.Module):
    """Small feed-forward classifier: Linear -> ReLU -> Linear -> sigmoid.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the single hidden layer. Defaults to 5, the
            width this model has always used.

    All parameters are converted to float64 (``.double()``), so inputs are
    expected to be float64 tensors.
    """

    def __init__(self, input_dim, hidden_dim=5):
        super(MLP, self).__init__()
        # Attribute name and Sequential layout are unchanged so existing
        # state_dicts (keys "linear.0.*", "linear.2.*") still load.
        self.linear = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ).double()

    def forward(self, x):
        """Return ``sigmoid(score(x))``, i.e. values in (0, 1)."""
        out = self.linear(x)
        return torch.sigmoid(out)

    def score(self, x):
        """Return the raw network output (the pre-sigmoid logit)."""
        return self.linear(x)


class DatasetFromProgramTestCaseTesting(Dataset):
    """Map-style dataset that pairs ``x[idx]`` with ``y[idx]``.

    The length comes from the first dimension of ``x``; ``y`` is assumed
    (not checked) to be indexable over the same range.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        features = self.x[idx]
        label = self.y[idx]
        return features, label


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    Args:
        dataloader: iterable of ``(X, y)`` batches. ``y`` is unsqueezed along
            dim 1 before being passed to ``loss_fn`` (presumably turning a
            1-D ``(batch,)`` target into ``(batch, 1)`` to match the model
            output — confirm for your data).
        model: module to train; it is put into train mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer that is zeroed and stepped once per batch.

    Returns:
        ``(avg_loss, num_samples)``: the mean of the per-batch losses and the
        number of samples processed this epoch. If the loader yields no
        batches, returns ``(nan, 0)`` instead of dividing by zero.
    """
    model.train()

    # Send each batch to wherever the model's parameters live, so the loop
    # works for both CPU and GPU models. Parameter-free models: leave as is.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0
    num_samples = 0
    for X, y in dataloader:
        if target_device is not None:
            X, y = X.to(target_device), y.to(target_device)

        # Compute prediction error
        pred = model(X)
        y = y.unsqueeze(1)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # Count real samples; the last batch may be smaller than the others.
        num_samples += len(X)

    if num_batches == 0:
        return float("nan"), 0
    return total_loss / num_batches, num_samples


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    Predictions are thresholded at 0.5, so ``model`` is expected to output
    probabilities (as ``forward`` of the classifiers here does via sigmoid).

    Args:
        dataloader: iterable of ``(X, y)`` batches. ``y`` is unsqueezed along
            dim 1 before comparison with the prediction (presumably
            ``(batch,)`` -> ``(batch, 1)`` — confirm for your data).
        model: module to evaluate; it is put into eval mode and run under
            ``torch.no_grad()``.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        The mean of the per-batch losses, or ``nan`` if no samples were seen.
    """
    model.eval()

    # Send batches to the device of the model's parameters rather than a
    # global default, so a CPU model still works on a CUDA-capable host.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    test_loss, correct = 0.0, 0.0
    num_batches, num_samples = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            if target_device is not None:
                X, y = X.to(target_device), y.to(target_device)

            y = y.unsqueeze(1)
            pred = model(X)

            test_loss += loss_fn(pred, y).item()
            correct += ((pred >= 0.5) == y).type(torch.float).sum().item()
            num_batches += 1
            num_samples += len(y)

    if num_samples == 0:
        # Nothing evaluated: the averages are undefined, not a crash.
        return float("nan")

    test_loss /= num_batches
    # Divide by samples actually evaluated; len(dataloader.dataset) overcounts
    # when the loader uses drop_last or a subset sampler.
    correct /= num_samples
    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss


# pytest collects any function named ``test`` that ends up in a test module's
# namespace (e.g. via ``import *``) and would fail on its positional args.
test.__test__ = False
