import torch

from algorithms.space.base_space import IdentifiableSpace

from algorithms.space.basic_callable_space import BasicCallableSpace


class VAESpace(BasicCallableSpace, IdentifiableSpace):
    """Callable space that scores a batch of images with a weighted reward.

    The reward combines three terms: a landmark loss (weight 100), a
    tanh-squashed, negated discriminator output (weight 2) and a per-sample
    mean attribute loss (weight 1).

    NOTE(review): ``vae`` and ``classifier`` are stored but not read anywhere
    in this class. The reward instead uses ``landmarks``, ``landmark_loss``,
    ``attributes``, ``discriminator``, ``att_loss`` and ``attributes_target``,
    none of which are assigned here — presumably provided by a base class or
    set by the caller. TODO confirm.
    """

    def __init__(self, vae, classifier, *args, **kwargs):
        # Remaining positional/keyword arguments are forwarded unchanged
        # to the base-class initialisers.
        super().__init__(*args, **kwargs)
        self.vae = vae
        self.classifier = classifier

    @property
    def callable_env(self):
        """Build and return a function mapping ``images`` to a reward tensor.

        A fresh closure is created on every access; it reads the model
        attributes from ``self`` at call time.
        """

        def score(images):
            # Landmarks are extracted one image at a time and then stacked.
            # Assumes ``images`` iterates over individual images along its
            # first dimension — TODO confirm.
            per_image = [self.landmarks(image) for image in images]
            landmark_term = self.landmark_loss(torch.stack(per_image))
            # NOTE(review): unlike the other two terms, the landmark term is
            # not detached, so any autograd graph it carries is kept.

            attribute_pred = self.attributes(images).detach()
            disc_out = self.discriminator(images).detach()

            # Negated tanh of the (already detached) discriminator output.
            disc_term = -torch.tanh(disc_out).detach()

            # The attribute target is tiled once per sample so that it lines
            # up with the predictions; the loss is then averaged over dim 1.
            tiled_target = self.attributes_target.repeat(len(attribute_pred), 1)
            attribute_term = (
                self.att_loss(attribute_pred, tiled_target).detach().mean(1)
            )

            return landmark_term * 100 + disc_term * 2 + attribute_term

        return score
