import torch.optim
import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, NT_xent
from utils import AverageMeter, normalize
from adv_training.attack import RepresentationAdv
import time
import random

# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Shared horizontal-flip layer, moved to the selected device.
hflip = TL.HorizontalFlipLayer().to(device)

def _random_cyclic_permutation(size):
    """Return a random permutation of ``range(size)`` that forms a single cycle."""
    # Sattolo-style shuffle: each position is swapped with a strictly later
    # one, so no index can end up at its own position.
    order = list(range(size))
    for i in range(size - 1):
        pick = random.randint(i + 1, size - 1)
        order[i], order[pick] = order[pick], order[i]
    return order


def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
          simclr_aug=None):
    """Train ``model`` for one epoch with adversarial contrastive and shift losses.

    For every batch the function builds two views of the input, applies each
    of the ``P.K_shift`` shifting transformations to both, augments them with
    ``simclr_aug``, and then crafts three sets of adversarial examples with
    ``RepresentationAdv``:

    * NAT: ``images2`` attacked against a shuffled copy of its own
      shift chunks (negative transformations),
    * RTD: ``images2`` attacked against the shift-prediction labels,
    * PAT: ``images1`` attacked against ``images2`` (positive pair).

    The optimized loss is ``loss_sim + loss_shift``.

    Args:
        P: Configuration object. This function reads ``sim_lambda``,
            ``K_shift``, ``epsilon``, ``alpha``, ``min``, ``max``, ``k``,
            ``attack_type``, ``loss_type``, ``regularize_to``, ``n_gpus``,
            ``dataset``, ``shift_trans``, ``lamda``, ``random_start`` and
            ``multi_gpu`` from it.
        epoch: Current epoch number; used for the scheduler step
            (``epoch - 1 + n / len(loader)``) and for logging.
        model: Network called with ``simclr``/``penultimate``/``shift``
            keyword flags; its second return value is indexed with
            ``'simclr'`` and ``'shift'``.
        criterion: Loss applied to the shift logits; also handed to
            ``RepresentationAdv``.
        optimizer: Optimizer stepped once per batch; also passed to the
            adversarial example generators.
        scheduler: Learning-rate scheduler stepped with a fractional epoch.
        loader: Iterable yielding ``(images, labels)``. When
            ``P.dataset == 'imagenet'`` ``images`` is indexed as a pair of
            pre-built views; otherwise it is a single tensor.
        logger: Optional logger providing ``log`` and ``scalar_summary``.
            ``print`` is used for messages when it is ``None``.
        simclr_aug: Augmentation callable applied to both views. Required.

    Returns:
        None. Progress is reported through the logger / ``print``.

    Raises:
        AssertionError: If ``simclr_aug`` is ``None``, ``P.sim_lambda`` is
            not ``1.0`` or ``P.K_shift`` is not greater than 1.
    """
    assert simclr_aug is not None
    assert P.sim_lambda == 1.0  # to avoid mistake
    assert P.K_shift > 1

    log_ = print if logger is None else logger.log

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = {name: AverageMeter() for name in ('cls', 'sim', 'shift')}
    attacker = RepresentationAdv(model, epsilon=P.epsilon, alpha=P.alpha,
                                 min_val=P.min, max_val=P.max, max_iters=P.k,
                                 _type=P.attack_type, loss_type=P.loss_type,
                                 regularize=P.regularize_to, criterion=criterion)

    check = time.time()
    print(f"number of batch={len(loader)}")
    print("P.K_shift", P.K_shift)
    for n, (images, labels) in enumerate(loader):
        model.train()
        count = n * P.n_gpus  # number of trained samples
        # Time spent waiting for the loader since the previous iteration ended.
        data_time.update(time.time() - check)
        check = time.time()

        ### SimCLR loss ###
        if P.dataset != 'imagenet':
            batch_size = images.size(0)
            images = images.to(device)
            # Duplicate the batch, flip it, and split it into the two views.
            images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2)  # hflip
        else:
            # The loader already provides the two views.
            batch_size = images[0].size(0)
            images1, images2 = images[0].to(device), images[1].to(device)
        labels = labels.to(device)

        # Stack every shifting transformation of each view along the batch axis.
        images1 = torch.cat([P.shift_trans(images1, k) for k in range(P.K_shift)])
        images2 = torch.cat([P.shift_trans(images2, k) for k in range(P.K_shift)])
        images1, images2 = images1.to(device), images2.to(device)

        # Label of each stacked sample is the index of the shift applied to it;
        # the original class labels only provide shape/dtype/device.
        shift_labels = torch.cat([torch.ones_like(labels) * k for k in range(P.K_shift)], 0)

        images1 = simclr_aug(images1)
        images2 = simclr_aug(images2)

        # Adversarial Training on negative transformations (NAT):
        # reorder the shift chunks of images2 so that every chunk is paired
        # with a differently-shifted chunk as the attack target.
        shift_chunks = images2.chunk(P.K_shift)
        order = _random_cyclic_permutation(P.K_shift)
        neg_img = torch.cat([shift_chunks[idx].clone() for idx in order])

        adv_nat = attacker.get_adversarial_contrastive_img(
            original_images=images2, target=neg_img, optimizer=optimizer,
            weight=P.lamda, random_start=P.random_start, reduce_loss=True)

        # robust transformation prediction (RTD)
        adv_rtd = attacker.get_adversarial_shift_trans_img(
            original_images=images2, optimizer=optimizer,
            weight=P.lamda, shift_labels=shift_labels,
            random_start=P.random_start)

        # Adversarial training on positive transformations (PAT)
        adv_pat = attacker.get_adversarial_contrastive_img(
            original_images=images1, target=images2, optimizer=optimizer,
            weight=P.lamda, random_start=P.random_start)

        # Four equally sized groups: clean view 1, clean view 2, PAT, NAT.
        images_pair = torch.cat([images1, images2, adv_pat, adv_nat], dim=0)

        _, outputs_aux = model(images_pair, simclr=True, penultimate=False, shift=True)
        outputs_aux['simclr'] = normalize(outputs_aux['simclr'])  # normalize

        _, outputs_shift_adv = model(adv_rtd, simclr=False, penultimate=False, shift=True)

        sim_matrix = get_similarity_matrix(outputs_aux['simclr'], multi_gpu=P.multi_gpu)
        # chunk=4 corresponds to the four groups concatenated in images_pair.
        loss_sim = NT_xent(sim_matrix, temperature=0.5, chunk=4) * P.sim_lambda

        # Shift prediction is supervised on the first half of images_pair
        # (the two clean views) plus the RTD adversarial images, hence the
        # labels are repeated three times.
        clean_shift_logits = outputs_aux['shift'][:outputs_aux['shift'].shape[0] // 2]
        loss_shift = criterion(
            torch.cat([clean_shift_logits, outputs_shift_adv['shift']]),
            shift_labels.repeat(3))

        ### total loss ###
        loss = loss_sim + loss_shift

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        scheduler.step(epoch - 1 + n / len(loader))
        lr = optimizer.param_groups[0]['lr']
        torch.cuda.empty_cache()

        batch_time.update(time.time() - check)

        losses['cls'].update(0, batch_size)
        losses['sim'].update(loss_sim.item(), batch_size)
        losses['shift'].update(loss_shift.item(), batch_size)

        if count % 50 == 0:
            log_('[Epoch %3d; %3d] [Time %.3f] [Data %.3f] [LR %.5f]\n'
                 '[LossC %f] [LossSim %f] [LossShift %f]' %
                 (epoch, count, batch_time.value, data_time.value, lr,
                  losses['cls'].value, losses['sim'].value, losses['shift'].value))

        # Restart the timer so the next data_time sample covers only the wait
        # for the loader, not this iteration's compute.
        check = time.time()

    log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossSim %f] [LossShift %f]' %
         (batch_time.average, data_time.average,
          losses['cls'].average, losses['sim'].average, losses['shift'].average))

    if logger is not None:
        logger.scalar_summary('train/loss_cls', losses['cls'].average, epoch)
        logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
        logger.scalar_summary('train/loss_shift', losses['shift'].average, epoch)
        logger.scalar_summary('train/batch_time', batch_time.average, epoch)
