import numpy as np
import torch


# keep top k largest values, and smooth others
def keep_top_k(p, k, n_classes=1000):
    """Keep the ``k`` largest entries of each row and flatten the rest.

    Args:
        p: 2-D tensor (rows x classes); per the original note this is the
            softmax over the label output, so each row is expected to sum
            to 1.
        k: number of largest entries per row that are kept unchanged.
        n_classes: total number of classes; used to spread the leftover
            mass over the ``n_classes - k`` remaining entries.

    Returns:
        ``p`` itself when ``k == n_classes``; otherwise a new tensor with
        the top-k entries of ``p`` preserved and every other entry of a row
        set to ``(1 - sum of that row's top-k) / (n_classes - k)``.

    Raises:
        AssertionError: if the result does not sum to (about) the number
            of rows.
    """
    if k == n_classes:
        return p

    top_values, top_indices = p.topk(k, dim=1)

    # Complementary 0/1 masks: one selects the top-k slots, one the others.
    is_top = torch.zeros_like(p).scatter_(-1, top_indices, 1.0)
    is_rest = torch.ones_like(p).scatter_(-1, top_indices, 0)

    # Mass left after the top-k, shared equally by the remaining classes
    # and broadcast to the full shape of p.
    leftover = (1 - torch.sum(top_values, dim=1)) / (n_classes-k)
    leftover = leftover.unsqueeze(1).expand(p.shape)

    topk_smooth_p = is_top * p + is_rest * leftover
    assert np.isclose(topk_smooth_p.sum().item(), p.shape[0]), f'{topk_smooth_p.sum().item()} not close to {p.shape[0]}'
    return topk_smooth_p


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.avg = 0  # running mean: sum / cnt
        self.sum = 0  # weighted sum of all values seen
        self.cnt = 0  # total weight (sample count) seen
        self.val = 0  # last value passed to update()

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: the observed value (e.g. a per-batch average).
            n: weight of the observation, typically the number of samples
                that ``val`` was averaged over.
        """
        self.val = val
        self.sum += val * n
        self.cnt += n
        # An update with n=0 on an empty meter would otherwise divide by
        # zero; report the same neutral average that reset() uses.
        self.avg = self.sum / self.cnt if self.cnt else 0


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested ``k``.

    Args:
        output: 2-D score tensor (samples x classes); the top entries are
            taken along dim 1.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one scalar tensor per entry of ``topk``, holding the
        percentage of samples whose target is among the ``k`` highest
        scoring classes.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the best `largest_k` classes per sample, best first,
    # transposed so that row r holds every sample's rank-r prediction.
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
    top_indices = top_indices.t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))

    to_percent = 100.0 / n_samples
    # The first k rows of `hits` cover exactly the top-k predictions.
    return [
        hits[:k].reshape(-1).float().sum(0).mul_(to_percent)
        for k in topk
    ]


def get_parameters(model):
    """Split a model's parameters into decayed and non-decayed groups.

    A parameter is put in the weight-decay group when its name contains
    ``'weight'`` and it has more than one dimension; every other parameter
    goes into a group with ``weight_decay=0.``.

    Args:
        model: object exposing ``named_parameters()`` and ``parameters()``.

    Returns:
        A two-element list of dicts: ``{'params': decayed}`` followed by
        ``{'params': not_decayed, 'weight_decay': 0.}``. The layout looks
        like optimizer parameter groups; confirm against the caller.
    """
    decayed = []
    not_decayed = []
    for name, param in model.named_parameters():
        is_multi_dim_weight = 'weight' in name and len(param.size()) >= 2
        bucket = decayed if is_multi_dim_weight else not_decayed
        bucket.append(param)

    # Sanity check: every parameter landed in exactly one group.
    n_params = len(list(model.parameters()))
    assert n_params == len(decayed) + len(not_decayed)

    return [
        dict(params=decayed),
        dict(params=not_decayed, weight_decay=0.),
    ]

def compute_adjustment_auto_1(label_freq_array, tau):
    """Return ``log(label_freq_array ** tau + 1e-12)`` element-wise.

    Args:
        label_freq_array: NumPy array of per-class label frequencies.
        tau: exponent applied to the frequencies before taking the log.

    Returns:
        NumPy array of the same shape as ``label_freq_array``.
    """
    scaled_freq = label_freq_array ** tau
    # The small constant keeps the log finite when a frequency is zero.
    return np.log(scaled_freq + 1e-12)
        
def compute_adjustment_auto(model_teacher, original_label_freq_array, train_loader, device, class_number):
    """Search for the ``tau`` that equalises teacher confidence across classes.

    The teacher(s) are run over ``train_loader`` to collect logits and
    labels. For each candidate ``tau`` in ``np.linspace(0, 3, 30)`` the
    logits are shifted by ``compute_adjustment_auto_1(freq, tau)``, turned
    into softmax probabilities, and the mean probability assigned to the
    true class is computed per class. The ``tau`` with the smallest standard
    deviation of those per-class means wins.

    Args:
        model_teacher: a single model or a list/tuple of models; logits from
            all of them are pooled.
        original_label_freq_array: NumPy array of per-class label
            frequencies (presumably normalised to sum to 1 - confirm
            against the caller).
        train_loader: iterable of batches whose first two items are the
            inputs and the integer class targets.
        device: device the batches are moved to and the result is returned on.
        class_number: number of classes; class ids ``0..class_number-1`` are
            evaluated.

    Returns:
        Tensor holding ``compute_adjustment_auto_1(freq, best_tau)`` on
        ``device``.

    Raises:
        ValueError: if no candidate ``tau`` yields a finite spread (e.g. the
            teacher produced NaN logits or no class has any sample).

    Note:
        Each teacher is switched to eval mode and left in that mode.
    """
    if not isinstance(model_teacher, (list, tuple)):
        model_teacher = [model_teacher]

    logit = []
    label = []
    for model in model_teacher:
        model.eval()
        with torch.no_grad():
            # NOTE(review): five passes over the loader - presumably to
            # average over random data augmentation; confirm.
            for _ in range(5):
                for batch in train_loader:
                    inputs, targets = batch[:2]
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)
                    logit.append(outputs.cpu().numpy())
                    label.append(targets.cpu().numpy())
    logit = np.concatenate(logit, axis=0)
    label = np.concatenate(label, axis=0)

    # Class membership does not depend on tau, so build the masks once.
    # Classes with no collected sample are skipped: their mean would be NaN
    # and would poison the standard deviation for every tau.
    class_masks = []
    for cls in range(class_number):
        mask = label == cls
        if mask.any():
            class_masks.append((cls, mask))

    tau_candidate = np.linspace(0, 3, 30)
    min_std = np.inf
    best_tau = None
    for tau in tau_candidate:
        adjustment = compute_adjustment_auto_1(original_label_freq_array, tau)
        logit_ = logit - adjustment
        # Subtract the row maximum before exponentiating: the softmax is
        # unchanged mathematically but can no longer overflow to inf/NaN.
        logit_ = logit_ - logit_.max(axis=1, keepdims=True)
        exp_ = np.exp(logit_)
        softmax_ = exp_ / exp_.sum(axis=1, keepdims=True)
        # Mean probability assigned to the true class, per class.
        confs = [softmax_[mask, cls].mean() for cls, mask in class_masks]
        confs_std = np.std(confs)
        if confs_std < min_std:
            min_std = confs_std
            best_tau = tau
        print(f'best tau: {best_tau}, min std: {min_std}, class_number: {class_number}')

    print(f'best tau: {best_tau}, min std: {min_std}')
    if best_tau is None:
        raise ValueError(
            'could not select tau: the per-class confidence spread was never '
            'finite (check for NaN logits or an empty set of classes)')
    adjustment = compute_adjustment_auto_1(original_label_freq_array, best_tau)
    return torch.tensor(adjustment).to(device)



def compute_adjustment(original_label_freq_array, device, tau):
    """Build the logit adjustment for a fixed ``tau`` as a tensor on ``device``.

    Args:
        original_label_freq_array: NumPy array of per-class label frequencies.
        device: device the resulting tensor is moved to.
        tau: exponent forwarded to ``compute_adjustment_auto_1``.

    Returns:
        Tensor holding ``compute_adjustment_auto_1(freq, tau)`` on ``device``.
    """
    log_prior = compute_adjustment_auto_1(original_label_freq_array, tau)
    adjustment_tensor = torch.tensor(log_prior)
    return adjustment_tensor.to(device)


