#!/usr/bin/env python

import argparse
from functools import partial, reduce

output_home = 'output'

#
# Simple parser to get custom train/valid/test splits.
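# E.g. `--splits 80,10,10` stores the tuple (80, 10, 10) on the args namespace.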
#
class SplitsParser(argparse.Action):

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        try:
            _values = tuple(int(s.strip()) for s in values.split(','))
        except ValueError:
            parser.error("Invalid splits format: need a list of 3 integers separated by commas")
        if len(_values) != 3:
            parser.error("Invalid splits format: need a list of 3 integers separated by commas")
        if sum(_values) != 100:
            parser.error("The splits must add up to 100%")
        setattr(namespace, self.dest, _values)


parser = argparse.ArgumentParser(description="Training procedure")
# Tried subparsers and did not succeed (yet).
parser.add_argument("command",
        choices=["init", "run"],
        help=f"Choose a run mode. Please start with `init` for a new training session. Run writes to `{output_home}`.")
parser.add_argument("--config",
        required=False,
        default="training_config.yml",
        help="Path to a specific training configuration file. Default to `training_config.yml` in the current directory.")
parser.add_argument("--batch-size",
        dest='batch_size',
        required=False,
        type=int,
        default=64,
        help="Batch size for training, default to 64.")
parser.add_argument("--epochs",
        required=False,
        type=int,
        default=10,
        help="Number of epochs to train for.")
parser.add_argument("--early-stopping",
        dest="early_stopping",
        required=False,
        type=int,
        default=10,
        help="Number of epochs to wait for before early stopping (no improvement on the validation fold).")
parser.add_argument("--splits",
        required=False,
        action=SplitsParser,
        default=(85,5,10),
        help="Data point splits training/validation/testing. Data points are slices of AV files, not files themselves. Defaults to (85, 5, 10)%% for each. Numbers must add to 100.")
parser.add_argument("--seed",
        required=False,
        type=int,
        default=None,
        help="Seed for PyTorch and Numpy. It will affect the data split allocations.")
parser.add_argument("--force-scratch",
        required=False,
        dest='force_scratch',
        action='store_true',
        help="Force training from scratch, and ignore any existing snapshot.")
parser.add_argument("--with-residuals",
        required=False,
        dest='with_residuals',
        action='store_true',
        help="Activate residual connecitons in the sensoriplexer, useful when training is oddly slow.")
parser.add_argument("--optimization-mode",
        dest="optimization_mode",
        required=False,
        choices=['per-signal', 'global'],
        default='per-signal',
        help="Optimizer application mode: 'per-signal' creates and applies one optimizer per signal, with the traverse in common; 'global' applies a single optimizer over all parameters. Default to per-signal."),
parser.add_argument("--artifacts-frequency",
        required=False,
        dest='artifacts_frequency',
        type=int,
        default=0,
        help="Frequency to generate intermediary files and extra trace information. E.g. to get files every 10 batches, set to 10. Default to 0, no file."),
parser.add_argument("--test-only",
        required=False,
        dest='test_only',
        action='store_true',
        help="Test only. Skip training and validation. This option ignores other training-related options. Please consider setting `--seed` to replay a previous test session.")
parser.add_argument("--test-mode",
        required=False,
        dest='test_mode',
        default='comprehensive',
        choices=['comprehensive', 'centered', 'last'],
        help="Test mode: Comprehensive compares over complete sequences; centered only compares the middle value of each sequence; last only compares the last value of each sequence. Default to comprehensive.")
args = parser.parse_args()


from dataclasses import dataclass
import importlib

import torch
@dataclass
class RuntimeConfig:
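    '''Mutable state threaded through the training/validation/testing loops.'''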
    mode: str
    signals: dict
    model: torch.nn.Module
    loss_fn: torch.nn.Module
    artifacts_frequency: int
    artifacts_dir: str
    epoch: int = 0
    batch: int = 0
    video_base: int = 0
    video_scale: int = 255 # Basically RGB
    audio_base: float = 32768 # WAV
    audio_scale: float = 65536 # WAV


from datetime import datetime, timezone
import os
import yaml

import librosa
import librosa.display
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot


def init(config_path='training_config.yml'):
    import shutil
    shutil.copy(
            os.path.join(os.path.dirname(__file__), 'templates', 'training_config.yml'),
            config_path
            )
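
# A sketch of the configuration structure `run` expects, inferred from the
# accesses below (the template under `templates/` remains the reference):
#
#   kernel:
#     type: float32            # a torch dtype name
#     code_size: ...
#   signals:
#     video: { shape: [...], base: 0, scale: 255 }
#     audio: { shape: [...], base: 32768, scale: 65536 }
#   datasets:
#     <name>:
#       loader: { module: ..., class: ..., options: {...} }
#       paths: [...]
#       extra: { interpretation: {...} }   # optional label map
#   downstream:
#     <signal>:
#       - <name>: { description: ..., handler: <path/to/module> }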

# TODO Use it?
def calc_avg_loss(iteration, known_loss, batch_loss):
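    '''Incremental running mean: (batch_loss + iteration * known_loss) / (iteration + 1).'''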
    return (batch_loss + known_loss * iteration) / (iteration + 1)

def get_loss(rconfig, expected, outputs, optimizers=None):
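    '''
    Sum the per-signal losses. When optimizers are given (training), also
    backpropagate and step per signal; retain_graph=True keeps the shared
    graph alive across the successive per-signal backward passes.
    '''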
    function = rconfig.loss_fn
    signals = rconfig.signals
    _loss = 0
    for signal in signals:
        loss = function(outputs[signal], expected[signal])
        if optimizers and signal in optimizers.keys():
            optimizer = optimizers[signal]
            loss.backward(retain_graph=True)
            optimizer.step()
        _loss += loss.item()
    return _loss

def last_snapshot_for(base):
    '''Return the most recent .pth snapshot in the latest timestamped directory under `base`, or None.'''
    # Note: walking with os.walk visits subdirectories in listing order, not
    # sorted order, so pick the latest timestamp directory explicitly.
    if not os.path.isdir(base):
        return None
    timestamps = sorted(d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d)))
    if not timestamps:
        return None
    last_dir = os.path.join(base, timestamps[-1])
    snapshots = sorted(f for f in os.listdir(last_dir) if f.endswith('.pth'))
    if not snapshots:
        return None
    return os.path.join(last_dir, snapshots[-1])


from hashlib import sha1
def config_hash(config:dict) -> str:
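    '''Digest of the config dict, used to key snapshot directories so runs with an identical config resume each other.'''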
    return sha1(str(config).encode('utf-8')).hexdigest() # utf-8: tolerate non-ASCII config values.


import av
import numpy
from PIL import Image
def save_artifact(data, to):
    container = av.open(to + '.mp4', mode='w')

    if 'video' in data:
        # Video for the first batch entry.
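        # Frames are assumed shaped (batch, channels, time, height, width).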
        frames = data['video'].cpu()
        stream = container.add_stream('mpeg4', rate=30)
        stream.height = frames.shape[3]
        stream.width = frames.shape[4]
        stream.pix_fmt = 'yuv420p'
        for i in range(frames.shape[2]): # Time dimension, not channels.
            frame_ = frames[0].select(1, i).detach().numpy().astype(numpy.uint8).transpose((1, 2, 0))
            Image.fromarray(frame_, mode='RGB').save(f'{to}_{i+1}.jpg')
            frame = av.VideoFrame.from_ndarray(frame_, format='rgb24')
            for packet in stream.encode(frame):
                container.mux(packet)
        # Flush stream
        for packet in stream.encode():
            container.mux(packet)

    if 'audio' in data:
        aframes = data['audio'].cpu()
        #astream = container.add_stream('mp3', rate=48000) # Lazy alignment to RAVDESS
        #for i in range(aframes.shape[1]):
        #    frame_ = aframes[0].select(1, i).detach().numpy().astype(numpy.int16).transpose((1, 2, 0))
        #    f = av.audio.frame.AudioFrame.from_ndarray(frame_, format='s16p', layout='stereo')
        #    f.rate = 48000
        #    for packet in astream.encode(f):
        #        container.mux(packet)

        # Spectrograms on the 1st channel, for the first batch entry only.
        mel_spect = librosa.feature.melspectrogram(y=aframes.detach().numpy()[0,0,:].reshape((-1,)), sr=48000, n_fft=1024, hop_length=100)
        mel_spect = librosa.power_to_db(mel_spect, ref=numpy.max)
        librosa.display.specshow(mel_spect, y_axis='mel', fmax=20000, x_axis='time')
        pyplot.savefig(f"{to}_spectr.jpg")

    container.close()


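# Signal (un)standardization, using each signal's base/scale:
#   standardize:   x -> (x + base) / scale   (maps raw ranges into [0, 1])
#   unstandardize: x -> x * scale - base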
def unstandardize(rc, values):
    if 'video' in values:
        values['video'] = values['video'] * rc.video_scale - rc.video_base
    if 'audio' in values:
        values['audio'] = values['audio'] * rc.audio_scale - rc.audio_base
    return values

def standardize(rc, values):
    if 'video' in values:
        values['video'] = (values['video'] + rc.video_base) / rc.video_scale
    if 'audio' in values:
        values['audio'] = (values['audio'] + rc.audio_base) / rc.audio_scale
    return values


def apply_under(rc, optimizers, expected, inputs, with_model_output=False):
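    '''
    One pass over a (possibly partial) set of input signals: zero the
    gradients, standardize, run the model, unstandardize, and compute the
    loss (stepping the optimizers when provided). At the configured
    frequency, inputs are saved raw before the pass; outputs and the
    standardized inputs are saved after it.
    '''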
    if optimizers:
        for o in optimizers:
            optimizers[o].zero_grad()

    if rc.artifacts_frequency and rc.batch % rc.artifacts_frequency == 0:
        save_artifact(inputs,  os.path.join(rc.artifacts_dir, f'{rc.mode}_epoch-{rc.epoch}_batch-{rc.batch}_inputs'))

    inputs = standardize(rc, inputs)
    outputs = rc.model(inputs)
    outputs = unstandardize(rc, outputs)

    loss = get_loss(rc, expected, outputs, optimizers)

    if rc.artifacts_frequency and rc.batch % rc.artifacts_frequency == 0:
        save_artifact(outputs, os.path.join(rc.artifacts_dir, f'{rc.mode}_epoch-{rc.epoch}_batch-{rc.batch}_outputs'))
        save_artifact(inputs, os.path.join(rc.artifacts_dir, f'{rc.mode}_epoch-{rc.epoch}_batch-{rc.batch}_standardized_inputs'))

    if with_model_output:
        return (loss, outputs)
    else:
        return loss


def report_to(path, values:list):
    '''
    Yes, the headers could be part of an init function. Negligible, though.
    '''
    # Check existence before opening: open(path, 'a') creates the file, which
    # would otherwise make the header check always fail.
    write_header = not os.path.exists(path)
    with open(path, 'a') as f:
        if write_header:
            f.write('Epoch, AV Loss, A0 Loss, 0V Loss, Training Loss, Valid AV Loss, Valid A0 Loss, Valid 0V Loss, Validation Loss\n')
        f.write(','.join([f"{v:.2f}" for v in values]) + '\n')


def make_optimizer(params):
    return torch.optim.AdamW(
        params,
        lr=0.001, betas=(0.9, 0.999),
        eps=1e-08, weight_decay=0.01,
        amsgrad=False)


def run(config_path='training_config.yml', output_dir=output_home, force_scratch=False, test_only=False, batch_size:int=64):

    import logging
    logging.basicConfig(level=logging.INFO)

    with open(config_path) as f:
        config = yaml.safe_load(f)

    import sys
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'src'))
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'evaluation'))

    if args.seed is not None:
        torch.manual_seed(args.seed)
        #numpy.random.seed(args.seed) TODO either force a 32-bit seed, or change the RNG in Numpy
    logging.info(f'Torch seed: {torch.initial_seed()}')
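    # config['kernel']['type'] names a torch dtype, e.g. 'float32' -> torch.float32.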
    common_dtype = getattr(torch, config['kernel']['type'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Target model
    from coach.sensoriplexer import Sensoriplexer
    model = Sensoriplexer(config['kernel']['code_size'], dtype=common_dtype, device=device, with_residuals=args.with_residuals)
    for name, specs in config['signals'].items():
        model.add(name, tuple(specs['shape']))
    model.to(device=device, dtype=common_dtype)

    logging.debug(model)
    logging.info(f'Sensoriplexer parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    # Optimizers
    #   Global: 1 optimizer on all parameters, loss calculated once per signal and summed.
    #   Per signal: Each signal gets an optimizer on the signal's parameters and the traverse (common).
    optimizers = {}
    common_optimizer = make_optimizer(model.parameters()) if args.optimization_mode == 'global' else None
    for name in config['signals'].keys():
        if common_optimizer:
            optimizers[name] = common_optimizer
        else:
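            # This signal's encoder/decoder parameters, plus the shared traverse.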
            params = [tensor for pname, tensor in model.named_parameters() if name in pname or 'traverse' in pname]
            optimizers[name] = make_optimizer(params)

    # Model and optimizer initialization
    config_sig = config_hash(config)
    save_dir = None
    if test_only or not force_scratch:
        last_snapshot = last_snapshot_for(os.path.join(output_dir, config_sig))
        if last_snapshot:
            save_dir = os.path.dirname(last_snapshot)
            state = torch.load(last_snapshot)
            try:
                model.load_state_dict(state['model'])
            except KeyError:
                # For legacy format. Will go
                model.load_state_dict(state)
            model.eval()
            try:
                for o in optimizers:
                    optimizers[o].load_state_dict(state['optimizers'][o])
            except KeyError:
                # For legacy format. Will go
                pass
            try:
                initial_epoch = state['epoch']
            except KeyError:
                # For legacy format. Will go
                initial_epoch = 1
            logging.info(f"Resuming from snapshot {last_snapshot}, from epoch {initial_epoch}")
        else:
            logging.warning("No snapshot to resume from. Starting from scratch")
    if save_dir is None:
        save_dir = os.path.join(output_dir, config_sig, str(int(datetime.now(timezone.utc).timestamp())))
        state = {}
        initial_epoch = 1
        logging.info(f"Training from scratch, snapshotting to {save_dir}")

    from pathlib import Path
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Reporting
    report_path = os.path.join(save_dir, 'report.csv')
    report = partial(report_to, report_path)

    with open(os.path.join(save_dir, 'seed.txt'), 'w') as f:
        f.write(str(torch.initial_seed()))

    if args.artifacts_frequency:
        artifacts_dir = os.path.join(save_dir, 'artifacts')
        Path(artifacts_dir).mkdir(parents=True, exist_ok=True)
    else:
        artifacts_dir = None

    # Training metric
    criterion = torch.nn.MSELoss()

    # Data loading
    # For now, assume only one dataset is available, with its data in one folder.
    dataset_config = list(config['datasets'].keys())[0]
    dataset_loader_name = config['datasets'][dataset_config]['loader']['module']
    dataset_loader_module = importlib.import_module(dataset_loader_name)
    dataset_path = config['datasets'][dataset_config]['paths'][0]
    options = config['datasets'][dataset_config]['loader'].get('options', {})
    dataset = getattr(dataset_loader_module, config['datasets'][dataset_config]['loader']['class'])(dataset_path, **options)

    tn = int(len(dataset) * args.splits[0] / 100)
    vn = int(len(dataset) * args.splits[1] / 100)
    train_ds, valid_ds, test_ds = torch.utils.data.random_split(dataset, (tn, vn, len(dataset) - tn - vn))

    from torch.utils.data import DataLoader
    dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
    validation_dataloader = DataLoader(valid_ds, batch_size=batch_size, num_workers=0, drop_last=True)
    testing_dataloader = DataLoader(test_ds, batch_size=batch_size, num_workers=0, drop_last=True)

    from tqdm import tqdm
    save_every = 1000 # batches
    logging.info(f'Snapshots saved every {save_every} batches')

    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=os.path.join(output_dir, 'logs', config_sig, os.path.basename(save_dir)))

    rc = RuntimeConfig(
        mode='training',
        signals=config['signals'],
        model=model,
        loss_fn=criterion,
        artifacts_frequency=args.artifacts_frequency,
        artifacts_dir=artifacts_dir,
        video_base=config['signals']['video']['base'],
        video_scale=config['signals']['video']['scale'],
        audio_base=config['signals']['audio']['base'],
        audio_scale=config['signals']['audio']['scale'],
    )

    if not test_only:
        validation_history = []

        last_epoch = initial_epoch + args.epochs - 1
        for epoch in range(initial_epoch, last_epoch+1):
            running_loss = 0.0
            rc.mode = 'training'
            rc.epoch = epoch
            with tqdm(total=len(dataloader)) as pbar:
                pbar.set_description(f'Epoch {epoch}/{last_epoch}')
                for idx, batch in enumerate(dataloader):
                    rc.batch = idx

                    videos = batch['video'].type(common_dtype).to(device)
                    audios = batch['audio'].type(common_dtype).to(device)

                    expected = {'video': videos, 'audio': audios}

                    apply = partial(apply_under, rc, optimizers, expected)

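                    # Three reconstruction passes: both signals (AV), audio only (A0), video only (0V).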
                    _av_loss = apply({'video': videos, 'audio': audios})
                    _a0_loss = apply({'audio': audios})
                    _0v_loss = apply({'video': videos})

                    _loss = _av_loss + _a0_loss + _0v_loss
                    running_loss = (_loss + running_loss * idx) / (idx + 1)

                    if idx % save_every == (save_every-1):
                        state = {
                                'epoch': epoch,
                                'model': model.state_dict(),
                                'optimizers': { name: optim.state_dict() for name, optim in optimizers.items() },
                                }
                        batch_block = str(idx+1).zfill(6)
                        epoch_block = str(epoch).zfill(3)
                        torch.save(state, os.path.join(save_dir, f'sensoriplexer_{epoch_block}_{batch_block}.pth'))

                    pbar.set_postfix(loss=running_loss, xv_loss=_0v_loss, av_loss=_av_loss, ax_loss=_a0_loss)
                    pbar.update(1)

                    if idx % 100 == 99 or idx == len(dataloader) - 1:
                        writer.add_scalar('avloss/train', _av_loss, len(dataloader)*(epoch-1) + idx)
                        writer.add_scalar('a0loss/train', _a0_loss, len(dataloader)*(epoch-1) + idx)
                        writer.add_scalar('0vloss/train', _0v_loss, len(dataloader)*(epoch-1) + idx)
                        writer.add_scalar('loss/train', running_loss, len(dataloader)*(epoch-1) + idx)

            state = {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizers': { name: optim.state_dict() for name, optim in optimizers.items() },
                    }
            epoch_block = str(epoch).zfill(3)
            torch.save(state, os.path.join(save_dir, f'sensoriplexer_{epoch_block}_complete.pth'))

            v_av_loss, v_a0_loss, v_0v_loss, validation_loss = (float('inf'),) * 4
            with torch.no_grad():
                try:
                    rc.mode = 'validation'
                    validation_loss = 0.0
                    with tqdm(total=len(validation_dataloader)) as pbar:
                        pbar.set_description(f'Validation on epoch {epoch}')
                        for idx, batch in enumerate(validation_dataloader):
                            rc.batch = idx

                            videos = batch['video'].type(common_dtype).to(device)
                            audios = batch['audio'].type(common_dtype).to(device)

                            expected = {'video': videos, 'audio': audios}

                            apply = partial(apply_under, rc, None, expected)

                            v_av_loss = apply({'video': videos, 'audio': audios})
                            v_a0_loss = apply({'audio': audios})
                            v_0v_loss = apply({'video': videos})

                            _loss = v_av_loss + v_a0_loss + v_0v_loss
                            validation_loss = (_loss + validation_loss * idx) / (idx + 1)

                            pbar.set_postfix(loss=validation_loss)
                            pbar.update(1)

                            if idx % 100 == 99 or idx == len(validation_dataloader) - 1:
                                writer.add_scalar('avloss/valid', v_av_loss, len(validation_dataloader)*(epoch-1) + idx)
                                writer.add_scalar('a0loss/valid', v_a0_loss, len(validation_dataloader)*(epoch-1) + idx)
                                writer.add_scalar('0vloss/valid', v_0v_loss, len(validation_dataloader)*(epoch-1) + idx)
                                writer.add_scalar('loss/valid', validation_loss, len(validation_dataloader)*(epoch-1) + idx)

                    report([epoch, _av_loss, _a0_loss, _0v_loss, running_loss, v_av_loss, v_a0_loss, v_0v_loss, validation_loss])

                    validation_history.append(validation_loss)
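                    # Stop when the current validation loss improves on none of the last `early_stopping` recorded losses.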
                    if len(validation_history) > args.early_stopping and all(x <= validation_loss for x in validation_history[-args.early_stopping:]):
                        raise StopIteration()
                except StopIteration:
                    logging.info(f"Early stopping after validation loss not decreasing for {args.early_stopping} epochs.")
                    break



    # Downstream model
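    # The handler module is expected to expose specs() -> {'model', 'interpretation',
    # optionally 'preprocessing'} and run(model, input) -> an index into 'interpretation'.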
    ds_signal = list(config['downstream'].keys())[0]
    ds_specs = list(config['downstream'][ds_signal][0].values())[0]
    logging.info(f"Downstream system on the {ds_signal} signal: {ds_specs['description']}")
    ds_path = os.path.dirname(ds_specs['handler'])
    ds_module = os.path.basename(ds_specs['handler'])
    sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), ds_path))
    ds_handle = importlib.import_module(ds_module)
    ds = ds_handle.specs()
    downstream = partial(ds_handle.run, ds['model'])
    ds_interpret = ds['interpretation']
    ds_preprocess = ds.get('preprocessing')

    if 'extra' in config['datasets'][dataset_config] and 'interpretation' in config['datasets'][dataset_config]['extra']:
        dataset_interpret = config['datasets'][dataset_config]['extra']['interpretation']
    elif hasattr(dataset, 'label_map'):
        dataset_interpret = dataset.label_map
    else:
        raise RuntimeError("No label interpretation found: set `extra.interpretation` in the dataset config, or provide a `label_map` on the dataset.")

    with torch.no_grad():
        rc.mode = 'testing'
        testing_loss = 0.0
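        # Accuracy per variety: downstream on the raw video (DS) vs. on reconstructions
        # from both signals (AV), audio only (A0), or video only (0V).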
        accuracy_results = {
            'AV': 0,
            'A0': 0,
            '0V': 0,
            'DS': 0,
        }
        accuracy_frames = 0
        debug_dir = './debug'
        Path(debug_dir).mkdir(parents=True, exist_ok=True)
        with tqdm(total=len(testing_dataloader)) as pbar:
            pbar.set_description('Testing best model')
            for idx, batch in enumerate(testing_dataloader):
                rc.batch = idx

                videos = batch['video'].type(common_dtype).to(device)
                audios = batch['audio'].type(common_dtype).to(device)
                labels = batch['label']

                expected = {'video': videos, 'audio': audios}

                apply = partial(apply_under, rc, None, expected)

                t_av_loss, t_av_out = apply({'video': videos, 'audio': audios}, with_model_output=True)
                t_a0_loss, t_a0_out = apply({'audio': audios}, with_model_output=True)
                t_0v_loss, t_0v_out = apply({'video': videos}, with_model_output=True)

                _loss = t_av_loss + t_a0_loss + t_0v_loss
                testing_loss = (_loss + testing_loss * idx) / (idx + 1)

                pbar.set_postfix(loss=testing_loss)
                pbar.update(1)

                if idx % 100 == 99 or idx == len(testing_dataloader) - 1:
                    writer.add_scalar('avloss/test', t_av_loss, idx)
                    writer.add_scalar('a0loss/test', t_a0_loss, idx)
                    writer.add_scalar('0vloss/test', t_0v_loss, idx)
                    writer.add_scalar('loss/test', testing_loss, idx)

                batch_accuracy_results = { 'AV': 0, 'A0': 0, '0V': 0, 'DS': 0 }
                views = { 'AV': [], 'A0': [], '0V': [], 'DS': [] }
                confusion = { 'AV': [], 'A0': [], '0V': [], 'DS': [] }
                confusion_gt = []
                for case, ds_case in {'DS': videos, 'AV': t_av_out['video'], 'A0': t_a0_out['video'], '0V': t_0v_out['video']}.items():
                    for bidx in range(batch_size):
                        seq = ds_case[bidx]
                        result = None
                        if args.test_mode == 'comprehensive':
                            indices = range(seq.shape[1])
                            if case == 'DS':
                                accuracy_frames += seq.shape[1] # Count once the number of frames submitted to all varieties.
                        elif args.test_mode == 'centered':
                            indices = [seq.shape[1] // 2]
                            if case == 'DS':
                                accuracy_frames += 1 # Count once the number of frames submitted to all varieties.
                        elif args.test_mode == 'last':
                            indices = [seq.shape[1] - 1]
                            if case == 'DS':
                                accuracy_frames += 1 # Count once the number of frames submitted to all varieties.
                        for frame_idx in indices:
                            frame = seq.select(1, frame_idx)
                            if ds_preprocess:
                                dsinput = ds_preprocess(frame)
                            else:
                                dsinput = frame
                            result = downstream(dsinput)
                            try:
                                label_bidx = labels[bidx].item()
                            except (ValueError, RuntimeError, AttributeError):
                                # TODO Lazy workaround. Need to define the label shape requirements from the config file.
                                label_bidx = labels[bidx][0].item()

                            # Track accuracy.
                            if result is not None and ds_interpret[result] == dataset_interpret[label_bidx]:
                                accuracy_results[case] += 1
                                batch_accuracy_results[case] += 1

                            # Data for confusion matrices.
                            if result is not None:
                                confusion[case].append(ds_interpret[result])
                            else:
                                confusion[case].append('error')
                            if case == 'DS':
                                confusion_gt.append(dataset_interpret[label_bidx])

                            # Keep sample data for reporting.
                            if frame_idx == 0 or args.test_mode != 'comprehensive':
                                views[case].append({
                                    'frame': frame.cpu().numpy().astype(numpy.uint8).transpose((1,2,0)),
                                    'gt': dataset_interpret[label_bidx],
                                    'result': ds_interpret[result] if result is not None else 'error',
                                })

                # Plot reporting data, if demanded.
                if rc.artifacts_frequency > 0: # TODO These are debug data, distinct from the artifacts; reusing the artifacts flag as a switch here is lazy and possibly confusing.
                    for bidx in range(batch_size):
                        _, axes = pyplot.subplots(1, 4)
                        for col, case in enumerate(['DS', '0V', 'AV', 'A0']): # Order we want to see in reports. Avoid shadowing the batch index `idx`.
                            axes[col].imshow(views[case][bidx]['frame'])
                            axes[col].set_title(f"{case}\nGT={views[case][bidx]['gt']}\nR={views[case][bidx]['result']}")
                            axes[col].axis('off')
                        pyplot.savefig(os.path.join(debug_dir, f"batch={rc.batch}#{bidx+1}.jpg"), bbox_inches='tight')
                        pyplot.close()

            # Plot confusion matrices anyway.
            from sklearn.metrics import confusion_matrix
            for case, predictions in confusion.items():
                # Pin the label order so the tick labels line up with the matrix axes.
                labels_cm = sorted(set(confusion_gt) | set(predictions))
                cm = confusion_matrix(confusion_gt, predictions, labels=labels_cm)
                fig, ax = pyplot.subplots(figsize=(7.5, 7.5))
                ax.matshow(cm, cmap=pyplot.cm.Reds, alpha=0.3)
                ax.set_xticks(range(len(labels_cm)))
                ax.set_yticks(range(len(labels_cm)))
                ax.set_xticklabels(labels_cm)
                ax.set_yticklabels(labels_cm)
                for i in range(cm.shape[0]):
                    for j in range(cm.shape[1]):
                        ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
                pyplot.xlabel(case, fontsize=18)
                pyplot.ylabel('GT', fontsize=18)
                #pyplot.title('Confusion Matrix', fontsize=18)
                pyplot.savefig(os.path.join(save_dir, f"results_cm_{case.lower()}.jpg"))

            print('-'*80)
            print(f"acc/plain={accuracy_results['DS'] / accuracy_frames * 100.0}")
            print(f"acc/0v={accuracy_results['0V'] / accuracy_frames * 100.0}")
            print(f"acc/av={accuracy_results['AV'] / accuracy_frames * 100.0}")
            print(f"acc/a0={accuracy_results['A0'] / accuracy_frames * 100.0}")

            print('-'*80)
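            # Check whether each signal's diagonal block of the traverse is (near) zero
            # at decreasing tolerances, via its max, mean, and min entries.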
            traverse = [t for name, t in rc.model.named_modules() if 'traverse' in name][0]
            for idx, signal in enumerate(rc.signals): # Hazardous: relies on the Python 3.7+ dict key-ordering guarantee.
                target_block = traverse.diagonal_block(idx).cpu().numpy()
                for error in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]:
                    print(f"\t- Zero hypothesis at {error}: Max={abs(target_block.max()) < error} [{target_block.max():.2f}], Mean={target_block.ravel().mean() < error} [{target_block.ravel().mean():.2f}], Max={abs(target_block.min()) < error} [{target_block.min():.2f}]")
                print('-'*80)


if args.command == "init":
    init(args.config)
else:
    run(config_path=args.config, force_scratch=args.force_scratch, test_only=args.test_only, batch_size=args.batch_size)
