import json
import os

import torch
import torch.nn as nn
from torch.optim import Adam

from train_function import basic_train
from dataset import CVDatasetManager
from modules_seq import BasicResidualV1A, BasicResidualV1B, BasicResidualV1C, BasicResidualV1D, ResNetV1
from modules_seq import BasicResidualV2A, BasicResidualV2B, BasicResidualV2C, BasicResidualV2D, ResNetV2


def train_one_side(model, module, model_shape, num_classes, train_loader, test_loader, epochs=20, lr=1e-3):
    """Build one network/block combination, train it, and return its best accuracy.

    Args:
        model: network class; instantiated as ``model(module, model_shape, num_classes)``.
        module: residual block class handed to the network constructor.
        model_shape: per-stage block counts (e.g. ``[2, 2, 2, 2]`` for depth 18).
        num_classes: number of output classes.
        train_loader: loader used for training.
        test_loader: loader forwarded to ``basic_train`` for evaluation.
        epochs: number of epochs forwarded to ``basic_train``.
        lr: Adam learning rate (defaults to the previously hard-coded ``1e-3``).

    Returns:
        The third value returned by ``basic_train`` (its best accuracy).
    """
    # Bind the instance to a new name instead of shadowing the ``model`` class argument.
    net = model(module, model_shape, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=lr)
    _, _, best_accuracy = basic_train(net, train_loader, test_loader, criterion, optimizer, epochs)
    return best_accuracy


def train_one(model_name, train_loader, test_loader, num_classes, epochs=20):
    """Train residual-block variants A, B, C and D for one ResNet depth.

    Depths 18/34 use ``ResNetV1`` with the ``BasicResidualV1*`` blocks; depths
    50/101/152 use ``ResNetV2`` with the ``BasicResidualV2*`` blocks.

    Args:
        model_name: ResNet depth; one of 18, 34, 50, 101, 152.
        train_loader: loader used for training.
        test_loader: loader used for evaluation.
        num_classes: number of output classes.
        epochs: epochs per variant.

    Returns:
        Tuple ``(acc_A, acc_B, acc_C, acc_D)`` of best accuracies, one per variant.

    Raises:
        ValueError: if ``model_name`` is not a supported depth.
    """
    stage_depths = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }
    if model_name in (18, 34):
        net_cls = ResNetV1
        block_variants = (BasicResidualV1A, BasicResidualV1B, BasicResidualV1C, BasicResidualV1D)
    elif model_name in (50, 101, 152):
        net_cls = ResNetV2
        block_variants = (BasicResidualV2A, BasicResidualV2B, BasicResidualV2C, BasicResidualV2D)
    else:
        raise ValueError("model_name must be in [18, 34, 50, 101, 152]")

    shape = stage_depths[model_name]
    accuracies = []
    # Variants are trained strictly in A, B, C, D order, each preceded by its label.
    for label, block_cls in zip(('A:', 'B:', 'C:', 'D:'), block_variants):
        print(label)
        accuracies.append(
            train_one_side(net_cls, block_cls, shape, num_classes, train_loader, test_loader, epochs)
        )
    return tuple(accuracies)


def train_seq(n_runs=10, output_path='ez_model_results_dict.txt', model_names=(50, 18)):
    """Run the repeated ResNet block-variant benchmark over every dataset.

    For each of ``n_runs`` repetitions, every dataset returned by
    ``CVDatasetManager(0).get_names()`` is loaded with its own epoch count and
    batch size. Each depth in ``model_names`` is then trained via ``train_one``,
    which trains variants A-D. After every model the accumulated results are
    printed and re-saved as JSON, so partial progress survives an interrupted run.

    Args:
        n_runs: number of repetitions of the whole sweep (default 10).
        output_path: JSON checkpoint file, rewritten after each model.
        model_names: ResNet depths to train, in training order (default 50 then 18).

    Raises:
        KeyError: if a dataset name has no entry in the epoch/batch-size table.
    """
    data_manager = CVDatasetManager(0)
    data_names = data_manager.get_names()
    # Depth keys are sorted so the default JSON layout matches the original {18: ..., 50: ...}.
    results_dict = {data_name: {depth: [] for depth in sorted(model_names)} for data_name in data_names}
    # dataset name -> [epochs, batch_size]
    epochs_dict = {'CIFAR100': [100, 200], 'CIFAR10': [50, 200], 'FMNIST': [50, 1000], 'MNIST': [20, 1000]}
    for _ in range(n_runs):
        for data_name in data_names:
            epochs, batch_size = epochs_dict[data_name]
            print('\n', data_name, ": \n")
            # NOTE(review): the validation loader is unused; train_one evaluates on
            # test_loader, so the reported "best" accuracy is selected on the test split.
            train_loader, _val_loader, test_loader, n_class = data_manager.get_loader(data_name, batch_size)
            for model_name in model_names:
                print(model_name)
                results = train_one(model_name, train_loader, test_loader, n_class, epochs)
                results_dict[data_name][model_name].append(results)
                print(results_dict)
                _save_results(results_dict, output_path)


def _save_results(results, path):
    """Atomically write *results* as JSON to *path*.

    The data is serialized into a sibling temp file that then replaces *path*
    via ``os.replace``. A crash or serialization error mid-write therefore leaves
    the previous checkpoint intact instead of a truncated file.
    """
    tmp_path = os.fspath(path) + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f)
    os.replace(tmp_path, path)


if __name__ == "__main__":
    # Run the benchmark sweep only when executed as a script, not on import.
    train_seq()
