#!/usr/bin/env python
import argparse
import torch
import numpy as np
from tqdm import tqdm
import mmcv
from models import set_eval_model
from datasets import get_dataset_eval
from scipy.special import softmax
from sklearn import metrics
import pandas as pd


def parse_args():
    """Parse command-line options for the MSP out-of-distribution evaluation.

    Returns:
        argparse.Namespace. ``ood_dataset`` is the raw comma-separated string;
        ``main`` splits it into a list of dataset names.

    Raises:
        SystemExit: (from argparse) on invalid options or when the required
            ``--ood_dataset`` is missing.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate maximum-softmax-probability (MSP) OOD detection')
    parser.add_argument('data_path', help='Path to data')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        help="Model architecture (classifier head is read for "
                             "'resnet50' and 'vit-s')")
    parser.add_argument('--pretrained_path', default='', help='Path to checkpoint')
    parser.add_argument('--img_list', default=None, help='Path to image list')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='Number of classes of the classifier head')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size for the in-distribution loaders')
    parser.add_argument('--workers', type=int, default=4,
                        help='Number of DataLoader worker processes')
    parser.add_argument('--fc_save_path', default=None, help='Path to save fc')
    parser.add_argument('--id_dataset', default='cifar10', type=str,
                        help='In-distribution dataset name')
    # Required: main() splits this string, so a missing value would otherwise
    # crash later with AttributeError on None.
    parser.add_argument('--ood_dataset', type=str, required=True,
                        help='Comma-separated list of OOD dataset names')
    parser.add_argument('--scratch', action='store_true',
                        help='whether scratch or not')
    # type=float so a CLI-supplied value is numeric like the default, not a str.
    parser.add_argument('--clip_quantile', type=float, default=0.99,
                        help='Clip quantile for ReAct')
    return parser.parse_args()


def _make_loader(dataset, batch_size, workers):
    """Build an unshuffled evaluation DataLoader (sample order is preserved)."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=False)


def _fc_params(model, arch):
    """Return the classifier head's (weight, bias) as NumPy arrays.

    Raises:
        ValueError: if *arch* is not one whose head attribute is known here.
    """
    if arch == 'resnet50':
        head = model.fc
    elif arch == 'vit-s':
        head = model.head
    else:
        raise ValueError(
            f"unsupported arch {arch!r}: classifier head is only known for "
            "'resnet50' (model.fc) and 'vit-s' (model.head)")
    return head.weight.cpu().detach().numpy(), head.bias.cpu().detach().numpy()


def _build_backbone(args):
    """Build the evaluation model with ``num_classes=0`` and move it to the GPU.

    NOTE(review): assumes ``set_eval_model`` treats ``num_classes == 0`` as
    "no classifier head" so the output is the head's input features; confirm
    in models.set_eval_model. ``args.num_classes`` is restored even if model
    construction raises.
    """
    saved_num_classes = args.num_classes
    args.num_classes = 0
    try:
        backbone = set_eval_model(args)
    finally:
        args.num_classes = saved_num_classes
    return backbone.cuda()


def _extract_features(backbone, dataloader):
    """Run *backbone* over every batch of *dataloader* and stack the outputs.

    Labels yielded by the loader are ignored. np.concatenate raises
    ValueError if the loader yields no batches.
    """
    backbone.eval()
    batches = []
    with torch.no_grad():
        for x, _ in tqdm(dataloader):
            batches.append(backbone(x.cuda()).cpu().numpy())
    return np.concatenate(batches, axis=0)


def main():
    """Evaluate MSP (maximum softmax probability) OOD detection from the CLI.

    Extracts head-less features for the ID train/valid splits and each OOD
    dataset. Applies the full model's classifier head (w, b) to get logits,
    then scores each sample by its max softmax probability. Prints AUROC and
    FPR at 95% ID recall per OOD set plus their means.

    Returns:
        pandas.DataFrame with columns method, oodset, auroc, fpr.
    """
    args = parse_args()
    # Drop empty/whitespace entries so "a,,b" or "a, b" behave as expected.
    ood_names = [name.strip() for name in (args.ood_dataset or '').split(',')
                 if name.strip()]
    if not ood_names:
        raise SystemExit('--ood_dataset must name at least one dataset (comma-separated)')
    args.ood_dataset = ood_names

    id_train_dataset, id_valid_dataset = get_dataset_eval(args, args.id_dataset)
    id_train_loader = _make_loader(id_train_dataset, args.batch_size, args.workers)
    id_valid_loader = _make_loader(id_valid_dataset, args.batch_size, args.workers)

    torch.backends.cudnn.benchmark = True

    model = set_eval_model(args).cuda()
    w, b = _fc_params(model, args.arch)

    # OOD loaders keep the original fixed batch size of 64 (not args.batch_size).
    ood_loaders = {
        ood: _make_loader(get_dataset_eval(args, dataset=ood), 64, args.workers)
        for ood in args.ood_dataset
    }

    model_wo_fc = _build_backbone(args)

    # Roles are tracked explicitly instead of substring-matching loader names,
    # so an OOD set whose name contains 'train'/'valid' is not misfiled as ID.
    # feature_id_train is currently only used for its shape in the log below.
    feature_id_train = _extract_features(model_wo_fc, id_train_loader)
    feature_id_val = _extract_features(model_wo_fc, id_valid_loader)
    feature_oods = {name: _extract_features(model_wo_fc, loader)
                    for name, loader in ood_loaders.items()}

    print(f'{w.shape=}, {b.shape=}')

    recall = 0.95

    print('load features')

    print(f'{feature_id_train.shape=}, {feature_id_val.shape=}')
    for name, ood in feature_oods.items():
        print(f'{name} {ood.shape}')
    print('computing logits...')

    logit_id_val = feature_id_val @ w.T + b
    logit_oods = {name: feat @ w.T + b for name, feat in feature_oods.items()}

    print('computing softmax...')

    softmax_id_val = softmax(logit_id_val, axis=-1)
    softmax_oods = {name: softmax(logit, axis=-1) for name, logit in logit_oods.items()}

    method = 'MSP'
    print(f'\n{method}')
    result = []
    score_id = softmax_id_val.max(axis=-1)
    for name, softmax_ood in softmax_oods.items():
        score_ood = softmax_ood.max(axis=-1)
        auc_ood = auc(score_id, score_ood)[0]
        fpr_ood, _ = fpr_recall(score_id, score_ood, recall)
        result.append(dict(method=method, oodset=name, auroc=auc_ood, fpr=fpr_ood))
        print(f'{method}: {name} auroc {auc_ood:.2%}, fpr {fpr_ood:.2%}')
    # Explicit columns keep .auroc/.fpr accessible even if result were empty.
    df = pd.DataFrame(result, columns=['method', 'oodset', 'auroc', 'fpr'])
    print(f'mean auroc {df.auroc.mean():.2%}, fpr {df.fpr.mean():.2%}')
    return df


def num_fp_at_recall(ind_conf, ood_conf, tpr):
    """Count OOD samples accepted at the threshold that keeps a target ID recall.

    The threshold is the ``floor(tpr * len(ind_conf))``-th largest ID score, so
    at least that many ID samples score ``>= thresh``. Every OOD sample scoring
    ``>= thresh`` counts as a false positive (higher score = more ID-like).

    Args:
        ind_conf: 1-D array of in-distribution scores.
        ood_conf: 1-D NumPy array of OOD scores (compared elementwise).
        tpr: target true-positive rate (ID recall); values above 1 act as 1.

    Returns:
        (num_fp, thresh). When no ID sample needs to be kept (no ID samples,
        or ``floor(tpr * n) <= 0``), ``thresh`` lies above every score and
        ``num_fp`` is 0. With both inputs empty the result is ``(0, 0.)``.
    """
    num_ind = len(ind_conf)

    if num_ind == 0 and len(ood_conf) == 0:
        return 0, 0.
    if num_ind == 0:
        return 0, np.max(ood_conf) + 1

    # Clamp so tpr > 1 cannot index past the start of the sorted array.
    recall_num = min(int(np.floor(tpr * num_ind)), num_ind)
    if recall_num <= 0:
        # np.sort(...)[-0] is index 0, the *smallest* score, i.e. 100% recall.
        # Zero required ID positives means a threshold above every score.
        thresh = np.max(np.concatenate((ind_conf, ood_conf))) + 1
        return 0, thresh

    thresh = np.sort(ind_conf)[-recall_num]
    num_fp = np.sum(ood_conf >= thresh)
    return num_fp, thresh


def fpr_recall(ind_conf, ood_conf, tpr):
    """Return (fpr, thresh): the OOD false-positive rate at the given ID recall.

    The false-positive count and threshold come from ``num_fp_at_recall``.
    The rate divides by the number of OOD samples, using 1 as the denominator
    when there are none so an empty OOD set yields 0 rather than dividing by zero.
    """
    false_positives, threshold = num_fp_at_recall(ind_conf, ood_conf, tpr)
    denominator = len(ood_conf) or 1
    return false_positives / denominator, threshold


def auc(ind_conf, ood_conf):
    """Return (auroc, aupr_in, aupr_out) for separating ID from OOD scores.

    AUROC and AUPR-In treat in-distribution samples as the positive class,
    scored by the given confidences. AUPR-Out treats OOD samples as positive
    and scores them by ``1 - conf``, which reverses the ranking.
    """
    scores = np.concatenate((ind_conf, ood_conf))
    is_ind = np.concatenate((np.ones_like(ind_conf), np.zeros_like(ood_conf)))

    roc_fpr, roc_tpr, _ = metrics.roc_curve(is_ind, scores)
    auroc = metrics.auc(roc_fpr, roc_tpr)

    prec, rec, _ = metrics.precision_recall_curve(is_ind, scores)
    aupr_in = metrics.auc(rec, prec)

    # Swap roles: OOD is now the positive class, ranked by the complemented score.
    prec, rec, _ = metrics.precision_recall_curve(1 - is_ind, 1 - scores)
    aupr_out = metrics.auc(rec, prec)

    return auroc, aupr_in, aupr_out


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()