import torch
import torch.nn as nn




def prepare_data(imgs, labels, domains, y, d, num=None):
    '''
    Select the samples that have label ``y`` and domain ``d``.

    :param imgs: Sample tensor; its first dimension indexes samples.
    :param labels: Per-sample labels, aligned with the first dim of ``imgs``.
    :param domains: Per-sample domain ids, aligned with the first dim of ``imgs``.
    :param y: Wanted label
    :param d: Wanted domain
    :param num: Optional cap on the number of returned samples (the first
        ``num`` matches are kept). ``None`` keeps every match; ``0`` keeps none.
    :return: ``(imgs, labels, domains)`` restricted to the matching samples,
        in their original order.
    '''
    # One combined boolean mask selects exactly the rows that the
    # "filter by label, then filter by domain" sequence would select.
    mask = (labels == y) & (domains == d)
    imgs = imgs[mask]
    labels = labels[mask]
    domains = domains[mask]

    # Compare against None explicitly: a bare truthiness test would treat
    # num=0 as "no limit" and return every match.
    if num is not None:
        imgs = imgs[:num]
        labels = labels[:num]
        domains = domains[:num]

    return imgs, labels, domains

def prepare_data_domains(imgs, labels, domains, label, domain_list, train=True):
    '''
    Collect the samples of one label across several domains.

    :param imgs: Sample tensor; its first dimension indexes samples.
    :param labels: Per-sample labels, aligned with the first dim of ``imgs``.
    :param domains: Per-sample domain ids, aligned with the first dim of ``imgs``.
    :param label: Wanted label (the same for every domain).
    :param domain_list: Wanted domains; the per-domain selections are
        concatenated in this order. ``torch.cat`` raises if nothing is
        collected (e.g. an empty ``domain_list``).
    :param train: If True, drop the last sample when the total count is odd.
    :return: ``(x, d)`` - the selected samples and their domain ids.
    '''
    xlist = []
    dlist = []
    for dom in domain_list:
        xd, _, dd = prepare_data(imgs, labels, domains, label, dom)
        xlist.append(xd)
        dlist.append(dd)
    x = torch.cat(xlist)
    # cat along dim 0 already yields (N, *imgs.shape[1:]); the view acts as
    # a shape check (it raises if the element count does not match).
    x = x.view(x.shape[0], *imgs.shape[1:])
    d = torch.cat(dlist)

    if train:
        # Make the number of samples even (presumably the training code
        # consumes samples in pairs - confirm with callers).
        n_even = x.shape[0] - x.shape[0] % 2
        x = x[:n_even]
        d = d[:n_even]

    return x, d

def inb_translate(cd, x, d, target_d):
    '''
    Translate data to the domain ``target_d``.

    ``x`` is first mapped with ``cd(x, d)``; the result is then passed to
    ``cd.inverse`` together with a vector holding ``target_d`` for every
    sample.

    :param cd: Model callable as ``cd(x, d)`` that also exposes
        ``cd.inverse(z, d)``.
    :param x: Input batch.
    :param d: Domain(s) of ``x``, forwarded to ``cd``.
    :param target_d: Wanted domain.
    :return: Whatever ``cd.inverse`` returns for the target domain.
    '''
    z = cd(x, d)
    # Allocate the target-domain vector on z's device: a bare
    # torch.ones(n) always lives on the default (CPU) device and would
    # mismatch a z that lives on a GPU. The dtype stays the default float.
    trans_d = torch.ones(z.shape[0], device=z.device) * target_d
    x_trans = cd.inverse(z, trans_d)
    return x_trans

def sg_translate(sg, x, d, target_d):
    '''
    Translate data to the domain ``target_d`` using ``sg``.

    Delegates directly to ``sg(x, d, target_d)`` and returns its result
    unchanged.

    :param sg: Callable taking ``(x, d, target_d)``.
    :param x: Input batch.
    :param d: Domain(s) of ``x``.
    :param target_d: Wanted domain.
    :return: The value returned by ``sg``.
    '''
    translated = sg(x, d, target_d)
    return translated
