#!/usr/bin/python3
# coding=utf-8
import os
import sys

from metrics.miouAndCSCS import SegmentationMetric

sys.path.insert(0, '/')
sys.dont_write_bytecode = True
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
import torch.nn.functional as F

import argparse

plt.ion()
import torch

from net_sam_APG import SAM_APG as network
from dataloaders import make_data_loader


def parser(argv=None):
    """Build the evaluation command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before. Pass a list to
            parse programmatically, e.g. from tests or another script.

    Returns:
        argparse.Namespace with ``batch_size``, ``workers``, ``dataset``,
        ``snapshot`` and ``crop_size`` attributes.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', default=24, type=int)
    arg_parser.add_argument('--workers', type=int, default=4,
                            metavar='N', help='dataloader threads')
    arg_parser.add_argument('--dataset', type=str, default='pascal',
                            help='dataset name (default: pascal)')
    arg_parser.add_argument('--snapshot', type=str, default=None,
                            help='set the checkpoint name')
    arg_parser.add_argument('--crop-size', type=int, default=352,
                            help='crop image size')

    # Parse exactly once. The previous version parsed twice and discarded
    # the first result.
    return arg_parser.parse_args(argv)


def _predict_classes(net, image):
    """Run ``net`` on ``image`` and return per-pixel arg-max class indices on the CPU."""
    # net returns a 4-tuple; the last three items are dicts holding a 'masks'
    # tensor (named background / SOD / COD outputs in this script).
    _coarse_map, background_out, sod_out, cod_out = net(image)
    # Concatenating along dim 1 makes channel index 0/1/2 the class id.
    # Presumably each mask has one channel, matching SegmentationMetric(3) -- TODO confirm.
    logits = torch.cat(
        (background_out['masks'], sod_out['masks'], cod_out['masks']), dim=1)
    probs = F.softmax(logits, dim=1)
    # max(dim=1)[1] already yields int64 indices. Move them straight to the CPU
    # instead of the former GPU -> numpy -> GPU -> CPU round trip.
    return probs.max(dim=1)[1].cpu()


def test(Network):
    """Evaluate ``Network`` on the validation split and print segmentation metrics.

    Args:
        Network: Model class; instantiated as ``Network(args)`` with the parsed
            command-line arguments, then run on CUDA in eval mode.

    Raises:
        RuntimeError: if the validation loader yields no batches, since there
            would be no confusion matrix to report.
    """
    args = parser()
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    # Only the validation split is evaluated; the other return values are unused.
    _, val_loader, _, _ = make_data_loader(args, **kwargs)

    net = Network(args)
    net.eval()  # equivalent to net.train(False)
    net.cuda()

    hist = None
    with torch.no_grad():
        metric = SegmentationMetric(3)
        for step, sample in enumerate(val_loader):
            if step % 100 == 0:
                print('step:', step)

            image, target = sample[0]['image'], sample[0]['label']
            img_predict = _predict_classes(net, image.cuda())
            # The metric is computed on the CPU, so the target never needs to visit the GPU.
            hist = metric.addBatch(img_predict, target.long().cpu(), ignore_labels='')

        if hist is None:
            # Previously this surfaced as an UnboundLocalError on `hist`.
            raise RuntimeError('validation loader yielded no batches; nothing to evaluate')

        print(hist.shape)
        pa = metric.pixelAccuracy()
        cpa = metric.classPixelAccuracy()
        mpa = metric.meanPixelAccuracy()
        IoU = metric.IntersectionOverUnion()
        mIoU = metric.meanIntersectionOverUnion()
        MCA = metric.meanCODSODConfusionAccuracy()
        MCA2 = metric.meanCODSODConfusionAccuracy2()
        MCA3 = metric.meanCODSODConfusionAccuracy3()
        CS = metric.printColSum()

        print('hist is :\n', hist)
        print('PA is : %f' % pa)
        print('cPA is :', cpa)
        print('mPA is : %f' % mpa)
        print('IoU is : ', IoU)
        print('mIoU is : ', mIoU)
        print('MCA is : ', MCA)
        print('MCA2 is : ', MCA2)
        print('MCA3 is : ', MCA3)
        print('CS is : ', CS)









if __name__ == '__main__':
    # Script entry point: evaluate SAM_APG (imported above under the alias `network`).
    test(Network=network)
