from network.train_functions_global import rec_loss, affinity_loss

#1. Preprocess data from the dataset (unpack one sample into its components)
def get_data(data):
    """Unpack a dataset sample into a 3-tuple.

    Samples with more than two items yield ``(data[0], data[1], data[2])``;
    any items past index 2 are ignored. Two-item samples yield
    ``(data[0], None, data[1])``, so the last slot is always the labels and
    the first is always the features.

    The meaning of the middle item depends on the dataset (TODO confirm).
    """
    if not len(data) > 2:
        # Short sample: there is no middle item, so pad it with None.
        return data[0], None, data[1]
    return data[0], data[1], data[2]
#2. Compute the latent code z and the reconstruction by passing the observation through the encoder/decoder
def get_z(data, ae, ae_id):
    """Run ``data`` through the autoencoder ``ae``.

    Returns:
        ``(observation, reconstruction, latent)``, where ``observation`` is
        ``data`` itself and the other two come from ``ae(data, ae_id)``.
    """
    reconstruction, latent = ae(data, ae_id)
    return data, reconstruction, latent

#3. Return the label for the downstream task
def get_label(data):
    """Return the ``'label'`` entry of ``data`` (raises KeyError if it is missing)."""
    label = data['label']
    return label

def get_l_probs(z, classifier):
    """Apply ``classifier`` to the latent code ``z`` and return its output.

    The name suggests the output holds label probabilities; that depends on
    the classifier passed in (TODO confirm).
    """
    return classifier(z)

#4. Loss function
def loss_function_ae(model, z1, z2, l1, l2, idxToUse, C=0, map=None, alpha=0, scale=False, use_target=False, device='cpu'):
    """Combine a two-view reconstruction loss with an optional affinity term.

    Args:
        model: object exposing ``training_params['alpha']``, the weight of
            the affinity term. A weight equal to 0 skips ``affinity_loss``.
        z1, z2: ``(observation, reconstruction, latent)`` triples, e.g. the
            output of ``get_z``.
        l1, l2, idxToUse, map, alpha, scale, device: forwarded unchanged to
            ``affinity_loss``. ``C`` is forwarded as ``origC`` and
            ``use_target`` as ``train``.

    Returns:
        If the weight is 0:
            ``(rec, (rec_1, rec_2, 0, 0, 0), None, None)``
        Otherwise:
            ``(rec + weight * aff[0], (rec_1, rec_2, aff[0]), aff[1], aff[2])``
        Here ``rec = 0.5 * (rec_1 + rec_2)`` and ``aff`` is the value returned
        by ``affinity_loss``. The statistics entries are detached NumPy copies
        (presumably of torch tensors).

    NOTE(review): the statistics tuple has 5 entries in one branch and 3 in
    the other, so callers that unpack it must handle both shapes.
    NOTE(review): the ``alpha`` argument goes to ``affinity_loss``, but the
    weighting uses ``model.training_params['alpha']``. Confirm that these are
    meant to be distinct.
    """
    obs_a, recon_a, latent_a = z1
    obs_b, recon_b, latent_b = z2

    # Reconstruction term: the mean of the two per-view losses.
    rec_a = rec_loss(obs_a, recon_a)
    rec_b = rec_loss(obs_b, recon_b)
    rec = 0.5 * (rec_a + rec_b)

    weight = model.training_params['alpha']
    if weight == 0:
        # Affinity disabled: zeros pad the statistics tuple.
        stats = (rec_a.detach().cpu().numpy(), rec_b.detach().cpu().numpy(), 0, 0, 0)
        return rec, stats, None, None

    affinity = affinity_loss(latent_a, latent_b, l1, l2, idxToUse, origC=C, map=map, alpha=alpha,
                             scale=scale, train=use_target, device=device)
    aff_value = affinity[0]
    total = rec + weight * aff_value
    stats = (rec_a.detach().cpu().numpy(),
             rec_b.detach().cpu().numpy(),
             aff_value.detach().cpu().numpy())
    return total, stats, affinity[1], affinity[2]

#5. Housekeeping
def setup_housekeeping():
    """Create fresh history lists for ``housekeeping``.

    Returns:
        ``[reconstruction_history, slow_history]``: two independent, empty
        lists.
    """
    return [list() for _ in range(2)]

def housekeeping(hk_data, hk_lists):
    """Append one step's loss values to the running history lists.

    Args:
        hk_data: a 3-item sequence ``(rec, kld, slow)``. The middle (KLD)
            value is not recorded.
        hk_lists: the history lists, e.g. from ``setup_housekeeping``. The
            first list receives ``rec`` and the last list receives ``slow``.

    Returns:
        ``hk_lists``, mutated in place.
    """
    # Distinct local names avoid shadowing the module-level ``rec_loss`` import.
    rec_value, _kld_value, slow_value = hk_data
    hk_lists[0].append(rec_value)
    # setup_housekeeping() builds only two lists ([rec, slow]), so the old
    # hard-coded index 2 always raised IndexError. The slow-loss history is
    # the last list, which also keeps a 3-list [rec, kld, slow] layout working.
    hk_lists[-1].append(slow_value)
    return hk_lists
