import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import faiss

def sinkhorn(pred, eta, r_in=None, rec=False):
    """Rebalance a score matrix with Sinkhorn-Knopp scaling.

    ``pred`` (detached) is raised elementwise to the power ``eta`` and then
    alternately rescaled along the class axis and the sample axis. In the
    returned matrix every sample's row sums to (approximately) 1, and the
    total mass of class ``k`` is driven toward ``N * r_in[k]``.

    Args:
        pred: ``N x K`` score tensor (N samples, K classes); presumably
            non-negative probabilities, since ``pred + 1e-8`` is raised to
            ``eta`` -- TODO confirm with callers.
        eta: exponent applied elementwise to ``pred + 1e-8`` before scaling.
        r_in: target class marginal with ``K`` entries, as a tensor or
            sequence that reshapes to ``(K, 1)`` (a flat ``(K,)`` prior is
            accepted), or a scalar. ``None`` (the default) means a uniform
            ``1 / K`` marginal.
        rec: internal flag set on the single recursive retry made by the
            NaN fallback; callers should leave it ``False``.

    Returns:
        ``(Q, fell_back)`` where ``Q`` is a detached ``N x K`` tensor and
        ``fell_back`` is ``True`` when a NaN was detected and a fallback
        result was returned instead of the scaled matrix.
    """
    PS = pred.detach()
    K = PS.shape[1]
    N = PS.shape[0]
    PS = torch.pow(PS.T + 1e-8, eta)  # K x N

    # Build both scaling vectors in PS's dtype: `@` rejects mixed dtypes, so a
    # float64 pred or float64 prior must not meet float32 vectors.
    c = torch.ones((N, 1), device=pred.device, dtype=PS.dtype) / N
    if r_in is None:
        # r_in defaults to None in the signature; treat it as a uniform marginal.
        r = torch.ones((K, 1), device=pred.device, dtype=PS.dtype) / K
    else:
        # Force a (K, 1) column: a flat (K,) prior would otherwise broadcast
        # against the (K, 1) denominators below into a K x K matrix.
        r = torch.as_tensor(r_in, device=pred.device, dtype=PS.dtype).reshape(-1, 1)

    r_init = r.clone()  # target class marginal, fixed across iterations
    inv_N = 1.0 / N
    err = 1e6
    _counter = 1

    for _ in range(50):
        # NOTE(review): `err` is sum(c_new) + sum(r), a sum of positive scaling
        # factors rather than a between-iteration residual, so for typical
        # inputs it stays far above 1e-1 and this early exit likely never
        # fires. Confirm whether a residual such as sum(|c / c_new - 1|) was
        # intended before changing it (doing so changes results).
        if err < 1e-1:
            break

        r = r_init * (1.0 / (PS @ c + 1e-8))  # K x 1: pull class mass toward r_init
        c_new = inv_N / ((r.T @ PS).T + 1e-8)  # N x 1: normalize each sample

        if _counter % 10 == 0:
            # Periodic NaN check (every 10 iterations).
            err = torch.sum(c_new) + torch.sum(r)
            if torch.isnan(err):
                print('====> NaN detected in Sinkhorn, applying fallback solution')

                if rec:
                    # Already the retry: give up and return a uniform assignment.
                    relaxed_PS = torch.ones_like(pred) / K
                    return relaxed_PS.detach(), True
                else:
                    # Retry once with exact zeros nudged to 1e-5, then zero
                    # those entries again in the result.
                    pred_new = pred + 1e-5 * (pred == 0)
                    relaxed_PS, _ = sinkhorn(pred_new, eta, r_in=r_in, rec=True)
                    z = (1.0 * (pred != 0))
                    relaxed_PS = relaxed_PS * z
                    return relaxed_PS, True

        c = c_new
        _counter += 1

    # Q = diag(r) * PS^T * diag(c) * N, laid out as N x K.
    PS *= torch.squeeze(c)
    PS = PS.T
    PS *= torch.squeeze(r)
    PS *= N
    return PS.detach(), False

class AverageMeter(object):
    """Keep the latest value and a count-weighted running mean of a metric.

    Args:
        name: label printed first by ``str(meter)``.
        fmt: format spec, including the leading colon (e.g. ``':f'`` or
            ``':.3f'``), applied to both the latest value and the average.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the running average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**self.__dict__)

class ProgressMeter(object):
    """Print a tab-separated progress line for a loop over batches.

    Args:
        num_batches: total number of batches; sets the padded width of the
            batch counter.
        meters: objects rendered with ``str()`` after the counter (e.g.
            ``AverageMeter`` instances).
        prefix: text printed before the counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/num_batches]`` then every meter, tab-separated."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as ``'[{:3d}/100]'`` sized to ``num_batches``."""
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return '[%s/%s]' % (counter, counter.format(num_batches))

def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate of every optimizer param group for ``epoch``.

    With ``args.cosine`` truthy, the rate is cosine-annealed from ``args.lr``
    at epoch 0 down to ``args.lr * args.lr_decay_rate ** 3`` at
    ``args.epochs``. Otherwise it is ``args.lr`` multiplied by
    ``args.lr_decay_rate`` once for every entry of ``args.lr_decay_epochs``
    that ``epoch`` strictly exceeds.
    """
    base_lr = args.lr
    if not args.cosine:
        passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = base_lr * (args.lr_decay_rate ** passed) if passed > 0 else base_lr
    else:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        cos_term = 1 + math.cos(math.pi * epoch / args.epochs)
        new_lr = floor_lr + (base_lr - floor_lr) * cos_term / 2

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each ``k`` in ``topk``.

    ``output`` holds per-class scores with classes along dim 1 and ``target``
    the true class ids (one per row). Returns a list with one 1-element
    tensor per ``k``: the percentage of rows whose target is among the ``k``
    highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_rows = target.size(0)

        # (largest_k, n_rows): row j holds every sample's (j+1)-th best class.
        ranked = output.topk(largest_k, 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True) * (100.0 / n_rows)
            for k in topk
        ]

class AccurracyShot(object):
    """Accuracy broken down into many-/median-/low-shot class groups.

    Classes are bucketed by their training-set frequency. A class whose
    training count is strictly above ``many_shot_thr`` is many-shot, strictly
    below ``low_shot_thr`` is low-shot, and otherwise median-shot (many-shot
    wins if both hold). With distinct training counts this gives exactly
    ``many_shot_num`` many-shot and ``low_shot_num`` low-shot classes; ties
    at a threshold fall into the median group.

    Args:
        train_class_count: per-class training sample counts indexed by class
            id; ``.sort()`` must return ``(values, indices)`` as a ``torch``
            tensor's does.
        num_class: number of classes; replaced by ``len(train_class_count)``
            when the two disagree.
        many_shot_num: size of the many-shot group, clamped with
            ``min(max(1, n), num_class - 1)``.
        low_shot_num: size of the low-shot group, clamped the same way.
    """

    def __init__(self, train_class_count, num_class, many_shot_num=3, low_shot_num=3):
        self.train_class_count = train_class_count
        # Per-class test sample counts from the most recent get_shot_acc call.
        self.test_class_count = None

        if len(train_class_count) != num_class:
            num_class = len(train_class_count)

        many_shot_num = min(max(1, many_shot_num), num_class - 1)
        low_shot_num = min(max(1, low_shot_num), num_class - 1)

        sorted_counts = train_class_count.sort()[0]  # ascending
        self.num_class = num_class
        # Counts strictly above this value are the top `many_shot_num` classes.
        self.many_shot_thr = sorted_counts[max(0, num_class - many_shot_num - 1)]
        # Counts strictly below this value are the bottom `low_shot_num` classes.
        self.low_shot_thr = sorted_counts[min(low_shot_num, num_class - 1)]

    def get_shot_acc(self, preds, labels, acc_per_cls=False):
        """Mean per-class accuracy (in percent) for each shot group.

        Args:
            preds: predicted class ids aligned with ``labels``; presumably a
                tensor or array supporting boolean-mask indexing -- TODO confirm.
            labels: ground-truth class ids, same container type as ``preds``.
            acc_per_cls: also return the per-class accuracy list.

        Returns:
            ``(many, median, low)``: the mean of the per-class accuracies in
            each group, times 100. A group with no evaluable class reports 0.
            Classes absent from ``labels`` are left out of the group means.
            With ``acc_per_cls=True`` a fourth element holds
            ``correct / count`` for every class id (0/0, i.e. NaN, for
            classes absent from ``labels``).
        """
        # Recount on every call (and refresh the attribute): reusing the first
        # call's counts gave wrong accuracies whenever a later call evaluated
        # a different label set.
        self.test_class_count = []
        class_correct = []
        for cls in range(self.num_class):
            in_cls = labels == cls
            cls_labels = labels[in_cls]
            self.test_class_count.append(len(cls_labels))
            class_correct.append((preds[in_cls] == cls_labels).sum())

        many_shot = []
        median_shot = []
        low_shot = []
        for i in range(self.num_class):
            n_test = self.test_class_count[i]
            if n_test == 0:
                # No test samples: accuracy is 0/0, and one NaN would turn the
                # whole group mean into NaN, so leave the class out.
                continue
            cls_acc = class_correct[i] / float(n_test)
            if self.train_class_count[i] > self.many_shot_thr:
                many_shot.append(cls_acc)
            elif self.train_class_count[i] < self.low_shot_thr:
                low_shot.append(cls_acc)
            else:
                median_shot.append(cls_acc)

        # Keep np.mean well-defined for empty groups.
        for group in (many_shot, median_shot, low_shot):
            if not group:
                group.append(0)

        shot_accs = (np.mean(many_shot) * 100, np.mean(median_shot) * 100, np.mean(low_shot) * 100)
        if acc_per_cls:
            class_accs = [c / cnt for c, cnt in zip(class_correct, self.test_class_count)]
            return shot_accs + (class_accs,)
        return shot_accs

def accuracy_check(loader, model, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction.

    ``loader`` yields ``(images, labels)`` pairs. ``model(images)`` must
    return a 2-tuple whose first element holds class scores; the argmax over
    dim 1 is the prediction. Gradients are disabled, but the model's
    train/eval mode is not changed here. Raises ``ZeroDivisionError`` if
    ``loader`` yields no samples.
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for images, labels in loader:
            labels = labels.to(device)
            images = images.to(device)
            scores, _unused = model(images)
            predicted = torch.max(scores.data, 1)[1]
            n_correct += predicted.eq(labels).sum().item()
            n_total += labels.size(0)
    return n_correct / n_total

def sigmoid_rampup(current, rampup_length, exp_coe=5.0):
    """Exponential ("sigmoid-shaped") rampup weight in ``(0, 1]``.

    Follows https://arxiv.org/abs/1610.02242: returns
    ``exp(-exp_coe * (1 - t) ** 2)`` where ``t`` is ``current`` clipped to
    ``[0, rampup_length]`` and divided by ``rampup_length``. A zero-length
    rampup is treated as already complete (returns 1.0).
    """
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - progress / rampup_length
    return float(np.exp(-exp_coe * remaining * remaining))


def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length``, capped at 1.0.

    Both arguments must be non-negative (checked with ``assert``). Once
    ``current`` reaches ``rampup_length`` (including a zero-length rampup)
    the result is 1.0.
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Half-cosine rampdown from 1.0 (at 0) to 0.0 (at ``rampdown_length``).

    Follows https://arxiv.org/abs/1608.03983. ``current`` must lie in
    ``[0, rampdown_length]`` (checked with ``assert``).
    """
    assert 0 <= current <= rampdown_length
    angle = np.pi * current / rampdown_length
    return float(.5 * (np.cos(angle) + 1))

@torch.no_grad()
def faiss_knn(features, k):
    """Exact k-nearest-neighbour search by cosine similarity using faiss.

    Rows are L2-normalized and searched with an exact inner-product index
    (``faiss.IndexFlatIP``), so the inner product equals cosine similarity.
    Every row is used both as a database entry and as a query, so each row's
    own index is normally among its results (usually first, since a
    normalized vector's similarity with itself is maximal).

    Args:
        features: Tensor [N, D] on any device, any floating dtype.
        k: number of neighbours to return per row.

    Returns:
        indices [N, k] as a CPU tensor built from faiss's search output.
    """
    # faiss works on C-contiguous float32 host arrays. Move to CPU for any
    # device (not only CUDA) and cast before normalizing, so half/double
    # inputs are neither rejected by faiss nor normalized in low precision.
    features = features.detach().cpu().float()
    features = np.ascontiguousarray(F.normalize(features, dim=1).numpy())

    dim = features.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(features)
    _, idx = index.search(features, k)
    return torch.from_numpy(idx)
