import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from .task import Task
from ..utils import log, log_dict


class LinearModel(nn.Module):
    """Single-layer softmax classifier (multinomial logistic regression).

    Every sample is flattened into a vector, passed through one
    ``nn.Linear`` layer and normalised with ``log_softmax``, so the output
    is suitable for ``nn.NLLLoss``.

    Args:
        in_features: Number of values in one flattened sample. The default
            of 784 corresponds to a 1x28x28 image.
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, in_features=784, num_classes=10):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        Args:
            x: Batched input. All dimensions after the first are flattened,
                so their product must equal ``in_features``.
        """
        # Collapse everything but the batch dimension, e.g. 1x28x28 images
        # become 784-element vectors.
        x = torch.flatten(x, start_dim=1)
        return F.log_softmax(self.linear(x), dim=1)


class Net(nn.Module):
    """Small convolutional classifier for 1x28x28 inputs and 10 classes.

    Layout: two unpadded 3x3 convolutions, 2x2 max-pooling, dropout, then
    two fully connected layers with dropout in between. The output is a
    ``log_softmax`` over the 10 classes, suitable for ``nn.NLLLoss``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout applied to the 4-D feature maps after pooling.
        self.dropout1 = nn.Dropout2d(0.25)
        # Plain element-wise dropout: at this point the activations are 2-D
        # (batch, 128). Feeding 2-D input to Dropout2d is deprecated in
        # PyTorch (it warns and is slated to become an error); nn.Dropout
        # keeps the same behaviour without the warning.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: the two unpadded 3x3 convolutions
        # shrink 28 -> 26 -> 24 and the 2x2 pooling halves that to 12.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 10)``.

        Args:
            x: Batch of single-channel 28x28 images, shape
                ``(batch, 1, 28, 28)``; other spatial sizes do not match
                the 9216 inputs expected by ``fc1``.
        """
        x = self.conv1(x)
        x = F.relu(x)
        # NOTE(review): conv2's output is pooled directly, without a ReLU
        # like the one after conv1 -- confirm this is intentional.
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class MnistTask(Task):
    """MNIST classification task: provides a model, a loss and a dataloader factory.

    Args:
        use_linear_model: If truthy, ``get_model`` builds ``LinearModel``;
            otherwise it builds the convolutional ``Net``.
    """

    def __init__(self, use_linear_model):
        # NOTE(review): the base-class initialiser is not invoked here;
        # confirm that Task needs no initialisation of its own.
        self.use_linear_model = use_linear_model

    def get_model(self):
        """Return a newly constructed model chosen by ``use_linear_model``."""
        if self.use_linear_model:
            return LinearModel()
        return Net()

    def get_loss_function(self):
        """Return a negative log-likelihood loss.

        Both models end in ``log_softmax``, which is the input ``NLLLoss``
        expects.
        """
        return nn.NLLLoss()

    def get_dataloader(self):
        """Return a factory function that builds MNIST ``DataLoader`` objects."""

        def mnist(data_dir,
                  train,
                  download,
                  batch_size,
                  shuffle=False,
                  sampler_callback=None,
                  dataset_cls=datasets.MNIST,
                  drop_last=True,
                  **loader_kwargs):
            """Build a ``DataLoader`` over a normalised MNIST-style dataset.

            Args:
                data_dir: Passed as the first positional argument to
                    ``dataset_cls``.
                train: Forwarded to ``dataset_cls`` as ``train``.
                download: Forwarded to ``dataset_cls`` as ``download``.
                batch_size: Batch size for the ``DataLoader``.
                shuffle: Forwarded to the ``DataLoader``.
                sampler_callback: Optional callable taking the dataset and
                    returning a sampler; a falsy value means no sampler.
                dataset_cls: Dataset constructor, ``datasets.MNIST`` by
                    default.
                drop_last: Forwarded to the ``DataLoader``; defaults to True.
                **loader_kwargs: Extra keyword arguments for the
                    ``DataLoader``.

            Returns:
                The configured ``torch.utils.data.DataLoader``.
            """
            # Combining shuffle=True with a sampler is not checked here:
            # DataLoader documents the two options as mutually exclusive
            # and rejects the combination itself.
            preprocessing = transforms.Compose(
                [
                    transforms.ToTensor(),
                    # Presumably the MNIST training-set mean and standard
                    # deviation -- verify before reusing for another
                    # dataset_cls.
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            )
            dataset = dataset_cls(
                data_dir,
                train=train,
                download=download,
                transform=preprocessing,
            )

            sampler = sampler_callback(dataset) if sampler_callback else None
            # Compare against None instead of relying on truthiness: a
            # sampler that defines __len__ and is empty would be falsy and
            # would be logged as if no sampler had been supplied.
            sampler_description = None if sampler is None else str(sampler)

            log("Getting dataloader for mnist:")
            log_dict(
                {
                    "Type": "Setup",
                    "Dataset": "mnist",
                    "train": train,
                    "download": download,
                    "batch_size": batch_size,
                    "shuffle": shuffle,
                    "sampler": sampler_description,
                }
            )
            return torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               sampler=sampler,
                                               drop_last=drop_last,
                                               **loader_kwargs)

        return mnist
