import os.path
from torchvision.ops import nms
import glob
import json
import os
import shutil
import operator
import sys
import argparse
import math
import numpy as np
from ARFM import ARFM
import torch
from Testdata import MyDataset
from torch.utils.data import DataLoader
from tqdm import tqdm

def evaluation(net, dataloader):
    """Run ``net`` over ``dataloader`` and print the detections that survive filtering.

    For every sample the network is called as ``net(image, ssw)`` and is
    expected to return ``(detection_scores, classification_scores, extra)``,
    where both score arrays are indexed as ``[proposal, class]``. For each
    class the proposals are:

    1. kept only if their classification score is >= 0.8,
    2. reduced with NMS (IoU threshold 0.5, ranked by classification score),
    3. kept only if their detection (region) score is >= ``1 / num_proposals``.

    Every surviving detection is printed as
    ``"<class> <region_score> <b0> <b1> <b2> <b3>"`` followed by a blank line,
    where ``b0..b3`` are the four box columns taken from ``ssw``.

    Args:
        net: Model called as ``net(image, ssw)``; inputs are moved to the
            device of its parameters.
        dataloader: Iterable yielding ``(image, ssw, label, name)`` with a
            batch size of 1 (``ssw`` has its leading batch axis squeezed).

    Returns:
        None. Results are written to stdout.
    """
    # Use the device the model actually lives on. Choosing "cuda if available"
    # here independently of the model makes the forward pass fail whenever the
    # caller keeps the model on the CPU on a CUDA-capable machine.
    try:
        device = next(net.parameters()).device
    except StopIteration:
        # Parameter-less model: fall back to the historical default.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    classes = {
        0: "Flat",
        1: "Pedicle",
        2: "Edge"
    }

    for image, ssw, _label, _name in tqdm(dataloader, f"Test {len(dataloader)}"):
        image, ssw = image.to(device), ssw.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            detection_scores, classification_scores, _extra = net(image, ssw)
        detection_scores = detection_scores.detach().cpu().numpy()
        classification_scores = classification_scores.detach().cpu().numpy()

        # Proposal boxes for this image, converted once instead of per class.
        boxes = ssw.squeeze(0).detach().cpu().numpy()
        if len(boxes) == 0:
            # No proposals: nothing to report, and 1 / len(boxes) is undefined.
            continue

        per_class_dets = []
        for j in range(classification_scores.shape[1]):
            # Row indices of proposals whose class-j score is at least 80%.
            inds = np.where(classification_scores[:, j] >= 0.8)[0]
            cls_scores = classification_scores[inds, j]  # classification scores
            det_scores = detection_scores[inds, j]  # region (detection) scores
            cls_boxes = boxes[inds, :]
            # Row layout: (box0, box1, box2, box3, class_score, region_score).
            cls_dets = np.hstack(
                (cls_boxes, cls_scores[:, np.newaxis], det_scores[:, np.newaxis])
            ).astype(np.float32, copy=False)
            # Boxes and scores are handed to NMS with a common float dtype.
            keep = nms(
                torch.as_tensor(cls_boxes, dtype=torch.float32),
                torch.as_tensor(cls_scores, dtype=torch.float32),
                0.5,
            ).numpy()
            # Index with a NumPy integer array so the result stays 2-D even
            # when exactly one box survives.
            per_class_dets.append(cls_dets[keep, :])

        # Drop detections whose region score is below the uniform 1/N level.
        image_thresh = 1 / len(boxes)
        for index, dets in enumerate(per_class_dets):
            dets = dets[dets[:, -1] >= image_thresh]
            # Fall back to the numeric index for classes without a name.
            pre_cls = classes.get(index, str(index))
            for det in dets:
                fields = (det[-1], det[0], det[1], det[2], det[3])
                sentence = pre_cls + " " + " ".join(str(v) for v in fields) + '\n'
                print(sentence)


if __name__ == '__main__':
    # Script entry point: for every dataset and every proposal strategy,
    # load the matching checkpoint into a fresh-per-dataset ARFM model and
    # run the evaluation on the test split.
    classes = {0: "Flat", 1: "Pedicle", 2: "Edge"}
    DEVICE = "cpu"
    datasets = ['clinicdb', 'kvasir', 'private']
    strategies = ['samAndssw', 'samFilterAndssw']

    # Banner templates printed before each dataset / strategy section.
    dataset_banner = '------------------------------------------------------------------{}--------------------------------------------------------------'
    strategy_banner = '_________________________________{}___________________________________'

    # Checkpoints are remapped to the CPU only when running on the CPU.
    load_kwargs = {'map_location': 'cpu'} if DEVICE == "cpu" else {}

    for dataset in datasets:
        print(dataset_banner.format(dataset))
        data_root = os.path.join('data', dataset)
        # One model instance per dataset; its weights are replaced per strategy.
        net = ARFM(128)
        net.to(DEVICE)
        for strategy in strategies:
            test_split = MyDataset(data_root, 'test', strategy)
            test_loader = DataLoader(test_split, batch_size=1, shuffle=False)
            print(strategy_banner.format(strategy))
            # Checkpoint layout: params/<strategy>/<dataset>/<strategy>.pth
            weights_file = os.path.join('params', strategy, dataset, strategy + '.pth')
            state = torch.load(weights_file, **load_kwargs)
            net.load_state_dict(state)
            net.eval()
            evaluation(net, test_loader)
