import cv2  # still used to save images out
import os
import numpy as np
from decord import VideoReader
from decord import cpu, gpu
import json
import numpy as np
from natsort import natsorted
import imageio
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Model


def _write_resized_frame(rgb_frame, save_path, dsize):
    """
    Resize an RGB frame to ``dsize`` and save it to ``save_path``.

    OpenCV writes images in BGR channel order, so the frame is converted first.

    :param rgb_frame: frame array in RGB channel order
    :param save_path: destination image path
    :param dsize: (width, height) passed to cv2.resize
    :return: True if cv2.imwrite reports success, False otherwise
    """
    output = cv2.resize(rgb_frame, dsize)
    return bool(cv2.imwrite(save_path, cv2.cvtColor(output, cv2.COLOR_RGB2BGR)))


def extract_frames(video_path, frames_dir, overwrite=False, start=-1, end=-1, every=1):
    """
    Extract frames from a video using decord's VideoReader.

    Each selected frame is resized to 224x224 and saved as
    ``<frames_dir>/<video filename>/<frame index, 10 digits>.jpg``.
    The selected frames are ``start, start + every, start + 2 * every, ...``
    (strictly below ``end``), whichever read strategy is used.

    :param video_path: path of the video
    :param frames_dir: the directory to save the frames (a sub directory named
        after the video file is used, and created if missing)
    :param overwrite: to overwrite frames that already exist?
    :param start: start frame (negative means 0)
    :param end: end frame, exclusive (negative means the end of the video)
    :param every: frame spacing, must be >= 1
    :return: count of images actually written
    :raises ValueError: if ``every`` is smaller than 1
    """
    video_path = os.path.normpath(video_path)  # make the paths OS (Windows) compatible
    frames_dir = os.path.normpath(frames_dir)  # make the paths OS (Windows) compatible

    dsize = (224, 224)  # output (width, height) for cv2.resize

    video_filename = os.path.basename(video_path)
    out_dir = os.path.join(frames_dir, video_filename)

    assert os.path.exists(video_path), "video file not found: {}".format(video_path)

    if every < 1:
        raise ValueError("every must be a positive integer, got {}".format(every))

    # load the VideoReader
    vr = VideoReader(video_path, ctx=cpu(0))  # can set to cpu or gpu .. ctx=gpu(0)

    if start < 0:  # if start isn't specified assume 0
        start = 0
    if end < 0:  # if end isn't specified assume the end of the video
        end = len(vr)

    frames_list = list(range(start, end, every))
    if not frames_list:
        return 0

    # cv2.imwrite fails silently (returns False) when the directory is missing
    os.makedirs(out_dir, exist_ok=True)

    def save_path_for(index):
        return os.path.join(out_dir, "{:010d}.jpg".format(index))

    def needs_write(path):
        return overwrite or not os.path.exists(path)

    saved_count = 0

    if every > 25 and len(frames_list) < 1000:
        # sparse selection that fits in memory: decode the wanted frames in one batch
        frames = vr.get_batch(frames_list).asnumpy()
        for index, frame in zip(frames_list, frames):
            save_path = save_path_for(index)
            if needs_write(save_path) and _write_resized_frame(frame, save_path, dsize):
                saved_count += 1
    else:
        # dense selection: read frames in order (as the original strategy did),
        # presumably so decord decodes sequentially instead of seeking per frame.
        # Only the selected frames are converted, resized and written.
        for index in range(start, end):
            frame = vr[index]
            if (index - start) % every != 0:
                continue
            save_path = save_path_for(index)
            if needs_write(save_path) and _write_resized_frame(frame.asnumpy(), save_path, dsize):
                saved_count += 1

    return saved_count


def video_to_frames(video_path, frames_dir, overwrite=False, every=1):
    """
    Extract the frames of a video into ``<frames_dir>/<video filename>/``.

    Creates the output sub directory, then delegates to ``extract_frames``,
    forwarding both the ``overwrite`` and ``every`` options.

    :param video_path: path to the video
    :param frames_dir: directory to save the frames
    :param overwrite: overwrite frames if they exist?
    :param every: extract every this many frames
    :return: path to the directory where the frames were saved
    """
    video_path = os.path.normpath(video_path)  # make the paths OS (Windows) compatible
    frames_dir = os.path.normpath(frames_dir)  # make the paths OS (Windows) compatible

    video_filename = os.path.basename(video_path)
    out_dir = os.path.join(frames_dir, video_filename)

    # make directory to save frames, its a sub dir in the frames_dir with the video name
    os.makedirs(out_dir, exist_ok=True)

    print("Extracting frames from {}".format(video_filename))

    # overwrite was previously accepted but silently dropped here
    extract_frames(video_path, frames_dir, overwrite=overwrite, every=every)

    return out_dir


def extract_features(frames_dir, filename, feature_extractor):
    """
    Compute one pooled ImageNet-CNN feature vector per saved frame of a video.

    Every file in ``<frames_dir>/<filename>`` is read in natural sort order
    (so frame 2 comes before frame 10) and stacked into one batch. The batch is
    preprocessed for the chosen backbone and passed through it without its
    classification head (``include_top=False``). Each frame's feature map is
    then reduced with a max over axes 1 and 2.

    :param frames_dir: directory containing one sub directory of frames per video
    :param filename: name of the video's sub directory inside frames_dir
    :param feature_extractor: one of "ResNet50V2", "ResNet101V2", "InceptionV3"
    :return: numpy array with one row per frame
    :raises ValueError: if feature_extractor is not supported, or the frame
        directory contains no files
    """
    supported = ("ResNet50V2", "ResNet101V2", "InceptionV3")
    # validate before touching the disk or loading weights; previously an unknown
    # name surfaced as an UnboundLocalError only after all frames were read
    if feature_extractor not in supported:
        raise ValueError("unsupported feature_extractor {!r}; expected one of {}".format(
            feature_extractor, ", ".join(supported)))

    vid_path = os.path.join(frames_dir, filename)
    names = os.listdir(vid_path)
    if not names:
        raise ValueError("no frames found in {}".format(vid_path))

    frames = [imageio.imread(os.path.join(vid_path, name)) for name in natsorted(names)]
    X = np.stack(frames, axis=0)

    if feature_extractor == "ResNet50V2":
        from tensorflow.keras.applications import ResNet50V2 as backbone
        from tensorflow.keras.applications.resnet_v2 import preprocess_input
    elif feature_extractor == "ResNet101V2":
        from tensorflow.keras.applications import ResNet101V2 as backbone
        from tensorflow.keras.applications.resnet_v2 import preprocess_input
    else:  # "InceptionV3"
        from tensorflow.keras.applications import InceptionV3 as backbone
        from tensorflow.keras.applications.inception_v3 import preprocess_input

    model = backbone(weights='imagenet', include_top=False)

    X = preprocess_input(X)
    features = model.predict(X)
    # global max pooling over the two spatial axes of the feature map
    features = np.max(features, axis=(1, 2))

    return features


def get_attention(model, layer_name, input_data):
    """
    Return the activations of one named layer of ``model`` for ``input_data``.

    A temporary Keras ``Model`` is built that shares ``model``'s input and
    outputs the layer called ``layer_name``. That sub-model is then run with
    ``predict``.

    :param model: a Keras model containing a layer named ``layer_name``
    :param layer_name: name of the layer whose output is wanted
    :param input_data: batch accepted by ``model``'s input
    :return: the layer's output as returned by ``Model.predict``
    """
    model_input = model.input
    layer_output = model.get_layer(layer_name).output
    probe = Model(inputs=model_input, outputs=layer_output)
    return probe.predict(input_data)


def visualize_concepts(test_input, model, concepts_text, inv_class_dict):
    """
    Predict a single example and plot the attention scores of its predicted concepts.

    The model is expected to return two outputs: ``pred[0]`` is thresholded at
    0.5 to choose concepts, and ``pred[1]`` is argmax'ed to choose the class.
    The attention scores are read from the layer named ``'attn_score'``.

    Note: this changes matplotlib's global style and ``font.size`` rcParam.

    :param test_input: one example without a batch axis (one is added here)
    :param model: Keras model with a layer named 'attn_score'
    :param concepts_text: table with a 'text' column, indexed positionally by
        concept index (presumably a pandas DataFrame — confirm with caller)
    :param inv_class_dict: mapping from predicted class index to class name
    :return: (pred, pred_label, pred_concepts, pred_attn)
    """
    batch = np.expand_dims(test_input, axis=0)
    pred = model.predict(batch)
    pred_label = np.argmax(pred[1], axis=1)[0]
    pred_class = inv_class_dict[pred_label]
    # np.where on the (1, n_concepts)-shaped batch returns (row_idx, concept_idx)
    pred_concepts = np.where(pred[0] >= 0.5)
    print(f'Predicted Concepts: {pred_concepts[1]}')
    print(f'Predicted Class: {pred_class}')
    attention = np.squeeze(get_attention(model, 'attn_score', batch))
    pred_attn = attention[pred_concepts[1]]
    pred_text = concepts_text['text'].iloc[pred_concepts[1]]

    # matplotlib 3.6 renamed 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid' and
    # 3.8 removed the old name; prefer the new one and fall back on old versions.
    style = 'seaborn-v0_8-whitegrid'
    if style not in plt.style.available:
        style = 'seaborn-whitegrid'
    plt.style.use(style)
    plt.rcParams.update({'font.size': 18})
    _, ax = plt.subplots(figsize=(6, 5))

    y_pos = np.arange(len(pred_text))
    ax.barh(y_pos, pred_attn, align='center')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(pred_text, fontsize=16)
    ax.invert_yaxis()  # labels read top-to-bottom
    ax.set_xlabel('Concept Score', fontsize=20)
    ax.set_title(f'Predicted Activity: {pred_class}', fontsize=20)
    plt.show()

    return pred, pred_label, pred_concepts, pred_attn
