import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np



def compute_accuracy(net, testloader):
    """Return the top-1 classification accuracy of ``net`` over ``testloader``.

    If ``net`` is an ``nn.Module`` it is switched to evaluation mode for the
    duration of the pass (so layers such as dropout or batch-norm do not act
    stochastically), and its previous train/eval mode is restored afterwards,
    even if an exception is raised. Gradient tracking is disabled throughout.

    Args:
        net: Callable mapping a batch of inputs to per-class scores; the
            predicted class is the argmax over dimension 1 of its output.
        testloader: Iterable yielding ``(images, labels)`` pairs; ``labels``
            is compared element-wise with the predicted class indices and
            its size along dimension 0 is counted as the number of samples.

    Returns:
        Fraction of samples whose predicted class equals the label.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    is_module = isinstance(net, nn.Module)
    was_training = is_module and net.training
    if is_module:
        # Evaluation must not be affected by train-time behaviour (dropout etc.).
        net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in testloader:
                outputs = net(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        if is_module:
            # Leave the caller's model in the mode it was handed to us in.
            net.train(was_training)
    return correct / total


class NN(nn.Module):
    """Fully connected network with a single ReLU hidden layer.

    Every input sample is flattened to a vector before the first layer, so
    the product of all non-batch dimensions of the input must equal
    ``input_size``.

    Args:
        input_size: Number of features per flattened sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw (no softmax applied) output scores for batch ``x``.

        ``x`` must carry the batch along dimension 0; all remaining
        dimensions are flattened, e.g. ``(B, C, H, W) -> (B, C*H*W)``.
        The result has shape ``(B, output_size)``.
        """
        # Flatten all non-batch dims. The previous expression multiplied
        # shape[1] twice (C*H*H), which broke non-square images, inputs
        # without a channel dim, and already-flat (B, D) inputs.
        x = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)



class NN_2hidden(nn.Module):
    """Fully connected network with two ReLU hidden layers.

    Every input sample is flattened to a vector before the first layer, so
    the product of all non-batch dimensions of the input must equal
    ``input_size``.

    Args:
        input_size: Number of features per flattened sample.
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.out = nn.Linear(hidden_size2, output_size)

    def forward(self, x):
        """Return raw (no softmax applied) output scores for batch ``x``.

        ``x`` must carry the batch along dimension 0; all remaining
        dimensions are flattened, e.g. ``(B, C, H, W) -> (B, C*H*W)``.
        The result has shape ``(B, output_size)``.
        """
        # Flatten all non-batch dims. The previous expression multiplied
        # shape[1] twice (C*H*H), which broke non-square images, inputs
        # without a channel dim, and already-flat (B, D) inputs.
        x = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)


