import torch
from torch import nn
import torch.nn.functional as F


class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. the result of ``transpose``/``permute``) are accepted; for
        contiguous inputs it still returns a view without copying.
        """
        return x.reshape(x.size(0), -1)


class Clf_MNIST_SVHN(nn.Module):
    """Small convolutional classifier with 10 outputs for (3, 32, 32) inputs.

    Two stride-2 convolutions halve the spatial size twice (32 -> 16 -> 8);
    the resulting 20 * 8 * 8 = 1280 features go through a two-layer MLP head.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Conv stack: (3, 32, 32) -> (10, 16, 16) -> (20, 8, 8)
        for c_in, c_out in ((3, 10), (10, 20)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.Dropout2d(0.5),
                nn.ReLU(),
            ]
        # Head: flatten to 1280 features -> 128 hidden units -> 10 scores
        layers += [
            Flatten(),
            nn.Linear(1280, 128),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities over the 10 classes for batch ``x``."""
        scores = self.encoder(x)
        return F.log_softmax(scores, dim=-1)


def train_Clf_MNIST(dl, model, mod, epochs, device):
    """Train ``model`` with Adam on one modality of paired SVHN/MNIST batches.

    Args:
        dl: Mapping whose ``'sup'`` entry is an iterable yielding
            ``(svhn, mnist, y)`` tensor triples.
        model: Module that maps an image batch to per-class scores.
        mod: ``'svhn'`` feeds the SVHN tensor to the model unchanged; any
            other value feeds the MNIST tensor, reshaped to
            ``(N, 1, 28, 28)``, zero-padded to 32x32 and expanded to
            3 channels.
        epochs: Number of passes over ``dl['sup']``.
        device: Device the model and batches are moved to.

    Returns:
        The model (after ``model.to(device)``), left in training mode.

    Prints the mean batch loss of every epoch (``nan`` if the loader
    produced no batches).
    """
    model = model.to(device)
    # CrossEntropyLoss applies log-softmax itself, so it accepts raw scores;
    # log-probabilities are also fine because log-softmax leaves them as is.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    use_svhn = mod == 'svhn'

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for svhn, mnist, y in dl['sup']:
            y = y.to(device)
            # Only the modality that is actually trained on is transferred.
            if use_svhn:
                data_batch = svhn.to(device)
            else:
                data_batch = F.pad(mnist.to(device).view(-1, 1, 28, 28),
                                   (2, 2, 2, 2),
                                   mode='constant', value=0).expand(-1, 3, -1, -1)

            optimizer.zero_grad()
            outputs = model(data_batch)

            loss = criterion(outputs, y.long())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Report the epoch average rather than whatever the last batch gave.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        print('====> Epoch: {:03d} loss: {:.3f}'.format(epoch, epoch_loss))
    return model


class Clf_CUBICC(nn.Module):
    """Small convolutional classifier with 8 outputs for (3, 64, 64) inputs.

    Three stride-2 convolutions halve the spatial size three times
    (64 -> 32 -> 16 -> 8); the resulting 20 * 8 * 8 = 1280 features go
    through a two-layer MLP head.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Conv stack: (3, 64, 64) -> (10, 32, 32) -> (20, 16, 16) -> (20, 8, 8)
        for c_in, c_out in ((3, 10), (10, 20), (20, 20)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.Dropout2d(0.5),
                nn.ReLU(),
            ]
        # Head: flatten to 1280 features -> 128 hidden units -> 8 scores
        layers += [
            Flatten(),
            nn.Linear(1280, 128),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(128, 8),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-probabilities over the 8 classes for batch ``x``."""
        scores = self.encoder(x)
        return F.log_softmax(scores, dim=-1)
    

def train_Clf_CUBICC(dl, model, epochs, device):
    """Train ``model`` with Adam on the image element of each batch.

    Args:
        dl: Mapping whose ``'sup'`` entry is an iterable yielding 3-tuples
            ``(image, _, y)``; the middle element is ignored.
        model: Module that maps an image batch to per-class scores.
        epochs: Number of passes over ``dl['sup']``.
        device: Device the model and batches are moved to.

    Returns:
        The model (after ``model.to(device)``), left in training mode.

    Prints the mean batch loss of every epoch (``nan`` if the loader
    produced no batches).
    """
    model = model.to(device)
    # CrossEntropyLoss applies log-softmax itself, so it accepts raw scores;
    # log-probabilities are also fine because log-softmax leaves them as is.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for image, _, y in dl['sup']:
            data_batch, y = image.to(device), y.to(device)

            optimizer.zero_grad()
            outputs = model(data_batch)

            loss = criterion(outputs, y.long())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Report the epoch average rather than whatever the last batch gave.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        print('====> Epoch: {:03d} loss: {:.3f}'.format(epoch, epoch_loss))
    return model