import sys
import logging
import pprint
import time
import numpy as np
import torch
import os.path as osp
import cv2

from models import build_model
from utils import (
    get_best_model_path, get_last_model_path, setup_logging, mkdir
)
from utils.metric import AverageMeter, DiceMetric, RegressMetric
from utils.parser import parse_args, load_config
from dataset import build_data_pipeline

# Module-level logger named after this module. main() calls setup_logging()
# before any message is emitted through it.
logger = logging.getLogger(__name__)


def zipdir(path, result_path):
    """Archive every file under ``path`` into the zip file ``result_path``.

    Entries are stored relative to the parent of ``path``, so the archive
    contains a single top-level folder named after ``path`` itself.

    Args:
        path: Directory whose contents are walked recursively.
        result_path: Destination of the zip archive; it is opened in ``"w"``
            mode, so an existing file at that location is overwritten.
    """
    import os
    import zipfile

    archive_root = osp.join(path, '..')
    # The context manager guarantees the archive is finalised and closed even
    # when walking the tree or writing an entry raises.
    with zipfile.ZipFile(result_path, "w") as zipf:
        for root, _dirs, files in os.walk(path):
            for file in files:
                file_path = osp.join(root, file)
                zipf.write(file_path, osp.relpath(file_path, archive_root))


@torch.no_grad()
def perform_test(test_loader, model, device, thres=0.5, print_freq=10, save_path=None) -> float:
    """Evaluate ``model`` on ``test_loader`` and optionally save predicted masks.

    Args:
        test_loader: Iterable of batches, indexed positionally: ``samples[0]``
            is the model input, ``samples[1]`` the label, ``samples[2]`` the
            regression target (read only when ``model.branch_number == 2``)
            and ``samples[-1]`` the sample ids (read only when ``save_path``
            is set).
        model: Model exposing ``n_classes``, ``branch_number`` and ``act``;
            ``reg_act`` is also used whenever ``branch_number`` is not 1.
        device: Device the batch tensors are moved to before the forward pass.
        thres: Threshold handed to ``DiceMetric`` and used to binarise the
            masks written to disk.
        print_freq: Log progress every ``print_freq`` batches. A falsy value
            (``0`` or ``None``) disables the periodic log.
        save_path: Directory for the predicted masks. When set, one
            ``<sample_id>.png`` is written per sample and the directory is
            zipped to ``<save_path>.zip`` afterwards. The directory is not
            created here.

    Returns:
        The value of ``eval_metric.mean_score()`` after all batches.
    """
    single_branch = model.branch_number == 1
    has_reg_branch = model.branch_number == 2

    batch_time = AverageMeter()
    eval_metric = DiceMetric(thres=thres, num_classes=model.n_classes)
    reg_metric = RegressMetric() if has_reg_branch else None

    model.eval()

    max_iter = len(test_loader)
    end = time.time()
    for i, samples in enumerate(test_loader):
        inputs, labels = samples[0].to(device), samples[1].to(device)
        # forward
        outputs = model(inputs)
        if single_branch:
            predicts = model.act(outputs)
        else:
            predicts, aux_predicts = model.act(outputs[0]), model.reg_act(outputs[1])
        # update metric
        score = eval_metric.update(predicts.detach(), labels.detach())
        if has_reg_branch:
            reg_score = reg_metric.update(aux_predicts.detach(), samples[2].to(device))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Guarding on print_freq avoids a ZeroDivisionError when it is 0.
        if print_freq and (i + 1) % print_freq == 0:
            # Report i + 1 so the counter reflects the batches already
            # processed and reaches max_iter on the last batch.
            log_str = (
                ("Eval[{0}/{1}]\t"
                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                 "Score {score:.3f}").format(
                     i + 1, max_iter, batch_time=batch_time, score=score)
            )
            if has_reg_branch:
                log_str += "\tReg Score {:.4f}".format(reg_score)
            logger.info(log_str)

        if save_path:
            sample_ids = samples[-1]
            # NOTE(review): squeeze(dim=1) assumes a singleton channel axis
            # (single-class output) — confirm behaviour for multi-class models.
            pred_label = (predicts.squeeze(dim=1) > thres).int().cpu().numpy()
            for j, sample_id in enumerate(sample_ids):
                # Scale the {0, 1} mask to {0, 255} so it is visible as an image.
                out_image = (pred_label[j] * 255).astype(np.uint8)
                out_file = osp.join(save_path, sample_id + ".png")
                cv2.imwrite(out_file, out_image)

    mean_score = eval_metric.mean_score()
    log_str = "Eval[{}]\tScore {:.4f}".format(eval_metric.num_samples, mean_score)
    if has_reg_branch:
        log_str += "\tReg Score {}".format(str(reg_metric))
    logger.info(log_str)
    if eval_metric.num_classes > 1:
        logger.info("Score by classes :")
        eval_metric.print_class_score()

    if save_path:
        # Strip a trailing slash so the archive lands next to the directory
        # rather than inside it as ".zip".
        zip_path = save_path.rstrip("/") + ".zip"
        zipdir(save_path, zip_path)
        logger.info("Predictions saved to {}".format(zip_path))

    return mean_score


def test(cfg):
    """Run the test-set evaluation described by ``cfg``.

    Logs the config, seeds NumPy and PyTorch from ``cfg.RNG_SEED``, builds the
    model from the selected checkpoint, builds the "test" data pipeline and
    hands everything to ``perform_test``.

    Checkpoint selection, in priority order:
        1. ``cfg.TEST.CHECKPOINT_PATH`` when it is truthy;
        2. ``model/checkpoint_epoch_<N>.pth`` under ``cfg.OUTPUT_DIR`` when
           ``cfg.TEST.MODEL_EPOCH`` is positive;
        3. ``get_best_model_path(cfg)`` when ``cfg.TEST.BEST_CHECKPOINT`` is
           truthy;
        4. ``get_last_model_path(cfg)`` otherwise.
    """
    # Dump the effective configuration.
    for message in ("Test with config:", pprint.pformat(cfg)):
        logger.info(message)

    # Seed both RNGs from the config.
    for seed_everything in (np.random.seed, torch.manual_seed):
        seed_everything(cfg.RNG_SEED)

    device = torch.device(cfg.DEVICE)

    # Pick the checkpoint to evaluate.
    if cfg.TEST.CHECKPOINT_PATH:
        checkpoint = cfg.TEST.CHECKPOINT_PATH
    elif cfg.TEST.MODEL_EPOCH > 0:
        checkpoint_name = "model/checkpoint_epoch_{}.pth".format(cfg.TEST.MODEL_EPOCH)
        checkpoint = osp.join(cfg.OUTPUT_DIR, checkpoint_name)
    elif cfg.TEST.BEST_CHECKPOINT:
        checkpoint = get_best_model_path(cfg)
    else:
        checkpoint = get_last_model_path(cfg)

    model = build_model(cfg, model_path=checkpoint)
    model.to(device)

    # Test loader; the dataset is asked to also return sample ids.
    loader = build_data_pipeline(cfg, "test")
    loader.dataset.return_id = True

    # Predictions are only written out when requested by the config.
    predicts_dir = None
    if cfg.TEST.SAVE_PREDICTS:
        predicts_dir = osp.join(cfg.OUTPUT_DIR, "test_predicts")
        mkdir(predicts_dir)

    perform_test(
        loader,
        model,
        device,
        thres=cfg.THRES,
        print_freq=cfg.LOG_PERIOD,
        save_path=predicts_dir,
    )


def main():
    """Script entry point: load the config, set up logging and run ``test``."""
    cfg = load_config(parse_args())

    # The output directory must exist before logging is configured with it.
    output_dir = cfg.OUTPUT_DIR
    mkdir(output_dir)
    setup_logging(output_dir=cfg.OUTPUT_DIR, level=logging.INFO)

    # Record the exact command line for reproducibility.
    launch_command = " ".join(sys.argv)
    logger.info("Launch command:")
    logger.info(launch_command)

    test(cfg)


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
