import abc

import torch
from torch.utils.data import random_split

from src.models.cifar_densenets import DenseNet
from src.models.cifar_resnets import ResNet
from src.models.cifar_resnets_sd import ResNetSD
from src.models.cifar_wrn import WideResNet
from src.models.densenets import DenseNet121


class Experiment:
    """Base class bundling the configuration of one training experiment.

    Subclasses supply the datasets by overriding :meth:`load_data`. This base
    class turns them into data loaders, builds the requested network and
    offers helpers that remap integer class labels in place.

    Args:
        name: Identifier of the experiment (stored as-is).
        model: Model specification (stored as-is; not read by any method of
            this base class).
        num_classes: Number of output classes passed to the networks built by
            :meth:`get_model`.
        classes: Label values of interest, read by :meth:`change_labels` and
            :meth:`change_labels_3_classes`.
        seed: Seed for the train/validation split in :meth:`get_data_loaders`.
    """

    def __init__(self, name, model, num_classes, classes, seed):
        self.name = name
        self.model = model
        self.num_classes = num_classes
        self.classes = classes
        # Initialised to None here; presumably assigned later by callers or
        # subclasses -- NOTE(review): confirm intended use.
        self.path = None
        self.file = None
        self.seed = seed

    @abc.abstractmethod
    def load_data(self):
        """Return ``(trainset, testset)``; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Experiment does not use abc.ABCMeta, so the decorator above is not
        # enforced. Raising here makes a missing override fail loudly instead
        # of returning None and breaking tuple unpacking in get_data_loaders.
        raise NotImplementedError(
            f"{type(self).__name__} must implement load_data()")

    def get_data_loaders(self, batch_size, *, train_fraction=0.9, num_workers=2):
        """Build train, validation and test loaders from :meth:`load_data`.

        The training set is split into training and validation subsets using a
        ``torch.Generator`` seeded with ``self.seed``, so the split is
        reproducible for a given seed.

        Args:
            batch_size: Batch size for all three loaders. It is converted with
                ``int()``, so values such as ``64.0`` from a config are accepted.
            train_fraction: Fraction of the training set kept for training. The
                remainder becomes the validation subset. Defaults to 0.9.
            num_workers: Worker processes for the train and validation loaders.
                Defaults to 2. The test loader uses DataLoader's default.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        # Cast once so every loader (including the test loader) gets an int.
        # DataLoader's BatchSampler rejects non-int batch sizes.
        batch_size = int(batch_size)
        trainset, testset = self.load_data()

        train_len = int(len(trainset) * train_fraction)
        train_subset, val_subset = random_split(
            trainset,
            [train_len, len(trainset) - train_len],
            generator=torch.Generator().manual_seed(self.seed))

        train_loader = torch.utils.data.DataLoader(
            train_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers)
        # NOTE(review): shuffling is usually unnecessary for validation; it is
        # left enabled to preserve existing behaviour.
        val_loader = torch.utils.data.DataLoader(
            val_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=batch_size,
            shuffle=False)

        return train_loader, val_loader, test_loader

    def get_model(self, model_name, args):
        """Instantiate the network named ``model_name``.

        Args:
            model_name: One of ``'resnet'``, ``'resnet_sd'``, ``'densenet'``,
                ``'densenet121'`` or ``'wideresnet'``.
            args: Mapping of architecture hyper-parameters. Keys read:
                ``'depth'`` and ``'block_name'`` (resnet); ``'depth'``
                (resnet_sd); ``'depth'``, ``'widen_factor'``, ``'dropout'`` and
                optionally ``'num_classes'`` (wideresnet).

        Returns:
            The constructed network, or ``None`` (after printing a message)
            when ``model_name`` is not recognised.
        """
        if model_name == 'resnet':
            return ResNet(args['depth'], num_classes=self.num_classes, block_name=args['block_name'])
        elif model_name == 'resnet_sd':
            return ResNetSD(args['depth'], death_mode='linear', death_rate=0.5, num_classes=self.num_classes)
        elif model_name == 'densenet':
            return DenseNet(num_classes=self.num_classes)
        elif model_name == 'densenet121':
            return DenseNet121(num_classes=self.num_classes)
        elif model_name == 'wideresnet':
            # Every other architecture uses self.num_classes. Fall back to it
            # when args does not override the class count.
            num_classes = args.get('num_classes', self.num_classes)
            return WideResNet(args['depth'], num_classes, args['widen_factor'], args['dropout'])
        else:
            print("Please enter a valid model name: resnet, resnet_sd, densenet, densenet121 or wideresnet.")
            return None

    def change_labels(self, labels):
        """Binarise ``labels`` in place.

        Entries whose value is in ``self.classes`` become 1; all others become
        0. ``labels`` must support boolean-mask assignment (e.g. a torch
        tensor).

        Note:
            -1 is used as a temporary marker, so any entry that is already -1
            on input also ends up as 1.
        """
        for which_c in self.classes:
            labels[labels == which_c] = -1
        labels[labels != -1] = 0
        labels[labels == -1] = 1

    def change_labels_3_classes(self, labels):
        """Remap ``self.classes[0..2]`` to 0, 1 and 2 in ``labels``, in place.

        Entries matching none of the three classes are left unchanged. If such
        values are 0, 1 or 2 they will coincide with the remapped labels.
        """
        self.change_labels_3_classes_grouped(
            labels, [self.classes[0]], [self.classes[1]], [self.classes[2]])

    def change_labels_3_classes_grouped(self, labels, classes1, classes2, classes3):
        """Map groups of label values to 0, 1 and 2 in ``labels``, in place.

        Values in ``classes1`` become 0, ``classes2`` become 1 and ``classes3``
        become 2. Entries in no group are left unchanged. A value listed in
        more than one group takes the first group's label.
        """
        # Negative markers prevent a freshly remapped value (0/1/2) from being
        # matched again by a later group whose class value happens to be 0/1/2.
        for c1 in classes1:
            labels[labels == c1] = -1
        for c2 in classes2:
            labels[labels == c2] = -2
        for c3 in classes3:
            labels[labels == c3] = -3

        labels[labels == -1] = 0
        labels[labels == -2] = 1
        labels[labels == -3] = 2
