import logging

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset

from config import DATA_DIR
from .experiment import Experiment


class CIFAR(Experiment):
    """CIFAR-10 / CIFAR-100 experiment, optionally restricted to a few classes."""

    def __init__(self, args):
        """Create the model and store the experiment configuration.

        Args:
            args: mapping providing ``'model_name'``, ``'num_classes'``,
                ``'classes'``, ``'seed'``, ``'dataset'``, ``'classes1'``,
                ``'classes2'`` and ``'classes3'``.
        """
        model_name, num_classes = args['model_name'], args['num_classes']

        # NOTE(review): get_model runs before Experiment.__init__, so it must
        # not depend on state set by the base initialiser - confirm.
        model = self.get_model(model_name, args)
        super().__init__(f'cifar{num_classes}', model, num_classes, args['classes'], args['seed'])
        self.dataset = args['dataset']
        self.classes1 = args['classes1']
        self.classes2 = args['classes2']
        self.classes3 = args['classes3']

    def load_data(self):
        """Build the train/test datasets, applying any configured class restriction.

        * ``num_classes == 1`` with ``classes`` set: the targets of both splits
          are passed to ``change_labels``.
        * ``num_classes == 3`` with ``classes`` set (must hold three labels):
          both splits are reduced to those labels and relabelled via
          ``change_labels_3_classes``.
        * ``num_classes == 3`` with ``classes1``/``classes2``/``classes3`` all
          set: both splits are reduced to the labels of the three groups and
          relabelled via ``change_labels_3_classes_grouped``.

        Returns:
            ``(trainset, testset)``: the torchvision datasets, or
            ``torch.utils.data.Subset`` views over them when a 3-class
            restriction applies.
        """
        mean, std = self.get_mean_and_std()

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        trainset, testset = self.get_datasets(transform_train, transform_test)

        # Tensor targets allow the vectorised label comparisons below. The
        # change_labels* helpers' return values are ignored, so they presumably
        # relabel these tensors in place.
        trainset.targets = torch.tensor(trainset.targets)
        testset.targets = torch.tensor(testset.targets)

        if self.num_classes == 1 and self.classes:
            self.change_labels(trainset.targets)
            self.change_labels(testset.targets)
        if self.num_classes == 3:
            if self.classes:
                assert len(self.classes) == 3
                trainset = Subset(trainset, self._indices_of_classes(trainset.targets, self.classes))
                testset = Subset(testset, self._indices_of_classes(testset.targets, self.classes))

                # Relabel the full underlying datasets; the Subsets only hold
                # indices into them.
                self.change_labels_3_classes(trainset.dataset.targets)
                self.change_labels_3_classes(testset.dataset.targets)

            elif self.classes1 and self.classes2 and self.classes3:
                # Group order (classes1, classes2, classes3) and per-group label
                # order determine the order of samples in the resulting Subsets.
                selected = [*self.classes1, *self.classes2, *self.classes3]
                trainset = Subset(trainset, self._indices_grouped_by_class(trainset.targets, selected))
                testset = Subset(testset, self._indices_grouped_by_class(testset.targets, selected))

                self.change_labels_3_classes_grouped(trainset.dataset.targets, self.classes1, self.classes2, self.classes3)
                self.change_labels_3_classes_grouped(testset.dataset.targets, self.classes1, self.classes2, self.classes3)

        return trainset, testset

    @staticmethod
    def _indices_of_classes(targets, classes):
        """Return the ascending indices of ``targets`` whose label is in ``classes``."""
        mask = torch.zeros_like(targets, dtype=torch.bool)
        for cls in classes:
            mask |= targets == cls
        return torch.where(mask)[0]

    @staticmethod
    def _indices_grouped_by_class(targets, classes):
        """Return the indices of ``targets`` for each label in ``classes``, concatenated in that order."""
        return torch.cat([torch.where(targets == cls)[0] for cls in classes])

    def _uses_cifar10(self):
        """Return True when CIFAR-10 (rather than CIFAR-100) is the dataset in use.

        This is the single decision point shared by ``get_datasets`` and
        ``get_mean_and_std``, so the normalisation statistics always match the
        data that is actually loaded.
        """
        return self.dataset == 'cifar10' or self.num_classes == 10

    def get_mean_and_std(self):
        """Return the per-channel ``(mean, std)`` used to normalise the loaded dataset."""
        if self._uses_cifar10():
            mean = (0.49139968, 0.48215841, 0.44653091)
            std = (0.24703223, 0.24348513, 0.26158784)
        else:
            mean = (0.5071, 0.4867, 0.4408)
            std = (0.2675, 0.2565, 0.2761)

        return mean, std

    def get_datasets(self, transform_train, transform_test):
        """Load the CIFAR-10 or CIFAR-100 train/test splits from ``DATA_DIR``.

        The data is downloaded if it is missing.

        Args:
            transform_train: transform applied to training images.
            transform_test: transform applied to test images.

        Returns:
            ``(train, test)`` torchvision datasets.
        """
        if self._uses_cifar10():
            logging.info("Using data from cifar10")
            dataset_cls = torchvision.datasets.CIFAR10
        else:
            logging.info("Using data from cifar100")
            dataset_cls = torchvision.datasets.CIFAR100

        train = dataset_cls(root=DATA_DIR, train=True, download=True, transform=transform_train)
        test = dataset_cls(root=DATA_DIR, train=False, download=True, transform=transform_test)

        return train, test


# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class ConvNet(nn.Module):
    """Small LeNet-style CNN (after the PyTorch CIFAR-10 tutorial).

    Two stages of 5x5 convolution + ReLU + 2x2 max-pool, followed by three
    fully connected layers. The hard-coded flatten size ``16 * 5 * 5``
    corresponds to 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        name: identifier stored as ``self.name``.
        num_classes: number of output scores produced by ``forward``.
    """

    # channels * height * width of the feature map after the second pool
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self, name, num_classes):
        super().__init__()
        self.name = name
        self.num_classes = num_classes
        kernel = (5, 5)
        # Keep this registration order: it fixes the state_dict key order and
        # the order in which parameters draw from the RNG at initialisation.
        self.conv1 = nn.Conv2d(3, 6, kernel)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) class scores for the image batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # NOTE(review): view(-1, ...) infers the leading dimension, so an input
        # that is not 32x32 can silently produce the wrong number of rows
        # instead of raising.
        flat = x.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
