import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import ConcatDataset, SubsetRandomSampler

from config import DATA_DIR
from .experiment import Experiment


class MNIST(Experiment):
    """MNIST experiment: builds the model and serves the MNIST train/test data.

    ``args`` must provide the keys ``'model_name'``, ``'num_classes'``,
    ``'classes'`` and ``'seed'``. The model is created by ``self.get_model``,
    which is not defined in this class (presumably inherited from
    ``Experiment`` — confirm there).
    """

    def __init__(self, args):
        # Both lookups happen before the model is built, so a missing key
        # surfaces before any model construction work.
        name_of_model, n_classes = args['model_name'], args['num_classes']
        net = self.get_model(name_of_model, args)
        super().__init__('mnist', net, n_classes, args['classes'], args['seed'])

    def load_data(self):
        """Load the MNIST train and test splits, downloading them to DATA_DIR if needed.

        Every image goes through ``ToTensor`` followed by ``Normalize`` with
        mean 0.1307 and std 0.3081. When ``self.num_classes == 1`` and
        ``self.classes`` is truthy, each split's ``targets`` tensor is handed to
        ``self.change_labels`` (defined elsewhere; presumably remaps labels in
        place — verify). Per-label counts of both splits are printed.

        Returns:
            ``(trainset, testset)`` — torchvision MNIST datasets.
        """
        preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.1307], [0.3081])])

        # Train split is constructed first, then test, matching the split order returned.
        train_split, test_split = (
            torchvision.datasets.MNIST(root=DATA_DIR, download=True, train=is_train, transform=preprocess)
            for is_train in (True, False)
        )

        if self.num_classes == 1 and self.classes:
            for split in (train_split, test_split):
                self.change_labels(split.targets)

        # NOTE: despite the wording, these lines print per-label counts (bincount),
        # not the set of unique labels.
        print(f'Unique labels in train and val: {torch.bincount(train_split.targets)}')
        print(f'Unique labels in test: {torch.bincount(test_split.targets)}')

        return train_split, test_split

    def get_concatenated_dataset(self):
        """Return the train and test splits merged into one ConcatDataset (train samples first)."""
        train_split, test_split = self.load_data()
        return ConcatDataset((train_split, test_split))

    @staticmethod
    def get_kfold_data_loaders(dataset, train_ids, test_ids, batch_size):
        """Build a (train_loader, test_loader) pair over subsets of ``dataset``.

        Each loader draws only the indices it is given, in random order and
        without replacement, via a ``SubsetRandomSampler``.
        """
        def make_loader(ids):
            return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(ids))

        return make_loader(train_ids), make_loader(test_ids)


class LinearClassifier(nn.Module):
    """Two-layer fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        name: Identifier stored on the module as ``self.name``.
        num_classes: Number of output logits.
        input_size: Number of features per sample after flattening
            (default 784 = 28 * 28, one single-channel MNIST image).
        hidden_size: Width of the hidden layer (default 392).
    """

    def __init__(self, name, num_classes, input_size=784, hidden_size=392):
        super().__init__()
        self.name = name
        self.num_classes = num_classes
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.hidden2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        All dimensions after the first are flattened, so ``x`` may be, e.g.,
        (batch, 1, 28, 28) or (batch, 784) for the default sizes.
        """
        # flatten(1) rather than view(x.size(0), -1): it also accepts
        # non-contiguous inputs (transposed / channels_last tensors) and an
        # empty batch, where view cannot infer the -1 dimension.
        x = torch.flatten(x, 1)
        x = torch.relu(self.hidden1(x))
        return self.hidden2(x)


class ConvNet(nn.Module):
    """Small CNN: two 5x5 conv -> 2x2 max-pool -> ReLU stages, then one linear layer.

    ``fc1`` takes 160 = 10 channels x 4 x 4 features, which is what a
    1 x 28 x 28 input produces: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.

    Args:
        name: Identifier stored on the module as ``self.name``.
        num_classes: Number of output logits.
    """

    def __init__(self, name, num_classes):
        super().__init__()
        self.name = name
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(5, 5))
        self.conv2 = nn.Conv2d(10, 10, kernel_size=(5, 5))
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(160, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 1, 28, 28); an unbatched (1, 28, 28)
        input yields a (1, num_classes) result.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        if x.dim() == 3:
            # Unbatched input: keep the leading batch dim of size 1 that the
            # former view(-1, 160) produced.
            x = x.unsqueeze(0)
        # Flatten per sample instead of view(-1, 160): with a wrong input size
        # the old reshape could silently regroup features across samples
        # (e.g. 16 images of 32x32 became 25 rows); now fc1 raises a shape error.
        # flatten also handles non-contiguous (channels_last) activations.
        x = torch.flatten(x, 1)
        return self.fc1(x)
