R"""Script for running mGLS on binary classification tasks.

cd ~/Desktop/projects/zonotopic_relu
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/zonotopic_relu


python3 scripts/classification/mgls.py \
    --outdir="/tmp" \
    --configs_path="exps.classification.mgls_configs.CONFIGS" \
    --config="fast" \
    --n_runs=2 \
    --n_processes=10 \
    --save_frequency=10

"""
import dataclasses
import json
import os
from pydoc import locate

from absl import app
from absl import flags
from absl import logging

if True:
    # Hide all CUDA devices so CUDA-backed libraries fall back to the CPU.
    # For whatever reason, the GPU appears to be slower for these small models
    # and/or this style of feeding data.
    # NOTE(review): this must execute before numpy/xoid are imported below so
    # the variable is already set when any CUDA-backed library initializes.
    # The `if True:` wrapper presumably stops import sorters from hoisting
    # those imports above this line — confirm before removing it.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np

from xoid.datasets import classification

from xoid.gls import base_mgls
from xoid.gls import multiprocess_mgls


FLAGS = flags.FLAGS

# Default Python path (resolved with pydoc.locate) of the dict mapping config
# names to Config instances.
_CONFIGS_PATH = "exps.classification.mgls_configs.CONFIGS"

if __name__ == "__main__":
    # Flags are only registered when this file runs as a script (presumably so
    # importing it elsewhere does not register duplicate flags).
    # The output directory must already exist: nothing in this script creates it.
    flags.DEFINE_string('outdir', None, 'Existing directory that output JSON files are written to.')

    flags.DEFINE_string('configs_path', _CONFIGS_PATH, 'Python path to configs dict.')
    flags.DEFINE_string('config', None, 'Name of the entry in the configs dict to use as configuration.')

    flags.DEFINE_integer('n_runs', 1, 'Number of times to repeat the experiment.')

    flags.DEFINE_integer('n_processes', 1, 'Passed as `n_processes` to MultiprocessMglsOptions.')
    # lower_bound=1: a value of 0 would raise ZeroDivisionError in the middle of
    # a run. absl's bounds validator still accepts the None default.
    flags.DEFINE_integer(
        'save_frequency', None,
        'If set, also write the results file every this many steps '
        '(in addition to the final write).',
        lower_bound=1)

    # `configs_path` has a non-None default, so listing it as required would be
    # a no-op that only triggers an absl warning. Require just the flags
    # without defaults.
    flags.mark_flags_as_required(['outdir', 'config'])


@dataclasses.dataclass()
class Config:
    """Settings for one mGLS binary-classification experiment.

    Field order is part of the interface: it fixes the positional constructor
    signature and the key order of `dataclasses.asdict`, which is written into
    the results JSON.
    """

    # Identifier used in output filenames (gls_<name>.<run_index>.json).
    name: str

    # Forwarded to classification.make_dataset(dataset, n_components).
    dataset: str
    n_components: int

    # Number of training examples kept by get_dataset.
    N: int
    # Forwarded to MultiprocessMgls as its third positional argument
    # (presumably the hidden-layer width — confirm).
    m: int

    # Step budget handed to solve_iter.
    max_steps: int

    # True: keep N uniformly random training rows; False: keep the first N.
    random_subset: bool = dataclasses.field(default=True)

    # Forwarded to MultiprocessMgls as `eps`.
    eps: float = dataclasses.field(default=1e-7)

    # Forwarded to MultiprocessMglsOptions.max_level_set_attempts.
    max_level_set_attempts: int = dataclasses.field(default=0)


def get_dataset(cfg):
    """Load `cfg.dataset` and keep at most `cfg.N` training examples.

    Args:
        cfg: a Config; reads `dataset`, `n_components`, `N` and `random_subset`.

    Returns:
        `((X, Y), val_data)`: the (possibly subsampled) training inputs and
        labels, plus the validation split exactly as returned by
        `classification.make_dataset`. The third element of make_dataset's
        result is discarded.
    """
    train_split, val_data, _unused = classification.make_dataset(cfg.dataset, cfg.n_components)
    X, Y = train_split

    if not cfg.random_subset:
        # Deterministic subset: the leading N rows.
        keep = slice(None, cfg.N)
    else:
        # Random subset drawn from NumPy's global RNG (not seeded in this
        # file), so each run may see different examples.
        keep = np.arange(X.shape[0], dtype=np.int32)
        np.random.shuffle(keep)
        keep = keep[:cfg.N]

    return (X[keep], Y[keep]), val_data


def call(model_params, X):
    """Evaluate a one-hidden-layer ReLU network on `X`.

    Computes `relu(X @ w + b) @ v + c` using the attributes `w`, `b`, `v`, `c`
    of `model_params`. Singleton dimensions of `v` and `c` are squeezed away,
    so for a column `v` of shape (m, 1) the result has one score per row of X.

    Args:
        model_params: object exposing `w`, `b`, `v`, `c` arrays.
        X: input matrix whose rows are examples (must support `@ w`).

    Returns:
        The network outputs (pre-sigmoid scores).
    """
    p = model_params
    acts = np.maximum(X @ p.w + p.b, 0)
    v = np.squeeze(p.v)
    if v.ndim == 0:
        # With a single hidden unit, squeeze collapses v of shape (1, 1) to a
        # 0-d scalar, and `acts @ scalar` raises. Restore a length-1 vector.
        v = v.reshape(1)
    return acts @ v + np.squeeze(p.c)


def compute_accuracy(preds, Y):
    """Return the fraction of examples whose thresholded score matches the label.

    A prediction counts as positive when its score is > 0. It is compared with
    `Y` via `==`, so labels are expected to be 0/1 (or boolean). A label of -1
    would never count as correct.

    Both inputs are flattened first. Without this, an (N, 1) label column
    compared with (N,) scores would broadcast to an (N, N) mask and give a
    meaningless accuracy.

    Args:
        preds: model scores, one per example.
        Y: labels, one per example.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if the number of predictions and labels differ, or if
            there are no examples.
    """
    scores = np.reshape(preds, -1)
    labels = np.reshape(Y, -1)
    if scores.shape != labels.shape:
        raise ValueError(
            f'preds and Y must have the same number of elements, got '
            f'{scores.size} and {labels.size}.')
    if labels.size == 0:
        raise ValueError('Cannot compute accuracy on an empty set of examples.')
    corrects_mask = (scores > 0) == labels
    return float(np.count_nonzero(corrects_mask)) / float(labels.size)


def do_run(cfg, run_index):
    """Train with mGLS for one repetition of `cfg` and write metrics to JSON.

    Results go to `<FLAGS.outdir>/gls_<cfg.name>.<run_index>.json` (with `~`
    expanded). The file is rewritten every `FLAGS.save_frequency` steps when
    that flag is set, and once more when training ends normally or because the
    level set was exhausted. Worker subprocesses are always shut down, even if
    training raises.

    Args:
        cfg: a Config describing the dataset and solver settings.
        run_index: index of this repetition; used only in the output filename.
    """
    results = {
        'config': dataclasses.asdict(cfg),
        #
        # The following are parallel arrays listing metrics
        # per training step.
        'train_losses': [],
        'train_accuracies': [],
        'val_accuracies': [],
    }

    def save_results():
        filepath = os.path.join(FLAGS.outdir, f'gls_{cfg.name}.{run_index}.json')
        filepath = os.path.expanduser(filepath)
        with open(filepath, 'w') as f:
            # default=float converts numpy scalars (e.g. a float32 loss), which
            # json cannot serialize natively.
            json.dump(results, f, default=float)

    (X, Y), (X_val, Y_val) = get_dataset(cfg)

    mgls_options = multiprocess_mgls.MultiprocessMglsOptions(
        loss_fn='sigmoid_cross_entropy',
        output_weights_style='pm_1',
        regularization=None,
        max_level_set_attempts=cfg.max_level_set_attempts,
        no_scs=True,
        n_processes=FLAGS.n_processes,
    )

    mgls = multiprocess_mgls.MultiprocessMgls(
        X, Y, cfg.m, mgls_options, eps=cfg.eps
    )

    try:
        try:
            for i in mgls.solve_iter(cfg.max_steps):
                res = mgls.get_current_vertex_results()
                if res is None:
                    break

                preds = call(res.model_params, X)
                acc = compute_accuracy(preds, Y)

                preds_val = call(res.model_params, X_val)
                acc_val = compute_accuracy(preds_val, Y_val)

                results['train_losses'].append(res.loss)
                results['train_accuracies'].append(acc)
                results['val_accuracies'].append(acc_val)
                print(f'{i}: {acc:.4f} {acc_val:.4f} {res.loss}')

                # Checkpoint only after step i is recorded, so each periodic
                # save contains a whole multiple of save_frequency steps.
                if FLAGS.save_frequency is not None and (i + 1) % FLAGS.save_frequency == 0:
                    save_results()

        except base_mgls.LevelSetExhaustedError:
            # An expected early stop of the search: keep the metrics gathered
            # so far and fall through to the final save.
            logging.info('Level set exhausted for config %s, run %d; stopping early.',
                         cfg.name, run_index)

        save_results()
    finally:
        # Reap the worker processes even if training raised.
        mgls.shutdown_subprocesses()


def main(_):
    """Resolve the selected config and run it `FLAGS.n_runs` times.

    Raises:
        app.UsageError: if `--configs_path` cannot be located or `--config` is
            not a key of the configs dict. absl reports this as a usage error.
    """
    configs = locate(FLAGS.configs_path)
    if configs is None:
        # pydoc.locate returns None rather than raising for unknown paths.
        raise app.UsageError(
            f'Could not locate a configs dict at --configs_path={FLAGS.configs_path!r}.')
    try:
        cfg = configs[FLAGS.config]
    except KeyError:
        available = ', '.join(map(str, configs))
        raise app.UsageError(
            f'Unknown --config={FLAGS.config!r}; available configs: {available}.') from None

    for run_index in range(FLAGS.n_runs):
        do_run(cfg, run_index)


if __name__ == "__main__":
    app.run(main)
