import cv2
import os
import torch
import PIL.Image as Image
from spikingjelly.activation_based import functional, neuron

from dataset import (
    restore_bsds,
    crop_bsds,
    pad_nyud,
    restore_nyud,
)
import time
from tqdm import tqdm
import scipy.io as sio
import torch.nn.functional as F
import numpy as np


def test_bsds(test_loader, model, save_dir, logger, device, multi_scale=True):
    """Run BSDS500 inference and save edge maps as ``.mat`` and ``.png`` files.

    A single-scale pass is always run; its outputs go to ``single_mat`` /
    ``single_png`` under ``save_dir`` and spike-sparsity statistics are logged
    for it. When ``multi_scale`` is true, a second pass averages the outputs
    obtained at 1x, 2x and 0.5x input scale and writes them to ``multi_mat`` /
    ``multi_png``.

    Args:
        test_loader: DataLoader with ``batch_size == 1``. Each item is a dict
            with an ``"images"`` tensor; ``test_loader.dataset.lbl_list`` gives
            the output file name for each batch index.
        model: Model accessed through ``model.module`` (e.g. a DataParallel
            wrapper) that provides the sparsity-hook API used below. Its
            spiking neurons are stateful, so the network is reset after every
            forward pass.
        save_dir: Root output directory; created if missing.
        logger: Logger used to report timing and sparsity.
        device: Device the images are moved to before inference.
        multi_scale: Whether to run the additional multi-scale pass.
    """
    os.makedirs(save_dir, exist_ok=True)

    model.eval()

    assert test_loader.batch_size == 1

    # Number of images used for the averages; guarded so an empty dataset
    # does not raise ZeroDivisionError in the log lines.
    length = max(len(test_loader.dataset), 1)

    # -------------------------- single-scale test --------------------------
    t_time = 0.0
    t_duration = 0.0

    single_png_dir = os.path.join(save_dir, "single_png")
    single_mat_dir = os.path.join(save_dir, "single_mat")
    os.makedirs(single_png_dir, exist_ok=True)
    os.makedirs(single_mat_dir, exist_ok=True)

    model.module.add_hooks(neuron.BaseNode)
    for batch_index, data in enumerate(tqdm(test_loader, ncols=200)):
        with torch.no_grad():
            images = data["images"].to(device)
            # Images with 481 rows (portrait) are transposed for the forward
            # pass and the prediction is transposed back afterwards.
            is_transpose = images.shape[2] == 481
            if is_transpose:
                images = torch.transpose(images, 2, 3)
            images = crop_bsds(images)  # BSDS500

            # perf_counter is monotonic and high-resolution; time.time() can
            # report a zero duration on coarse clocks and crash `1 / duration`.
            start_time = time.perf_counter()
            height, width = images.shape[2:]
            preds = model(images)[0].sigmoid()
            # Spiking neurons keep membrane state across calls; reset it so the
            # next image starts from a clean state.
            functional.reset_net(model)
            preds = preds.detach().cpu().numpy().squeeze()  # H*W

            # Hand this forward pass's hook counts to the model's sparsity
            # bookkeeping (read back via all_nnz / all_nnumel below), then
            # clear the hook counters.
            model.module.process_nz(model.module.get_nz_numel())
            model.module.reset_nz_numel()

            if is_transpose:
                preds = np.transpose(preds, (1, 0))
                height, width = preds.shape
            duration = time.perf_counter() - start_time
            t_time += duration
            t_duration += 1 / duration
            preds = restore_bsds(preds, height + 1, width + 1)  # BSDS500
            name = test_loader.dataset.lbl_list[batch_index]
            sio.savemat(
                os.path.join(single_mat_dir, "{}.mat".format(name)), {"result": preds}
            )
            Image.fromarray((preds * 255).astype(np.uint8)).save(
                os.path.join(single_png_dir, "{}.png".format(name))
            )
    logger.info(
        f"single test:\t avg_time: {t_time / length:.3f}, avg_FPS: {t_duration / length:.3f}, "
        + f"Total sparsity: {model.module.all_nnz} / {model.module.all_nnumel} ({100 * model.module.all_nnz / (model.module.all_nnumel + 1e-3):.2f}%)"
    )
    # Reset the accumulated sparsity statistics.
    model.module.all_nnz, model.module.all_nnumel = 0, 0

    if not multi_scale:
        return

    # -------------------------- multi-scale test ---------------------------
    t_time = 0.0
    t_duration = 0.0

    multi_png_dir = os.path.join(save_dir, "multi_png")
    multi_mat_dir = os.path.join(save_dir, "multi_mat")
    os.makedirs(multi_png_dir, exist_ok=True)
    os.makedirs(multi_mat_dir, exist_ok=True)

    # NOTE(review): the single-scale pass uses `model(x)[0].sigmoid()` and
    # transposes portrait images, while this pass uses the raw `model(x)[-1]`
    # and does not transpose — confirm this difference is intended.
    for batch_index, data in enumerate(tqdm(test_loader)):
        with torch.no_grad():
            images = crop_bsds(data["images"])

            height, width = images.shape[2:]
            images_2x = F.interpolate(
                images, scale_factor=2, mode="bilinear", align_corners=True
            )
            images_half = F.interpolate(
                images, scale_factor=0.5, mode="bilinear", align_corners=True
            )

            start_time = time.perf_counter()

            # Each scale is an independent input: reset the spiking state after
            # every forward pass so membrane potentials (whose shape depends on
            # the input size) do not carry over between scales or images.
            preds = model(images.to(device))[-1]
            functional.reset_net(model)

            preds_2x = model(images_2x.to(device))[-1]
            functional.reset_net(model)
            preds_2x_down = F.interpolate(
                preds_2x, size=(height, width), mode="bilinear", align_corners=True
            )

            preds_half = model(images_half.to(device))[-1]
            functional.reset_net(model)
            preds_half_up = F.interpolate(
                preds_half,
                size=(height, width),
                mode="bilinear",
                align_corners=True,
            )

            # The sparsity hooks added above are still registered; their
            # counts are not reported for this pass, so clear them per image
            # to keep them from accumulating.
            model.module.reset_nz_numel()

            fuse_final = (preds + preds_2x_down + preds_half_up) / 3
            fuse_final = fuse_final.cpu().detach().numpy().squeeze()
            duration = time.perf_counter() - start_time
            t_time += duration
            t_duration += 1 / duration
            fuse_final = restore_bsds(fuse_final, height + 1, width + 1)
            name = test_loader.dataset.lbl_list[batch_index]
            sio.savemat(
                os.path.join(multi_mat_dir, "{}.mat".format(name)),
                {"result": fuse_final},
            )
            Image.fromarray((fuse_final * 255).astype(np.uint8)).save(
                os.path.join(multi_png_dir, "{}.png".format(name))
            )
    logger.info(
        "multi test:\t avg_time: {:.3f}, avg_FPS: {:.3f}".format(
            t_time / length, t_duration / length
        )
    )


def test_nyud(test_loader, model, save_dir, logger, device):
    """Run single-scale NYUD inference and save edge maps as ``.mat``/``.png``.

    Outputs are written to ``png`` and ``mat`` sub-folders of ``save_dir``;
    average time per image and average FPS are logged at the end.

    Args:
        test_loader: DataLoader with ``batch_size == 1``. Each item is a dict
            with an ``"images"`` tensor; ``test_loader.dataset.lbl_list`` gives
            the output file name for each batch index.
        model: Edge-detection model; its last output (``model(x)[-1]``) is
            saved as the prediction. Any spiking-neuron state is reset after
            every forward pass.
        save_dir: Root output directory; created if missing.
        logger: Logger used to report timing.
        device: Device the images are moved to before inference.
    """
    # Create the output directories.
    os.makedirs(save_dir, exist_ok=True)

    model.eval()

    assert test_loader.batch_size == 1

    t_time = 0.0
    t_duration = 0.0
    # Guarded so an empty dataset does not raise ZeroDivisionError below.
    length = max(len(test_loader.dataset), 1)

    png_dir = os.path.join(save_dir, "png")
    mat_dir = os.path.join(save_dir, "mat")
    os.makedirs(png_dir, exist_ok=True)
    os.makedirs(mat_dir, exist_ok=True)

    for batch_index, data in enumerate(tqdm(test_loader)):
        with torch.no_grad():
            images = data["images"].to(device)
            images = pad_nyud(images)  # NYUD

            # perf_counter is monotonic and high-resolution; time.time() can
            # report a zero duration on coarse clocks and crash `1 / duration`.
            start_time = time.perf_counter()
            height, width = images.shape[2:]
            preds = model(images)[-1]
            # Spiking neurons keep membrane state across calls; reset it so the
            # next image is not influenced by the previous one (test_bsds does
            # the same after every forward pass).
            functional.reset_net(model)
            preds = preds.detach().cpu().numpy().squeeze()  # H*W
            duration = time.perf_counter() - start_time
            t_time += duration
            t_duration += 1 / duration
            # NOTE(review): assumes pad_nyud adds 7 rows that restore_nyud
            # removes again — confirm against dataset.pad_nyud.
            preds = restore_nyud(preds, height - 7, width)  # NYUD
            name = test_loader.dataset.lbl_list[batch_index]
            sio.savemat(os.path.join(mat_dir, "{}.mat".format(name)), {"result": preds})
            Image.fromarray((preds * 255).astype(np.uint8)).save(
                os.path.join(png_dir, "{}.png".format(name))
            )
    logger.info(
        "single test:\t avg_time: {:.3f}, avg_FPS: {:.3f}".format(
            t_time / length, t_duration / length
        )
    )
