import os

import timm
import torch

from dataset import CVDatasetManager
from train_function import basic_train, train_supernet_body
from modules_unac import CostumeResNetV2, ParallelConv2dV2, ParallelConv2dV1


def train_supernet(model_name='standard'):
    """Warm up and then train a supernet on CIFAR100.

    The model is selected by ``model_name``:

    * ``'standard'``    -- timm ``resnet50`` (``pretrained=False``).
    * ``'Parallel_v1'`` -- ``CostumeResNetV2`` built with ``ParallelConv2dV1``.
    * ``'Parallel_v2'`` -- ``CostumeResNetV2`` built with ``ParallelConv2dV2``.

    Training has two phases:

    1. A 10-epoch warm-up via ``basic_train`` using Adam at ``lr / 100``,
       with its checkpoint file named ``None_1e-5.pth``.
    2. ``train_supernet_body`` at the base ``lr``.

    Outputs go under ``$DATA/supernet_<model_name>``. ``DATA`` falls back to
    ``'data/'`` when unset, and the directory is created if missing.

    Args:
        model_name: One of ``'standard'``, ``'Parallel_v1'``, ``'Parallel_v2'``.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    supported = ('standard', 'Parallel_v1', 'Parallel_v2')
    # Validate before any data loading so a typo fails fast, instead of
    # surfacing later as an UnboundLocalError on ``model``.
    if model_name not in supported:
        raise ValueError(f'unknown model_name {model_name!r}; expected one of {supported}')

    data_name = 'CIFAR100'
    batch_size = 36
    criterion = torch.nn.CrossEntropyLoss()
    # Per-stage block counts for CostumeResNetV2 (same layout as ResNet-50;
    # presumably intentional -- confirm against modules_unac).
    model_shape = [3, 4, 6, 3]
    dataset_manager = CVDatasetManager(val_split=0)
    # The validation loader is not used below; test_loader is passed to training instead.
    train_loader, _val_loader, test_loader, n_class = dataset_manager.get_loader(data_name, batch_size)
    lr = 1e-3

    if model_name == 'standard':
        model = timm.create_model('resnet50', pretrained=False, num_classes=n_class)
    elif model_name == 'Parallel_v1':
        model = CostumeResNetV2(ParallelConv2dV1, model_shape, n_class, bottleneck=4, num_DoG=3,
                                costum_module_flag=True)
    else:  # 'Parallel_v2' -- guaranteed by the validation above
        model = CostumeResNetV2(ParallelConv2dV2, model_shape, n_class, bottleneck=4, num_DoG=3,
                                costum_module_flag=True)

    model_dir = os.path.join(os.environ.get("DATA", 'data/'), 'supernet_' + model_name)
    os.makedirs(model_dir, exist_ok=True)

    print('warm up...')
    # The file name encodes the warm-up learning rate (lr / 100 == 1e-5).
    model_file = os.path.join(model_dir, 'None_1e-5.pth')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr / 100)
    model, _, _ = basic_train(model, train_loader, test_loader, criterion, optimizer, model_file=model_file, epochs=10)
    print('warmed up...')
    train_supernet_body(model, train_loader, test_loader, criterion, model_dir, lr)


if __name__ == '__main__':
    # Train one supernet per architecture variant, one after another.
    for name in ('standard', 'Parallel_v1', 'Parallel_v2'):
        train_supernet(name)
