# System libs
import os
import argparse
from distutils.version import LooseVersion
# Numerical libs
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
import csv
# Our libs
from mit_semseg.dataset import TestDataset,TestDataset2
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode, find_recursive, setup_logger
from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
from mit_semseg.lib.utils import as_numpy
from PIL import Image
from tqdm import tqdm
from mit_semseg.config import cfg

# Label tables for the 150-class label set, loaded once at import time.
# NOTE: both paths are relative to the current working directory.
colors = loadmat('data/color150.mat')['colors']
with open('data/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)  # drop the CSV header row
    # Column 0 is the class id (1-based, as visualize_result adds 1 to the
    # predicted index); column 5 is a ';'-separated list of names, of which
    # only the first is kept for display.
    names = {int(row[0]): row[5].split(";")[0] for row in reader}


def visualize_result(infos, preds, cfg, label_names=None):
    """Print, per image, the share of pixels assigned to each predicted class.

    Classes are listed from largest to smallest area. Classes covering
    0.01% of the pixels or less are omitted.

    Args:
        infos: iterable of image identifiers (the caller passes file paths),
            printed as the header of each report.
        preds: iterable of integer label maps, one per entry of ``infos``.
            A predicted label ``k`` is looked up as ``label_names[k + 1]``.
        cfg: configuration object; currently unused, kept for API
            compatibility.
        label_names: optional mapping from 1-based class id to display
            name. Defaults to the module-level ``names`` table.
    """
    if label_names is None:
        label_names = names
    for info, pred in zip(infos, preds):
        pred = np.int32(pred)
        total = pred.size
        uniques, counts = np.unique(pred, return_counts=True)
        print("Predictions in [{}]:".format(info))
        # Descending order of pixel count.
        for idx in np.argsort(counts)[::-1]:
            label = int(uniques[idx]) + 1
            # Fall back to the numeric id rather than raising KeyError when
            # the model predicts a class that is not in the name table.
            name = label_names.get(label, str(label))
            ratio = counts[idx] / total * 100
            if ratio > 0.01:
                print("  {}: {:.2f}%".format(name, ratio))

def touch(fname, times=None):
    """Create ``fname`` if it is missing, then set its timestamps.

    The file is opened in append mode, so existing content is never
    truncated. ``times`` is forwarded unchanged to ``os.utime`` (``None``
    stamps the current time).
    """
    handle = open(fname, 'a')
    with handle:
        os.utime(fname, times)

def test(segmentation_module, loader, gpu):
    """Run inference over ``loader`` and print a class summary per image.

    Each item yielded by ``loader`` is a list whose first element is a dict
    holding ``'img_data'`` and ``'paths'``. ``img_data`` is indexed below as
    ``[batch, channel, height, width]`` (TODO confirm against TestDataset2).
    ``paths`` has one entry per image in that batch.

    Args:
        segmentation_module: model called as ``module(feed_dict, segSize=...)``;
            its output is reduced with ``torch.max`` over dim 1 (class axis).
        loader: iterable of batches as described above; must support ``len``.
        gpu: device id forwarded to ``async_copy_to``.
    """
    segmentation_module.eval()

    with tqdm(total=len(loader)) as pbar:
        for batch_data in loader:
            batch = batch_data[0]
            img_data = batch['img_data']
            # Predict at the input resolution.
            segSize = (img_data.shape[2], img_data.shape[3])

            with torch.no_grad():
                # 'paths' is metadata for reporting, not a network input.
                feed_dict = batch.copy()
                del feed_dict['paths']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores = segmentation_module(feed_dict, segSize=segSize)
                _, pred = torch.max(scores, dim=1)
                # Keep the batch axis: squeezing it for a single-image batch
                # would make visualize_result iterate over pixel rows
                # instead of images.
                pred = as_numpy(pred.cpu())

            visualize_result(batch['paths'], pred, cfg)
            pbar.update(1)

@torch.no_grad()
def main(cfg, gpu):
    """Build the encoder/decoder described by ``cfg`` and run ``test``.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        cfg: configuration with ``MODEL``, ``DATASET`` and ``list_test``
            entries (the latter a list of ``{'fpath_img': path}`` dicts).
        gpu: CUDA device index, made current and forwarded to ``test``.
    """
    torch.cuda.set_device(gpu)

    model_cfg = cfg.MODEL
    encoder = ModelBuilder.build_encoder(
        arch=model_cfg.arch_encoder,
        fc_dim=model_cfg.fc_dim,
        weights=model_cfg.weights_encoder,
    )
    decoder = ModelBuilder.build_decoder(
        arch=model_cfg.arch_decoder,
        fc_dim=model_cfg.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=model_cfg.weights_decoder,
        use_softmax=True,
    )
    # The criterion is required by the SegmentationModule constructor.
    module = SegmentationModule(encoder, decoder, nn.NLLLoss(ignore_index=-1))

    # TestDataset2 is built with bs=16 while the DataLoader uses
    # batch_size=1; presumably each dataset item is already a group of
    # images (TODO confirm). cfg.TEST.batch_size is intentionally not used.
    dataset = TestDataset2(cfg.list_test, cfg.DATASET, bs=16)
    loader_options = dict(
        batch_size=1,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=8,
        drop_last=False,
    )
    loader = torch.utils.data.DataLoader(dataset, **loader_options)

    module.cuda()

    test(module, loader, gpu)

    print('Inference done!')


if __name__ == '__main__':
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
        'PyTorch>=0.4.0 is required'

    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing"
    )
    parser.add_argument(
        "--imgs", default='',
        type=str,
        help="an image path, or a directory name"
    )
    parser.add_argument(
        "--img_list", default='',
        type=str,
        help="a text file listing one image path per line "
             "(takes precedence over --imgs)"
    )
    parser.add_argument(
        "--cfg",
        default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--gpu",
        default=0,
        type=int,
        help="gpu id for evaluation"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    # Script-specific inference sizes. They override the config file but are
    # applied BEFORE the command-line opts, so DATASET.imgSizes and
    # DATASET.imgMaxSize can still be overridden from the command line.
    cfg.DATASET.imgSizes = (320,)
    cfg.DATASET.imgMaxSize = 400
    cfg.merge_from_list(args.opts)
    logger = setup_logger(distributed_rank=0)   # TODO
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
    cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()

    # absolute paths of model weights
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)

    # Explicit checks instead of asserts: asserts vanish under `python -O`.
    for weights_path in (cfg.MODEL.weights_encoder, cfg.MODEL.weights_decoder):
        if not os.path.exists(weights_path):
            parser.error("checkpoint does not exist: {}".format(weights_path))

    # generate testing image list
    imgs = []  # defined up front so a missing input is reported, not a NameError
    if args.imgs != '':
        if os.path.isdir(args.imgs):
            imgs = find_recursive(args.imgs)
        else:
            imgs = [args.imgs]
    if args.img_list != '':
        with open(args.img_list) as f:
            # Skip blank lines so they don't turn into empty image paths.
            imgs = [line for line in f.read().splitlines() if line.strip()]
    if not imgs:
        parser.error("no input images: pass --imgs (an image or directory) "
                     "or --img_list (a text file of image paths).")
    cfg.list_test = [{'fpath_img': x} for x in imgs]

    os.makedirs(cfg.TEST.result, exist_ok=True)

    main(cfg, args.gpu)
