#!/usr/bin/env python
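"""Run coached emotion recognition over an audio-visual input.

Example invocation (script name and paths illustrative):

    python run_inference.py input.mp4 --config evaluation/run_config.yml
"""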


import argparse

parser = argparse.ArgumentParser(description="Running coached emotion recognition")
# Tried subparsers, without success (yet).
parser.add_argument("input",
                    help="An input file (MPEG4, AVI, etc) or a video stream number (typically 0).")
parser.add_argument("--config", "-c",
                    default="evaluation/run_config.yml',
                    help="Configuration file for the runtime.")
args = parser.parse_args()

# Imports are deferred until after argument parsing so that `--help` and
# argument errors stay fast.
import logging
import os

import av
import yaml
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)


def save_as_video(idx, batch, out_dir):
    """Save the first sequence of a [B, C, T, H, W] batch as a one-second clip."""
    duration = 1  # seconds
    fps = batch.shape[2]  # play all T frames within the one-second clip
    total_frames = duration * fps

    container = av.open(os.path.join(out_dir, f'save_{idx}.mp4'), mode='w')

    stream = container.add_stream('mpeg4', rate=fps)
    stream.height = batch.shape[3]
    stream.width = batch.shape[4]
    stream.pix_fmt = 'yuv420p'

    for frame_i in range(total_frames):
        # CHW -> HWC; `from_ndarray` expects uint8 RGB data.
        frame_ = batch[0].select(1, frame_i).numpy().transpose((1, 2, 0))
        frame = av.VideoFrame.from_ndarray(frame_, format='rgb24')
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)
    container.close()


def run(_input, config_path):

    with open(config_path) as f:
        config = yaml.safe_load(f)
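    # Expected shape of the runtime config (illustrative, inferred from the
    # accesses below; field values are examples, not the real defaults):
    #
    #   kernel:
    #     type: float32            # a torch dtype name
    #     code_size: 128
    #     state: path/to/checkpoint.pt
    #   signals:
    #     video: {shape: [B, C, T, H, W], ...}
    #     audio: {shape: [...], ...}
    #   downstream:
    #     video:
    #       - some_handler: {description: ..., handler: path/to/module}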

    import sys
    # Make the project-local `src` and `evaluation` packages importable
    # relative to this script.
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'src'))
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'evaluation'))

    import torch
    common_dtype = getattr(torch, config['kernel']['type'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Sensoriplexer
    from coach.sensoriplexer import Sensoriplexer
    sp = Sensoriplexer(config['kernel']['code_size'], dtype=common_dtype, device=device)
    for name, specs in config['signals'].items():
        sp.add(name, tuple(specs['shape']))
    sp_state = torch.load(config['kernel']['state'], map_location=device)
    sp.load_state_dict(sp_state['model'])
    sp.eval()
    sp.to(device)

    # Downstream model
    ds_signal = next(iter(config['downstream']))
    ds_specs = next(iter(config['downstream'][ds_signal][0].values()))
    logging.info(f"Downstream system on the {ds_signal} signal: {ds_specs['description']}")
    ds_path = os.path.dirname(ds_specs['handler'])
    ds_module = os.path.basename(ds_specs['handler'])
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), ds_path))
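    # The handler module is expected (inferred from its use below) to expose
    # `specs()` returning a dict with a 'model' and an 'interpretation'
    # (class index -> label name), and `run(model, frame)` returning a class
    # index, or None when no prediction is made.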
    import importlib
    ds_handle = importlib.import_module(ds_module)
    ds = ds_handle.specs()
    from functools import partial
    downstream = partial(ds_handle.run, ds['model'])
    ds_interpret = ds['interpretation']

    # RAVDESS emotion codes (01-08), shifted to zero-based indices.
    # TODO remove dependency on the handler interpretation
    ravdess_interpret = {
        0: 'neutral',
        1: 'calm',
        2: 'happiness',
        3: 'sadness',
        4: 'anger',
        5: 'fear',
        6: 'disgust',
        7: 'surprise',
    }

    # Output trace and artefacts
    output_dir = os.path.join('output', 'inference')
    from pathlib import Path
    Path(output_dir).mkdir(parents=True, exist_ok=True)


    # Results
    container = av.open(os.path.join(output_dir, 'run.mp4'), mode='w')

    stream = container.add_stream('mpeg4', rate=30)
    stream.height = config['signals']['video']['shape'][2]
    stream.width = config['signals']['video']['shape'][3]
    stream.pix_fmt = 'yuv420p'
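    # The data pipeline is not constructed in this script: `dataloader` is
    # assumed to be provided by the evaluation setup, built from `_input` and
    # the config, and to yield dicts with 'video', 'audio' and 'label' entries.
    # A hypothetical sketch (names illustrative):
    #   dataset = ...  # e.g. a RAVDESS-style dataset over `_input`
    #   dataloader = torch.utils.data.DataLoader(dataset, batch_size=...)
    results = {cond: {'match': 0, 'total': 0} for cond in ('AV', 'A0', '0V')}
    data_iter = iter(dataloader)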

    with torch.no_grad():
        with tqdm(total=5000) as pbar:
            for idx in range(5000):
                # Advance a single iterator; re-creating it each step would
                # restart the dataloader on the same first batch.
                batch = next(data_iter)
                #save_as_video(idx, batch['video'], output_dir)
                batch_size = batch['video'].shape[0]

                # Append the raw input frames to the output trace.
                for b in range(batch_size):
                    seq = batch['video'][b]
                    for fi in range(seq.shape[1]):
                        # CHW -> HWC; `from_ndarray` expects uint8 RGB data.
                        frame_ = seq.select(1, fi).numpy().transpose((1, 2, 0))
                        frame = av.VideoFrame.from_ndarray(frame_, format='rgb24')
                        for packet in stream.encode(frame):
                            container.mux(packet)

                videos = batch['video'].type(common_dtype).to(device)
                audios = batch['audio'].type(common_dtype).to(device)

                # Evaluate the downstream classifier under three input
                # conditions: AV (audio+video), A0 (audio only), 0V (video only).
                conditions = {
                    'AV': {'video': videos, 'audio': audios},
                    'A0': {'audio': audios},
                    '0V': {'video': videos},
                }
                for cond, inputs in conditions.items():
                    outputs = sp(inputs)
                    for b in range(batch_size):
                        seq = outputs['video'][b]
                        for frame_idx in range(seq.shape[1]):
                            results[cond]['total'] += 1
                            frame = seq.select(1, frame_idx)
                            result = downstream(frame)
                            if result is not None and ds_interpret[result] == ravdess_interpret[batch['label'][b].item()]:
                                results[cond]['match'] += 1

                # TODO Save as subtitle?

                pbar.update(1)

    from pprint import pprint as pp
    pp(results)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)
    container.close()



if __name__ == '__main__':
    run(args.input, args.config)
