import errno
import os
import numpy as np
import pandas as pd
import torch.utils.data as data


def accuracy(output, target, topk=(1,)):
    """Return precision@k, as a percentage, for every k in ``topk``.

    Args:
        output: score tensor with one row per sample and classes along dim 1.
        target: tensor holding one ground-truth class index per sample.
        topk: tuple of k values to evaluate.

    Returns:
        A list with one 0-d tensor per entry of ``topk`` (same order). Each
        holds the percentage of samples whose target is among the k
        highest-scoring classes.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Take the indices of the ``largest_k`` best scores per sample, sorted in
    # descending order. Transpose so row j holds every sample's (j+1)-th guess.
    _, ranked = output.topk(largest_k, 1, True, True)
    ranked = ranked.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]


def mkdir_p(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    Does nothing if ``path`` already exists as a directory.

    Raises:
        OSError: if ``path`` exists but is not a directory, or the directories
            cannot be created (e.g. permission denied, a parent is a file).
    """
    # exist_ok=True replaces the former hand-written EEXIST/isdir check.
    # os.makedirs still raises when the existing entry is not a directory.
    os.makedirs(path, exist_ok=True)


class AverageMeter(object):
    """Keep the most recent value together with a running weighted mean.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``.

        ``val`` is weighted by ``n`` in the running sum, so ``avg`` is the
        sample-weighted mean of every value seen since the last reset.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def print_acc_conf(infer_np, test_labels):
    """Print and return top-1 accuracy and mean confidence of predictions.

    Args:
        infer_np: 2-D array of per-class scores with one row per sample; the
            max/argmax are taken along axis 1. Presumably softmax
            probabilities — confirm with the caller.
        test_labels: ground-truth class index per sample, compared
            element-wise with the row-wise argmax.

    Returns:
        ``(acc_avg, conf_avg)``. ``acc_avg`` is the fraction of rows whose
        argmax equals the label. ``conf_avg`` is the mean of the row-wise
        maximum score.
    """
    # Confidence = highest score in each row; prediction = its column index.
    conf_metric = np.max(infer_np, axis=1)
    conf_metric_ind = np.argmax(infer_np, axis=1)
    conf_avg = np.mean(conf_metric)
    acc_avg = np.mean(conf_metric_ind == test_labels)
    print("Total data: {:d}. Average acc: {:.4f}. Average confidence: {:.4f}.".format(len(infer_np), acc_avg, conf_avg))

    return acc_avg, conf_avg


class binarydata(data.Dataset):
    """Map-style dataset pairing ``data[i]`` with ``labels[i]``.

    Both sequences are stored as given (no copy, no transform). The dataset
    length is taken from ``labels``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


class TrainRecorder:
    """Accumulate per-epoch loss/accuracy metrics and dump them to CSV.

    Every call to ``update`` appends one value to each of the six metric lists,
    so the lists stay the same length. ``save`` writes one row per call.
    """

    def __init__(self):
        self.training_loss_list = []
        self.training_acc_list = []
        self.train_loss_list = []
        self.train_acc_list = []
        self.test_loss_list = []
        self.test_acc_list = []

    def update(self, train_loss=0, train_acc=0, test_loss=0, test_acc=0, training_loss=0, training_acc=0):
        """Append one row of metrics; omitted metrics are recorded as 0."""
        self.train_loss_list.append(train_loss)
        self.train_acc_list.append(train_acc)
        self.test_loss_list.append(test_loss)
        self.test_acc_list.append(test_acc)
        self.training_loss_list.append(training_loss)
        self.training_acc_list.append(training_acc)

    def save(self, path, filename):
        """Print the recorded metrics and write them to ``path/filename`` as CSV.

        Rows are indexed from 1 in recording order. ``path`` is created if it
        does not exist. An empty ``path`` means the current directory.
        """
        # Column order matches the CSV header produced historically.
        columns = {
            'train_loss': self.train_loss_list,
            'train_acc': self.train_acc_list,
            'test_loss': self.test_loss_list,
            'test_acc': self.test_acc_list,
            'training_loss': self.training_loss_list,
            'training_acc': self.training_acc_list,
        }
        df = pd.DataFrame(columns, index=range(1, len(self.train_loss_list) + 1))
        # os.makedirs('') raises, so only create a directory when one is named.
        if path:
            os.makedirs(path, exist_ok=True)
        print(df)
        filepath = os.path.join(path, filename)
        df.to_csv(filepath)
