import os
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
from PIL import Image
from skimage import measure
from tabulate import tabulate
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
from tqdm import tqdm
import open_clip
from model import LinearLayer
from dataset import VisaDataset, MVTecDataset,OthersDataset
from prompt_ensemble import encode_text_with_prompt_ensemble
from sklearn.cluster import KMeans
import Fuzzy_clip
import copy
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value, then puts cuDNN into deterministic mode with
    auto-tuning (benchmark) switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def normalize(pred, max_value=None, min_value=None):
    """Min-max rescale ``pred``.

    Args:
        pred: Array or tensor exposing ``.min()`` / ``.max()`` and arithmetic.
        max_value: Upper bound of the source range. When this or
            ``min_value`` is ``None``, the bounds are taken from ``pred``.
        min_value: Lower bound of the source range.

    Returns:
        ``(pred - min) / (max - min)``. When the bounds are taken from
        ``pred`` and ``pred`` is constant, an all-zero result is returned
        instead of the NaNs a division by zero would produce.
    """
    if max_value is None or min_value is None:
        lowest, highest = pred.min(), pred.max()
        span = highest - lowest
        if span == 0:
            # Constant input: every element equals the minimum, so map it to 0.
            # Multiplying by 0.0 keeps the floating-point result type.
            return (pred - lowest) * 0.0
        return (pred - lowest) / span
    return (pred - min_value) / (max_value - min_value)


def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Overlay an anomaly score map on an image as a JET heat map.

    Args:
        image: Image convertible with ``np.asarray``; blended in float.
        scoremap: Score array scaled by 255 and cast to ``uint8`` before
            colouring (presumably expected in [0, 1] -- TODO confirm).
        alpha: Weight of the image; the heat map gets ``1 - alpha``.

    Returns:
        ``uint8`` array holding the blended image.
    """
    base = np.asarray(image, dtype=float)
    score_u8 = (scoremap * 255).astype(np.uint8)
    # Colour the scores, then convert the result from BGR to RGB channel order.
    heat = cv2.cvtColor(cv2.applyColorMap(score_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    blended = alpha * base + (1 - alpha) * heat
    return blended.astype(np.uint8)


def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Compute the area under the per-region-overlap (PRO) curve.

    The anomaly maps are binarised at ``max_step`` evenly spaced thresholds
    between their global minimum and maximum. For each threshold the PRO is
    the mean, over every connected ground-truth region, of the fraction of
    that region flagged as anomalous; the false-positive rate is the flagged
    fraction of all non-anomalous pixels. Only points with
    ``fpr < expect_fpr`` are kept, the kept FPRs are min-max rescaled to
    [0, 1], and the area under the resulting curve is returned.

    Args:
        masks: Binary ground-truth masks, iterated along the first axis and
            indexed as 2-D images, i.e. shaped ``(N, H, W)``.
        amaps: Anomaly score maps, same layout as ``masks``.
        max_step: Number of threshold steps.
        expect_fpr: Exclusive upper FPR limit of the integration range.

    Returns:
        The PRO-AUC as computed by ``sklearn.metrics.auc``.
    """
    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step

    # Connected components depend only on the ground truth, so label every
    # mask once here instead of once per threshold.
    regions_per_mask = [
        [(region.coords[:, 0], region.coords[:, 1], region.area)
         for region in measure.regionprops(measure.label(mask))]
        for mask in masks
    ]
    inverse_masks = 1 - masks
    negative_pixels = inverse_masks.sum()

    pros, fprs = [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps = amaps > th
        pro = []
        for binary_amap, regions in zip(binary_amaps, regions_per_mask):
            for rows, cols, area in regions:
                pro.append(binary_amap[rows, cols].sum() / area)
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        pros.append(np.array(pro).mean())
        fprs.append(fp_pixels / negative_pixels)

    pros, fprs = np.array(pros), np.array(fprs)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    # Rescale the retained FPR range to [0, 1] so the area is normalised.
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc

def clustering(text_fea, train_text_prompts):
    """Cluster training text features and weight a new feature by proximity.

    The first column of every value in ``train_text_prompts`` is clustered
    with 2-cluster KMeans. The first column of ``text_fea`` is then compared
    with both cluster centres, and the softmax of the inverse Euclidean
    distances is returned as per-cluster weights.

    Args:
        text_fea: Tensor indexed as ``text_fea[:, 0]`` to obtain the query
            vector.
        train_text_prompts: Mapping from class name to a tensor indexed the
            same way.

    Returns:
        Tuple ``(weights, clusters)``: ``weights`` is a NumPy array with one
        entry per cluster summing to 1, and ``clusters`` maps each cluster
        index to the list of class names assigned to it.
    """
    data_first_columns = np.array([v[:, 0].cpu().numpy() for v in train_text_prompts.values()])

    # Run KMeans clustering.
    n_clusters = 2  # adjust as needed
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto').fit(data_first_columns)

    # Group the class names by their assigned cluster.
    clusters = {i: [] for i in range(n_clusters)}
    for item, label in zip(train_text_prompts.keys(), kmeans.labels_):
        clusters[label].append(item)

    new_data_point = text_fea[:, 0].cpu().numpy()
    distances = np.linalg.norm(kmeans.cluster_centers_ - new_data_point, axis=1)

    # A query lying exactly on a centre would give 1/0 = inf and then NaN
    # weights out of the softmax; clamp to machine epsilon so that cluster
    # simply receives (almost) all of the weight.
    distances = np.maximum(distances, np.finfo(distances.dtype).eps)

    # Closer centres get larger weights: softmax over inverse distances.
    normalized_distances = softmax(1 / distances)

    return normalized_distances, clusters

def softmax(x):
    """Return the softmax of ``x``, normalised along axis 0.

    The global maximum of ``x`` is subtracted before exponentiation to
    improve numerical stability.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=0)

def train_data_get_clusters(train_text_prompts):
    """Split the training class names into two KMeans clusters.

    The first column of each value in ``train_text_prompts`` is used as that
    class's feature vector for 2-cluster KMeans (``random_state=0``).

    Args:
        train_text_prompts: Mapping from class name to a tensor indexed as
            ``value[:, 0]``.

    Returns:
        Dict mapping each cluster index (0 and 1) to the list of class names
        assigned to that cluster, in the mapping's iteration order.
    """
    n_clusters = 2
    class_names = list(train_text_prompts.keys())
    first_columns = np.array(
        [feature[:, 0].cpu().numpy() for feature in train_text_prompts.values()]
    )

    # Run KMeans clustering.
    fitted = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto').fit(first_columns)

    # Collect the class names belonging to each cluster.
    clusters = {cluster_id: [] for cluster_id in range(n_clusters)}
    for class_name, cluster_id in zip(class_names, fitted.labels_):
        clusters[cluster_id].append(class_name)
    return clusters

def _load_linear_layer(model_configs, num_features, model_name, checkpoint_path, device):
    """Build a ``LinearLayer`` on ``device`` and load its trained weights."""
    layer = LinearLayer(model_configs['vision_cfg']['width'], model_configs['embed_dim'],
                        num_features, model_name).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    layer.load_state_dict(checkpoint["trainable_linearlayer"])
    return layer


def _patch_anomaly_map(tokens, text_features, img_size):
    """Convert one layer's patch tokens into an image-sized probability map.

    ``tokens`` is L2-normalised in place, multiplied with the two-column
    ``text_features``, reshaped to a square grid, bilinearly upsampled to
    ``img_size`` and soft-maxed over the two channels; channel 1 is returned
    (presumably the "abnormal" prompt -- TODO confirm against the prompt
    ensemble).
    """
    tokens /= tokens.norm(dim=-1, keepdim=True)
    logits = 100.0 * tokens @ text_features
    B, L, _ = logits.shape
    H = int(np.sqrt(L))
    logits = F.interpolate(logits.permute(0, 2, 1).view(B, 2, H, H),
                           size=img_size, mode='bilinear', align_corners=True)
    return torch.softmax(logits, dim=1)[:, 1, :, :]


def _best_f1(gt, pred):
    """Return the highest finite F1 score along the precision-recall curve."""
    precisions, recalls, _ = precision_recall_curve(gt, pred)
    f1_scores = (2 * precisions * recalls) / (precisions + recalls)
    return np.max(f1_scores[np.isfinite(f1_scores)])


def test(args):
    """Evaluate the two-cluster model on a test dataset and log a metric table.

    Steps: build the CLIP model and tokenizer, load the two cluster-specific
    linear layers, cluster the auxiliary dataset's class names, encode the
    text prompts, score every test image, and finally log per-class and mean
    pixel-level (AUROC, F1, AP, AUPRO) and image-level (AUROC, F1, AP)
    metrics to the console and to ``<save_path>/log.txt``.

    Args:
        args: Namespace providing ``image_size``, ``features_list``,
            ``test_data_path``, ``train_data_path``, ``save_path``,
            ``dataset``, ``model``, ``pretrained``, ``mode``,
            ``config_path``, ``checkpoint_cluster1_path`` and
            ``checkpoint_cluster2_path``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'mvtec'``, ``'visa'`` or
            ``'others'``.
    """
    img_size = args.image_size
    features_list = args.features_list
    test_dir = args.test_data_path
    train_dir = args.train_data_path
    save_path = args.save_path
    dataset_name = args.dataset
    os.makedirs(save_path, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    txt_path = os.path.join(save_path, 'log.txt')

    # clip
    model, _, preprocess = Fuzzy_clip.create_model_and_transforms(
        args.model, img_size, pretrained=args.pretrained, design_details=None)
    model.to(device)
    tokenizer = Fuzzy_clip.get_tokenizer(args.model)

    # logger: silence the root logger and log through a dedicated 'test'
    # logger that writes to both the log file and the console.
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('test')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # record parameters (few-shot options are irrelevant in zero-shot mode)
    for arg in vars(args):
        if args.mode == 'zero_shot' and (arg == 'k_shot' or arg == 'few_shot_features'):
            continue
        logger.info(f'{arg}: {getattr(args, arg)}')

    # seg: one trained linear layer per class cluster
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)
    linearlayer_cluster1 = _load_linear_layer(model_configs, len(features_list), args.model,
                                              args.checkpoint_cluster1_path, device)
    linearlayer_cluster2 = _load_linear_layer(model_configs, len(features_list), args.model,
                                              args.checkpoint_cluster2_path, device)

    # dataset: the "train" dataset is only used for its class names below
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor()
    ])
    if dataset_name == 'mvtec':
        train_data = VisaDataset(root=train_dir, transform=preprocess, target_transform=transform, mode='test')
        test_data = MVTecDataset(root=test_dir, transform=preprocess, target_transform=transform,
                                 aug_rate=-1, mode='test')
    elif dataset_name == 'visa':
        train_data = MVTecDataset(root=train_dir, transform=preprocess, target_transform=transform,
                                  aug_rate=-1, mode='test')
        test_data = VisaDataset(root=test_dir, transform=preprocess, target_transform=transform, mode='test')
    elif dataset_name == 'others':
        train_data = VisaDataset(root=train_dir, transform=preprocess, target_transform=transform, mode='test')
        test_data = OthersDataset(root=test_dir, transform=preprocess, target_transform=transform,
                                  aug_rate=-1, mode='test')
    else:
        raise ValueError(
            f"unsupported dataset {dataset_name!r}; expected 'mvtec', 'visa' or 'others'")
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    test_obj_list = test_data.get_cls_names()
    train_obj_list = train_data.get_cls_names()
    with torch.cuda.amp.autocast(), torch.no_grad():
        # NOTE(review): the per-class test prompt features are never read
        # afterwards; the call is kept only to preserve existing behaviour.
        encode_text_with_prompt_ensemble(model, test_obj_list, tokenizer, device)
        train_text_prompts = encode_text_with_prompt_ensemble(model, train_obj_list, tokenizer, device)

    # Split the auxiliary class names into two clusters and encode each
    # cluster's comma-joined name list as a single prompt subject.
    obj_cluster = train_data_get_clusters(train_text_prompts)
    cluster1_name = ', '.join(obj_cluster[0])
    cluster2_name = ', '.join(obj_cluster[1])
    with torch.cuda.amp.autocast(), torch.no_grad():
        train_text_cluster1 = encode_text_with_prompt_ensemble(
            model, [cluster1_name], tokenizer, device)[cluster1_name]
        train_text_cluster2 = encode_text_with_prompt_ensemble(
            model, [cluster2_name], tokenizer, device)[cluster2_name]

    # Pre-compute, for every test class, the text features of
    # "<class>, <cluster names>" for both clusters.
    fea_dict = {}
    for name in test_obj_list:
        key1 = name + ', ' + cluster1_name
        key2 = name + ', ' + cluster2_name
        with torch.cuda.amp.autocast(), torch.no_grad():
            test_fea_1 = encode_text_with_prompt_ensemble(model, [key1], tokenizer, device)[key1]
            test_fea_2 = encode_text_with_prompt_ensemble(model, [key2], tokenizer, device)[key2]
        fea_dict[name] = (test_fea_1, test_fea_2)

    # text prompt
    # NOTE(review): the feature of all test categories joined together is
    # never read afterwards; the call is kept only to preserve behaviour.
    test_categories_str = ', '.join(test_obj_list)
    with torch.cuda.amp.autocast(), torch.no_grad():
        encode_text_with_prompt_ensemble(model, [test_categories_str], tokenizer, device)

    results = {
        'cls_names': [],
        'imgs_masks': [],
        'anomaly_maps': [],
        'gt_sp': [],
        'pr_sp': [],
    }
    for items in tqdm(test_dataloader, desc="Processing test dataloader"):
        image = items['img'].to(device)
        cls_name = items['cls_name']
        results['cls_names'].append(cls_name[0])
        gt_mask = items['img_mask']
        # Binarise the ground-truth mask at 0.5.
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results['imgs_masks'].append(gt_mask)  # px
        results['gt_sp'].append(items['anomaly'].item())
        test_fea_1, test_fea_2 = fea_dict[cls_name[0]]
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features, patch_tokens = model.encode_image(image, features_list)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # weight: image-level score against each cluster's prompts
            text_probs1 = (100.0 * image_features @ train_text_cluster1).softmax(dim=-1)
            image_score1 = text_probs1[0][1].cpu().item()
            text_probs2 = (100.0 * image_features @ train_text_cluster2).softmax(dim=-1)
            image_score2 = text_probs2[0][1].cpu().item()
            # Softmax over the two scores gives the per-cluster mixing weights.
            weights = softmax(np.array([image_score1, image_score2]))
            # NOTE(review): the image-level prediction uses only the
            # cluster-1 score, not a weighted mix -- confirm this is intended.
            results['pr_sp'].append(image_score1)

            # pixel: each linear layer gets its own deep copy of the tokens
            patch_tokens_1 = linearlayer_cluster1(copy.deepcopy(patch_tokens))
            patch_tokens_2 = linearlayer_cluster2(copy.deepcopy(patch_tokens))
            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                anomaly_map_1 = _patch_anomaly_map(patch_tokens_1[layer], test_fea_1, img_size)
                anomaly_map_2 = _patch_anomaly_map(patch_tokens_2[layer], test_fea_2, img_size)
                anomaly_map = weights[0]*anomaly_map_1+weights[1]*anomaly_map_2
                anomaly_maps.append(anomaly_map.cpu().numpy())
            # Sum the per-layer maps into the final anomaly map.
            anomaly_map = np.sum(anomaly_maps, axis=0)
            results['anomaly_maps'].append(anomaly_map)
            # Optional visualisation (disabled):
            # path = items['img_path']
            # cls = path[0].split('/')[-2]
            # filename = path[0].split('/')[-1]
            # vis = cv2.cvtColor(cv2.resize(cv2.imread(path[0]), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB
            # mask = normalize(anomaly_map[0])
            # vis = apply_ad_scoremap(vis, mask)
            # vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR
            # save_vis = os.path.join(save_path, 'imgs', cls_name[0], cls)
            # if not os.path.exists(save_vis):
            #     os.makedirs(save_vis)
            # cv2.imwrite(os.path.join(save_vis, filename), vis)

    # metrics
    metric_names = ['auroc_px', 'f1_px', 'ap_px', 'aupro', 'auroc_sp', 'f1_sp', 'ap_sp']
    per_metric = {metric: [] for metric in metric_names}
    table_ls = []

    def as_percent(value):
        # Format a [0, 1] score as a percentage rounded to one decimal.
        return str(np.round(value * 100, decimals=1))

    for obj in test_obj_list:
        gt_px = []
        pr_px = []
        gt_sp = []
        pr_sp = []
        for idx, sample_cls in enumerate(results['cls_names']):
            if sample_cls == obj:
                gt_px.append(results['imgs_masks'][idx].squeeze(1).numpy())
                pr_px.append(results['anomaly_maps'][idx])
                gt_sp.append(results['gt_sp'][idx])
                pr_sp.append(results['pr_sp'][idx])
        gt_px = np.array(gt_px)
        gt_sp = np.array(gt_sp)
        pr_px = np.array(pr_px)
        pr_sp = np.array(pr_sp)

        scores = {
            'auroc_px': roc_auc_score(gt_px.ravel(), pr_px.ravel()),
            'auroc_sp': roc_auc_score(gt_sp, pr_sp),
            'ap_sp': average_precision_score(gt_sp, pr_sp),
            'ap_px': average_precision_score(gt_px.ravel(), pr_px.ravel()),
            'f1_sp': _best_f1(gt_sp, pr_sp),
            'f1_px': _best_f1(gt_px.ravel(), pr_px.ravel()),
        }
        # aupro: drop the singleton axis so each sample is a 2-D map
        if len(gt_px.shape) == 4:
            gt_px = gt_px.squeeze(1)
        if len(pr_px.shape) == 4:
            pr_px = pr_px.squeeze(1)
        scores['aupro'] = cal_pro_score(gt_px, pr_px)

        table_ls.append([obj] + [as_percent(scores[metric]) for metric in metric_names])
        for metric in metric_names:
            per_metric[metric].append(scores[metric])

    # logger
    table_ls.append(['mean'] + [as_percent(np.mean(per_metric[metric])) for metric in metric_names])
    summary = tabulate(table_ls, headers=['objects'] + metric_names, tablefmt="pipe")
    logger.info("\n%s", summary)


if __name__ == '__main__':
    # Command-line entry point: parse the options, seed all random number
    # generators, then run the evaluation.
    parser = argparse.ArgumentParser("FuzzyCLIP", add_help=True)
    # paths
    parser.add_argument("--test_data_path", type=str, default="/home/hyn/workspace/AIAD/VAND-APRIL-GAN/data/mvtec_anomaly_detection/data", help="path to test dataset")
    parser.add_argument("--train_data_path", type=str, default="/home/hyn/workspace/AIAD/VAND-APRIL-GAN/data/visa", help="path to train dataset (its class names are used for clustering)")
    parser.add_argument("--save_path", type=str, default='./results/mvtec/zero_shot', help='path to save results')
    parser.add_argument("--checkpoint_cluster1_path", type=str, default='/home/hyn/workspace/AIAD/VAND-APRIL-GAN/exps/spilt/obj0.0001/best.pth', help='path to the cluster 1 linear layer checkpoint')
    parser.add_argument("--checkpoint_cluster2_path", type=str, default='/home/hyn/workspace/AIAD/VAND-APRIL-GAN/exps/spilt/texture/best.pth', help='path to the cluster 2 linear layer checkpoint')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-L-14-336.json', help="model configs")
    # model
    parser.add_argument("--dataset", type=str, default='mvtec', help="test dataset")
    parser.add_argument("--model", type=str, default="ViT-L-14-336", help="model used")
    parser.add_argument("--pretrained", type=str, default="openai", help="pretrained weight used")
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    parser.add_argument("--few_shot_features", type=int, nargs="+", default=[3, 6, 9], help="features used for few shot")
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    parser.add_argument("--mode", type=str, default="zero_shot", help="zero shot or few shot")
    # few shot
    parser.add_argument("--k_shot", type=int, default=0, help="e.g., 10-shot, 5-shot, 1-shot")
    parser.add_argument("--seed", type=int, default=10, help="random seed")
    args = parser.parse_args()

    setup_seed(args.seed)
    test(args)
