import os
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
from PIL import Image
from skimage import measure
from tabulate import tabulate
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
from tqdm import tqdm
import open_clip
from model import LinearLayer
from dataset import VisaDataset, MVTecDataset,OthersDataset
from prompt_ensemble import encode_text_with_prompt_ensemble
from sklearn.cluster import KMeans
import Fuzzy_clip
def setup_seed(seed):
    """Seed all random number generators used by this script.

    Seeds PyTorch (CPU and every CUDA device), NumPy and Python's ``random``
    with the same value, then switches cuDNN to deterministic mode with
    autotuning (``benchmark``) turned off.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def normalize(pred, max_value=None, min_value=None):
    """Min-max rescale ``pred``.

    Args:
        pred: Array or tensor to rescale; must provide ``.min()``/``.max()``
            when the bounds are not both supplied.
        max_value: Upper bound of the source range. Used only when
            ``min_value`` is given as well.
        min_value: Lower bound of the source range. Used only when
            ``max_value`` is given as well.

    Returns:
        ``(pred - min) / (max - min)``, where the bounds are the supplied ones
        if both are given and ``pred``'s own extrema otherwise. When the range
        is a scalar zero (e.g. a constant map) an all-zero result is returned
        instead of dividing by zero.
    """
    if max_value is None or min_value is None:
        # Either bound missing: fall back to the data's own extrema for both.
        min_value, max_value = pred.min(), pred.max()

    span = max_value - min_value
    # A flat input would otherwise yield NaN/inf; map it to zeros instead.
    # Only scalar ranges are checked so per-element bounds still broadcast.
    if np.ndim(span) == 0 and span == 0:
        return (pred - min_value) * 0.0
    return (pred - min_value) / span


def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Blend an anomaly score map, rendered with the JET colormap, over an image.

    Args:
        image: Image convertible to a float NumPy array; it is blended with an
            RGB heatmap, so presumably RGB with the same height/width as
            ``scoremap`` -- TODO confirm against callers.
        scoremap: Score array that is multiplied by 255 and cast to ``uint8``
            (presumably already scaled to [0, 1] -- TODO confirm).
        alpha: Weight of the image; the heatmap gets ``1 - alpha``.

    Returns:
        The weighted blend as a ``uint8`` array.
    """
    base = np.asarray(image, dtype=float)
    heat_u8 = (scoremap * 255).astype(np.uint8)
    # OpenCV colormaps are produced in BGR order; convert to RGB before blending.
    heat_bgr = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    blended = alpha * base + (1 - alpha) * heat_rgb
    return blended.astype(np.uint8)


def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Compute the area under the per-region-overlap (PRO) curve.

    The anomaly maps are binarized at ``max_step`` evenly spaced thresholds
    between their global minimum and maximum. For each threshold the mean
    overlap ratio over all connected ground-truth regions (PRO) and the
    false-positive rate over non-anomalous pixels are recorded. Only points
    with ``fpr < expect_fpr`` are kept, their FPR values are min-max rescaled
    to [0, 1], and the area under the resulting PRO-vs-FPR curve is returned.

    Args:
        masks: Ground-truth masks, iterated along the first axis; each item is
            indexed as a 2-D image. Must support ``1 - masks``.
        amaps: Anomaly maps aligned element-wise with ``masks``.
        max_step: Number of threshold steps between ``amaps.min()`` and
            ``amaps.max()``.
        expect_fpr: Upper FPR limit (exclusive) of the integrated range.

    Returns:
        The area under the rescaled PRO curve as computed by
        ``sklearn.metrics.auc``.

    Note:
        Degenerate inputs are not handled specially: a constant ``amaps``
        gives a zero threshold step, and masks without any labelled region
        make the per-threshold mean undefined.
    """
    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step

    # The ground-truth regions do not depend on the threshold, so run the
    # connected-component analysis once instead of once per threshold.
    regions_per_image = [
        [(region.coords[:, 0], region.coords[:, 1], region.area)
         for region in measure.regionprops(measure.label(mask))]
        for mask in masks
    ]
    # The set of normal (non-anomalous) pixels is likewise threshold-invariant.
    inverse_masks = 1 - masks
    normal_pixel_count = inverse_masks.sum()

    pros, fprs = [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps = amaps > th
        pro = []
        for binary_amap, regions in zip(binary_amaps, regions_per_image):
            for rows, cols, area in regions:
                # Fraction of this ground-truth region flagged as anomalous.
                pro.append(binary_amap[rows, cols].sum() / area)
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        pros.append(np.array(pro).mean())
        fprs.append(fp_pixels / normal_pixel_count)

    pros, fprs = np.array(pros), np.array(fprs)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    # Rescale the retained FPR range to [0, 1] so the area is normalized.
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc


def test(args):
    """Run pixel-level anomaly segmentation on a test set and log a metrics table.

    Steps, in order:
      1. Build the CLIP model/tokenizer via ``Fuzzy_clip`` and configure a
         ``'test'`` logger writing to ``<save_path>/log.txt`` and the console.
      2. Load the trained ``LinearLayer`` weights from ``args.checkpoint_path``.
      3. Build the test dataset plus an auxiliary "train" dataset whose class
         names are appended to every test class name to form the text prompt.
      4. For every test image, project the patch tokens of each selected layer,
         compare them with the class text features, and sum the per-layer
         anomaly maps.
      5. Per class, compute pixel AUROC, max-F1, AP and AUPRO, and log them
         (plus their mean) as a pipe-formatted table.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``image_size``, ``features_list``, ``test_data_path``,
            ``train_data_path``, ``save_path``, ``dataset``, ``model``,
            ``pretrained``, ``mode``, ``config_path``, ``checkpoint_path``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'mvtec'``, ``'visa'`` or
            ``'others'``.
    """
    img_size = args.image_size
    features_list = args.features_list
    test_dir = args.test_data_path
    train_dir = args.train_data_path
    save_path = args.save_path
    dataset_name = args.dataset
    os.makedirs(save_path, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    txt_path = os.path.join(save_path, 'log.txt')

    # clip
    model, _, preprocess = Fuzzy_clip.create_model_and_transforms(
        args.model, img_size, pretrained=args.pretrained, design_details=None)
    model.to(device)
    tokenizer = Fuzzy_clip.get_tokenizer(args.model)

    # logger: silence whatever is attached to the root logger, then log only
    # through the dedicated 'test' logger (file + console).
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('test')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # record parameters (few-shot-only options are skipped in zero-shot mode)
    for arg in vars(args):
        if args.mode == 'zero_shot' and (arg == 'k_shot' or arg == 'few_shot_features'):
            continue
        logger.info(f'{arg}: {getattr(args, arg)}')

    # seg
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)
    linearlayer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
                              len(features_list), args.model).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    linearlayer.load_state_dict(checkpoint["trainable_linearlayer"])

    # dataset
    transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.CenterCrop(img_size),
            transforms.ToTensor()
        ])
    # "train_data" is only queried for its class names (used in the text
    # prompts below); it is built from the dataset that is NOT being tested.
    if dataset_name == 'mvtec':
        train_data = VisaDataset(root=train_dir, transform=preprocess, target_transform=transform, mode='test')
        test_data = MVTecDataset(root=test_dir, transform=preprocess, target_transform=transform,
                                 aug_rate=-1, mode='test')
    elif dataset_name == 'visa':
        train_data = MVTecDataset(root=train_dir, transform=preprocess, target_transform=transform,
                                  aug_rate=-1, mode='test')
        test_data = VisaDataset(root=test_dir, transform=preprocess, target_transform=transform, mode='test')
    elif dataset_name == 'others':
        train_data = VisaDataset(root=train_dir, transform=preprocess, target_transform=transform, mode='test')
        test_data = OthersDataset(root=test_dir, transform=preprocess, target_transform=transform,
                                  aug_rate=-1, mode='test')
    else:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected 'mvtec', 'visa' or 'others'.")

    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    test_obj_list = test_data.get_cls_names()
    train_obj_list = train_data.get_cls_names()

    # Pre-compute the text features once per test class and cache them by
    # class name, so the image loop only does a dictionary lookup.
    fea_dict = {}
    train_names = ', '.join(train_obj_list)
    for name in test_obj_list:
        prompt_key = name + ', ' + train_names
        with torch.cuda.amp.autocast(), torch.no_grad():
            text_features = encode_text_with_prompt_ensemble(model, [prompt_key], tokenizer, device)
        fea_dict[name] = text_features[prompt_key]

    results = {}
    results['cls_names'] = []
    results['imgs_masks'] = []
    results['anomaly_maps'] = []
    results['gt_sp'] = []
    results['pr_sp'] = []
    for items in tqdm(test_dataloader, desc="Processing test dataloader"):
        image = items['img'].to(device)
        cls_name = items['cls_name']
        results['cls_names'].append(cls_name[0])
        gt_mask = items['img_mask']
        # Binarize the ground-truth mask at 0.5.
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results['imgs_masks'].append(gt_mask)  # px
        results['gt_sp'].append(items['anomaly'].item())
        test_fea = fea_dict[cls_name[0]]
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features, patch_tokens = model.encode_image(image, features_list)
            # Kept from the original pipeline; the image-level embedding is
            # not used by the pixel-level metrics computed below.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # pixel
            patch_tokens_obj = linearlayer(patch_tokens)

            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                patch_tokens_obj[layer] /= patch_tokens_obj[layer].norm(dim=-1, keepdim=True)
                anomaly_map = (100.0 * patch_tokens_obj[layer] @ test_fea)
                B, L, C = anomaly_map.shape
                H = int(np.sqrt(L))
                # Reshape the per-patch two-channel logits to a square grid and
                # upsample to the input resolution.
                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                            size=img_size, mode='bilinear', align_corners=True)
                # Channel 1 of the softmax is used as the anomaly score.
                anomaly_map = torch.softmax(anomaly_map, dim=1)[:, 1, :, :]
                anomaly_maps.append(anomaly_map.cpu().numpy())
            # Sum the per-layer maps into one map per image.
            anomaly_map = np.sum(anomaly_maps, axis=0)
            results['anomaly_maps'].append(anomaly_map)
            # Optional visualization (disabled):
            # path = items['img_path']
            # cls = path[0].split('/')[-2]
            # filename = path[0].split('/')[-1]
            # vis = cv2.cvtColor(cv2.resize(cv2.imread(path[0]), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB
            # mask = normalize(anomaly_map[0])
            # vis = apply_ad_scoremap(vis, mask)
            # vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR
            # save_vis = os.path.join(save_path, 'imgs', cls_name[0], cls)
            # if not os.path.exists(save_vis):
            #     os.makedirs(save_vis)
            # cv2.imwrite(os.path.join(save_vis, filename), vis)

    # metrics (pixel level only)
    table_ls = []
    auroc_px_ls = []
    f1_px_ls = []
    aupro_ls = []
    ap_px_ls = []
    for obj in test_obj_list:
        gt_px = []
        pr_px = []
        for sample_cls, sample_mask, sample_map in zip(
                results['cls_names'], results['imgs_masks'], results['anomaly_maps']):
            if sample_cls == obj:
                gt_px.append(sample_mask.squeeze(1).numpy())
                pr_px.append(sample_map)

        gt_px = np.array(gt_px)
        pr_px = np.array(pr_px)
        auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())

        ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
        # f1_px: best F1 over all thresholds; 0/0 points are dropped via isfinite.
        precisions, recalls, thresholds = precision_recall_curve(gt_px.ravel(), pr_px.ravel())
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
        f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
        # aupro: cal_pro_score iterates per image, so drop the singleton axis.
        if len(gt_px.shape) == 4:
            gt_px = gt_px.squeeze(1)
        if len(pr_px.shape) == 4:
            pr_px = pr_px.squeeze(1)
        aupro = cal_pro_score(gt_px, pr_px)

        table_ls.append([
            obj,
            str(np.round(auroc_px * 100, decimals=1)),
            str(np.round(f1_px * 100, decimals=1)),
            str(np.round(ap_px * 100, decimals=1)),
            str(np.round(aupro * 100, decimals=1)),
        ])
        auroc_px_ls.append(auroc_px)
        f1_px_ls.append(f1_px)
        aupro_ls.append(aupro)
        ap_px_ls.append(ap_px)

    # logger: the mean row has exactly one cell per header column.
    table_ls.append([
        'mean',
        str(np.round(np.mean(auroc_px_ls) * 100, decimals=1)),
        str(np.round(np.mean(f1_px_ls) * 100, decimals=1)),
        str(np.round(np.mean(ap_px_ls) * 100, decimals=1)),
        str(np.round(np.mean(aupro_ls) * 100, decimals=1)),
    ])
    report = tabulate(table_ls, headers=['objects', 'auroc_px', 'f1_px', 'ap_px', 'aupro'], tablefmt="pipe")
    logger.info("\n%s", report)


# Script entry point: parse the command-line options, seed the RNGs, run the evaluation.
if __name__ == '__main__':
    parser = argparse.ArgumentParser("FuzzyCLIP", add_help=True)
    # paths
    # NOTE(review): with the default --dataset 'mvtec', test() builds an
    # MVTecDataset from --test_data_path and a VisaDataset from
    # --train_data_path, yet the default paths below point at visa and mvtec
    # respectively -- confirm the defaults before relying on them.
    parser.add_argument("--test_data_path", type=str, default="/home/hyn/workspace/AIAD/VAND-APRIL-GAN/data/visa", help="path to test dataset")
    parser.add_argument("--train_data_path", type=str, default="/home/hyn/workspace/AIAD/VAND-APRIL-GAN/data/mvtec_anomaly_detection/data", help="path to the auxiliary (train) dataset whose class names are added to the text prompts")
    parser.add_argument("--save_path", type=str, default='./results/mvtec/zero_shot', help='path to save results')
    parser.add_argument("--checkpoint_path", type=str, default='/home/hyn/workspace/AIAD/VAND-APRIL-GAN/beifen_model/visa/best.pth', help='path to the trained linear-layer checkpoint')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-L-14-336.json', help="model configs")
    # model
    # Restrict to the dataset names test() knows how to build.
    parser.add_argument("--dataset", type=str, default='mvtec', choices=['mvtec', 'visa', 'others'], help="test dataset")
    parser.add_argument("--model", type=str, default="ViT-L-14-336", help="model used")
    parser.add_argument("--pretrained", type=str, default="openai", help="pretrained weight used")
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    parser.add_argument("--mode", type=str, default="zero_shot", help="zero shot or few shot")
    # few shot
    parser.add_argument("--k_shot", type=int, default=0, help="e.g., 10-shot, 5-shot, 1-shot")
    parser.add_argument("--seed", type=int, default=10, help="random seed")
    args = parser.parse_args()

    setup_seed(args.seed)
    test(args)
