import time
import logging
import torch
import torch.nn as nn
from apex import amp
from tools.utils import AverageMeter
    
def _backward(loss, optimizer, use_amp):
    """Back-propagate ``loss``, going through apex AMP loss scaling when ``use_amp`` is true.

    ``optimizer`` is only passed to ``amp.scale_loss``; the caller remains
    responsible for calling ``zero_grad()`` before and ``step()`` after.
    """
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()


def train(config, epoch, model, Generator, discriminator, classifier, criterion_cla, criterion_cal,
          criterion_mm, criterion_cicl, criterion_tri, optimizer, optimizer_cc, trainloader, pid2clothes):
    """Run one training epoch and log averaged losses/accuracies to the 'reid.train' logger.

    Training alternates on a 7-epoch cycle:
      * ``epoch % 7`` in {0, 1}: each batch first updates the discriminator with
        ``disc_loss`` through ``optimizer_cc``; the main loss excludes the GAN terms.
      * otherwise: the discriminator is not stepped, and ``gan_loss * 0.2`` plus
        ``dise_loss * config.Hyper.eta`` are added to the main loss.
    In every batch the main loss is back-propagated and ``optimizer`` is stepped.

    Args:
        config: Reads ``config.TRAIN.AMP`` (use apex AMP loss scaling) and
            ``config.Hyper.beta`` / ``.alpth`` / ``.eta`` (loss weights).
        epoch: Epoch counter; logged as ``epoch + 1`` (presumably 0-based — confirm with caller).
        model: Backbone called on both image views concatenated along dim 0. Its
            output is indexed as a tuple: [1] per-view features (split back into two
            halves), [2] clothes features, [3] scores compared against ``camids``,
            [4] features compared against ``clothes_ids``.
        Generator: Takes (features, clo_feat), both detached, and returns four outputs.
        discriminator: Produces logits scored with ``BCEWithLogitsLoss`` against
            (batch, 1) targets of ones/zeros.
        classifier: Identity classifier applied to features.
        criterion_cla, criterion_cal, criterion_mm, criterion_cicl, criterion_tri:
            Loss callables; ``criterion_cicl``'s result is multiplied element-wise by a
            per-sample probability vector before ``torch.mean`` (so it may be per-sample).
        optimizer: Steps the main loss every batch.
        optimizer_cc: Steps the discriminator loss during the discriminator phase.
        trainloader: Iterable yielding ``(imgs, imgs_b, pids, camids, clothes_ids)``.
        pid2clothes: Currently unused; kept so existing callers continue to work.

    Requires CUDA: every batch tensor is moved with ``.cuda()``.
    """
    logger = logging.getLogger('reid.train')
    batch_cla_loss = AverageMeter()
    batch_dise_loss = AverageMeter()
    batch_cicl_loss = AverageMeter()
    batch_int_loss = AverageMeter()
    batch_gan_loss = AverageMeter()
    batch_tri_loss = AverageMeter()
    batch_mm_loss = AverageMeter()
    batch_disc_loss = AverageMeter()
    corrects = AverageMeter()
    clothes_corrects = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    adv_loss = nn.BCEWithLogitsLoss()
    cyc_loss = nn.MSELoss()
    use_amp = config.TRAIN.AMP
    model.train()
    classifier.train()
    Generator.train()
    discriminator.train()

    # First two epochs of every 7-epoch cycle train the discriminator; the rest
    # train the generator/disentanglement terms through the main loss.
    disc_phase = epoch % 7 <= 1

    end = time.time()
    for imgs, imgs_b, pids, camids, clothes_ids in trainloader:
        imgs, imgs_b, pids, camids, clothes_ids = (
            imgs.cuda(), imgs_b.cuda(), pids.cuda(), camids.cuda(), clothes_ids.cuda())
        data_time.update(time.time() - end)
        batch_size = pids.size(0)
        # Discriminator targets, built once per batch directly on the GPU.
        real = torch.ones(batch_size, 1, device=pids.device)
        fake = torch.zeros(batch_size, 1, device=pids.device)

        # Both views share one backbone pass; output[1] is split back per view.
        tuple_features = model(torch.cat((imgs, imgs_b), dim=0))
        features, features2 = tuple_features[1].split(imgs.size(0), dim=0)
        clo_feat = tuple_features[2]
        clo_score = tuple_features[3]
        mask_feat = tuple_features[4]
        # Generator inputs are detached, so generator losses do not reach the
        # backbone through these inputs.
        out1_2, out2_1, out1_1, out2_2 = Generator(features.detach(), clo_feat.detach())
        outputs = classifier(features)
        outputs2 = classifier(features2)
        cla_loss = criterion_cla(outputs, pids)
        black_loss = criterion_cla(outputs2, pids)
        adv3_loss = criterion_cicl(features, features2, pids)

        # All discriminator inputs are detached: this loss only trains the discriminator.
        disc_loss = (adv_loss(discriminator(features2.detach()), real)
                     + adv_loss(discriminator(clo_feat.detach()), fake)
                     + adv_loss(discriminator(out1_2.detach()), real)
                     + adv_loss(discriminator(out2_1.detach()), fake))
        if disc_phase:
            optimizer_cc.zero_grad()
            _backward(disc_loss, optimizer_cc, use_amp)
            optimizer_cc.step()

        int_loss = criterion_cal(mask_feat, clothes_ids) + criterion_cla(clo_score, camids)
        # Weight each sample's cicl term by the gradient-free probability the
        # classifier assigns to its true identity on the second view. This is a
        # separate forward (not a reuse of outputs2) to keep the original call
        # sequence; NOTE(review): if classifier has BN/dropout this forward may
        # differ from outputs2 and updates running stats again — confirm intent.
        with torch.no_grad():
            p_feat = nn.functional.softmax(classifier(features2), dim=1)
            rows = torch.arange(p_feat.size(0), device=p_feat.device)
            true_label_probs = p_feat[rows, pids]
        cicl_loss = torch.mean(adv3_loss * true_label_probs)

        _, preds = torch.max(outputs.detach(), 1)
        _, preds2 = torch.max(outputs2.detach(), 1)
        mm_loss = criterion_mm(features, pids) + criterion_mm(features2, pids)
        new_feat = classifier(features2) - classifier(out2_1)
        dise_loss = criterion_cla(new_feat, pids)
        tri_loss = criterion_tri(features2, pids) + criterion_tri(features, pids)
        # Generator objective: flip the discriminator's targets for the swapped
        # outputs, plus MSE reconstruction of the detached inputs.
        gan_loss = (adv_loss(discriminator(out2_1), real)
                    + adv_loss(discriminator(out1_2), fake)
                    + cyc_loss(out1_1, features2.detach())
                    + cyc_loss(out2_2, clo_feat.detach()))

        loss = (cla_loss + black_loss + tri_loss
                + (mm_loss + cicl_loss) * config.Hyper.beta
                + int_loss * config.Hyper.alpth)
        if not disc_phase:
            loss += gan_loss * 0.2 + dise_loss * config.Hyper.eta
        optimizer.zero_grad()
        _backward(loss, optimizer, use_amp)
        optimizer.step()

        # statistics — stored as Python floats, like the loss meters.
        # 'Acc' tracks the second-view predictions, 'CloAcc' the first-view ones.
        corrects.update((preds2 == pids).sum().item() / batch_size, batch_size)
        clothes_corrects.update((preds == pids).sum().item() / batch_size, batch_size)
        batch_cla_loss.update(cla_loss.item(), batch_size)
        batch_dise_loss.update(dise_loss.item(), batch_size)
        batch_cicl_loss.update(cicl_loss.item(), batch_size)
        batch_int_loss.update(int_loss.item(), batch_size)
        batch_gan_loss.update(gan_loss.item(), batch_size)
        batch_tri_loss.update(tri_loss.item(), batch_size)
        batch_mm_loss.update(mm_loss.item(), batch_size)
        batch_disc_loss.update(disc_loss.item(), batch_size)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    logger.info('Epoch{0} '
                'Time:{batch_time.sum:.1f}s '
                'Data:{data_time.sum:.1f}s '
                'ClaLoss:{cla_loss.avg:.4f} '
                'DiseLoss:{dise_loss.avg:.4f} '
                'ciclLoss:{cicl_loss.avg:.4f} '
                'IntLoss:{int_loss.avg:.4f} '
                'GANloss:{gan_loss.avg:.4f} '
                'TriLoss:{tri_loss.avg:.4f} '
                'MMLoss:{mm_loss.avg:.4f} '
                'DiscLoss:{disc_loss.avg:.4f} '
                'Acc:{acc.avg:.2%} '
                'CloAcc:{clo_acc.avg:.2%} '.format(
                    epoch + 1, batch_time=batch_time, data_time=data_time,
                    cla_loss=batch_cla_loss, dise_loss=batch_dise_loss,
                    cicl_loss=batch_cicl_loss,
                    int_loss=batch_int_loss,
                    gan_loss=batch_gan_loss, tri_loss=batch_tri_loss,
                    mm_loss=batch_mm_loss, disc_loss=batch_disc_loss,
                    acc=corrects, clo_acc=clothes_corrects))
