import os
import argparse

import numpy as np
from scipy.special import softmax, logsumexp
import torch
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
import torchvision.transforms as transforms

import faiss
from tqdm import tqdm

import models.hyptorch as hypb
from models.ash import get_score


def get_dist_score(args, net, test_loader,
        classwise_mean, precision, in_dist=True, score_thred=0.1, attacker=None
    ):
    '''
    Compute a class-conditional distance score for every sample in test_loader.

    Supported options (args.score):
        "maha" -- Mahalanobis: per-class similarity -0.5 * z^T P z with
                  z = feature - class_mean and P = precision.
        "hypb" -- Poincare ball: per-class similarity is the negated
                  hypb.dist_matrix(feature, class_mean, c=args.c_ball).

    For each sample the best (largest) per-class similarity is taken and
    negated, so a larger returned value means the sample is farther from
    its closest class.

    Args:
        args: namespace; reads score, n_cls, batch_size and (for "hypb") c_ball.
        net: model exposing intermediate_forward(images) -> features.
        test_loader: loader yielding (images, labels) batches.
        classwise_mean: indexable per-class means, classwise_mean[i] for i < args.n_cls.
        precision: matrix P used in the Mahalanobis quadratic form.
        in_dist: when it is exactly False, iteration stops at batch index
            len(dataset) // args.batch_size (the trailing partial batch is dropped).
        score_thred: unused; kept for backward compatibility.
        attacker: optional callable(images, labels) -> images, applied before
            normalisation and outside torch.no_grad().

    Returns:
        np.ndarray (float32) with one score per processed sample.

    Raises:
        ValueError: if args.score is neither "maha" nor "hypb".
    '''
    # Fail fast: otherwise an unknown score leaves `dist` unbound inside the loop.
    if args.score not in ("maha", "hypb"):
        raise ValueError(f"unsupported distance score: {args.score!r}")

    # CIFAR statistics; ImageNet would be mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
    normalize = transforms.Normalize(
        mean=[0.491, 0.482, 0.447], std=[0.247, 0.244, 0.262]
    )
    n_full_batches = len(test_loader.dataset) // args.batch_size
    score_all = []

    tqdm_object = tqdm(test_loader, total=len(test_loader))
    for batch_idx, (images, labels) in enumerate(tqdm_object):
        if batch_idx >= n_full_batches and in_dist is False:
            break
        if attacker:
            images = attacker(images, labels)
        images = normalize(images)

        with torch.no_grad():
            features = net.intermediate_forward(images.cuda())
            per_class = []
            for i in range(args.n_cls):
                class_mean = classwise_mean[i]
                if args.score == "maha":
                    zero_f = features - class_mean
                    # Row-wise quadratic form z_i^T P z_i: same values as
                    # diag(Z P Z^T) without materialising the (B, B) matrix.
                    dist = -0.5 * (torch.mm(zero_f, precision) * zero_f).sum(dim=1)
                else:
                    # presumably dist_matrix returns (B, 1) here, so .T is (1, B) — TODO confirm
                    dist = -hypb.dist_matrix(
                        features, class_mean.unsqueeze(0), c=args.c_ball
                    ).T
                per_class.append(dist)

            # vstack stacks classes along dim 0; reduce over classes per sample.
            best = torch.vstack(per_class).max(dim=0).values
            score_all.extend(-best.cpu().numpy())

    return np.asarray(score_all, dtype=np.float32)


# def get_Mahalanobis_score(args, net, test_loader, 
#         classwise_mean, precision, in_dist=True, attacker=False
#     ):
#     '''
#     deprecated function. Please see get_dist_score
#     '''
#     # net.eval()
#     Mahalanobis_score_all = []
#     total_len = len(test_loader.dataset)
#     tqdm_object = tqdm(test_loader, total=len(test_loader))
#     for batch_idx, (images, labels) in enumerate(tqdm_object):
#         if (batch_idx >= total_len // args.batch_size) and in_dist is False:
#             break   
#         with torch.no_grad():
#             features = net.intermediate_forward(images.cuda()) 
#             #features = net.head_forward(images.cuda()) 
# 
#             for i in range(args.n_cls):
#                 class_mean = classwise_mean[i]
#                 zero_f = features - class_mean
#                 Mahalanobis_dist = -0.5*torch.mm(torch.mm(zero_f, precision), zero_f.t()).diag()
#                 if i == 0:
#                     Mahalanobis_score = Mahalanobis_dist.view(-1,1)
#                 else:
#                     Mahalanobis_score = torch.cat((Mahalanobis_score, Mahalanobis_dist.view(-1,1)), 1)      
#             Mahalanobis_score, _ = torch.max(Mahalanobis_score, dim=1)
#             Mahalanobis_score_all.extend(-Mahalanobis_score.cpu().numpy())
#         
#     return np.asarray(Mahalanobis_score_all, dtype=np.float32)


class KNN:
    '''
    Nearest-neighbour OOD scorer backed by a flat faiss index over training features.

    Attributes:
        args: namespace; K and n_cls are read.
        index_bad: faiss.IndexFlatIP when metric == "inner", else faiss.IndexFlatL2,
            populated with ftrain.
        K: number of neighbours searched by get_knn_score.
        cutoff: softmax-probability threshold used by get_knn_weighted_score;
            0 means "not estimated yet".
    '''

    def __init__(self, args, ftrain, metric="L2"):
        self.args = args
        # faiss only accepts C-contiguous float32 arrays.
        ftrain = self._as_faiss_array(ftrain)
        if metric == "inner":
            index = faiss.IndexFlatIP(ftrain.shape[1])
        else:
            index = faiss.IndexFlatL2(ftrain.shape[1])
        index.add(ftrain)
        self.index_bad = index
        self.K = args.K
        self.cutoff = 0

    @staticmethod
    def _as_faiss_array(x):
        '''Return a numpy array as C-contiguous float32; other objects pass through unchanged.'''
        if isinstance(x, np.ndarray):
            return np.ascontiguousarray(x, dtype=np.float32)
        return x

    def get_knn_score(self, ftest, ind=-1):
        '''
        Search the K nearest training features of each row of ftest and return
        column `ind` of the distance matrix (default -1: the last of the K
        results returned by faiss).
        '''
        D, _ = self.index_bad.search(self._as_faiss_array(ftest), self.K)
        return D[:, ind]

    def get_knn_weighted_score(self, ftest, logits):
        '''
        Softmax-weighted sum of the distances to the args.n_cls nearest
        training features, keeping only entries whose probability is below
        self.cutoff.

        If the distance matrix and logits differ in shape, the distance in the
        last of the n_cls result columns is returned instead. The cutoff is
        estimated once, on the first non-empty call, as the 60th percentile of
        the softmax of 10 rows drawn at random (with replacement), so it
        depends on NumPy's global RNG state.
        '''
        D, _ = self.index_bad.search(self._as_faiss_array(ftest), self.args.n_cls)
        if D.shape != logits.shape:
            return D[:, -1]
        logits = softmax(logits, 1)
        # np.random.choice(0, 10) raises, so skip estimation on an empty batch.
        if not self.cutoff and len(logits):
            sample = logits[np.random.choice(len(logits), 10)]
            self.cutoff = np.percentile(sample, 60)
        weighted_D = D * logits * (logits < self.cutoff)
        return np.sum(weighted_D, axis=1)

    def get_neg_score(self, ftest, logits, neg_sample):
        '''
        Pairwise L2 distances between ftest and neg_sample, with singleton
        dimensions squeezed. `logits` is unused.
        '''
        n_score = pairwise_distances(ftest, neg_sample, metric="l2")
        return np.squeeze(n_score)

    def get_train_dist(self, ftrain, logits, prototypes):
        '''Unfinished stub: performs no computation and returns None.'''
        return None


def get_pyood_score(args, detector, test_loader):
    '''
    Score every whole batch of test_loader with a pyod-style detector.

    `detector` is called on each CUDA batch and must return a tensor of
    per-sample scores. Iteration stops at batch index
    len(test_loader.dataset) // args.batch_size, so a trailing partial
    batch is skipped. If the dataset is smaller than one batch, the result
    is empty.

    Returns:
        np.ndarray (float32) of the collected scores.
    '''
    n_full_batches = len(test_loader.dataset) // args.batch_size
    collected = []
    progress = tqdm(test_loader, total=len(test_loader))
    with torch.no_grad():
        for step, (inputs, _) in enumerate(progress):
            if step >= n_full_batches:
                break
            batch_scores = detector(inputs.cuda()).cpu().numpy()
            collected.extend(batch_scores)

    return np.asarray(collected, dtype=np.float32)


def get_ash_score(args, net, test_loader):
    '''
    Score every whole batch of test_loader with models.ash.get_score(method="msp").

    Logits come from net(images) on CUDA. Iteration stops at batch index
    len(test_loader.dataset) // args.batch_size, so a trailing partial
    batch is skipped.

    Returns:
        np.ndarray (float32) of the collected scores.
    '''
    n_full_batches = len(test_loader.dataset) // args.batch_size
    collected = []
    progress = tqdm(test_loader, total=len(test_loader))
    with torch.no_grad():
        for step, (inputs, _) in enumerate(progress):
            if step >= n_full_batches:
                break
            logits = net(inputs.cuda())
            # get_score's return is extended directly (no .cpu().numpy()).
            collected.extend(get_score(logits, method="msp"))

    return np.asarray(collected, dtype=np.float32)


def get_mean_prec(args, net, train_loader, attacker=None, use_cache=False):
    '''
    Compute class-wise feature means and a shared precision matrix, as used
    by the Mahalanobis / Poincare distance scores.

    Features come from net.intermediate_forward on CIFAR-normalised images
    (optionally perturbed by `attacker(images, labels)` first). Class means
    are hypb.poincare_mean over each class when args.head == "hypb", and
    the arithmetic mean otherwise. The precision is the pseudo-inverse of
    torch.cov over all training features. That covariance is centred on the
    global mean, not per class.

    Both tensors are saved under
    {args.main_dir}/feat/{args.in_dataset}/{args.name}/maha/.

    Args:
        args: namespace; reads main_dir, in_dataset, name, loss, score,
            n_cls, embedding_dim, head and (for "hypb") c_ball.
        net: model exposing intermediate_forward(images) -> features.
        train_loader: loader yielding (images, labels). Its
            dataset.targets are used to group features by class, which
            assumes the loader iterates the dataset in order (no shuffling)
            — TODO confirm at call sites.
        attacker: optional callable(images, labels) -> images.
        use_cache: when True and both saved files exist, load them instead
            of recomputing. Defaults to False, i.e. always recompute.

    Returns:
        (classwise_mean, precision), both CUDA tensors.
    '''
    save_dir = os.path.join(
        args.main_dir, 'feat', f"{args.in_dataset}", f"{args.name}", 'maha'
    )
    mean_loc = os.path.join(
        save_dir, f'{args.loss}_{args.score}_classwise_mean.pt'
    )
    prec_loc = os.path.join(save_dir,  f'{args.loss}_precision.pt')
    normalize = transforms.Normalize(
        mean=[0.491, 0.482, 0.447],
        std=[0.247, 0.244, 0.262]
    )
    os.makedirs(save_dir, exist_ok=True)

    # Reuse saved tensors only on request, and only if both files are present.
    if use_cache and os.path.exists(mean_loc) and os.path.exists(prec_loc):
        classwise_mean = torch.load(mean_loc, map_location='cpu').cuda()
        precision = torch.load(prec_loc, map_location='cpu').cuda()
        return classwise_mean, precision

    # Gather batch features and concatenate once; concatenating inside the
    # loop would re-copy all previous features every batch. The leading empty
    # float32 chunk keeps the dtype promotion and empty-loader behaviour.
    feature_chunks = [torch.zeros((0, args.embedding_dim), device='cuda')]
    for image, labels in tqdm(train_loader):
        if attacker:
            image = attacker(image, labels)
        image = normalize(image)
        with torch.no_grad():
            feature_chunks.append(net.intermediate_forward(image.cuda()))
    all_features = torch.cat(feature_chunks, dim=0)

    targets = np.array(train_loader.dataset.targets)
    classwise_mean = torch.empty(args.n_cls, args.embedding_dim, device='cuda')
    for cls in range(args.n_cls):
        cls_features = all_features[np.where(targets == cls)[0]].float()
        if args.head == "hypb":
            classwise_mean[cls] = hypb.poincare_mean(
                cls_features, dim=0, c=args.c_ball
            )
        else:
            classwise_mean[cls] = torch.mean(cls_features, dim=0)

    # torch.cov treats rows as variables, hence the transpose to (D, N).
    cov = torch.cov(all_features.T.double())
    precision = torch.linalg.pinv(cov).float()
    print(f'cond number: {torch.linalg.cond(precision)}')
    torch.save(classwise_mean, mean_loc)
    torch.save(precision, prec_loc)
    return classwise_mean, precision


def get_grad_score(args, model, data_loader, percentile=30):
    '''
    GradNorm-style OOD score with ASH-S sharpening of per-sample gradients.

    For each sample, the logits model(x) / args.T are scored with
    mean(sum(-1 * log_softmax(logits))) against an all-ones target of width
    args.feat_dim. The loss is back-propagated and |grad| of
    model.head[0].weight is collected. Per batch, the gradient stack is
    passed through ash_s(percentile) and averaged, giving one score per
    sample.

    Only whole batches are processed: iteration stops at batch index
    len(data_loader.dataset) // args.batch_size.

    Returns:
        1-D float32 np.ndarray with one score per processed sample.
    '''
    assert 0 <= percentile <= 100
    confs = []
    n_full_batches = len(data_loader.dataset) // args.batch_size
    logsoftmax = torch.nn.LogSoftmax(dim=-1).cuda()
    tqdm_object = tqdm(data_loader, total=len(data_loader))
    for b, (images, labels) in enumerate(tqdm_object):
        if b >= n_full_batches:
            break
        all_grad = []
        for x in images:
            inputs = x.unsqueeze(0).cuda().requires_grad_(True)

            model.zero_grad()
            outputs = model(inputs) / args.T
            targets = torch.ones((inputs.shape[0], args.feat_dim)).cuda()
            loss = torch.mean(torch.sum(-targets * logsoftmax(outputs), dim=-1))
            loss.backward()

            # Copy the gradient. With zero_grad(set_to_none=False), the
            # default before PyTorch 2.0, .grad is zeroed and re-accumulated
            # in place, so a stored reference would alias the last sample.
            all_grad.append(model.head[0].weight.grad.detach().clone())

        # (B, 1, D, N) for a 2-D head weight; ash_s prunes per sample.
        all_grad = torch.abs(torch.stack(all_grad).unsqueeze(1))
        all_grad = ash_s(all_grad, percentile)
        confs.append(all_grad.mean(dim=[1, 2, 3]).cpu().numpy())

    return np.array(np.ravel(confs), dtype=np.float32)


def ash_s(x, percentile=65):
    '''
    ASH-S activation shaping for a batch tensor of any rank >= 2.

    Per sample (dim 0), the largest n - round(n * percentile / 100) of its n
    values are kept and the rest are zeroed. The survivors are then
    multiplied by exp(s1 / s2), where s1 and s2 are the per-sample sums
    before and after pruning.

    When x is contiguous, the flattened tensor below is a view, so the
    pruning is also written into x itself.

    Args:
        x: tensor whose first dimension is the batch.
        percentile: percentage of values to prune per sample.

    Returns:
        Tensor with the same shape as x.
    '''
    b = x.shape[0]
    t = x.reshape(b, -1)
    s1 = t.sum(dim=1)
    n = t.shape[1]
    k = n - int(np.round(n * percentile / 100.0))
    v, i = torch.topk(t, k, dim=1)
    t.zero_().scatter_(dim=1, index=i, src=v)

    # new per-sample sum after pruning
    s2 = t.sum(dim=1)

    # Scale on the flattened (b, n) tensor so one factor is broadcast per
    # sample regardless of the input rank, then restore the input shape.
    scale = s1 / s2
    return (t * torch.exp(scale)[:, None]).reshape(x.shape)


# def ash_s(x, percentile=65):
#     b, c, h, w = x.shape
# 
#     # calculate the sum of the input per sample
#     s1 = x.sum(dim=[1, 2, 3])
#     n = x.shape[1:].numel()
#     k = n - int(np.round(n * percentile / 100.0))
#     t = x.view((b, -1))
#     v, i = torch.topk(t, k, dim=1)
#     t.zero_().scatter_(dim=1, index=i, src=v)
# 
#     # calculate new sum of the input per sample after pruning
#     s2 = x.sum(dim=[1, 2, 3])
# 
#     # apply sharpening
#     scale = s1 / s2
#     x = x * torch.exp(scale[:, None, None, None])
# 
#     return x


