import torch as T
from torch import nn
import torch.nn.functional as F

class CNNCifar(nn.Module):
    """Small LeNet-style convolutional classifier for 3-channel images.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed a
    three-layer fully connected head.  ``fc1`` expects ``16 * 5 * 5``
    features, which is what the convolutional stack produces for 32x32
    inputs (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Number of output classes (width of the final layer).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def _device(self):
        """Return the device the model's parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, x):
        """Compute unnormalised class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw scores
            (no softmax applied).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything but the batch dimension, so a wrongly sized
        # input fails in fc1 instead of silently changing the batch size.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def loss(self, x, targets):
        """Return the cross-entropy loss of the model on one batch.

        Inputs are moved to the device of the model's parameters, and
        ``targets`` to the device of the output, so the caller does not
        have to keep them in sync.

        Args:
            x: Batch of input images.
            targets: Class-index tensor with one entry per sample.

        Returns:
            Scalar loss tensor (attached to the autograd graph).
        """
        x = x.to(self._device())

        # Run the forward pass
        output = self.forward(x)
        return F.cross_entropy(output, targets.to(output.device))

    def acc(self, x, targets):
        """Evaluate accuracy and cross-entropy loss on one batch.

        Switches the module to eval mode (and leaves it there) and runs
        without gradient tracking.

        Args:
            x: Batch of input images.
            targets: Class-index tensor with one entry per sample.

        Returns:
            Tuple ``(accuracy, loss)`` where ``accuracy`` is a Python float
            (fraction of correct predictions) and ``loss`` is a scalar
            tensor.
        """
        self.eval()
        with T.no_grad():
            x = x.to(self._device())
            # Run the forward pass
            output = self.forward(x)
            targets = targets.to(output.device)
            loss = F.cross_entropy(output, targets)
            # Softmax is monotonic, so the arg-max of the raw scores is the
            # predicted class; no need to normalise first.
            preds = output.argmax(dim=1)
            correct = T.sum(T.eq(preds, targets)).item() / len(preds)

        return correct, loss

    def get_logit(self, x=None, evalis=True, logmax=False):
        """Return normalised model outputs for a ``(data, target)`` batch.

        Despite the name, the returned values are softmax probabilities
        (or log-probabilities when ``logmax`` is true), not raw scores.

        Args:
            x: Pair ``(data, target)`` of tensors.
            evalis: If true, switch to eval mode and disable gradient
                tracking; otherwise switch to train mode and keep the
                autograd graph.
            logmax: If true, apply log-softmax instead of softmax.

        Returns:
            Tuple ``(None, outputs, target, loss)``: a ``None`` placeholder,
            the normalised outputs on CPU, the targets on CPU, and a
            placeholder loss of ``1``.

        Raises:
            TypeError: If ``x`` is ``None``.
        """
        if x is None:
            raise TypeError("get_logit() requires x to be a (data, target) pair")
        data, target = x

        normalise = F.log_softmax if logmax else F.softmax

        # Follow the model's device rather than assuming CUDA is present.
        data = data.to(self._device())
        if evalis:
            self.eval()
            with T.no_grad():
                # forward() returns a single tensor; do not tuple-unpack it.
                logits = normalise(self.forward(data), dim=1)
        else:
            self.train()
            logits = normalise(self.forward(data), dim=1)

        loss = 1  # placeholder kept for interface compatibility
        return None, logits.cpu(), target.cpu(), loss

