import torch
import cv2
import torch.nn as nn
import numpy as np
import random
import os
import json
import argparse
from torch.utils.data import DataLoader
from datetime import datetime
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import logging
import Fuzzy_clip
import warnings
from tqdm import tqdm
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
from tabulate import tabulate
import open_clip
from dataset import VisaDataset, MVTecDataset
from model import LinearLayer
from loss import FocalLoss, BinaryDiceLoss
from prompt_ensemble import encode_text_with_prompt_ensemble
from skimage import measure
from learnable_prompt_ensemble import FuzzyCLIP_PromptLearner

warnings.filterwarnings("ignore", category=FutureWarning)
# NumPy >= 1.25 exposes VisibleDeprecationWarning in numpy.exceptions, and NumPy 2.0
# removed the top-level alias, so `np.VisibleDeprecationWarning` raises AttributeError
# there. Prefer the new location and fall back to the old one for older NumPy.
warnings.filterwarnings(
    "ignore",
    category=getattr(np, "exceptions", np).VisibleDeprecationWarning,
)
def normalize(pred, max_value=None, min_value=None):
    """Min-max scale ``pred`` into [0, 1].

    Args:
        pred: array-like supporting ``min()``, ``max()`` and arithmetic
            (e.g. a NumPy array or a torch tensor).
        max_value: explicit upper bound. It is used only when ``min_value``
            is also given.
        min_value: explicit lower bound. It is used only when ``max_value``
            is also given. If either bound is None, both are taken from
            ``pred`` itself.

    Returns:
        ``(pred - min) / (max - min)``. When ``max == min`` the range is
        degenerate. In that case, zeros with ``pred``'s shape are returned
        instead of the NaN/inf that the raw division would produce.
    """
    if max_value is None or min_value is None:
        min_value, max_value = pred.min(), pred.max()
    shifted = pred - min_value
    span = max_value - min_value
    if span == 0:
        # Constant input (or equal explicit bounds): avoid 0/0 -> NaN.
        return shifted * 0.0
    return shifted / span


def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Blend a JET-coloured anomaly score map over an image.

    Args:
        image: array-like image convertible to float. It is presumably
            HxWx3 RGB in [0, 255]; TODO confirm with callers.
        scoremap: NumPy score map expected to lie in [0, 1]. Out-of-range
            values are clipped and NaN is mapped to 0 before the uint8 cast.
        alpha: weight given to ``image``; the coloured map gets ``1 - alpha``.

    Returns:
        The blended image as a ``uint8`` array.
    """
    np_image = np.asarray(image, dtype=float)
    # Without clipping, e.g. 1.01 * 255 -> 257.55 wraps (or is undefined) in the
    # uint8 cast and shows up as a near-zero colour; negatives wrap to high values.
    scores = np.clip(np.nan_to_num(scoremap, nan=0.0), 0.0, 1.0)
    scoremap_u8 = (scores * 255).astype(np.uint8)
    colored = cv2.applyColorMap(scoremap_u8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; convert so they blend with an RGB image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return (alpha * np_image + (1 - alpha) * colored).astype(np.uint8)

def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    """Compute the area under the PRO (per-region overlap) curve up to an FPR limit.

    Method:
    - The score maps are binarised at about ``max_step`` evenly spaced
      thresholds between their global min and max.
    - At each threshold, PRO is the mean over every connected ground-truth
      region of every sample. Each region contributes the fraction of its
      pixels that are predicted positive.
    - FPR is the fraction of ground-truth-negative pixels predicted positive.
    - Points with ``FPR < expect_fpr`` are kept, their FPRs are min-max
      rescaled to [0, 1], and the area under PRO-vs-FPR is returned.

    Args:
        masks: binary ground-truth masks, one per sample, aligned element-wise
            with ``amaps``. Presumably (N, H, W); TODO confirm.
        amaps: anomaly score maps with the same shape as ``masks``.
        max_step: number of threshold steps between the min and max score.
        expect_fpr: exclusive FPR upper limit for the integrated curve part.

    Returns:
        The normalised PRO-AUC computed by ``sklearn.metrics.auc``.
    """
    # Ground-truth regions and the negative mask do not depend on the threshold,
    # so label them once instead of once per threshold step.
    regions = [
        (sample_idx, region.coords[:, 0], region.coords[:, 1], region.area)
        for sample_idx, mask in enumerate(masks)
        for region in measure.regionprops(measure.label(mask))
    ]
    inverse_masks = 1 - masks
    num_negatives = inverse_masks.sum()

    binary_amaps = np.zeros_like(amaps, dtype=bool)
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step
    pros, fprs = [], []
    for th in np.arange(min_th, max_th, delta):
        # In-place binarisation; NaN scores compare False and stay negative.
        np.greater(amaps, th, out=binary_amaps)
        pro = [
            binary_amaps[sample_idx][rows, cols].sum() / area
            for sample_idx, rows, cols, area in regions
        ]
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        pros.append(np.array(pro).mean())
        fprs.append(fp_pixels / num_negatives)
    pros, fprs = np.array(pros), np.array(fprs)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    # Rescale the kept FPR range to [0, 1] so the AUC is comparable across limits.
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc

def setup_seed(seed):
    """Seed all random number generators and force deterministic cuDNN.

    Args:
        seed: integer seed shared by Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

def _setup_logger(txt_path):
    """Return the 'train' logger, writing to ``txt_path`` (overwritten) and the console.

    Side effects:
    - Removes all root-logger handlers and sets the root level to WARNING.
    - Removes and closes handlers previously attached to 'train', so
      repeated calls do not duplicate output lines or leak open log files.
    """
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)

    logger = logging.getLogger('train')
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    file_handler = logging.FileHandler(txt_path, mode='w')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger


def _encode_learned_text_features(model, prompt_learner):
    """Encode the learned prompts into L2-normalised text features for inference.

    The encoded prompts are split into two equal chunks along dim 0, and
    the chunks are stacked on dim 1. So the result has shape (K, 2, D),
    where K is half the number of encoded prompts. Runs under
    ``torch.no_grad()`` because the features are only used for scoring.
    """
    with torch.no_grad():
        prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
        text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
        text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features


def train(args, device):
    """Evaluate a trained FuzzyCLIP prompt checkpoint on the test dataset.

    Despite its name, this function performs no optimisation. It builds the
    model, loads ``args.checkpoint_path`` into the prompt learner, encodes
    the learned text features and runs ``test()``. Command-line arguments
    and the results table are logged to ``<save_path>/log.txt``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
        device: torch device string, e.g. ``'cuda:0'`` or ``'cpu'``.

    Returns:
        The score returned by ``test()``. Earlier versions discarded it and
        returned None.
    """
    image_size = args.image_size
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    txt_path = os.path.join(save_path, 'log.txt')

    # The model config itself is not used below. Parsing it still fails fast
    # if --config_path is missing or is not valid JSON.
    with open(args.config_path, 'r') as f:
        json.load(f)

    FuzzyCLIP_parameters = {"Prompt_length": 12, "learnabel_text_embedding_depth": 9,
                            "learnabel_text_embedding_length": 36}
    model, _, preprocess = Fuzzy_clip.create_model_and_transforms(
        args.model, image_size, pretrained=args.pretrained, design_details=FuzzyCLIP_parameters)
    model.eval()
    # Ground-truth masks get the same resize/crop as images, but no normalisation.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])

    # NOTE(review): --dataset is documented as the *training* dataset, so 'mvtec'
    # evaluates on VisA and anything else evaluates on MVTec. Presumably this is
    # intentional cross-dataset zero-shot testing; confirm.
    if args.dataset == 'mvtec':
        test_data = VisaDataset(root=args.test_data_path, transform=preprocess, target_transform=transform)
    else:
        test_data = MVTecDataset(root=args.test_data_path, transform=preprocess, target_transform=transform,
                                 aug_rate=0.0)
    # test() indexes cls_name[0] and calls .item() on the label, so batch_size must stay 1.
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)
    test_obj_list = test_data.get_cls_names()
    test_categories_str = ', '.join(test_obj_list)
    # The prompt learner is built from the CPU copy of the model. Both are moved
    # to `device` after the checkpoint has been loaded.
    prompt_learner = FuzzyCLIP_PromptLearner(model.to("cpu"), FuzzyCLIP_parameters, test_categories_str)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint_path, map_location=torch.device(device))
    prompt_learner.load_state_dict(checkpoint["prompt_learner"])
    prompt_learner.to(device)
    model.to(device)

    text_features = _encode_learned_text_features(model, prompt_learner)

    logger = _setup_logger(txt_path)
    for arg in vars(args):
        logger.info(f'{arg}: {getattr(args, arg)}')

    return test(args, model, text_features, test_obj_list, test_dataloader, logger, device)

def _collect_image_scores(model, text_features, test_loader, features_list, device):
    """Score every image in ``test_loader``.

    Returns three parallel lists: class names, 0/1 anomaly labels, and
    per-image score tensors. Assumes ``batch_size == 1``: only the first
    class name is kept and the label is read with ``.item()``.
    """
    cls_names, gt_sp, pr_sp = [], [], []
    # Inference only: do not build an autograd graph for every image.
    with torch.no_grad():
        for items in tqdm(test_loader, desc="Processing test dataloader"):
            image = items['img'].to(device)
            cls_names.append(items['cls_name'][0])
            gt_sp.append(items['anomaly'].item())

            image_features, _ = model.encode_image(image, features_list)
            text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1)
            # Temperature-scaled (1 / 0.07) softmax over the two prompt groups.
            text_probs = (text_probs / 0.07).softmax(-1)
            # Index 1 is presumably the anomalous prompt group (the second chunk
            # of the learned prompts); confirm against FuzzyCLIP_PromptLearner.
            pr_sp.append(text_probs[:, 0, 1].detach().cpu())
    return cls_names, gt_sp, pr_sp


def _image_level_metrics(gt_sp, pr_sp):
    """Return (AUROC, max F1 over all PR thresholds, AP) for labels ``gt_sp`` and scores ``pr_sp``."""
    gt_sp = np.array(gt_sp)
    pr_sp = np.array(pr_sp)
    auroc_sp = roc_auc_score(gt_sp, pr_sp)
    ap_sp = average_precision_score(gt_sp, pr_sp)
    precisions, recalls, _ = precision_recall_curve(gt_sp, pr_sp)
    # Points where precision == recall == 0 give 0/0; they are filtered out below.
    with np.errstate(divide='ignore', invalid='ignore'):
        f1_scores = (2 * precisions * recalls) / (precisions + recalls)
    f1_sp = np.max(f1_scores[np.isfinite(f1_scores)])
    return auroc_sp, f1_sp, ap_sp


def test(args, model, text_features, obj_list, test_loader, logger, device):
    """Compute and log image-level anomaly-detection metrics per class.

    Ground-truth masks are not read; only image-level AUROC, max-F1 and AP
    are computed. This is done for each class in ``obj_list`` and for their
    mean. The metrics are logged as a pipe-formatted table.

    Args:
        args: namespace providing ``features_list`` (layers passed to ``encode_image``).
        model: FuzzyCLIP model exposing ``encode_image(image, features_list)``.
        text_features: normalised text features from the prompt learner.
        obj_list: class names to report, in table order.
        test_loader: DataLoader with ``batch_size=1`` yielding dicts with
            'img', 'cls_name' and 'anomaly'.
        logger: logger that receives the results table.
        device: torch device the images are moved to.

    Returns:
        Sum of the mean AUROC, mean max-F1 and mean AP over ``obj_list``.
    """
    cls_names, gt_all, pr_all = _collect_image_scores(
        model, text_features, test_loader, args.features_list, device)

    # Group once (O(N)) instead of rescanning every result for each class.
    grouped = {}
    for name, label, score in zip(cls_names, gt_all, pr_all):
        labels, scores = grouped.setdefault(name, ([], []))
        labels.append(label)
        scores.append(score)

    def pct(value):
        return str(np.round(value * 100, decimals=1))

    table_ls, auroc_sp_ls, f1_sp_ls, ap_sp_ls = [], [], [], []
    for obj in obj_list:
        labels, scores = grouped.get(obj, ([], []))
        auroc_sp, f1_sp, ap_sp = _image_level_metrics(labels, scores)
        table_ls.append([obj, pct(auroc_sp), pct(f1_sp), pct(ap_sp)])
        auroc_sp_ls.append(auroc_sp)
        f1_sp_ls.append(f1_sp)
        ap_sp_ls.append(ap_sp)

    mean_auroc, mean_f1, mean_ap = np.mean(auroc_sp_ls), np.mean(f1_sp_ls), np.mean(ap_sp_ls)
    table_ls.append(['mean', pct(mean_auroc), pct(mean_f1), pct(mean_ap)])
    table_str = tabulate(table_ls, headers=['objects', 'auroc_sp', 'f1_sp', 'ap_sp'], tablefmt="pipe")
    logger.info("\n%s", table_str)
    return mean_auroc + mean_f1 + mean_ap

if __name__ == '__main__':
    parser = argparse.ArgumentParser("FuzzyCLIP", add_help=True)
    # path
    parser.add_argument("--test_data_path", type=str, default="/home/hyn/workspace/CLIP/VAND-APRIL-GAN-master/data/visa", help="test dataset path")
    parser.add_argument("--save_path", type=str, default='./exps/visa/vit_large_14_518', help='path to save results')
    parser.add_argument("--config_path", type=str, default='./open_clip/model_configs/ViT-L-14-336.json', help="model configs")
    # model
    parser.add_argument("--dataset", type=str, default='mvtec', help="train dataset name")
    parser.add_argument("--model", type=str, default="ViT-L-14-336", help="model used")
    parser.add_argument("--pretrained", type=str, default="openai", help="pretrained weight used")
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    # hyper-parameter
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    # print_freq / save_freq are not read by train() or test(); they are kept for CLI compatibility.
    parser.add_argument("--print_freq", type=int, default=1, help="print frequency")
    parser.add_argument("--save_freq", type=int, default=1, help="save frequency")
    parser.add_argument("--checkpoint_path", type=str, default='/home/hyn/workspace/AIAD/VAND-APRIL-GAN/prompt_model/epoch_5.pth', help='path to the trained prompt-learner checkpoint to evaluate')
    parser.add_argument("--device", type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu', help="device to use for training and testing")
    # Previously hard-coded as setup_seed(111); the default keeps that behaviour.
    parser.add_argument("--seed", type=int, default=111, help="random seed")
    args = parser.parse_args()

    setup_seed(args.seed)
    train(args, args.device)
