import sys
import logging
import pprint
import time
import numpy as np
import torch
import os.path as osp
import cv2

from models import build_model
from utils import (
    get_best_model_path, get_last_model_path, setup_logging, mkdir, zipdir
)
from utils.metric import AverageMeter, DiceMetric, RegressMetric, IouMetric
from utils.parser import parse_args, load_config
from dataset import build_data_pipeline
from predict import Predictor
from dataset.cityscapes import CityscapesDataset

logger = logging.getLogger(__name__)


@torch.no_grad()
def perform_test(test_loader, predictor, device, thres=0.5, print_freq=10, save_path=None) -> tuple:
    """Evaluate ``predictor.model`` on ``test_loader`` and log IoU / accuracy.

    Args:
        test_loader: iterable of batches where ``samples[0]`` is the input
            tensor and ``samples[1]`` the label tensor; ``samples[2]`` (sample
            ids) is only read when ``save_path`` is set. Its ``dataset`` must
            expose ``classes``.
        predictor: object exposing ``model`` (whose ``act`` method is applied
            to the raw outputs) and ``num_classes``.
        device: torch device the inputs are moved to.
        thres: currently unused; kept for interface compatibility.
        print_freq: log running metrics every ``print_freq`` batches.
        save_path: if truthy, directory where one ``"<sample_id>predict.png"``
            label image is written per sample.

    Returns:
        tuple: ``(acc, iou)`` as returned by ``eval_metric.mean_score()``.
    """
    batch_time = AverageMeter()
    # Pixels labelled 255 are passed as ignore_index (presumably excluded from
    # scoring by IouMetric -- confirm against utils.metric).
    eval_metric = IouMetric(num_classes=predictor.num_classes,
                            classes=test_loader.dataset.classes,
                            ignore_index=255)

    max_iter = len(test_loader)
    end = time.time()
    for i, samples in enumerate(test_loader):
        inputs, labels = samples[0].to(device), samples[1]
        # forward
        # NOTE(review): this calls the raw model directly, bypassing the
        # Predictor's configured mode/scales/flip -- confirm this is intended.
        outputs = predictor.model(inputs)
        predicts = predictor.model.act(outputs)
        # Convert once; reused both for the metric and for saving predictions.
        predicts_np = predicts.detach().cpu().numpy()
        # update metric
        # NOTE(review): update() is unpacked as (iou, acc) while mean_score()
        # below is unpacked as (acc, iou) -- verify the order in utils.metric.
        iou, acc = eval_metric.update(predicts_np, labels.numpy())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            # Report the 1-based count of processed batches.
            log_str = (
                ("Eval[{0}/{1}]\t"
                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                 "Acc {acc:.4f}\tIou {iou:.4f}").format(
                     i + 1, max_iter, batch_time=batch_time,
                     acc=acc, iou=iou)
            )
            logger.info(log_str)

        if save_path:
            _save_predictions(predicts_np, samples[2], save_path)

    acc, iou = eval_metric.mean_score()
    log_str = "Eval[{}]\tAcc {:.4f}\tIou {:.4f}".format(
        eval_metric.num_samples, acc, iou)
    logger.info(log_str)
    if eval_metric.num_classes > 1:
        logger.info("Score by classes :")
        eval_metric.print_class_score()

    return acc, iou


def _save_predictions(predicts, sample_ids, save_path):
    """Write each prediction's argmax label map (mapped through CityscapesDataset.IDS) as a PNG."""
    for j, sample_id in enumerate(sample_ids):
        predict = predicts[j]
        # Argmax over axis 0 -- assumes predict is (num_classes, H, W); TODO confirm.
        pred_label = np.argmax(predict, axis=0)
        out_image = np.zeros_like(pred_label, dtype=np.uint8)
        # Replace each predicted class index with its CityscapesDataset.IDS value.
        for cls_idx in np.unique(pred_label):
            out_image[pred_label == cls_idx] = CityscapesDataset.IDS[cls_idx]
        save_file = osp.join(save_path, "{}predict.png".format(sample_id))
        # cv2.imwrite reports failure via its return value instead of raising.
        if not cv2.imwrite(save_file, out_image):
            logger.warning("Failed to write prediction image: %s", save_file)


def _resolve_model_path(cfg):
    """Return the checkpoint path to evaluate, chosen in priority order.

    1. ``cfg.TEST.CHECKPOINT_PATH`` when set;
    2. ``<OUTPUT_DIR>/model/checkpoint_epoch_<MODEL_EPOCH>.pth`` when
       ``cfg.TEST.MODEL_EPOCH > 0``;
    3. ``get_best_model_path(cfg)`` when ``cfg.TEST.BEST_CHECKPOINT``;
    4. otherwise ``get_last_model_path(cfg)``.
    """
    if cfg.TEST.CHECKPOINT_PATH:
        return cfg.TEST.CHECKPOINT_PATH
    if cfg.TEST.MODEL_EPOCH > 0:
        return osp.join(
            cfg.OUTPUT_DIR, "model/checkpoint_epoch_{}.pth".format(cfg.TEST.MODEL_EPOCH)
        )
    if cfg.TEST.BEST_CHECKPOINT:
        return get_best_model_path(cfg)
    return get_last_model_path(cfg)


def test(cfg):
    """Build the model and the validation loader from ``cfg``, then evaluate.

    When ``cfg.TEST.SAVE_PREDICTS`` is set, the per-sample predictions are
    written to ``<OUTPUT_DIR>/test_results`` and that directory is zipped.

    Returns:
        Whatever ``perform_test`` returns (its final ``(acc, iou)`` scores).
    """
    # Print config
    logger.info("Test with config:")
    logger.info(pprint.pformat(cfg))

    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    device = torch.device(cfg.DEVICE)
    # Build the model
    model_path = _resolve_model_path(cfg)
    logger.info("Loading model from: %s", model_path)
    model = build_model(cfg, model_path=model_path)
    model.to(device)
    model.eval()
    predictor = Predictor(
        model, model.n_classes, device,
        mode=cfg.PREDICT.MODE,
        scales=cfg.PREDICT.SCALES,
        flip=cfg.PREDICT.FLIP
    )

    # Create the test data loader
    test_loader = build_data_pipeline(cfg, "val")

    save_path = None
    if cfg.TEST.SAVE_PREDICTS:
        # Sample ids are needed to name the saved prediction files.
        test_loader.dataset.return_id = True
        save_path = osp.join(cfg.OUTPUT_DIR, "test_results")
        mkdir(save_path)

    # perform test
    results = perform_test(test_loader, predictor, device, cfg.THRES, cfg.LOG_PERIOD, save_path)
    if save_path is not None:
        logger.info("Zipping ... ")
        zipdir(save_path, save_path.rstrip("/") + ".zip")
        logger.info("Mission complete !")

    return results


def main():
    """Script entry point: load the config, prepare output dir and logging, then test."""
    cfg = load_config(parse_args())
    output_dir = cfg.OUTPUT_DIR
    mkdir(output_dir)
    setup_logging(output_dir=output_dir, level=logging.INFO)
    # Record the exact command line for reproducibility.
    launch_cmd = " ".join(sys.argv)
    logger.info("Launch command:")
    logger.info(launch_cmd)
    test(cfg)


if __name__ == "__main__":
    main()
