""" Additional utility functions. """
import matplotlib
matplotlib.use('pdf')
import os
import time
import pprint
import torch
import numpy as np
import torch.nn.functional as F
import pulp as lp


def ensure_path(path):
    """Make sure the directory ``path`` exists, creating it if needed.

    Missing parent directories are created as well. If the directory is
    created concurrently (e.g. by another process) between the existence
    check and the creation, that is not treated as an error. If ``path``
    already exists, nothing is done.

    Args:
      path: the generated saving path.
    """
    if os.path.exists(path):
        return
    # exist_ok=True closes the race between the exists() check above and the
    # creation below; makedirs also creates any missing parent directories.
    os.makedirs(path, exist_ok=True)

class Averager():
    """Keep a running arithmetic mean of the values passed to ``add``."""

    def __init__(self):
        self.n = 0  # how many values have been added so far
        self.v = 0  # mean of the values added so far (0 before any add)

    def add(self, x):
        """Fold ``x`` into the running mean."""
        seen = self.n
        total = self.v * seen + x
        self.v = total / (seen + 1)
        self.n = seen + 1

    def item(self):
        """Return the current mean."""
        return self.v

def count_acc(logits, label):
    """Compute the classification accuracy of a batch.

    Args:
      logits: input logits; the predicted class of each row is the index of
        its largest entry along dim 1.
      label: ground truth labels, compared element-wise with the predicted
        class indices.
    Return:
      The output accuracy as a Python float (fraction of correct predictions).
    """
    # softmax is monotonic, so argmax over the raw logits yields the same
    # prediction without the extra computation.
    pred = logits.argmax(dim=1)
    # .float() keeps the comparison on its current device; casting to
    # torch.cuda.FloatTensor would copy CPU tensors to the GPU whenever CUDA
    # happens to be available.
    return (pred == label).float().mean().item()

class Timer():
    """Wall-clock timer that reports elapsed time as a short string."""

    def __init__(self):
        # Start time in seconds, as returned by time.time().
        self.o = time.time()

    def measure(self, p=1):
        """Return the time elapsed since construction, divided by ``p``.

        Args:
          p: divisor applied to the elapsed time; must be non-zero.
        Return:
          A string: '<h>h' with one decimal for an hour or more,
          '<m>m' with minutes rounded to the nearest integer for a minute or
          more, and '<s>s' with whole (truncated) seconds otherwise.
        """
        x = int((time.time() - self.o) / p)
        minutes = round(x / 60)
        # Compare the *rounded* minutes too, so 3570-3599 s is reported as
        # '1.0h' rather than the inconsistent '60m'.
        if x >= 3600 or minutes >= 60:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(minutes)
        return '{}s'.format(x)

# A single shared printer, built here while the name ``pprint`` still refers
# to the stdlib module; the ``def`` below rebinds that module-level name to
# this file's helper function.
_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` using the shared module-level PrettyPrinter."""
    printer = _utils_pp
    printer.pprint(x)


def compute_confidence_interval(data, z=1.96):
    """Compute the mean of ``data`` and the half-width of its confidence interval.

    The half-width is ``z * std / sqrt(len(data))``, where ``std`` is the
    population standard deviation (``np.std`` with its default ``ddof=0``).
    This is a normal-approximation interval; the default ``z=1.96`` is the
    two-sided 95% normal critical value.

    Args:
      data: input records (a sequence accepted by ``np.array``).
      z: critical value multiplying the standard error; defaults to 1.96.
    Return:
      m: mean value
      pm: half-width of the confidence interval, i.e. the interval is m +/- pm.
    """
    a = 1.0 * np.array(data)
    m = np.mean(a)
    std = np.std(a)
    pm = z * (std / np.sqrt(len(a)))
    return m, pm


def get_cos_similar_matrix(v1, v2):
    """Compute pairwise cosine similarities between rows, rescaled to [0, 1].

    Args:
      v1: 2-D array of shape (n, d); each row is a vector.
      v2: 2-D array of shape (m, d); each row is a vector.
    Return:
      An (n, m) array whose entry (i, j) is 0.5 + 0.5 * cos(v1[i], v2[j]).
      For pairs involving a zero-norm row the cosine is undefined; it is
      treated as 0, giving 0.5.
    """
    num = np.dot(v1, v2.T)
    denom = np.linalg.norm(v1, axis=1).reshape(-1, 1) * np.linalg.norm(v2, axis=1)
    # Divide only where the norm product is non-zero. A zero-norm row yields
    # 0/0 = NaN (not -inf), which the -inf guard below never caught; those
    # entries now keep the pre-filled 0.
    res = np.zeros(num.shape, dtype=np.result_type(num, denom))
    np.divide(num, denom, out=res, where=denom != 0)
    res[np.isneginf(res)] = 0
    return 0.5 + 0.5 * res

def route_plan(Dij):
    """Solve a 0-1 assignment problem over the cost matrix ``Dij``.

    Chooses binary x[i][j] minimizing sum(Dij[i, j] * x[i][j]) subject to
    every row and every column of x summing to exactly 1, i.e. a minimum-cost
    permutation matrix. No subtour-elimination constraints are added, so the
    result is an assignment rather than necessarily a single tour. The model
    is solved with PuLP's CBC solver with solver output suppressed.

    Args:
      Dij: cost matrix indexable as Dij[i, j]; assumed square (K, K) with
        K = Dij.shape[0] -- TODO confirm with callers.
    Return:
      W: (K, K) numpy array with W[i, j] set to the solver's value of x[i][j].
    """
    K = Dij.shape[0]
    if K == 0:
        return np.zeros((0, 0))
    model = lp.LpProblem(name='plan_0_1', sense=lp.LpMinimize)
    # Underscore separators keep names unique once K > 10; without them
    # x[1][11] and x[11][1] would both be named "x111".
    x = [[lp.LpVariable("x_{}_{}".format(i, j), cat="Binary") for j in range(K)]
         for i in range(K)]
    # objective
    model += lp.lpSum(Dij[i, j] * x[i][j] for i in range(K) for j in range(K))
    # constraints: each row i selects exactly one column ...
    for i in range(K):
        model += lp.lpSum(x[i][j] for j in range(K)) == 1
    # ... and each column j is selected by exactly one row.
    for j in range(K):
        model += lp.lpSum(x[i][j] for i in range(K)) == 1

    model.solve(lp.apis.PULP_CBC_CMD(msg=False))

    # Read each variable through its own handle: model.variables() returns
    # variables sorted by name, which stops matching row-major order for
    # K > 10 ("x10..." sorts before "x2...").
    W = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            W[i, j] = x[i][j].varValue
    return W