import torch


def predict_2(model, src, trg, opt):
    """Autoregressively decode a target sequence with ``model``.

    The encoder stack is run once on ``src``. A zero-initialised output
    buffer is then filled one position along dim 2 at a time: at step ``i``
    the partially filled buffer is fed through ``trg_pro`` -> ``decoder`` ->
    ``dec_rdu`` and only position ``i`` of the result is kept.

    Args:
        model: Object exposing ``eval``, ``enc_exp``, ``enc_spa_enco``,
            ``enc_tem_enco``, ``encoder``, ``trg_pro``, ``decoder`` and
            ``dec_rdu``. ``model.eval()`` is called and the model is left in
            eval mode on return.
        src: Source tensor; indexed with four subscripts, so it must have at
            least 4 dims. The output buffer is allocated on ``src.device``.
        trg: Only ``trg.shape`` is used (its values are ignored and it is not
            modified). ``trg.shape[2]`` is the number of decoding steps.
        opt: Config object; when ``opt.HP['use']`` is truthy, the last slice
            of ``src`` along dim 2 with a trailing singleton dim appended is
            passed to ``trg_pro`` as the ``head`` keyword argument.

    Returns:
        A new tensor of shape ``trg.shape`` (default float dtype, on
        ``src.device``) holding the decoded values.
    """
    model.eval()
    with torch.no_grad():
        enc_input = model.enc_exp(src)
        enc_input = model.enc_spa_enco(enc_input)
        enc_input = model.enc_tem_enco(enc_input)
        enc_output = model.encoder(enc_input)

        steps = trg.shape[2]
        # Allocate on the input's device instead of hard-coding .cuda(), so
        # this works on CPU-only machines and when inputs live on a GPU other
        # than the current default CUDA device.
        out = torch.zeros(trg.shape, device=src.device)

        # The head hint depends only on src, so build it once instead of on
        # every decoding step.
        use_head = opt.HP['use']
        head_pre = src[:, :, -1, :].unsqueeze(-1) if use_head else None

        for i in range(steps):
            if use_head:
                dec_input = model.trg_pro(out, enc_output, head=head_pre)
            else:
                dec_input = model.trg_pro(out, enc_output)
            dec_output = model.decoder(dec_input, enc_output)
            dec_output = model.dec_rdu(dec_output)
            # Commit only position i; later positions are (re)computed in
            # later steps from the buffer as it is filled.
            out[:, :, i, :] = dec_output[:, :, i, :]

        return out