

import os
import sys
import numpy as np
# Root of this project. Defaults to the original hard-coded location, but can be
# overridden with the DISENTANGLING_EVERYTHING_PATH environment variable so the
# module also works when the project is checked out somewhere else.
PROJECT_PATH = os.environ.get("DISENTANGLING_EVERYTHING_PATH", "/home/disentangling_everything")
# disentanglement_lib is expected to live next to the project root.
DISENTANGLEMENT_LIB_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")

# Add paths to project and to disentanglement_lib so the imports below resolve.
# Paths already on sys.path are skipped so repeated imports don't grow it.
for _path in (PROJECT_PATH, DISENTANGLEMENT_LIB_PATH):
    if _path not in sys.path:
        sys.path.append(_path)
del _path
from disentanglement_lib.data.ground_truth.ground_truth_data import GroundTruthData
from disentanglement_lib.data.ground_truth import util

# Project modules
from data.data_loader import load_factor_data

class DisLibGroundTruthData(GroundTruthData):
    """Expose a project factor dataset through disentanglement_lib's GroundTruthData API.

    The wrapped dataset is created with ``load_factor_data(data=data, **kwargs)``
    and must provide ``n_factors``, ``factors_shape``, ``image_shape``,
    ``flat_images`` and ``flat_factor_mesh_enumerated``. Every factor of the
    dataset is treated as a latent factor.
    """

    def __init__(self, data, **kwargs):
        """Load the dataset and build the disentanglement_lib sampling structures.

        Args:
            data: Dataset identifier, stored as ``data_name`` and forwarded to
                ``load_factor_data``.
            **kwargs: Additional keyword arguments forwarded to ``load_factor_data``.
        """
        self.data_name = data
        self.dataset_class = load_factor_data(data=data, **kwargs)
        # All factors are latent, so the latent indices are simply 0..n_factors-1.
        self.latent_factor_indices = list(range(self.dataset_class.n_factors))

        # NOTE(review): assumed to be a (num_images, n_factors) array of factor
        # value indices whose rows line up with ``flat_images`` — confirm.
        features = self.dataset_class.flat_factor_mesh_enumerated

        self.state_space = util.SplitDiscreteStateSpace(self.factors_num_values, self.latent_factor_indices)
        # Maps a full factor combination to the row of ``flat_images`` holding it.
        self.index = util.StateSpaceAtomIndex(self.factors_num_values, features)

    @property
    def num_factors(self):
        """Number of ground-truth factors, taken from ``dataset_class.n_factors``."""
        return self.dataset_class.n_factors

    @property
    def factors_num_values(self):
        """Number of values of each factor, as a list built from ``factors_shape``."""
        return list(self.dataset_class.factors_shape)

    @property
    def observation_shape(self):
        """Shape of one observation, as a list built from ``image_shape``."""
        return list(self.dataset_class.image_shape)

    def sample_factors(self, num, random_state):
        """Sample a batch of factors Y."""
        return self.state_space.sample_latent_factors(num, random_state)

    def sample_observations_from_factors(self, factors, random_state):
        """Return the observations matching ``factors`` as a float32 array.

        Args:
            factors: Batch of latent factor values, e.g. from ``sample_factors``.
            random_state: Random state passed to the state space (used there for
                any non-latent factors; here every factor is latent).
        """
        all_factors = self.state_space.sample_all_factors(factors, random_state)
        indices = self.index.features_to_index(all_factors)
        # astype returns a float32 copy, so callers never alias the dataset storage.
        return self.dataset_class.flat_images[indices].astype(np.float32)

    def sample(self, num, random_state):
        """Sample a batch of factors Y and observations X."""
        factors = self.sample_factors(num, random_state)
        return factors, self.sample_observations_from_factors(factors, random_state)

    def sample_observations(self, num, random_state):
        """Sample a batch of observations X."""
        return self.sample(num, random_state)[1]