import json
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from tensorflow.keras.applications import ResNet50  # Replace with your desired model
from tensorflow.keras.applications.resnet50 import preprocess_input


import cfg


def create():
    """Build the MS COCO metadata table and the matching image embeddings.

    Both Python's and NumPy's global random generators are seeded with a
    fixed value first, so the random subset assignment made while creating
    the metadata is reproducible between runs.
    """
    seed = 0
    random.seed(seed)
    np.random.seed(seed)

    # the metadata frame produced by the first step drives the embedding step
    metadata = save_metadata()
    save_embeddings(metadata)


def save_metadata():
    """Create, save and return the multi-label metadata table for MS COCO.

    Reads ``instances_train2017.json`` from ``cfg.path_label_mscoco`` and
    builds one row per image with:

    * ``fname`` - the image file name,
    * one ``cfg.label_prefix + <category name>`` column per category,
      holding 1 if at least one annotation of that category references
      the image and 0 otherwise,
    * ``subset`` - ``cfg.tag_evaluate`` for roughly 20 % of the rows
      (drawn with NumPy's global RNG), ``cfg.tag_unlabelled`` for the rest.

    The columns are sorted alphabetically and the table is written to
    ``<cfg.path_data>/mscoco/metadata.csv`` without the index.

    Returns:
        The DataFrame that was written to disk.
    """
    # Path(a, b) accepts both str and Path config values; the explicit
    # encoding keeps JSON decoding independent of the platform default.
    annotation_file = Path(cfg.path_label_mscoco, 'instances_train2017.json')
    with open(annotation_file, encoding='utf-8') as f:
        instances = json.load(f)

    images = instances['images']
    annotations = instances['annotations']
    categories = instances['categories']

    # category id -> category name, names kept in file order
    category_map = {category['id']: category['name'] for category in categories}
    all_categories = list(category_map.values())

    # collect, per image, the set of category names that occur in its annotations
    present = {image['id']: set() for image in images}
    for annotation in annotations:
        present[annotation['image_id']].add(category_map[annotation['category_id']])

    # one row per image: 1 if the category is present, 0 otherwise
    label_columns = [cfg.label_prefix + category for category in all_categories]
    df_data = []
    for image in images:
        found = present[image['id']]
        row = {'fname': image['file_name']}
        for column, category in zip(label_columns, all_categories):
            row[column] = int(category in found)
        df_data.append(row)

    df_label = pd.DataFrame(df_data, columns=['fname'] + label_columns)

    # random split; relies on the caller having seeded np.random for reproducibility
    df_label['subset'] = np.where(np.random.rand(len(df_label)) < 0.2, cfg.tag_evaluate, cfg.tag_unlabelled)

    # save metadata
    df_label = df_label[sorted(df_label.columns)]  # sort columns
    numeric_cols = df_label.select_dtypes(include='number').columns
    df_label[numeric_cols] = df_label[numeric_cols].astype(int)  # make label columns int
    df_label.to_csv(Path(cfg.path_data, 'mscoco', 'metadata.csv'), index=False)

    return df_label


def save_embeddings(df, batch_size=1000):
    """Save the resized raw images and their ResNet50 embeddings for ``df``.

    Every image named in ``df['fname']`` is read from
    ``cfg.path_data_mscoco_train``, resized to 224x224 and embedded in
    batches with an ImageNet-pretrained ResNet50 (no classification head,
    global average pooling). Two arrays are written to
    ``<cfg.path_data>/mscoco``:

    * ``data.npy`` - uint8, shape ``(len(df), 224, 224, 3)``, pixels in the
      channel order delivered by ``cv2.imread`` (BGR),
    * ``data_embedding.npy`` - float32, shape ``(len(df), 2048)``.

    Row ``i`` of both arrays belongs to the ``i``-th row of ``df`` (by
    position, independent of the DataFrame's index labels).

    Args:
        df: DataFrame with an ``fname`` column.
        batch_size: Number of images embedded per ``model.predict`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        FileNotFoundError: If an image cannot be read by OpenCV.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    n_images = len(df)

    # init outputs
    data_raw = np.zeros((n_images, 224, 224, 3), dtype=np.uint8)
    data_embedded = np.zeros((n_images, 2048), dtype='float32')

    # load model
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

    batch = []
    # enumerate gives the array position; index labels of df are not relied upon
    for position, fname in enumerate(df['fname']):
        print(f'PREPROCESSING MSCOCO: Row {position} / {n_images}')

        # load image; str() keeps older OpenCV builds happy, and imread
        # signals failure by returning None instead of raising
        img_path = Path(cfg.path_data_mscoco_train, fname)
        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        batch.append(cv2.resize(img, (224, 224)))

        # keep collecting until the batch is full or the last row is reached
        is_last = position == n_images - 1
        if len(batch) < batch_size and not is_last:
            continue

        batch_numpy = np.array(batch)

        # preprocess_input expects RGB and converts to BGR itself, while
        # cv2.imread delivers BGR - flip the channel axis for the model only
        batch_preprocessed = preprocess_input(batch_numpy[..., ::-1])
        batch_features = model.predict(batch_preprocessed)

        # save raw images and embeddings at the positions of this batch
        index_end = position + 1
        index_start = index_end - len(batch_numpy)
        data_raw[index_start:index_end] = batch_numpy
        data_embedded[index_start:index_end] = batch_features

        # reset batch
        batch = []

    np.save(Path(cfg.path_data, 'mscoco', 'data.npy'), data_raw)
    np.save(Path(cfg.path_data, 'mscoco', 'data_embedding.npy'), data_embedded)


# Run the full dataset preparation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    create()
