import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from dis_lib_andreea.data_methods.data.ground_truth import ground_truth_data

RESOLUTION = [64, 64]


def unpack_features(ds, brands, brand_dict, years, year_dict):
    """Build the per-example factor matrix ``[model_id, brand_id, year_id]``.

    :param ds: iterable of examples; each must support ``example["label"]``,
        whose value is used as an index into ``brands`` and ``years``.
    :param brands: brand string per model id.
    :param brand_dict: maps each brand string to its integer brand id.
    :param years: year string per model id.
    :param year_dict: maps each year string to its integer year id.
    :return: array of shape ``(num_examples, 3)``; column 0 is the label,
        column 1 the brand id and column 2 the year id of that label's model.
    """
    # ``ds`` is consumed exactly once; every later step works on plain labels.
    model_ids = [example["label"] for example in ds]
    brand_ids = [brand_dict[brands[model]] for model in model_ids]
    year_ids = [year_dict[years[model]] for model in model_ids]
    return np.stack((model_ids, brand_ids, year_ids), axis=-1)


def crop_norm_image(dictionary_):
    """Crop an example to its bounding box, resize it and scale it by 1/255.

    Intended to be used as a ``tf.data.Dataset.map`` function (see
    ``StanfordCars``).

    :param dictionary_: example dict with ``"image"``, ``"bbox"`` and
        ``"label"`` entries. ``bbox`` is passed to ``tf.image.crop_and_resize``
        as one box, so it is presumably normalized
        ``[ymin, xmin, ymax, xmax]`` (tfds BBox convention) -- confirm against
        the dataset's feature spec.
    :return: tuple ``(image, label)``; ``image`` is float32 with spatial size
        ``RESOLUTION`` and pixel values divided by 255, ``label`` is passed
        through unchanged.
    """
    image = dictionary_["image"]
    bbox = dictionary_["bbox"]
    label = dictionary_["label"]
    # crop_and_resize operates on batches: add a batch dimension, crop the
    # single box from batch element 0, then drop the batch dimension again.
    resized_image = tf.image.crop_and_resize(tf.expand_dims(image, 0), [bbox], [0], RESOLUTION)[0]
    return tf.cast(resized_image, tf.float32) / 255., label

class StanfordCars:
    """Stanford Cars (tfds ``cars196``) loaded fully into memory.

    Each example is cropped to its bounding box, resized to ``RESOLUTION``
    and divided by 255 (see ``crop_norm_image``). Three discrete factors are
    attached to every example:

    - the car model id (the tfds label),
    - a brand id,
    - a year id.

    Brand and year are parsed from the model's class name.

    Attributes:
        features: array of shape ``(num_examples, 3)`` with columns
            ``[model_id, brand_id, year_id]``; row ``i`` describes
            ``images[i]``.
        data_shape: ``RESOLUTION + [3]``.
        factor_sizes: number of distinct values of each ``features`` column.
        images: stacked cropped/resized/normalized images.
        labels: tfds labels (car model ids), aligned with ``images``.
    """

    def __init__(self, split):
        """Load ``split`` of cars196 and materialize images and factors.

        :param split: split specification forwarded to ``tfds.load``.
        """
        ds, ds_info = tfds.load("cars196", split=split, with_info=True)

        # Complete class name for every unique car model, indexed by label.
        car_model_name = ds_info.features["label"].names
        # Brand = first whitespace-separated token, year = last token of each
        # class name (assumes names look like "<Brand> ... <Year>").
        brand_per_model = [category.split(" ")[0] for category in car_model_name]
        year_per_model = [category.split(" ")[-1] for category in car_model_name]

        # Map each unique brand / year string to a contiguous integer id.
        brand_dict = {b_key: num_brand for num_brand, b_key in enumerate(np.unique(brand_per_model))}
        year_dict = {y_key: num_year for num_year, y_key in enumerate(np.unique(year_per_model))}

        self.data_shape = RESOLUTION + [3]
        # Cardinality of each factor: #models, #distinct brands, #distinct
        # years. Use len(year_dict), not len(year_per_model): the latter has
        # one entry per model and would overstate the number of years.
        self.factor_sizes = np.array([len(car_model_name), len(brand_dict), len(year_dict)])

        # Add cropping to the bounding box, resizing and normalization.
        ds = ds.map(crop_norm_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        # Iterate the pipeline once. Decoding and cropping every image is the
        # expensive part, so images and labels are collected together rather
        # than in two separate passes.
        images, labels = [], []
        for image, label in tfds.as_numpy(ds):
            images.append(image)
            labels.append(label)
        self.images = np.array(images)
        self.labels = np.array(labels)

        # Derive the factors from the labels gathered above. Row i of
        # ``features`` then always matches ``images[i]``, and no additional
        # full pass (decoding every image again) is needed.
        self.features = unpack_features(
            ({"label": label} for label in self.labels),
            brand_per_model,
            brand_dict,
            year_per_model,
            year_dict,
        )
