# -*- coding:utf-8 -*- 
import logging
import os
import json
import torch.nn.functional as F
import torch
import numpy as np
from torch.utils.data import DataLoader, RandomSampler

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

def debias(logits, bias, tau=0.4):
    """Turn logits into prior-corrected (debiased) pseudo-label probabilities.

    The scaled log-prior ``tau * log(bias)`` is subtracted from ``logits``
    before the softmax, which relatively down-weights classes whose prior
    in ``bias`` is large.

    Args:
        logits: raw class scores; the softmax is taken over dim 1.
        bias: class prior; presumably a per-class probability vector that
            broadcasts against ``logits`` -- TODO confirm with callers.
        tau: strength of the correction (0 disables it for strictly
            positive ``bias``).

    Returns:
        Softmax probabilities over dim 1, same shape as ``logits``.
    """
    log_prior = torch.log(bias)
    adjusted_logits = logits - tau * log_prior
    return F.softmax(adjusted_logits, dim=1)

def debias_output(logits, bias, tau=0.4):
    """Re-inject the class prior into logits and return probabilities.

    This is the counterpart of ``debias``: the scaled log-prior
    ``tau * log(bias)`` is *added* to ``logits`` before the softmax, which
    relatively up-weights classes whose prior in ``bias`` is large.

    Args:
        logits: raw class scores; the softmax is taken over the last dim.
        bias: class prior; presumably a per-class probability vector that
            broadcasts against ``logits`` -- TODO confirm with callers.
        tau: strength of the adjustment.

    Returns:
        Softmax probabilities over the last dim, same shape as ``logits``.
    """
    prior_term = tau * torch.log(bias)
    return F.softmax(logits + prior_term, dim=-1)

def bias_initial(train_dataloader, num_labels=8, device="cuda"):
    """Estimate the empirical prior over entity-type labels.

    Every valid target label whose value lies in ``1..num_labels`` is
    counted; label ``k`` maps to class index ``k - 1``. Label ``0`` and any
    value outside that range (e.g. padding ids) are ignored. The counts are
    normalised into a probability vector.

    Args:
        train_dataloader: iterable of batches. ``batch[2]`` is a validity
            mask (positions with value > 0 are counted) and ``batch[3]``
            holds the labels, indexable by that mask.
        num_labels: number of entity-type classes (default 8, the value that
            used to be hard-coded).
        device: device of the returned tensor (default ``"cuda"``, as before).

    Returns:
        float32 tensor of shape ``(num_labels,)`` that sums to 1. If no
        countable label is seen, a uniform prior is returned instead of NaNs.
        A uniform prior makes the log-prior adjustment a constant shift.
    """
    counts = torch.zeros(num_labels, dtype=torch.long)
    for batch in train_dataloader:
        valid_pos = batch[2]
        target = batch[3][valid_pos > 0]
        # Shift labels 1..num_labels to class indices 0..num_labels-1;
        # .long() truncates toward zero like int() did.
        keys = target.reshape(-1).long().cpu() - 1
        keys = keys[(keys >= 0) & (keys < num_labels)]
        counts += torch.bincount(keys, minlength=num_labels)

    total = int(counts.sum().item())
    if total == 0:
        logging.getLogger(__name__).warning(
            "bias_initial: no labels in 1..%d found; using a uniform prior",
            num_labels,
        )
        base_prob = torch.full((num_labels,), 1.0 / num_labels)
    else:
        # Normalise in float64 and then cast, matching the old numpy path.
        base_prob = (counts.double() / total).float()
    return base_prob.to(device)

def bin_bias_initial(train_dataloader, num_labels=8, device="cuda"):
    """Estimate the entity-type prior and the binary (O vs. entity) prior.

    Valid target labels are counted: label ``0`` is the non-entity class and
    labels ``1..num_labels`` are entity types. Any other value is ignored, so
    the type prior always has exactly ``num_labels`` entries. Previously,
    unexpected label values grew the returned vector and were counted as
    entities.

    Args:
        train_dataloader: iterable of batches. ``batch[2]`` is a validity
            mask (positions with value > 0 are counted) and ``batch[3]``
            holds the labels, indexable by that mask.
        num_labels: number of entity-type classes (default 8, the value that
            used to be hard-coded).
        device: device of the returned tensors (default ``"cuda"``).

    Returns:
        ``(base_prob, bin_base_prob)``, both float32 and summing to 1.
        ``base_prob`` has shape ``(num_labels,)`` and is the distribution over
        labels ``1..num_labels``. ``bin_base_prob`` has shape ``(2,)``:
        ``[P(label == 0), P(label > 0)]``. When a distribution has no counts,
        a uniform one is returned instead of NaNs.
    """
    def _to_prob(counts):
        # Normalise integer counts; fall back to uniform when they are all zero.
        size = counts.numel()
        total = int(counts.sum().item())
        if total == 0:
            logging.getLogger(__name__).warning(
                "bin_bias_initial: no counts; using a uniform prior of size %d",
                size,
            )
            return torch.full((size,), 1.0 / size).to(device)
        return (counts.double() / total).float().to(device)

    # Index 0 holds the non-entity count; 1..num_labels hold entity types.
    label_counts = torch.zeros(num_labels + 1, dtype=torch.long)
    for batch in train_dataloader:
        valid_pos = batch[2]
        target = batch[3][valid_pos > 0].reshape(-1).long().cpu()
        target = target[(target >= 0) & (target <= num_labels)]
        label_counts += torch.bincount(target, minlength=num_labels + 1)

    entity_counts = label_counts[1:]
    bin_counts = torch.stack([label_counts[0], entity_counts.sum()])
    return _to_prob(entity_counts), _to_prob(bin_counts)

def bias_update(input, bias, momentum=0.99, bias_mask=None):
    """Exponential-moving-average update of a class-prior estimate.

    ``new_bias = momentum * bias + (1 - momentum) * batch_mean``, where
    ``batch_mean`` is the mean of ``input`` over dim 0. When ``bias_mask`` is
    given, the mean is taken over the masked rows only. Previously, the
    masked branch multiplied by the mask without averaging, which turned the
    per-class prior into a per-sample tensor.

    Args:
        input: per-sample class scores/probabilities; assumed ``(N, C)``
            -- TODO confirm with callers. Gradients are not propagated.
        bias: current prior estimate, broadcastable against a row of ``input``.
        momentum: EMA decay factor.
        bias_mask: optional ``(N,)`` boolean mask or non-negative weights
            selecting which rows contribute to the mean.

    Returns:
        The updated prior as a new tensor (``bias`` is not modified). If the
        mask selects no rows, ``bias`` is returned unchanged.
    """
    x = input.detach()
    if bias_mask is None:
        input_mean = x.mean(dim=0)
    else:
        weights = bias_mask.detach().to(x.dtype).unsqueeze(dim=-1)
        if float(weights.sum()) == 0:
            # No rows selected: nothing to average, keep the current prior.
            return bias
        input_mean = (x * weights).sum(dim=0) / weights.sum(dim=0)
    return momentum * bias + (1 - momentum) * input_mean

def soft_frequency(logits, power=2, probs=False):
    """Sharpen class probabilities with frequency normalisation.

    This follows the target distribution of "Unsupervised Deep Embedding for
    Clustering Analysis" (Xie et al., https://arxiv.org/abs/1511.06335):
    ``p_ij ∝ q_ij ** power / f_j``, where ``f_j = sum_i q_ij`` is summed over
    dim 0 and each row is renormalised over the last dim.

    Args:
        logits: raw scores, or probabilities when ``probs`` is True. When raw,
            a softmax over the last dim is applied first.
        power: sharpening exponent.
        probs: whether ``logits`` already holds probabilities.

    Returns:
        Tensor of the same shape whose last-dim slices sum to 1.

    NOTE(review): ``f`` is summed over dim 0 only; for inputs of rank > 2
    this gives a per-position frequency rather than a global one -- confirm
    this is intended.
    """
    if probs:
        q = logits
    else:
        num_classes = logits.shape[-1]
        flat = logits.view(-1, num_classes)
        q = F.softmax(flat, dim=1).view(logits.shape)

    class_freq = q.sum(dim=0)
    weighted = q.pow(power) / class_freq
    return weighted / weighted.sum(dim=-1, keepdim=True)

def get_hard_label(args, combined_labels, pred_labels, pad_token_label_id, pred_logits=None):
    """Copy padding positions from ``combined_labels`` into ``pred_labels``.

    Wherever ``combined_labels`` equals ``pad_token_label_id``, the
    corresponding entry of ``pred_labels`` is overwritten with that id. This
    is done in place, so ``pred_labels`` itself is mutated.

    ``args`` and ``pred_logits`` are unused; they keep the signature in line
    with ``mask_tokens``.

    Returns:
        ``(pred_labels, None)``. The second element is a placeholder for a
        label mask, which hard labels do not use.
    """
    is_padding = combined_labels == pad_token_label_id
    pred_labels[is_padding] = pad_token_label_id
    label_mask = None
    return pred_labels, label_mask

def mask_tokens(args, combined_labels, pred_labels, pad_token_label_id, pred_logits=None):
    """Discard low-confidence pseudo labels for self-training.

    Both modes keep only tokens whose maximum class probability exceeds
    ``args.threshold``:

    * ``"hard"``: ``pred_logits`` are softmaxed over the last dim. Tokens with
      confidence ``<= args.threshold`` get ``pad_token_label_id`` written into
      ``pred_labels``, in place. Returns ``(pred_labels, None)``.
    * ``"soft"``: ``pred_labels`` is presumably a soft label distribution
      (its max over the last dim is used as confidence) -- TODO confirm.
      Returns ``(pred_labels, label_mask)``, where ``label_mask`` is True for
      tokens with confidence ``> args.threshold``.

    ``combined_labels`` is unused; it keeps the signature in line with
    ``get_hard_label``.

    Raises:
        ValueError: if ``args.self_learning_label_mode`` is neither ``"hard"``
            nor ``"soft"``. Previously such calls silently returned ``None``.
    """
    mode = args.self_learning_label_mode
    if mode == "hard":
        num_classes = pred_logits.shape[-1]
        y = torch.softmax(pred_logits.view(-1, num_classes), dim=1).view(pred_logits.shape)
        confidence = y.max(dim=-1)[0]
        # Ignore uncertain tokens so that, as in "soft" mode, only
        # predictions above the threshold act as training targets.
        pred_labels[confidence <= args.threshold] = pad_token_label_id
        return pred_labels, None

    if mode == "soft":
        label_mask = pred_labels.max(dim=-1)[0] > args.threshold
        return pred_labels, label_mask

    raise ValueError(
        "Unsupported self_learning_label_mode: {!r} (expected 'hard' or 'soft')".format(mode)
    )

def mask_bitokens(args, type_logits, bin_logits, labels):
    """Build confidence masks for the type head and a sigmoid binary head.

    Only ``args.self_learning_label_mode == "soft"`` is handled; for any other
    mode the function returns ``None``.

    Args:
        args: must provide ``self_learning_label_mode`` and ``threshold``.
        type_logits: type-head scores/probabilities; their max over the last
            dim is compared with ``args.threshold``.
        bin_logits: binary-head logits, turned into ``[1 - p, p]`` with a
            sigmoid. Assumed ``(N, 1)``, with one row per entry of ``labels``
            that is not -100 -- TODO confirm.
        labels: labels where -100 marks ignored positions and values > 0 are
            entities.

    Returns:
        ``(type_logits, type_label_mask, bin_label_mask)``. ``bin_label_mask``
        is True where the binary head is confident OR the label is an entity.
    """
    if args.self_learning_label_mode != "soft":
        return None

    threshold = args.threshold
    type_label_mask = type_logits.max(dim=-1).values > threshold

    p_entity = torch.sigmoid(bin_logits)
    bin_probs = torch.cat((1 - p_entity, p_entity), dim=-1)
    bin_confident = bin_probs.max(dim=-1).values > threshold

    # Entity positions (label > 0) are always kept, whatever the confidence.
    is_entity = (labels > 0)[labels != -100]
    return type_logits, type_label_mask, bin_confident | is_entity

def mask_bitokens2(args, type_logits, bin_logits, labels):
    """Build confidence masks for the type head and a two-logit binary head.

    This matches ``mask_bitokens``, except that the binary probabilities come
    from a softmax over the last dim of ``bin_logits`` instead of a sigmoid.
    Only ``args.self_learning_label_mode == "soft"`` is handled; any other
    mode returns ``None``.

    Args:
        args: must provide ``self_learning_label_mode`` and ``threshold``.
        type_logits: type-head scores/probabilities; their max over the last
            dim is compared with ``args.threshold``.
        bin_logits: binary-head logits, assumed ``(N, 2)`` with one row per
            entry of ``labels`` that is not -100 -- TODO confirm.
        labels: labels where -100 marks ignored positions and values > 0 are
            entities.

    Returns:
        ``(type_logits, type_label_mask, bin_label_mask)``. ``bin_label_mask``
        is True where the binary head is confident OR the label is an entity.
    """
    if args.self_learning_label_mode != "soft":
        return None

    threshold = args.threshold
    type_label_mask = type_logits.max(dim=-1).values > threshold

    bin_confident = torch.softmax(bin_logits, dim=-1).max(dim=-1).values > threshold

    # Entity positions (label > 0) are always kept, whatever the confidence.
    is_entity = (labels > 0)[labels != -100]
    return type_logits, type_label_mask, bin_confident | is_entity

def opt_grad(loss, in_var, optimizer):
    """Return ``d(loss)/d(in_var)``, scaled by the optimizer's loss scale.

    If ``optimizer`` has a non-None ``scaler`` attribute, ``loss`` is
    multiplied by ``optimizer.scaler.loss_scale`` before differentiating, so
    the returned gradients are scaled by the same factor. The check
    previously looked for a misspelled ``scalar`` attribute, so the scaling
    never applied.

    Args:
        loss: scalar tensor to differentiate.
        in_var: tensor (or sequence of tensors) to differentiate with respect to.
        optimizer: any object; only its optional ``scaler`` attribute is read.

    Returns:
        The tuple of gradients produced by ``torch.autograd.grad``.
    """
    scaler = getattr(optimizer, "scaler", None)
    if scaler is not None:
        loss = loss * scaler.loss_scale
    return torch.autograd.grad(loss, in_var)

def _update_mean_model_variables(model, m_model, alpha, global_step):
    """Move ``m_model``'s parameters toward ``model``'s by an EMA step, in place.

    ``m_param = a * m_param + (1 - a) * param``, with
    ``a = min(1 - 1 / (global_step + 1), alpha)``. Early in training this
    makes ``a`` small, so the averaged model tracks ``model`` closely; at
    step 0 it is a plain copy. Parameters are paired positionally via
    ``zip`` over ``.parameters()``.

    This uses the keyword form ``add_(tensor, alpha=...)``; the positional
    ``add_(Number, Tensor)`` overload is deprecated in PyTorch.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for m_param, param in zip(m_model.parameters(), model.parameters()):
        m_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def _update_mean_prediction_variables(prediction, m_prediction, alpha, global_step):
    """Move ``m_prediction`` toward ``prediction`` by an EMA step, in place.

    ``m_prediction = a * m_prediction + (1 - a) * prediction``, with
    ``a = min(1 - 1 / (global_step + 1), alpha)``. At step 0 this is a plain
    copy. Returns ``None``; only ``m_prediction`` is modified.

    This uses the keyword form ``add_(tensor, alpha=...)``; the positional
    ``add_(Number, Tensor)`` overload is deprecated in PyTorch.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    m_prediction.data.mul_(alpha).add_(prediction.data, alpha=1 - alpha)
