import time
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from metrics.myAUC import AUCMeter
from utils.utils import AverageMeter
from metrics.accuracy import accuracy


def _stack_gcn_targets(gcn_target):
    """Collate per-node GCN labels into one float tensor with the batch axis first.

    ``gcn_target`` is a sequence of K tensors (the previous code indexed
    exactly 12 of them). They are stacked along a new leading axis and the
    result is transposed, so for entries of shape ``(batch,)`` -- presumably
    what the default collate function yields here, TODO confirm against the
    dataset -- the result has shape ``(batch, K)``.
    """
    return torch.stack(list(gcn_target), dim=0).transpose(0, 1).float()


def _update_running_loss(running, current):
    """Fold ``current`` into a running average and return a variance penalty.

    Args:
        running: the previous running value (a tensor), or ``None`` before
            the first batch has been seen.
        current: this batch's loss tensor; it is detached before use.

    Returns:
        ``(new_running, penalty)``. On the first call ``new_running`` is the
        detached ``current`` and ``penalty`` is ``0``. Afterwards
        ``new_running`` is the mean of ``[running, current]`` and ``penalty``
        is their unbiased variance (``torch.Tensor.var`` default).
    """
    current = current.detach()
    if running is None:
        return current, 0
    # reshape(-1) lets 0-dim and shape-(1,) losses be combined uniformly.
    pair = torch.cat((running.reshape(-1), current.reshape(-1)))
    return pair.mean(), pair.var()


def train(train_loader, model, criterion_cls, criterion_gcn, optimizer, epoch, args):
    """Train ``model`` for one epoch over ``train_loader`` (requires CUDA).

    For each batch the model is called as ``model(input, id_tar, 'train')``.
    The optimised loss is the sum of:

    * ``loss_dict['loss']`` from ``model.loss_function(*output, M_N=...)``;
    * ``args.para`` times ``criterion_cls(output[-1], target)``;
    * ``criterion_gcn(output[-2], gcn_target)``;
    * a variance penalty over running averages of the two losses above
      (see the NOTE in the loop: as written it carries no gradient).

    Meters for every loss term, top-1 accuracy and an AUC over hard top-1
    predictions are printed every ``args.print_freq`` iterations.

    Args:
        train_loader: yields ``(input, target, gcn_target, id_tar)`` batches.
            Its ``dataset`` must expose ``imgs``; ``M_N`` is passed as
            ``args.batch_size / len(dataset.imgs)``.
        model: must provide ``loss_function`` returning a dict with the keys
            ``'loss'``, ``'Reconstruction_Loss'`` and ``'KLD'``.
        criterion_cls: loss applied to ``output[-1]`` and ``target``.
        criterion_gcn: loss applied to ``output[-2]`` and the stacked
            ``gcn_target``.
        optimizer: stepped once per batch.
        epoch: used only in the progress message.
        args: must provide ``batch_size``, ``para`` and ``print_freq``.
            ``para`` is converted with ``float`` (it was previously truncated
            with ``int``, which silently zeroed fractional weights).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_recon = AverageMeter()
    losses_kl = AverageMeter()
    losses_cls = AverageMeter()
    losses_gcn = AverageMeter()
    top1 = AverageMeter()
    eval_auc = AUCMeter()

    # Loop-invariant loss weights, computed once per epoch.
    kld_weight = int(args.batch_size) / len(train_loader.dataset.imgs)
    cls_weight = float(args.para)

    # Detached running averages of the cls / gcn losses; None until the
    # first batch has been processed.
    running_cls = None
    running_gcn = None

    # switch to train mode
    model.train()
    end = time.time()

    for i, (inputs, target, gcn_target, id_tar) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch_size = inputs.size(0)

        # `async=True` is a SyntaxError on Python >= 3.7; `non_blocking=True`
        # is its replacement (it only overlaps the copy for pinned memory).
        target_var = target.cuda(non_blocking=True)
        input_var = inputs.cuda(non_blocking=True)
        gcn_target_var = _stack_gcn_targets(gcn_target).cuda(non_blocking=True)

        output = model(input_var, id_tar, 'train')

        loss_cls = criterion_cls(output[-1], target_var)
        loss_gcn = criterion_gcn(output[-2], gcn_target_var)
        loss_vae_dict = model.loss_function(*output, M_N=kld_weight)
        loss = loss_vae_dict['loss'] + cls_weight * loss_cls + loss_gcn

        # topk=(1, 1): the second entry just duplicates top-1 (presumably to
        # avoid requesting top-5 with fewer than 5 classes -- TODO confirm).
        acc1, _ = accuracy(output[-1], target_var, topk=(1, 1))

        # Hard top-1 class indices, flattened to one entry per sample, for AUC.
        _, predi = output[-1].topk(1, 1, True, True)
        predi = predi.view(-1)

        losses.update(loss.item(), batch_size)
        losses_recon.update(loss_vae_dict['Reconstruction_Loss'].item(), batch_size)
        losses_kl.update(loss_vae_dict['KLD'].item(), batch_size)
        losses_cls.update(loss_cls.item(), batch_size)
        losses_gcn.update(loss_gcn.item(), batch_size)
        eval_auc.update(predi, target_var)
        # Weight by the batch size; the old `input[0].size(0)` was the first
        # dimension of a single sample, not the number of samples.
        top1.update(acc1[0], batch_size)

        # NOTE(review): both penalties are computed from detached tensors, so
        # adding them to `loss` changes its value but contributes no gradient;
        # as written this variance term (named `rex` originally) does not
        # affect training. Confirm whether the current loss was meant to keep
        # its graph before enabling it.
        running_cls, penalty_cls = _update_running_loss(running_cls, loss_cls)
        running_gcn, penalty_gcn = _update_running_loss(running_gcn, loss_gcn)
        loss = loss + penalty_cls + penalty_gcn

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                  'Loss_gcn {loss_gcn.val:.4f} ({loss_gcn.avg:.4f})\t'
                  'Loss_recon {loss_recon.val:.4f} ({loss_recon.avg:.4f})\t'
                  'Loss_kl {loss_kl.val:.4f} ({loss_kl.avg:.4f})\t'
                  'AUC {AUC}\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, loss_cls=losses_cls, loss_gcn=losses_gcn, loss_recon=losses_recon,
                loss_kl=losses_kl, AUC=eval_auc.get_auc(), top1=top1))
