import torch
import config
import os
import numpy as np

def get_basic_dir(model_name, dataset_name, phase):
    """Return ``<config.output_dir>/<model_name>_<dataset_name>/<phase>``, creating it if needed.

    Args:
        model_name: model identifier; joined with ``dataset_name`` by ``'_'``.
        dataset_name: dataset identifier.
        phase: name of the final sub-directory.

    Returns:
        The directory path (as built by ``os.path.join``).

    Raises:
        FileExistsError: if the path already exists but is not a directory.
    """
    run_name = model_name + '_' + dataset_name
    basic_dir = os.path.join(config.output_dir, run_name, phase)
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # if another process created the directory between the two calls.
    os.makedirs(basic_dir, exist_ok=True)
    return basic_dir

class ClassAccuracy:
    """Accumulate per-class top-1 accuracy over many batches."""

    def __init__(self, num_classes):
        """
        Args:
            num_classes: number of classes; labels must lie in ``[0, num_classes)``.
        """
        self.sum = np.zeros(num_classes)    # correct predictions per class
        self.count = np.zeros(num_classes)  # samples seen per class

    def accuracy(self, outputs, labels):
        """Add one batch to the per-class tallies.

        Args:
            outputs: class scores; the prediction is the argmax over dim 1.
            labels: ground-truth class indices, one per sample.

        Raises:
            IndexError: if any label is outside ``[0, num_classes)``; the
                tallies are left unchanged in that case.
        """
        with torch.no_grad():
            _, pred = outputs.max(dim=1)
            correct = pred.eq(labels)
        # Copy to CPU NumPy once and count with bincount, instead of indexing
        # NumPy arrays with (possibly CUDA) tensors one sample at a time.
        labels_np = labels.detach().cpu().numpy().astype(np.int64).ravel()
        correct_np = correct.cpu().numpy().astype(np.float64).ravel()
        num_classes = self.count.shape[0]
        # Validate up front: negative labels would otherwise wrap around to the
        # last classes, and a too-large label would leave a half-counted batch.
        if labels_np.size and (labels_np.min() < 0 or labels_np.max() >= num_classes):
            raise IndexError(
                'label out of range [0, {}): min={}, max={}'.format(
                    num_classes, labels_np.min(), labels_np.max()))
        self.count += np.bincount(labels_np, minlength=num_classes)
        self.sum += np.bincount(labels_np, weights=correct_np, minlength=num_classes)

    def __str__(self):
        """One ``'<class>:<accuracy %>'`` line per class; unseen classes show ``nan``."""
        fmtstr = '{}:{:6.2f}'
        # Divide only where samples exist, so unseen classes stay NaN without
        # triggering NumPy's divide-by-zero RuntimeWarning.
        avg = np.full(self.count.shape, np.nan)
        np.divide(self.sum, self.count, out=avg, where=self.count > 0)
        avg = avg * 100
        return '\n'.join(fmtstr.format(label, acc) for label, acc in enumerate(avg))

class AverageStorage(object):
    """Track the latest value and the running (sample-weighted) average of a metric."""

    def __init__(self, name, fmt=':f'):
        """
        Args:
            name: label printed by ``__str__``.
            fmt: format spec, including the leading ':', applied to the average.
        """
        self.name = name
        self.fmt = fmt
        self._reset()

    def _reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        # Start at 0 so that avg == sum / (number of samples actually recorded);
        # starting at 1 counted a phantom sample and biased every average low.
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid dividing by zero when an n=0 update arrives before any samples.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name}[AVG:{avg' + self.fmt + '}]'
        return fmtstr.format(**self.__dict__)
     
class TensorStorage(object):
    """Running element-wise average of a fixed-length tensor metric.

    The accumulators are ``size``-element tensors on ``device``; each value
    passed to ``update`` must broadcast against them.
    """

    def __init__(self, name, fmt=':f', size=32, device=None):
        """
        Args:
            name: label printed by ``__str__``.
            fmt: format spec, including the leading ':', applied to each
                averaged element.
            size: number of elements in the accumulators (default 32).
            device: device for the accumulators. ``None`` selects ``cuda:0``
                when CUDA is available and the CPU otherwise.
        """
        self.name = name
        self.fmt = fmt
        self.size = size
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self._reset()

    def _reset(self):
        """Zero the accumulators and the update counter on ``self.device``."""
        # Allocate directly on the target device instead of CPU-then-copy.
        self.val = torch.zeros(self.size, device=self.device)
        self.avg = torch.zeros(self.size, device=self.device)
        self.sum = torch.zeros(self.size, device=self.device)
        self.count = torch.tensor(0, device=self.device)

    def update(self, val):
        """Add ``val`` to the running sum and refresh the average."""
        self.val = val
        # Detach the previous sum so the autograd graph does not grow across
        # updates. The addition allocates a new tensor, so no clone is needed.
        self.sum = self.sum.detach() + val
        self.count = self.count + 1
        self.avg = self.sum / self.count

    def __str__(self):
        avg = self.avg
        if isinstance(avg, torch.Tensor) and avg.dim() > 0:
            # Tensor.__format__ only honours a format spec for 0-d tensors
            # (otherwise it raises TypeError), so format element by element.
            elem_fmt = '{' + self.fmt + '}'
            body = ', '.join(elem_fmt.format(x) for x in avg.detach().flatten().tolist())
            return '{}[AVG:[{}]]'.format(self.name, body)
        fmtstr = '{name}[AVG:{avg' + self.fmt + '}]'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    """Print a one-line progress summary every ``step`` iterations.

    Each printed line looks like ``<prefix>[ cur/total]  <meter>  <meter> ...``,
    where every meter is rendered with ``str()``.
    """

    def __init__(self, total, step, prefix, meters):
        """
        Args:
            total: number of iterations; printed as the denominator.
            step: print every ``step`` iterations. A value <= 0 disables
                periodic printing, so only the final iteration is shown.
            prefix: text placed directly before the ``[cur/total]`` counter.
            meters: objects whose ``str()`` is appended to each line.
        """
        self._fmtstr = self._get_fmtstr(total)
        self.meters = meters
        self.prefix = prefix
        self.total = total
        self.step = step

    def display(self, running):
        """Print progress for the 0-based iteration index ``running`` if it is due."""
        current = running + 1
        # Guard the modulo so step == 0 no longer raises ZeroDivisionError.
        periodic = self.step > 0 and current % self.step == 0
        if periodic or current == self.total:
            entries = [self.prefix + self._fmtstr.format(current)]  # e.g. "prefix[ 5/10]"
            entries += [str(meter) for meter in self.meters]
            print('  '.join(entries))

    def _get_fmtstr(self, total):
        """Return a ``'[{:Nd}/<total>]'`` template, N being the digit count of total."""
        total = int(total)
        width = len(str(total))
        fmt = '{:' + str(width) + 'd}'
        return '[' + fmt + '/' + fmt.format(total) + ']'
    

def accuracy(outputs, labels):
    """Top-1 accuracy of one batch of class scores.

    Args:
        outputs: per-sample class scores; the prediction is the argmax over dim 1.
        labels: ground-truth class indices, compared element-wise with the predictions.

    Returns:
        A 0-d tensor holding the fraction of samples whose prediction matches
        the label. Computed without autograd tracking.
    """
    with torch.no_grad():
        n_samples = labels.size(0)
        hits = outputs.argmax(dim=1).eq(labels).sum()
        return hits / n_samples