# flake8: noqa
# type: ignore


# Module-level model setup. Every print() below runs unconditionally: the
# bare string prints ('use_hyp', 'decode_hyp', 'encode_hyp') are fixed marker
# strings, the others print a label followed by the named object.
# The statements are order-sensitive (stdout order, construction order).

# NOTE(review): the trailing "vgg_fixed" looks like the args.backbone value
# observed for this configuration -- confirm against the run's arguments.
print('args.backbone', args.backbone)  # vgg_fixed
# No backbone network is constructed here; ImageRep below is given None.
backbone_model = None

# Image encoder: ImageRep(backbone_model) wrapped in ExWrapper, then moved to
# `device`. train() applies it to both `image` and `examples`.
image_model = ExWrapper(ImageRep(backbone_model)).to(device)

# NOTE(review): the trailing "dotp" looks like the args.comparison value
# observed for this configuration -- confirm.
print('args.comparison', args.comparison)  # dotp
# Scorer; train() calls scorer_model.score(hint_rep, image_rep) on it.
scorer_model = DotPScorer().to(device)

print('use_hyp')
# Embedding table built from train_vocab_size and the constant 512.
# Assumes `nn` is torch.nn, in which case these are (num_embeddings,
# embedding_dim) -- TODO confirm.
# Note: .to(device) is not called on this object directly; presumably it is
# moved when the models that receive it below are moved -- verify.
embedding_model = nn.Embedding(train_vocab_size, 512)
print('embedding_model', embedding_model)

print('decode_hyp')
# Text decoder. Receives the same embedding_model object as hint_model below.
# train() calls it as proposal_model(examples_rep_mean, hint_seq, hint_length).
proposal_model = TextProposal(embedding_model).to(device)
print('proposal_model', proposal_model)

print('encode_hyp')
# Text encoder. Receives the same embedding_model object as proposal_model.
# train() calls it as hint_model(hint_seq, hint_length).
hint_model = TextRep(embedding_model).to(device)
print('hint_model', hint_model)

    def train(epoch, n_steps=100):
        image_model.train()
        scorer_model.train()
        if args.decode_hyp:
            proposal_model.train()
        if args.encode_hyp:
            hint_model.train()
        if args.multimodal_concept:
            multimodal_model.train()

        loss_total = 0
        pbar = tqdm(total=n_steps)
        for batch_idx in range(n_steps):
            examples, image, label, hint_seq, hint_length, *rest = \
                train_dataset.sample_train(args.batch_size)
            print('examples.size()', examples.size(), 'image.size()', image.size(), 'label.size()', label.size(), 'hint_seq.size()', hint_seq.size(), 'hint_length.size()', hint_length.size())

            examples = examples.to(device)
            image = image.to(device)
            label = label.to(device)
            batch_size = len(image)
            n_ex = examples.shape[1]

            # Load hint
            hint_seq = hint_seq.to(device)
            hint_length = hint_length.to(device)
            max_hint_length = hint_length.max().item()
            # Cap max length if it doesn't fill out the tensor
            if max_hint_length != hint_seq.shape[1]:
                hint_seq = hint_seq[:, :max_hint_length]

            # Learn representations of images and examples
            image_rep = image_model(image)
            examples_rep = image_model(examples)
            examples_rep_mean = torch.mean(examples_rep, dim=1)

            # Prediction loss
            # supervised learning of receiver network:
            # given hypothesis and receiver image, calculate loss against receiver label
            hint_rep = hint_model(hint_seq, hint_length)
            score = scorer_model.score(hint_rep, image_rep)
            # pred loss is receiver supervised loss
            pred_loss = F.binary_cross_entropy_with_logits(score, label.float())

            # Hypothesis loss
            print('use_hyp')
            # How plausible is the true hint under example/image rep?
            print('no predict_image_hyp')

            # Decode images/examples to hints
            # basically sender model supervised training
            # examples_rep_mean => hypothesis (hypo_out)
            # hypo_out, hint_seq => sender supervised loss
            # *** hint_seq is provided as teacher forcing ***
            hypo_out = proposal_model(examples_rep_mean, hint_seq, hint_length)
            seq_len = hint_seq.size(1)
            hypo_out = hypo_out[:, :-1].contiguous()
            hint_seq = hint_seq[:, 1:].contiguous()

            hypo_out_2d = hypo_out.view(batch_size * (seq_len - 1),
                                        train_vocab_size)
            hint_seq_2d = hint_seq.long().view(batch_size * (seq_len - 1))
            print('calculate hypo_loss from hypo_out_2d and hint_seq_2d')
            hypo_loss = F.cross_entropy(hypo_out_2d,
                                        hint_seq_2d,
                                        reduction='none')
            hypo_loss = hypo_loss.view(batch_size, (seq_len - 1))
            print('calculate hypo_loss as mean of hypo_loss')
            hypo_loss = torch.mean(torch.sum(hypo_loss, dim=1))

            loss = args.pred_lambda * pred_loss + args.hypo_lambda * hypo_loss
