R"""Script for running fixed vertex optimization on binary classification tasks.

cd ~/Desktop/projects/zonotopic_relu
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/zonotopic_relu


python3 scripts/classification/random_vertex.py \
    --outdir="/tmp" \
    --configs_path="exps.classification.random_vertex_configs.CONFIGS" \
    --config="fast" \
    --n_runs=2

"""
import dataclasses
import json
import os
from pydoc import locate
import time

from absl import app
from absl import flags
from absl import logging

# Set CUDA_VISIBLE_DEVICES to "-1" so that the GPU is not used by this script.
# NOTE(review): this runs before the `xoid` imports below, presumably so the
# setting is in place before any GPU-backed library initializes, and the
# `if True:` wrapper looks like a way to keep a statement inside the import
# block without tripping import-order linters -- confirm before moving or
# simplifying it.
if True:
    # Empirically, using the GPU appears to be slower for these small models
    # and/or this style of feeding data.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np

from xoid.datasets import classification
from xoid.solvers import vertex_solvers

from xoid.util import misc_util
from xoid.util import vertex_util


# Shared absl flag container; parsed flag values are read from it as attributes
# (e.g. `FLAGS.outdir`).
FLAGS = flags.FLAGS

# Default value of the `configs_path` flag: dotted Python path to the configs
# dict.
_CONFIGS_PATH = "exps.classification.random_vertex_configs.CONFIGS"

if __name__ == "__main__":
    # Flags are registered only when this file is executed as a script.

    # The output directory must already exist; nothing in this script
    # creates it.
    flags.DEFINE_string(
        'outdir', None,
        'Path to an existing directory where output files are written.')

    flags.DEFINE_string(
        'configs_path', _CONFIGS_PATH, 'Python path to configs dict.')
    flags.DEFINE_string(
        'config', None,
        'Name of the entry in the configs dict to use as configuration.')

    flags.DEFINE_integer(
        'n_runs', 1, 'Number of times to repeat the experiment.')

    # Only flags without a default are marked as required. `configs_path`
    # has a non-None default, so requiring it could never fail and only makes
    # absl emit a warning about the pointless requirement.
    flags.mark_flags_as_required(['outdir', 'config'])


@dataclasses.dataclass
class Config:
    """Settings describing one random-vertex classification experiment.

    Fields are declared in constructor order; `random_subset` and `eps` are
    optional and fall back to the defaults shown below.
    """

    # Identifier of the configuration; it becomes part of the name of the
    # JSON file written by `do_run`.
    name: str

    # Dataset selection; both values are forwarded to
    # `classification.make_dataset`.
    dataset: str
    n_components: int

    # Upper bound on the number of training examples kept by `get_dataset`.
    N: int
    # Forwarded as `m` to the vertex helpers and the solver.
    # NOTE(review): presumably the hidden-layer width -- confirm.
    m: int

    # True: the kept training examples are drawn from a random shuffle.
    # False: the leading `N` examples are used as-is.
    random_subset: bool = True

    # Forwarded as `eps` to `vertex_solvers.VertexSolver`.
    eps: float = 1e-7

    def __post_init__(self):
        """Hook invoked after the generated `__init__`; intentionally a no-op."""


def get_dataset(cfg):
    """Load the dataset named by `cfg` and cut the training split to `cfg.N`.

    Args:
        cfg: configuration object providing `dataset`, `n_components`, `N`
            and `random_subset`.

    Returns:
        A pair `((X, Y), val_data)`: the reduced training arrays and the
        validation data exactly as returned by `classification.make_dataset`.
        The third value returned by `make_dataset` is discarded.
    """
    train_data, val_data, _ = classification.make_dataset(
        cfg.dataset, cfg.n_components)
    X, Y = train_data

    if not cfg.random_subset:
        # Deterministic choice: simply keep the leading examples.
        return (X[:cfg.N], Y[:cfg.N]), val_data

    # Random choice: shuffle all row indices with NumPy's global RNG and keep
    # the first `cfg.N` of them, applying the same rows to inputs and labels.
    order = np.arange(X.shape[0], dtype=np.int32)
    np.random.shuffle(order)
    keep = order[:cfg.N]
    return (X[keep], Y[keep]), val_data


def call(model_params, X):
    """Evaluate the ReLU model described by `model_params` on inputs `X`.

    Computes `maximum(X @ w + b, 0) @ v + c`, where `v` and `c` have their
    singleton axes removed first.

    Args:
        model_params: object exposing array attributes `w`, `b`, `v` and `c`.
        X: input array that can be matrix-multiplied with `model_params.w`.
            NOTE(review): assumed to be (n_examples, n_features) -- confirm.

    Returns:
        The result of the expression above; one value per row of `X` when
        `v` squeezes down to a vector.
    """
    p = model_params
    hidden = np.maximum(X @ p.w + p.b, 0)
    # `np.squeeze` removes every singleton axis, so with a single hidden unit
    # `v` would collapse to a 0-d array, which `@` rejects. `atleast_1d`
    # restores one axis in that case and leaves all other shapes untouched.
    out_weights = np.atleast_1d(np.squeeze(p.v))
    return hidden @ out_weights + np.squeeze(p.c)


def compute_accuracy(preds, Y):
    """Return the fraction of predictions whose sign agrees with the labels.

    A prediction counts as positive when it is strictly greater than zero;
    that boolean is compared element-wise against `Y`.

    Args:
        preds: array of raw model outputs.
        Y: array of labels compared for equality with `preds > 0`.

    Returns:
        The number of matches divided by `Y.shape[0]`, as a Python float.
    """
    predicted_positive = preds > 0
    hits = predicted_positive == Y
    n_correct = float(np.count_nonzero(hits))
    n_total = float(Y.shape[0])
    return n_correct / n_total


def do_run(cfg, run_index: int):
    """Run one experiment for `cfg` and write its metrics to a JSON file.

    Steps: load the data, build a `VertexSolver` with the sigmoid
    cross-entropy loss, solve from a vertex obtained via
    `vertex_util.random_vertex`, then measure accuracy on the training and
    validation data. The accuracies and the loss are printed, and the results
    are written to `rv_<cfg.name>.<run_index>.json` inside `FLAGS.outdir`.

    Args:
        cfg: configuration object for the experiment (a dataclass instance,
            since it is serialized with `dataclasses.asdict`).
        run_index: index of this repetition; used only in the file name.
    """
    train_data, val_data = get_dataset(cfg)
    X, Y = train_data
    X_val, Y_val = val_data

    pm1_v = misc_util.make_pm_1_v(cfg.m, X.dtype)
    solver = vertex_solvers.VertexSolver(
        X,
        Y,
        loss_fn='sigmoid_cross_entropy',
        m=cfg.m,
        v=pm1_v,
        eps=cfg.eps,
    )

    start_vertex = vertex_util.random_vertex(X, cfg.m)
    res = solver.solve(start_vertex)

    # Accuracy of the solved model on both splits.
    train_acc = compute_accuracy(call(res.model_params, X), Y)
    val_acc = compute_accuracy(call(res.model_params, X_val), Y_val)

    print(f'{train_acc:.4f} {val_acc:.4f} {res.loss}')

    results = {
        'train_loss': res.loss,
        'train_acc': train_acc,
        'val_acc': val_acc,
        'config': dataclasses.asdict(cfg),
    }

    # One file per (config, run); a leading "~" in the output directory is
    # expanded to the user's home directory.
    out_name = f'rv_{cfg.name}.{run_index}.json'
    out_path = os.path.expanduser(os.path.join(FLAGS.outdir, out_name))
    with open(out_path, 'w') as out_file:
        json.dump(results, out_file)


def main(_):
    """Look up the selected config and run the experiment `n_runs` times.

    The configs mapping is resolved from the dotted path in
    `FLAGS.configs_path` and indexed with `FLAGS.config`. Each repetition is
    timed and the elapsed wall-clock time is printed.

    Args:
        _: positional argument passed in by `app.run`; unused.

    Raises:
        ValueError: if `FLAGS.configs_path` cannot be resolved to an object.
        KeyError: if `FLAGS.config` is not a key of the resolved configs.
    """
    configs = locate(FLAGS.configs_path)
    if configs is None:
        # `pydoc.locate` reports an unresolvable path by returning None;
        # indexing that would fail with an unhelpful TypeError.
        raise ValueError(
            f'Could not resolve configs_path {FLAGS.configs_path!r}; check '
            'the dotted path and PYTHONPATH.')

    try:
        cfg = configs[FLAGS.config]
    except KeyError as err:
        raise KeyError(
            f'Config {FLAGS.config!r} not found in '
            f'{FLAGS.configs_path!r}.') from err

    for i in range(FLAGS.n_runs):
        start = time.time()
        do_run(cfg, i)
        print(f'Time taken: {time.time() - start} seconds')


# Script entry point: hand control to absl, which invokes `main`.
if __name__ == "__main__":
    app.run(main)
