import torch
from torch.utils.data import DataLoader
import horovod
import horovod.torch as hvd


def metric_average(val, name):
    """Reduce a scalar metric across Horovod workers and return it as a Python number.

    Args:
        val: Local metric value; either a Python number or a tensor.
        name: Name forwarded to ``hvd.allreduce`` to identify the reduction.

    Returns:
        The result of ``hvd.allreduce`` (called with its default reduction,
        which the Horovod documentation describes as averaging) converted
        with ``.item()``.
    """
    if torch.is_tensor(val):
        # torch.tensor(existing_tensor) is the discouraged copy-construct form
        # and emits a UserWarning; detach().clone() is the recommended copy.
        tensor = val.detach().clone()
    else:
        tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

def valid(dataloader, sampler, model, args, num_class, L):
    """Evaluate ``model`` on ``dataloader`` and reduce the metrics across Horovod workers.

    The model is put into eval mode and run under ``torch.no_grad()``. A sample
    counts as correct when its label is among the ``args.top_k`` highest-scoring
    classes along dim 1 of the model output.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        sampler: Only ``len(sampler)`` is used, as the divisor for both the
            summed per-batch loss and the hit count.
        model: Module called as ``model(images)``.
        args: Object providing ``top_k``.
        num_class: Number of classes used to one-hot encode ``labels``.
        L: Loss callable invoked as ``L(outputs, one_hot_labels)``.

    Returns:
        ``(loss, accuracy)`` after ``metric_average`` has been applied to each.
        NOTE: when the model output has no more than ``args.top_k`` classes, a
        notice is printed and the bare float ``1.0`` is returned instead of a
        tuple, before any cross-worker reduction takes place.
    """
    model.eval()
    # Same placement as .cuda() when CUDA is present; falls back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    top_k = args.top_k
    acc_number = 0.
    loss = 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # one_hot already lives on the labels' device.
            labels_cal = torch.nn.functional.one_hot(labels, num_class).type(torch.float32)
            loss += L(outputs, labels_cal)

            categories = outputs.shape[1]
            if categories <= top_k:
                print(f"It's meaningless to compute top{top_k} accuracy "
                      f"on a dataset with {categories} categories.")
                return 1.0

            if top_k > 0:
                # topk does not depend on a sentinel value, so it stays correct
                # for scores below -1 (e.g. raw logits) and leaves outputs intact.
                _, top_idx = outputs.topk(top_k, dim=1)
                acc_number += top_idx.eq(labels.unsqueeze(1)).any(dim=1).sum().item()

    accuracy = acc_number / len(sampler)
    loss /= len(sampler)

    # Horovod: average metric values across workers.
    loss = metric_average(loss, 'avg_loss')
    accuracy = metric_average(accuracy, 'avg_accuracy')

    return loss, accuracy
              
def valid_old(model, v_dataset, arg, L, num_class, aug):
    """Single-process evaluation with an optional augmentation-consistency term.

    Batches are drawn from ``v_dataset`` through a shuffling ``DataLoader`` and
    the model is run under ``torch.no_grad()``. This function does not switch
    the model to eval mode; the caller controls the model's mode.

    Args:
        model: Callable invoked as ``model(images)``.
        v_dataset: Dataset of ``(image, label)`` samples; ``len(v_dataset)`` is
            the accuracy divisor.
        arg: Object providing ``batchsize``, ``topology`` and ``top_k``.
        L: Loss callable, used both as ``L(outputs, one_hot_labels)`` and as
            ``L(model(aug(images)), outputs)``.
        num_class: Number of classes used to one-hot encode labels.
        aug: Augmentation callable applied to a batch of images.

    Returns:
        ``(accuracy, loss, topology_loss)`` where ``topology_loss`` is the
        per-batch consistency loss (itself averaged over ``arg.topology``
        augmented passes) averaged over all batches, and ``loss`` is the mean
        per-batch loss plus ``topology_loss``.
        NOTE: when the model output has no more than ``arg.top_k`` classes, a
        notice is printed and the bare float ``1.0`` is returned instead.
    """
    data_loader = DataLoader(v_dataset, batch_size=arg.batchsize, shuffle=True)
    # Same placement as .cuda() when CUDA is present; falls back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    top_k = arg.top_k
    acc_number = 0
    loss = 0.
    topology_loss = 0.
    num_batches = 0
    with torch.no_grad():
        for images, labels in data_loader:
            num_batches += 1
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            labels_cal = torch.nn.functional.one_hot(labels, num_class).type(torch.float32)
            loss += L(outputs, labels_cal)

            if arg.topology > 0:
                batch_topology = 0.
                for _ in range(arg.topology):
                    batch_topology += L(model(aug(images)), outputs)
                # Accumulate across batches; resetting per batch would keep
                # only the last batch's value.
                topology_loss += batch_topology / arg.topology

            categories = outputs.shape[1]
            if categories <= top_k:
                print(f"It's meaningless to compute top{top_k} accuracy "
                      f"on a dataset with {categories} categories.")
                return 1.0

            if top_k > 0:
                # topk leaves outputs untouched, so no second forward pass is
                # needed, and it stays correct for scores below -1.
                _, top_idx = outputs.topk(top_k, dim=1)
                acc_number += top_idx.eq(labels.unsqueeze(1)).any(dim=1).sum().item()

    # Guard the divisors so an empty loader yields zeros instead of an error.
    batch_divisor = max(num_batches, 1)
    loss = loss / batch_divisor
    topology_loss = topology_loss / batch_divisor
    loss = loss + topology_loss
    num_samples = len(v_dataset)
    accuracy = acc_number / num_samples if num_samples else 0.0
    return accuracy, loss, topology_loss
        
def test_time_aug(dataloader, sampler, model, aug, args, num_class, L):
    """Evaluate with test-time augmentation and reduce the metrics across Horovod workers.

    For every batch the prediction is the equal-weight mean of the model output
    on the raw images and on ``aug(images)``. The model is put into eval mode
    and run under ``torch.no_grad()``. A sample counts as correct when its label
    is among the ``args.top_k`` highest-scoring classes along dim 1.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        sampler: Only ``len(sampler)`` is used, as the divisor for both the
            summed per-batch loss and the hit count.
        model: Module called as ``model(images)``.
        aug: Augmentation callable applied to a batch of images.
        args: Object providing ``top_k``.
        num_class: Number of classes used to one-hot encode ``labels``.
        L: Loss callable invoked as ``L(outputs, one_hot_labels)``.

    Returns:
        ``(loss, accuracy)`` after ``metric_average`` has been applied to each.
        NOTE: when the model output has no more than ``args.top_k`` classes, a
        notice is printed and the bare float ``1.0`` is returned instead of a
        tuple, before any cross-worker reduction takes place.
    """
    model.eval()
    # Same placement as .cuda() when CUDA is present; falls back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    top_k = args.top_k
    acc_number = 0.
    loss = 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = 0.5 * model(images) + 0.5 * model(aug(images))
            # one_hot already lives on the labels' device.
            labels_cal = torch.nn.functional.one_hot(labels, num_class).type(torch.float32)
            loss += L(outputs, labels_cal)

            categories = outputs.shape[1]
            if categories <= top_k:
                print(f"It's meaningless to compute top{top_k} accuracy "
                      f"on a dataset with {categories} categories.")
                return 1.0

            if top_k > 0:
                # topk does not depend on a sentinel value, so it stays correct
                # for scores below -1 (e.g. raw logits) and leaves outputs intact.
                _, top_idx = outputs.topk(top_k, dim=1)
                acc_number += top_idx.eq(labels.unsqueeze(1)).any(dim=1).sum().item()

    acc = acc_number / len(sampler)
    loss /= len(sampler)

    # Horovod: average metric values across workers.
    loss = metric_average(loss, 'avg_loss')
    acc = metric_average(acc, 'avg_accuracy')

    return loss, acc