import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from .task import Task
from ..utils import log, log_dict


class Net(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Architecture: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> 2x2 max-pool
    -> channel dropout -> flatten -> linear(9216->128) -> ReLU -> dropout
    -> linear(128->num_classes) -> log-softmax.

    Args:
        num_classes: Number of output classes. Defaults to 26, matching the
            EMNIST "letters" split used elsewhere in this module.
    """

    def __init__(self, num_classes=26):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout: zeroes whole feature maps of the 4-D pooled
        # activations.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D (N, 128) hidden activations.
        # nn.Dropout2d is meant for 3-D/4-D inputs; on a 2-D input it already
        # acts element-wise but emits a deprecation warning (slated to become
        # an error), so plain nn.Dropout is the equivalent, supported choice.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12: the two unpadded 3x3 convs take a 28x28 input
        # to 24x24, and the 2x2 max-pool halves that to 12x12.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute per-class log-probabilities.

        Args:
            x: Image batch, expected shape (N, 1, 28, 28); the single channel
                comes from ``conv1`` and the 28x28 size is implied by
                ``fc1``'s 9216 input features.

        Returns:
            Tensor of shape (N, num_classes) with log-softmax scores, suitable
            for ``nn.NLLLoss``.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        # NOTE(review): no ReLU after conv2 (the reference MNIST CNN applies
        # one here). Max-pooling still adds non-linearity; left unchanged so
        # existing trained weights keep producing the same outputs.
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class EmnistTask(Task):
    """Task definition for EMNIST "letters": model, loss and dataloader factory."""

    def get_model(self):
        """Return a freshly initialised ``Net``."""
        return Net()

    def get_loss_function(self):
        """Return the NLL loss that pairs with ``Net``'s log-softmax output."""
        return nn.NLLLoss()

    def get_dataloader(self):
        """Return a factory that builds EMNIST "letters" dataloaders.

        The returned callable takes ``(data_dir, train, download, batch_size,
        shuffle=None, sampler_callback=None, dataset_cls=datasets.EMNIST,
        drop_last=True, **loader_kwargs)`` and returns a
        ``torch.utils.data.DataLoader``. ``sampler_callback``, when given, is
        called with the dataset to produce the loader's sampler.
        """

        def emnist(data_dir,
                   train,
                   download,
                   batch_size,
                   shuffle=None,
                   sampler_callback=None,
                   dataset_cls=datasets.EMNIST,
                   drop_last=True,
                   **loader_kwargs):
            # Combining shuffle=True with a sampler is rejected by DataLoader
            # itself, so no separate check is done here.
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.2,), (0.3499,)),
            ])
            # Only the "letters" split is supported for now.
            dataset = dataset_cls(
                data_dir,
                train=train,
                download=download,
                split="letters",
                transform=preprocess,
            )
            # "letters" labels run 1..26; shift them to 0..25 so they index
            # the 26 model outputs, as NLLLoss expects targets in [0, C-1].
            dataset.targets = dataset.targets - 1

            sampler = None
            if sampler_callback:
                sampler = sampler_callback(dataset)

            log("Getting dataloader for emnist:")
            summary = {
                "Type": "Setup",
                "Dataset": "emnist",
                "train": train,
                "download": download,
                "batch_size": batch_size,
                "shuffle": shuffle,
                "sampler": str(sampler) if sampler else None,
            }
            log_dict(summary)

            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                sampler=sampler,
                drop_last=drop_last,
                **loader_kwargs
            )
            return loader

        return emnist


if __name__ == "__main__":

    def count_parameters(model):
        """Return how many scalar parameters of *model* require gradients."""
        trainable = [param for param in model.parameters() if param.requires_grad]
        return sum(param.numel() for param in trainable)

    print(count_parameters(Net()))