""" Additional utility functions. """
import os
import time
import pprint
import torch
import numpy as np
import torch.nn.functional as F
import scipy.stats

def ensure_path(path):
    """Make sure the log/saving directory ``path`` exists.

    Missing parent directories are created as well. If ``path`` already
    exists (as a directory or as any other filesystem entry) nothing is done.

    Args:
      path: the generated saving path.
    """
    if os.path.exists(path):
        # Preserve the original contract: an existing entry is left alone.
        return
    # makedirs creates intermediate directories (os.mkdir raised
    # FileNotFoundError for them), and exist_ok avoids a crash when another
    # process creates the directory between the check above and this call.
    os.makedirs(path, exist_ok=True)

class Averager():
    """Running arithmetic mean of every value passed to ``add``.

    ``item()`` returns 0 before any value has been added.
    """
    def __init__(self):
        self.n = 0  # how many values have been added so far
        self.v = 0  # mean of the values added so far

    def add(self, x):
        """Fold ``x`` into the running mean."""
        seen = self.n
        total = self.v * seen + x
        self.v = total / (seen + 1)
        self.n = seen + 1

    def item(self):
        """Return the current mean."""
        return self.v

def count_acc(logits, label):
    """Compute classification accuracy from logits.

    Args:
      logits: input logits; the predicted class is the argmax over dim 1.
      label: ground truth labels, compared element-wise with the predictions.
    Return:
      The output accuracy as a Python float (fraction of correct predictions).
    """
    # softmax is monotonic within each row, so applying it before argmax
    # cannot pick a better winner; it only costs time and can create
    # rounding ties. argmax on the raw logits is equivalent and cheaper.
    pred = logits.argmax(dim=1)
    # .float() keeps the result on the tensors' own device instead of
    # forcing a CUDA tensor type whenever CUDA happens to be available.
    return (pred == label).float().mean().item()

def ensemble_acc(logits_dist, logits_sim, label, lamda):
    """Compute accuracy of a weighted blend of two sets of class probabilities.

    Both logit tensors are turned into probabilities with softmax over dim 1,
    mixed as ``lamda * p_dist + (1 - lamda) * p_sim``, and the argmax over
    dim 1 is taken as the prediction. ``lamda=0.5`` gives a plain average.

    Args:
      logits_dist: first set of logits (softmax over dim 1).
      logits_sim: second set of logits, same shape as ``logits_dist``.
      label: ground truth labels, compared element-wise with the predictions.
      lamda: mixing weight for ``logits_dist``'s probabilities; the other
        set receives ``1 - lamda``.
    Return:
      The output accuracy as a Python float.
    """
    prob_dist = F.softmax(logits_dist, dim=1)
    prob_sim = F.softmax(logits_sim, dim=1)
    prob_score = lamda * prob_dist + (1 - lamda) * prob_sim
    pred = prob_score.argmax(dim=1)
    # .float() is device-agnostic; no need to branch on CUDA availability.
    return (pred == label).float().mean().item()


def emb_loss(pred_emb, real_emb, arg):
    '''Mean cosine distance between predicted and target embeddings.

    Computes ``mean(1 - cos(pred_emb[i], real_emb[i]))`` with the cosine
    similarity taken along dim 1, so the result is 0 for identical
    directions and 1 for orthogonal ones.

    :param pred_emb:         predicted embeddings, e.g. (bs, 300)
    :param real_emb:         target embeddings, same shape as pred_emb
    :param arg:              unused here; kept so existing callers still work
    :return:                 scalar tensor holding the mean cosine distance
    '''
    cos_sim = torch.cosine_similarity(pred_emb, real_emb, dim=1)
    return (1 - cos_sim).mean()

def count_acc1(pred_emb, real_emb, label):
    """Compute nearest-embedding accuracy under cosine similarity.

    Every row of ``pred_emb`` is compared with every row of ``real_emb``;
    the predicted class is the index of the ``real_emb`` row with the highest
    cosine similarity.

    Args:
      pred_emb: predicted embeddings, one row per sample.
      real_emb: reference embeddings, one row per candidate class; cast to
        ``pred_emb``'s dtype before use.
      label: ground truth class indices, one per row of ``pred_emb``.
    Return:
      A tuple ``(accuracy, pred)``: the accuracy as a Python float and the
      tensor of predicted class indices.
    """
    real_emb = real_emb.type(pred_emb.dtype)
    # Cosine similarity matrix: pairwise dot products divided by the outer
    # product of the row norms.
    x_mul = torch.matmul(pred_emb, real_emb.T)
    norms = torch.norm(pred_emb, dim=1).unsqueeze(1) * torch.norm(real_emb, dim=1).unsqueeze(0)
    logits = x_mul / norms
    # softmax is monotonic per row, so argmax on the similarities directly
    # gives the same winner without the extra work / rounding ties.
    pred = logits.argmax(dim=1)
    # .float() stays on the current device rather than copying to CPU.
    return (pred == label).float().mean().item(), pred

class Timer():
    """Stopwatch based on ``time.time()``, started when the object is built."""
    def __init__(self):
        self.o = time.time()  # start timestamp (seconds, from time.time())

    def measure(self, p=1):
        """Return elapsed time divided by ``p`` as a short string.

        The scaled elapsed time is truncated to whole seconds, then shown as
        seconds below 60, as rounded minutes below 3600, and otherwise as
        hours with one decimal place.

        Args:
          p: divisor applied to the elapsed time (e.g. a progress fraction).
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return '{}s'.format(seconds)
        if seconds < 3600:
            return '{}m'.format(round(seconds / 60))
        return '{:.1f}h'.format(seconds / 3600)

# Shared printer used by ``pprint`` below. It has to be built here, before
# the ``def pprint`` statement rebinds the module-level name ``pprint`` from
# the standard-library module to the function defined below.
_utils_pp = pprint.PrettyPrinter()

def pprint(x):
    """Pretty-print ``x`` with the module's shared PrettyPrinter."""
    printer = _utils_pp
    printer.pprint(x)

def compute_confidence_interval(data, z=1.96):
    """Compute the mean of ``data`` and a normal-approximation interval half-width.

    Args:
      data: input records; anything ``np.array`` accepts.
      z: critical value multiplying the standard error. The default 1.96
        corresponds to a two-sided 95% interval under a normal approximation.
    Return:
      m: mean value
      pm: half-width of the interval, ``z * std / sqrt(len(data))``; the
        interval is ``[m - pm, m + pm]``.

    Note: ``std`` is the population standard deviation (``np.std`` with its
    default ``ddof=0``), not the sample standard deviation.
    """
    a = 1.0 * np.array(data)
    m = np.mean(a)
    std = np.std(a)
    pm = z * (std / np.sqrt(len(a)))
    return m, pm