import json
import random
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.datasets import cifar10
import tensorflow as tf
from tensorflow.keras.applications import ResNet50  # Replace with your desired model
from tensorflow.keras.applications.resnet50 import preprocess_input
from sklearn.preprocessing import OneHotEncoder

import cfg


def create():
    """Prepare the CIFAR-10 dataset files.

    Seeds Python's ``random`` module and NumPy's global generator with 0,
    writes the metadata via ``save_metadata`` and then passes the image
    array it returns to ``save_embeddings``.
    """
    # fix both random generators to the same seed
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(0)

    # metadata first; its return value (the images) feeds the embedding step
    images = save_metadata()
    save_embeddings(images)


def save_metadata():
    """Build the CIFAR-10 metadata table, write it to CSV and return the images.

    The train and test splits returned by ``cifar10.load_data()`` are stacked
    (train first, then test). Labels are one-hot encoded into integer columns
    named ``f'{cfg.label_prefix}{label}'``, and a ``subset`` column tags the
    train rows with ``cfg.tag_unlabelled`` and the test rows with
    ``cfg.tag_evaluate``. The table is written without its index to
    ``<cfg.path_data>/cifar10/metadata.csv``; the directory is created if it
    does not exist yet.

    Returns:
        The stacked image array (train images followed by test images).
    """
    # load data
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # stack arrays; train comes first, which the subset column below relies on
    x = np.vstack([x_train, x_test])
    y = np.vstack([y_train, y_test])

    # One-hot encoding
    encoder = OneHotEncoder(sparse_output=False)
    y_one_hot = encoder.fit_transform(y).astype(int)

    # create df with one column per label found by the encoder
    columns = [f'{cfg.label_prefix}{int(label)}' for label in encoder.categories_[0]]
    df = pd.DataFrame(y_one_hot, columns=columns)

    # add subset column (row order matches the stacking order above)
    df['subset'] = [cfg.tag_unlabelled] * len(y_train) + [cfg.tag_evaluate] * len(y_test)

    # save df; to_csv does not create missing parent directories, so do it here
    out_dir = Path(cfg.path_data, 'cifar10')
    out_dir.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_dir / 'metadata.csv', index=False)

    return x


def save_embeddings(x_original, batch_size=1000):
    """Resize the images, embed them with ResNet50 and save both arrays.

    Every image is resized to 224x224 with ``cv2.resize``. The resized images
    are then run batch by batch through ``preprocess_input`` and an
    ImageNet-weighted ``ResNet50`` without its top layers
    (``include_top=False, pooling='avg'``). The resized images are saved to
    ``<cfg.path_data>/cifar10/data.npy`` and the stacked model outputs to
    ``<cfg.path_data>/cifar10/data_embedding.npy``.

    Args:
        x_original: Sequence of images accepted by ``cv2.resize``
            (presumably the uint8 image array returned by ``save_metadata``
            -- confirm against the caller).
        batch_size: Number of images passed to ``model.predict`` per call.
            Defaults to 1000, the value that was previously hard-coded.
    """
    # Create the output directory before the expensive work so that a missing
    # directory cannot make np.save fail after all embeddings were computed.
    out_dir = Path(cfg.path_data, 'cifar10')
    out_dir.mkdir(parents=True, exist_ok=True)

    # resize data
    # NOTE(review): the full resized array is kept in memory because it is
    # saved below; at 224x224x3 bytes per uint8 image this grows quickly.
    x = np.array([cv2.resize(img, (224, 224)) for img in x_original])
    n_images = len(x)

    # load model
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

    # get predictions
    x_features_list = []

    for start in range(0, n_images, batch_size):
        # clamp the end index so the message is correct for a partial last batch
        end = min(start + batch_size, n_images)
        print(f'Cifar10: Process batch {start}-{end} of {n_images}')
        x_preprocessed = preprocess_input(x[start:end])
        x_features_list.append(model.predict(x_preprocessed))

    x_features = np.vstack(x_features_list)

    # save embeddings
    # NOTE(review): data.npy holds the resized images, not x_original --
    # confirm this is intended.
    np.save(out_dir / 'data.npy', x)
    np.save(out_dir / 'data_embedding.npy', x_features)


# Run the dataset preparation only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    create()
