# Adapted from the HRNet repo.

import argparse
import logging
import operator
import os
import pprint
import shutil
import sys
import time
import timeit
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

from deq.mdeq_vision.lib import models
from deq.mdeq_vision.lib import datasets
from deq.mdeq_vision.lib.config import config
from deq.mdeq_vision.lib.config import update_config
from deq.mdeq_vision.lib.core.seg_function import testval, test
from deq.mdeq_vision.lib.utils.modelsummary import get_model_summary
from deq.mdeq_vision.lib.utils.utils import create_logger

def parse_args():
    """Parse the evaluation script's command line and apply it to ``config``.

    The parsed namespace is passed to ``update_config(config, args)``, which
    presumably loads ``--cfg`` and applies the trailing ``opts`` overrides to
    the global ``config`` object (confirm in lib/config).

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # The description previously said "Train", but this script only evaluates.
    parser = argparse.ArgumentParser(description='Test segmentation network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--testModel',
                        help='testModel',
                        type=str,
                        default='')
    # Within this script the value is only recorded in the results CSV.
    parser.add_argument('--percent',
                        help='percentage of training data to use',
                        type=float,
                        default=1.0)
    parser.add_argument('--results_name',
                        help='file in which to store the accuracy and hyperparameters',
                        type=str,
                        default=None)
    parser.add_argument("--local_rank", type=int, default=0)
    # Everything after the named options is collected verbatim as config
    # overrides.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    update_config(config, args)

    return args

def _load_pretrained(model, model_state_file, logger):
    """Copy matching weights from a checkpoint file into ``model`` in place.

    Checkpoint keys starting with ``model.`` have that prefix stripped. The
    checkpoint presumably stores the network under a wrapper attribute named
    ``model``; confirm against the training script. Only keys that then exist
    in ``model.state_dict()`` are loaded, and all other parameters keep their
    current values.

    Args:
        model: the ``nn.Module`` to load weights into.
        model_state_file: path of a file readable by ``torch.load`` whose
            content is a mapping from parameter name to tensor.
        logger: logger used to report which keys were loaded.
    """
    prefix = 'model.'
    # Without a GPU, remap stored tensors onto the CPU.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(model_state_file, map_location=map_location)

    model_dict = model.state_dict()
    pretrained_dict = {}
    for key, value in checkpoint.items():
        # Strip the prefix only when it is actually present. Blindly slicing
        # off six characters would mangle un-prefixed keys.
        if key.startswith(prefix):
            key = key[len(prefix):]
        if key in model_dict:
            pretrained_dict[key] = value

    for key in pretrained_dict:
        logger.info('=> loading {} from pretrained model'.format(key))
    if not pretrained_dict:
        # Otherwise evaluation would silently run on untrained weights.
        logger.warning(
            '=> no parameters in {} matched the model; '
            'evaluating with its initial weights'.format(model_state_file))

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)


def _append_results(args, mean_IoU, cvg_rel, cvg_abs):
    """Append one row of evaluation results to the CSV ``args.results_name``.

    A header row is written only when the file does not exist yet.
    """
    write_header = not Path(args.results_name).is_file()
    try:
        # mean_IoU may be a tensor; otherwise use the value as-is.
        perf = mean_IoU.cpu().numpy().item()
    except AttributeError:
        perf = mean_IoU
    df_results = pd.DataFrame({
        'phase': 'eval',
        'miou': perf,
        'cvg_rel': cvg_rel,
        'cvg_abs': cvg_abs,
        'percent': args.percent,
        'opts': ",".join(args.opts or []),
        'f_thres_val': config.DEQ.F_THRES,
        # NOTE(review): hard-coded even though the evaluated dataset comes
        # from config.DATASET.DATASET -- confirm this is intended.
        'dataset': 'cityscapes',
        'model_size': os.path.basename(args.cfg).split('.')[0],
    }, index=[0])
    df_results.to_csv(
        args.results_name,
        mode='a',
        header=write_header,
        index=False,
    )


def main():
    """Evaluate a segmentation model on ``config.DATASET.TEST_SET``.

    Steps:
    1. Build the network named by ``config.MODEL.NAME``.
    2. Load weights from ``config.TEST.MODEL_FILE``, or from
       ``<output dir>/final_state.pth`` when that option is empty.
    3. Run ``testval`` when the test-set name contains 'val', otherwise
       ``test`` when it contains 'test'.

    For 'val' runs the metrics are logged and, if ``--results_name`` is
    given, appended as a row to that CSV file.
    """
    args = parse_args()

    logger, final_output_dir, _ = create_logger(
        config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Build the model. attrgetter resolves the (possibly dotted) name from the
    # config without eval-ing an arbitrary config string.
    get_seg_net = operator.attrgetter(config.MODEL.NAME + '.get_seg_net')(models)
    model = get_seg_net(config)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    )
    if torch.cuda.is_available():
        logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

    model_state_file = config.TEST.MODEL_FILE or os.path.join(
        final_output_dir, 'final_state.pth')
    logger.info('=> loading model from {}'.format(model_state_file))
    _load_pretrained(model, model_state_file, logger)

    if torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=list(config.GPUS)).cuda()

    # prepare data
    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    dataset_cls = operator.attrgetter(config.DATASET.DATASET)(datasets)
    test_dataset = dataset_cls(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1,
        return_convergence=True,
    )

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
    )

    start = timeit.default_timer()
    if 'val' in config.DATASET.TEST_SET:
        mean_IoU, IoU_array, pixel_acc, mean_acc, cvg_rel, cvg_abs = testval(
            config,
            test_dataset,
            testloader,
            model,
            sv_dir=final_output_dir,
            sv_pred=True,
        )

        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the next line's indentation in the message.
        msg = ('MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, '
               'Mean_Acc: {: 4.4f}, Class IoU: ').format(
                   mean_IoU, pixel_acc, mean_acc)
        logger.info(msg)
        logger.info(IoU_array)
        if args.results_name is not None:
            _append_results(args, mean_IoU, cvg_rel, cvg_abs)
    elif 'test' in config.DATASET.TEST_SET:
        test(config,
             test_dataset,
             testloader,
             model,
             sv_dir=final_output_dir)
    else:
        logger.warning(
            "=> TEST_SET {} contains neither 'val' nor 'test'; "
            "nothing was evaluated".format(config.DATASET.TEST_SET))

    end = timeit.default_timer()
    # np.int was removed in NumPy 1.24; the builtin int truncates identically.
    logger.info('Mins: %d', int((end - start) / 60))
    logger.info('Done')


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
