import torch
import cv2
import torch.nn as nn
import numpy as np
import random
import os
import json
import argparse
from torch.utils.data import DataLoader
from datetime import datetime
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import logging
from tqdm import tqdm
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
from tabulate import tabulate
import open_clip
from dataset import VisaDataset, MVTecDataset
from model import LinearLayer
from loss import FocalLoss, BinaryDiceLoss
from prompt_ensemble import encode_text_with_prompt_ensemble
from skimage import measure
from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import SpectralClustering
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
import pandas as pd
def normalize(pred, max_value=None, min_value=None):
    """Min-max rescale ``pred``.

    Args:
        pred: array-like object exposing ``.min()`` / ``.max()`` and
            supporting arithmetic (e.g. a NumPy array or a torch tensor).
        max_value: upper bound to scale against. Used only when
            ``min_value`` is also given.
        min_value: lower bound to scale against. Used only when
            ``max_value`` is also given.

    Returns:
        ``(pred - lo) / (hi - lo)`` where ``lo``/``hi`` are the supplied
        bounds, or the minimum/maximum of ``pred`` itself when either
        bound is missing. When the bounds are derived from ``pred`` and
        ``pred`` is constant, an all-zero result of the same shape is
        returned instead of dividing by zero.
    """
    if max_value is None or min_value is None:
        low = pred.min()
        span = pred.max() - low
        if span == 0:
            # Constant input: every element equals the minimum, so the
            # normalized value is 0 everywhere (avoids a 0/0 -> NaN result).
            return (pred - low) * 0.0
        return (pred - low) / span
    return (pred - min_value) / (max_value - min_value)
def find_best_kmeans_n_clusters(data, min_clusters=1, max_clusters=10):
    """Select a KMeans cluster count by inertia plus an exponential penalty.

    For every candidate ``k`` in ``[min_clusters, max_clusters]`` a KMeans
    model is fitted and scored as ``inertia + 0.1 * exp(k)``; the candidate
    with the lowest score wins. Progress is reported with ``print``.

    Args:
        data: sample matrix accepted by ``KMeans.fit``; ``len(data)`` must
            give the number of samples.
        min_clusters: smallest cluster count to try.
        max_clusters: largest cluster count to try. Candidates above the
            number of samples are skipped.

    Returns:
        ``(best_n_clusters, best_kmeans)`` where ``best_kmeans`` is the
        fitted model that produced the best score.
    """
    # Never search past the number of samples, where fitting would fail.
    upper = min(max_clusters, len(data))

    best_n_clusters = min_clusters
    best_score = float('inf')
    best_kmeans = None

    for n_clusters in range(min_clusters, upper + 1):
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0).fit(data)
        inertia = kmeans.inertia_
        penalty = np.exp(n_clusters)
        score = inertia + 0.1 * penalty  # evaluation metric with the penalty term added
        print(f"Number of clusters: {n_clusters}, Inertia: {inertia}, Penalty: {penalty}, Score: {score}")

        if score < best_score:
            best_score = score
            best_n_clusters = n_clusters
            # Keep the scored model itself so the returned model is exactly
            # the one the score refers to (and no second fit is needed).
            best_kmeans = kmeans

    print(f"Best number of clusters: {best_n_clusters} with Score: {best_score}")

    if best_kmeans is None:
        # No candidate was retained (e.g. an empty search range): fit the
        # fallback count with the same settings as the search loop.
        best_kmeans = KMeans(n_clusters=best_n_clusters, n_init='auto', random_state=0).fit(data)

    # Return the best cluster count and its fitted KMeans model.
    return best_n_clusters, best_kmeans

def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Overlay a JET-coloured score map on an image.

    Args:
        image: array-like image, converted to a float array before blending.
        scoremap: array that is multiplied by 255 and cast to ``uint8``
            before colour-mapping (presumably values in [0, 1] -- TODO
            confirm against callers).
        alpha: weight given to ``image``; the colour map gets ``1 - alpha``.

    Returns:
        The weighted sum of image and RGB colour map, cast to ``uint8``.
    """
    base = np.asarray(image, dtype=float)
    levels = (scoremap * 255).astype(np.uint8)
    # OpenCV colour maps come out as BGR; convert to RGB before blending.
    heat_rgb = cv2.cvtColor(cv2.applyColorMap(levels, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    blended = alpha * base + (1 - alpha) * heat_rgb
    return blended.astype(np.uint8)
def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Area under the per-region-overlap (PRO) vs. false-positive-rate curve.

    The anomaly maps are binarised at ``max_step`` evenly spaced thresholds
    between their global minimum and maximum. For every threshold the mean
    overlap between the binarised map and each connected ground-truth region
    is recorded together with the false-positive rate over all non-mask
    pixels. Only points with ``fpr < expect_fpr`` are kept, the kept FPR
    values are min-max rescaled, and the area under the curve is returned.

    Args:
        masks: array of 2-D ground-truth masks, iterated along its first
            axis (presumably binary, shape (N, H, W) -- TODO confirm).
        amaps: array of anomaly maps, iterated in lockstep with ``masks``.
        max_step: number of threshold steps between min and max of ``amaps``.
        expect_fpr: upper bound (exclusive) on the FPR values integrated.

    Returns:
        The AUC of PRO over the rescaled FPR axis.
    """
    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step

    # The connected regions and the background mask depend only on `masks`,
    # so compute them once rather than once per threshold.
    regions_per_mask = [
        [
            (region.coords[:, 0], region.coords[:, 1], region.area)
            for region in measure.regionprops(measure.label(mask))
        ]
        for mask in masks
    ]
    inverse_masks = 1 - masks
    negative_pixels = inverse_masks.sum()

    pros, fprs = [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps = amaps > th
        # Fraction of each ground-truth region flagged at this threshold.
        pro = [
            binary_amap[rows, cols].sum() / area
            for binary_amap, regions in zip(binary_amaps, regions_per_mask)
            for rows, cols, area in regions
        ]
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        pros.append(np.array(pro).mean())
        fprs.append(fp_pixels / negative_pixels)

    pros, fprs = np.array(pros), np.array(fprs)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    # Rescale the retained FPR range to [0, 1] so the AUC is normalized.
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc

def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed: integer seed handed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def train(args):
    """Cluster the class-name text embeddings and visualise them with t-SNE.

    Despite its name, this function optimises no parameters. It:

    1. creates the CLIP model and tokenizer through ``open_clip``;
    2. configures a ``'train'`` logger writing to ``<save_path>/log.txt``
       and to the console, and logs every argument;
    3. builds the train/test datasets (MVTec -> VisA when
       ``args.dataset == 'mvtec'``, otherwise VisA -> MVTec);
    4. encodes the class names of both datasets with the prompt ensemble,
       runs KMeans on column 0 of every training prompt tensor, and prints
       the training clusters, the distance of every test class to each
       centroid, and the test classes grouped by nearest centroid;
    5. projects the training embeddings to 2-D with t-SNE and writes
       ``kmeans_tsne_clustering_data.csv`` and ``kmeans_tsne_clustering.png``
       to the current working directory.

    Args:
        args: parsed command-line namespace. Fields read here are
            ``image_size``, ``save_path``, ``config_path``, ``model``,
            ``pretrained``, ``dataset``, ``train_data_path``,
            ``test_data_path`` and ``aug_rate``; every field is logged.
    """
    # configs
    image_size = args.image_size
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    txt_path = os.path.join(save_path, 'log.txt')  # log

    # model configs (loaded here but not consulted anywhere below)
    with open(args.config_path, 'r') as f:
        model_configs = json.load(f)

    # clip model
    model, _, preprocess = open_clip.create_model_and_transforms(args.model, image_size, pretrained=args.pretrained)
    model.to(device)
    tokenizer = open_clip.get_tokenizer(args.model)

    # logger: silence the root logger and route our own messages to file + console
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('train')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # record parameters
    for arg in vars(args):
        logger.info(f'{arg}: {getattr(args, arg)}')

    # transforms
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])

    # datasets
    if args.dataset == 'mvtec':
        train_data = MVTecDataset(root=args.train_data_path, transform=preprocess, target_transform=transform,
                                  aug_rate=args.aug_rate)
        test_data = VisaDataset(root=args.test_data_path, transform=preprocess, target_transform=transform)
    else:
        train_data = VisaDataset(root=args.train_data_path, transform=preprocess, target_transform=transform)
        test_data = MVTecDataset(root=args.test_data_path, transform=preprocess, target_transform=transform,
                                 aug_rate=0.0)

    with torch.cuda.amp.autocast(), torch.no_grad():
        train_obj_list = train_data.get_cls_names()
        train_text_prompts = encode_text_with_prompt_ensemble(model, train_obj_list, tokenizer, device)
        # One feature vector per class: column 0 of its prompt tensor.
        data_first_columns = np.array([v[:, 0].cpu().numpy() for v in train_text_prompts.values()])

        # Search for the best number of clusters.
        best_n_clusters, best_kmeans = find_best_kmeans_n_clusters(data_first_columns, min_clusters=1, max_clusters=2)

        # Print the training classes that belong to each cluster.
        clusters = {i: [] for i in range(best_n_clusters)}
        for item, label in zip(train_text_prompts.keys(), best_kmeans.labels_):
            clusters[label].append(item)

        for label, items in clusters.items():
            print(f"train_Cluster {label}: {', '.join(items)}")

        # Fetch and encode the test classes.
        test_obj_list = test_data.get_cls_names()
        test_text_prompts = encode_text_with_prompt_ensemble(model, test_obj_list, tokenizer, device)

        # Assign every test class to its nearest training centroid so the
        # "test_Cluster" report lists test classes, not the training ones.
        test_clusters = {i: [] for i in range(best_n_clusters)}
        for key, tensor in test_text_prompts.items():
            new_data_point = tensor[:, 0].cpu().numpy()
            distances = np.linalg.norm(best_kmeans.cluster_centers_ - new_data_point, axis=1)
            print(f"\nDistances from {key} to each cluster center:")
            for i, distance in enumerate(distances):
                print(f"Distance to cluster {i} center: {distance}")
            closest_cluster = int(np.argmin(distances))
            test_clusters[closest_cluster].append(key)
            print(f"{key} is closest to cluster {closest_cluster} center with a distance of {distances[closest_cluster]}")

        for label, items in test_clusters.items():
            print(f"test_Cluster {label}: {', '.join(items)}")

        n_samples = data_first_columns.shape[0]
        # Cap the perplexity at half the sample count for small class lists.
        perplexity = min(30, n_samples // 2)
        tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
        data_first_columns_2d = tsne.fit_transform(data_first_columns)

        # Cluster centres in the 2-D embedding (mean of each cluster's points).
        cluster_centers_2d = np.array([data_first_columns_2d[best_kmeans.labels_ == i].mean(axis=0) for i in range(best_n_clusters)])

        # Assemble the data to be saved as CSV.
        labels = np.array(list(train_text_prompts.keys()))
        tsne_df = pd.DataFrame(data_first_columns_2d, columns=['t-SNE Component 1', 't-SNE Component 2'])
        tsne_df['Label'] = labels
        tsne_df['Cluster'] = best_kmeans.labels_

        # Save to a CSV file.
        tsne_df.to_csv('kmeans_tsne_clustering_data.csv', index=False)

        # Plot the clustered points.
        plt.figure(figsize=(10, 6))
        colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
        markers = ['o', 's', 'D', '^', 'v', '<', '>']

        # Draw each cluster with its own colour and marker.
        for i in range(best_n_clusters):
            in_cluster = best_kmeans.labels_ == i
            points = data_first_columns_2d[in_cluster]
            point_labels = labels[in_cluster]
            plt.scatter(points[:, 0], points[:, 1], s=100, c=colors[i % len(colors)], marker=markers[i % len(markers)], label=f'Cluster {i}')

            # Show the class name next to every point.
            for point, label in zip(points, point_labels):
                plt.text(point[0], point[1], label, fontsize=9, ha='right')

        # Draw the cluster centres.
        plt.scatter(cluster_centers_2d[:, 0], cluster_centers_2d[:, 1], s=300, c='black', marker='x', label='Centroids')

        plt.title('KMeans Clustering with t-SNE')
        plt.xlabel('t-SNE Component 1')
        plt.ylabel('t-SNE Component 2')
        plt.legend()

        # Save the figure, then release it so repeated calls do not pile up
        # open figures.
        plt.savefig('kmeans_tsne_clustering.png')
        plt.close()





if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    # NOTE(review): the default --dataset is 'visa' while the default
    # --train_data_path points at an MVTec directory -- confirm this pairing
    # is intended.
    option_table = (
        # path
        ("--train_data_path", dict(type=str, default="data/mvtec_anomaly_detection/data", help="train dataset path")),
        ("--test_data_path", dict(type=str, default="data/visa", help="test dataset path")),
        ("--save_path", dict(type=str, default='./exps/cusl', help='path to save results')),
        ("--config_path", dict(type=str, default='./open_clip/model_configs/ViT-L-14-336.json', help="model configs")),
        # model
        ("--dataset", dict(type=str, default='visa', help="train dataset name")),
        ("--model", dict(type=str, default="ViT-L-14-336", help="model used")),
        ("--pretrained", dict(type=str, default="openai", help="pretrained weight used")),
        ("--features_list", dict(type=int, nargs="+", default=[6, 12, 18, 24], help="features used")),
        # hyper-parameter
        ("--epoch", dict(type=int, default=3, help="epochs")),
        ("--learning_rate", dict(type=float, default=0.0001, help="learning rate")),
        ("--batch_size", dict(type=int, default=16, help="batch size")),
        ("--image_size", dict(type=int, default=518, help="image size")),
        ("--aug_rate", dict(type=float, default=0.2, help="augmentation rate")),
        ("--print_freq", dict(type=int, default=1, help="print frequency")),
        ("--save_freq", dict(type=int, default=1, help="save frequency")),
    )

    parser = argparse.ArgumentParser("VAND Challenge", add_help=True)
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Fix every random generator before running the clustering pipeline.
    setup_seed(111)
    train(args)

