import argparse
import json
import logging
import os

import torch
from torch.backends import cudnn
from torchdistill.common import yaml_util

from sc2bench.analysis import analyze_model_size
from sc2bench.common.config_util import overwrite_config
from sc2bench.models.backbone import check_if_updatable
from sc2bench.models.detection.registry import load_detection_model
from sc2bench.models.detection.wrapper import get_wrapped_detection_model
from sc2bench.models.registry import load_classification_model
from sc2bench.models.segmentation.registry import load_segmentation_model
from sc2bench.models.segmentation.wrapper import get_wrapped_segmentation_model
from sc2bench.models.wrapper import get_wrapped_classification_model

# Process-wide switch: every logging call at INFO severity or below is ignored.
logging.disable(level=logging.INFO)


def get_argparser():
    """Build the command-line parser for the model size analyzer.

    Options:
        --task: one of 'classification', 'detection', 'segmentation' (required).
        --device: torch device string, default 'cuda'.
        --configs: one or more YAML config files (required; ``main`` iterates over them).
        --json: optional JSON string whose contents overwrite each loaded config.
        --enc_paths: optional module paths treated as the encoder.
        --rest_paths: optional additional module paths treated as the rest of the model.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # The description previously read 'Tradeoff plotter for ImageNet dataset' (copy-paste),
    # which misdescribed this script in --help output.
    parser = argparse.ArgumentParser(description='Model size analyzer')
    parser.add_argument('--task', required=True, choices=['classification', 'detection', 'segmentation'],
                        help='target tasks')
    parser.add_argument('--device', default='cuda', help='device')
    # Required: without it ``main`` would fail iterating over None instead of showing usage.
    parser.add_argument('--configs', nargs='+', required=True, help='yaml config files')
    parser.add_argument('--json', help='json string to overwrite config')
    parser.add_argument('--enc_paths', nargs='+', help='encoder module paths')
    parser.add_argument('--rest_paths', nargs='+', help='additional rest module paths')
    return parser


def load_classifier(model_config, device):
    """Build a classification model from ``model_config`` on ``device``.

    Configs that contain a ``classification_model`` key go through the wrapper
    builder; all others go straight to the registry loader. Both builders receive
    ``False`` as their third positional argument.
    """
    if 'classification_model' in model_config:
        return get_wrapped_classification_model(model_config, device, False)
    return load_classification_model(model_config, device, False)


def load_detector(model_config, device):
    """Build a detection model from ``model_config`` on ``device``.

    A ``detection_model`` key in the config selects the wrapper builder;
    otherwise the registry loader is used.
    """
    builder = get_wrapped_detection_model if 'detection_model' in model_config else load_detection_model
    return builder(model_config, device)


def load_segmentor(model_config, device):
    """Build a segmentation model from ``model_config`` on ``device``.

    Falls back to the registry loader unless the config carries a
    ``segmentation_model`` key, in which case the wrapper builder is used.
    """
    has_wrapped_entry = 'segmentation_model' in model_config
    if not has_wrapped_entry:
        return load_segmentation_model(model_config, device)
    return get_wrapped_segmentation_model(model_config, device)


def load_model(model_config, device, task):
    """Load a model for ``task`` by delegating to the matching task-specific loader.

    Args:
        model_config: model section of the YAML config.
        device: target torch device.
        task: 'classification', 'detection' or 'segmentation'.

    Raises:
        ValueError: if ``task`` matches none of the supported names.
    """
    task_loaders = (
        ('classification', load_classifier),
        ('detection', load_detector),
        ('segmentation', load_segmentor),
    )
    for task_name, loader in task_loaders:
        if task == task_name:
            return loader(model_config, device)
    raise ValueError(f'task: {task} is not expected')


def main(args):
    """Print a tab-separated size report for every config file in ``args.configs``.

    Output starts with the header ``config\\tmodel_size\\tencoder_size\\trest_size``.
    Each config then adds one row with the values that ``analyze_model_size``
    returns under the keys ``'model'``, ``'encoder'`` and ``'rest'``.

    Args:
        args: namespace from ``get_argparser().parse_args()``. Reads ``task``,
            ``device``, ``configs``, ``json`` (optional overrides applied to every
            config), ``enc_paths`` and ``rest_paths``.
    """
    cudnn.benchmark = True
    cudnn.deterministic = True
    encoder_paths = args.enc_paths
    # Must come from --rest_paths: reading --enc_paths here passed the encoder paths
    # twice to analyze_model_size and silently ignored the user's rest module paths.
    additional_rest_paths = args.rest_paths
    task = args.task
    # The device does not depend on the config file, so resolve it once.
    device = torch.device(args.device)
    print('config\tmodel_size\tencoder_size\trest_size')
    for config_file_path in args.configs:
        config = yaml_util.load_yaml_file(os.path.expanduser(config_file_path))
        if args.json is not None:
            # Parsed per config so each overwrite_config call receives a fresh dict.
            overwrite_config(config, json.loads(args.json))

        models_config = config['models']
        # Use the 'student_model' entry when present, otherwise 'model'.
        student_model_config = \
            models_config['student_model'] if 'student_model' in models_config else models_config['model']
        student_model = load_model(student_model_config, device, task)
        if check_if_updatable(student_model):
            student_model.update()

        model_size_dict = analyze_model_size(student_model, encoder_paths, additional_rest_paths)
        row = [config_file_path, model_size_dict['model'], model_size_dict['encoder'], model_size_dict['rest']]
        print('\t'.join(map(str, row)))


if __name__ == '__main__':
    # Parse the command line and run the size report.
    main(get_argparser().parse_args())
