import os
import argparse
import numpy as np
import torch
from scipy import stats
from transformers import CLIPTokenizer

from utils.common import setup_seed, get_num_cls, get_test_labels
from utils.detection_util import get_Mahalanobis_score, get_mean_prec, print_measures, get_and_print_results, \
    get_ood_scores_clip
from utils.file_ops import save_as_dataframe, setup_log
from utils.plot_util import plot_distribution
from utils.train_eval_util import set_model_clip, set_train_loader, set_val_loader, set_ood_loader_ImageNet
# sys.path.append(os.path.dirname(__file__))


import warnings

# Suppress every Python warning (all categories, all modules) for the rest of
# the process. NOTE(review): this also hides deprecation and numerical
# warnings raised by the libraries used below, not just console noise.
warnings.filterwarnings("ignore")


def _str_to_bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty token (including ``"False"``) is parsed as ``True``.
    This converter interprets the usual spellings explicitly. The empty string
    still maps to ``False``, which is what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def process_args():
    """Parse the command line and derive the run configuration.

    All options are registered before the command line is parsed, so the
    Mahalanobis-related flags (``--feat_dim``, ``--normalize``, ...) are
    accepted instead of being rejected as unrecognized arguments.

    Side effects:
        Prints the main hyper-parameters to stdout and creates the result
        directory ``args.log_directory`` if it does not exist.

    Returns:
        argparse.Namespace: the parsed options, extended with ``n_cls`` (the
        value returned by ``get_num_cls(args)``) and ``log_directory``.
    """
    parser = argparse.ArgumentParser(description='Evaluates MCM Score for CLIP',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # setting for each run
    parser.add_argument('--in_dataset', default='ImageNet', type=str,
                        choices=['ImageNet', 'ImageNet10', 'ImageNet20', 'ImageNet100',
                                 'pet37', 'food101', 'car196', 'bird200'], help='in-distribution dataset')
    parser.add_argument('--root-dir', default="datasets", type=str,
                        help='root dir of datasets')
    parser.add_argument('--name', default="eval_ood",
                        type=str, help="unique ID for the run")
    parser.add_argument('--seed', default=42, type=int, help="random seed")
    parser.add_argument('--gpu', default=3, type=int,
                        help='the GPU indice to use')

    parser.add_argument('--model', default='CLIP', type=str, help='model architecture')
    parser.add_argument('--CLIP_ckpt', type=str, default='ViT-B/16',
                        choices=['ViT-B/32', 'ViT-B/16', 'ViT-L/14'], help='which pretrained img encoder to use')
    parser.add_argument('--score', default='ours', type=str, choices=[
        'MCM', 'energy', 'max-logit', 'entropy', 'var', 'maha', 'ours'], help='score options')

    parser.add_argument(
        "--mc_global_number",
        type=int,
        default=1,
        help="Number of random global crops.",
    )
    parser.add_argument(
        "--mc_global_scale",
        type=float,
        nargs="+",
        default=(0.5, 1.0),
        help="Scale range for global crops.",
    )

    parser.add_argument('-b', '--batch-size', default=512, type=int,
                        help='mini-batch size')

    parser.add_argument('--ngroup', type=int, default=10,
                        help='number of grouping')

    # hyper-parameters
    parser.add_argument('--alpha', type=float, default=0.3,
                        help='temperature parameter')
    parser.add_argument('--beta', type=float, default=0.8,
                        help='temperature parameter')
    parser.add_argument('--t1', type=float, default=0.08,
                        help='temperature parameter')
    parser.add_argument('--t2', type=float, default=0.02,
                        help='temperature parameter')
    parser.add_argument('--t3', type=float, default=50,
                        help='temperature parameter')
    parser.add_argument(
        "--mc_local_number",
        type=int,
        default=2,
        help="Number of random local crops.",
    )
    parser.add_argument(
        "--mc_local_scale",
        type=float,
        nargs="+",
        default=(0.55, 1),
        help="Scale range for local crops.",
    )

    # for Mahalanobis score
    parser.add_argument('--feat_dim', type=int, default=512, help='feat dim； 512 for ViT-B and 768 for ViT-L')
    parser.add_argument('--normalize', type=_str_to_bool, default=False,
                        help='whether use normalized features for Maha score')
    parser.add_argument('--generate', type=_str_to_bool, default=True,
                        help='whether to generate class-wise means or read from files for Maha score')
    parser.add_argument('--template_dir', type=str, default='img_templates',
                        help='the loc of stored classwise mean and precision matrix')
    parser.add_argument('--subset', default=False, type=_str_to_bool,
                        help="whether uses a subset of samples in the training set")
    parser.add_argument('--max_count', default=250, type=int,
                        help="how many samples are used to estimate classwise mean and precision matrix")

    # Parse exactly once, after every option has been registered.
    args = parser.parse_args()

    # Echo the main hyper-parameters so they show up in the console output.
    print(f"mc_global_number: {args.mc_global_number}")
    print(f"mc_global_scale: {args.mc_global_scale}")
    print(f"batch-size 固定: {args.batch_size}")
    print(f"ngroup: {args.ngroup}")
    print(f"alpha: {args.alpha}")
    print(f"beta: {args.beta}")
    print(f"t1: {args.t1}")
    print(f"t2: {args.t2}")
    print(f"t3: {args.t3}")
    print(f"mc_local_number: {args.mc_local_number}")
    print(f"mc_local_scale: {args.mc_local_scale}")

    args.n_cls = get_num_cls(args)
    # CLIP_ckpt contains a '/', so this path has one more directory level than
    # the format string suggests; os.makedirs creates all of them.
    args.log_directory = f"results/{args.in_dataset}/{args.score}/{args.model}_{args.CLIP_ckpt}_ID_{args.name}"
    os.makedirs(args.log_directory, exist_ok=True)

    return args


def _select_ood_datasets(in_dataset):
    """Return the list of OOD dataset names evaluated against ``in_dataset``.

    Raises:
        ValueError: if no OOD datasets are configured for ``in_dataset``.
    """
    if in_dataset == 'ImageNet10':
        return ['ImageNet20']
    if in_dataset == 'ImageNet20':
        return ['ImageNet10']
    if in_dataset in ('ImageNet', 'ImageNet100', 'bird200', 'car196', 'food101', 'pet37'):
        return ['iNaturalist', 'SUN', 'places365', 'dtd']
    raise ValueError(f"No OOD datasets configured for in-distribution dataset {in_dataset!r}")


def _group_negative_features(text_features_neg, ngroup, random_permute=True):
    """Split the rows of a 2-D feature tensor into ``ngroup`` equally sized groups.

    Trailing rows that do not fit into ``ngroup`` equal groups are dropped
    first; the remaining rows are optionally shuffled and then reshaped to
    ``(ngroup, rows_per_group, feature_dim)``.
    """
    drop = text_features_neg.shape[0] % ngroup
    if drop > 0:
        text_features_neg = text_features_neg[:-drop, :]
    if random_permute:
        # The permutation is drawn on the CPU (default generator) and then
        # moved next to the features, so the random stream is unchanged.
        idx = torch.randperm(text_features_neg.size(0)).to(text_features_neg.device)
        text_features_neg = text_features_neg[idx]
    return torch.reshape(text_features_neg, (ngroup, -1, text_features_neg.size(1)))


def main():
    """Run the OOD evaluation described by the command-line arguments.

    Loads the CLIP model and the precomputed text embeddings, scores the
    in-distribution validation loader and every configured OOD loader, then
    plots the score distributions, logs the per-dataset and mean metrics and
    saves them through ``save_as_dataframe``.

    Raises:
        RuntimeError: if CUDA is not available.
        ValueError: if no OOD datasets are configured for ``args.in_dataset``.
    """
    args = process_args()
    setup_seed(args.seed)
    log = setup_log(args)
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run this evaluation but is not available")
    torch.cuda.set_device(args.gpu)

    net, preprocess = set_model_clip(args)
    net.eval()

    out_datasets = _select_ood_datasets(args.in_dataset)
    test_loader = set_val_loader(args, preprocess)
    test_labels = get_test_labels(args, test_loader)

    # Text embeddings are read from a precomputed dump rather than encoded
    # here. NOTE(review): the dump is presumably produced offline by encoding
    # the label prompts with the CLIP text encoder and L2-normalising them --
    # confirm against the script that writes this file.
    dump_dict = torch.load('./selected_neg_labels/neg_dump_new2.pth')
    text_features_neg = dump_dict['neg_emb_selected'].cuda().to(torch.float32)
    print(text_features_neg.size())

    text_features_neg = _group_negative_features(text_features_neg, args.ngroup, random_permute=True)
    text_features_pos = dump_dict['pos_emb'].cuda().to(torch.float32)

    if args.score == 'maha':
        os.makedirs(args.template_dir, exist_ok=True)
        train_loader = set_train_loader(args, preprocess, subset=args.subset)
        if args.generate:
            classwise_mean, precision = get_mean_prec(args, net, train_loader)
        # The statistics are always (re)loaded from ``template_dir``, even
        # right after generation. NOTE(review): this relies on get_mean_prec
        # having written these files -- confirm in utils.detection_util.
        classwise_mean = torch.load(os.path.join(args.template_dir,
                                                 f'{args.model}_classwise_mean_{args.in_dataset}_{args.max_count}_{args.normalize}.pt'),
                                    map_location='cpu').cuda()
        precision = torch.load(os.path.join(args.template_dir,
                                            f'{args.model}_precision_{args.in_dataset}_{args.max_count}_{args.normalize}.pt'),
                               map_location='cpu').cuda()
        in_score = get_Mahalanobis_score(args, net, test_loader, classwise_mean, precision, in_dist=True)
    else:
        in_score = get_ood_scores_clip(args, net, test_loader, test_labels, text_features_neg,
                                       text_features_pos=text_features_pos, in_dist=True)

    auroc_list, aupr_list, fpr_list = [], [], []
    for out_dataset in out_datasets:
        log.debug(f"Evaluting OOD dataset {out_dataset}")
        ood_loader = set_ood_loader_ImageNet(args, out_dataset, preprocess,
                                             root=os.path.join(args.root_dir, 'ImageNet_OOD_dataset'))
        if args.score == 'maha':
            out_score = get_Mahalanobis_score(args, net, ood_loader, classwise_mean, precision, in_dist=False)
        else:
            out_score = get_ood_scores_clip(args, net, ood_loader, test_labels, text_features_neg,
                                            text_features_pos=text_features_pos)
        plot_distribution(args, in_score, out_score, out_dataset)
        # Appends this dataset's metrics to the three lists passed in.
        get_and_print_results(args, log, in_score, out_score,
                              auroc_list, aupr_list, fpr_list)
    log.debug('\n\nMean Test Results')
    print_measures(log, np.mean(auroc_list), np.mean(aupr_list),
                   np.mean(fpr_list), method_name=args.score)
    save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list)


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
