import torch

from dataset import CVDatasetManager
from train_function import basic_train
from modules_ac import CostumeResNetV2, CostumCNN, CostumeResNetV3


def train_full_resnet50(data_name):
    """Train a ResNet-50-shaped network on a single dataset.

    The number of epochs and the batch size are looked up per dataset. The
    model class depends on the dataset: ``CostumeResNetV2`` for the CIFAR
    datasets and ``CostumeResNetV3`` for everything else.

    Args:
        data_name: One of ``'CIFAR100'``, ``'CIFAR10'``, ``'FMNIST'``,
            ``'MNIST'``. Any other value raises ``KeyError``.

    Returns:
        ``(best_loss, best_accuracy)``: the second and third values returned
        by ``basic_train``.
    """
    # dataset name -> (epochs, batch size)
    schedule = {
        'CIFAR100': (100, 200),
        'CIFAR10': (50, 200),
        'FMNIST': (50, 1000),
        'MNIST': (20, 1000),
    }
    n_epochs, bs = schedule[data_name]

    manager = CVDatasetManager(val_split=0)
    # The validation loader is not used here (val_split=0).
    train_dl, _unused_val_dl, test_dl, num_classes = manager.get_loader(data_name, bs)
    loss_fn = torch.nn.CrossEntropyLoss()

    print('full resnet model ResNet50:')
    # Blocks per stage: the standard ResNet-50 layout.
    stage_depths = [3, 4, 6, 3]
    resnet_cls = CostumeResNetV2 if data_name in ('CIFAR100', 'CIFAR10') else CostumeResNetV3
    # NOTE(review): the meaning of the trailing 4 is defined in modules_ac - confirm.
    net = resnet_cls(CostumCNN, stage_depths, num_classes, 4).cuda()
    adam = torch.optim.Adam(net.parameters(), lr=1e-3)

    # NOTE(review): the test loader is passed where an evaluation loader goes,
    # so the "best" metrics are presumably selected on test data - confirm intended.
    _, loss_best, acc_best = basic_train(net, train_dl, test_dl, loss_fn, adam, n_epochs)
    return loss_best, acc_best


def train_all_full_resnet50(data_names=('CIFAR100', 'CIFAR10', 'FMNIST', 'MNIST')):
    """Train the full ResNet-50 model on each dataset in turn.

    Args:
        data_names: Dataset names to train on, in order. Each one is passed to
            ``train_full_resnet50``. Defaults to the four datasets this script
            was written for. The default is a tuple, so it is never mutated.

    Returns:
        dict mapping each dataset name to its ``(best_loss, best_accuracy)``
        tuple.
    """
    results = {}
    for data_name in data_names:
        best_loss, best_accuracy = train_full_resnet50(data_name)
        results[data_name] = (best_loss, best_accuracy)
        # Print the cumulative results after every run. If a later (long)
        # run fails, the earlier results are still visible.
        print(results)
    # Previously the collected results were discarded; return them to callers.
    return results


if __name__ == '__main__':
    # Script entry point: run training on every supported dataset.
    train_all_full_resnet50()
