r"""
Evaluate the effect of a traditional regularizer (weight-decay)
on the approximated posteriors.
"""

import argparse
import os
import shutil

import awflow as aw
import numpy as np
import papermill as pm
import torch

from hypothesis.benchmark.tractable_small import Prior
from hypothesis.benchmark.tractable_small import Simulator

from ratio_estimation import estimate_coverage
from ratio_estimation import load_estimator



# Argument parsing
parser = argparse.ArgumentParser()
parser.add_argument('--partition', type=str, default=None, help='Slurm partition to execute the pipeline on (default: none).')
parser.add_argument('--redo', action='store_true', help='Remove the generating files, pipeline will be re-executed (default: false).')
# Unknown command-line flags are tolerated rather than rejected.
arguments, _ = parser.parse_known_args()
if arguments.redo:
    # Wipe earlier results in-process instead of shelling out to `rm -rf`.
    # `ignore_errors=True` keeps the "missing directory is fine" semantics.
    shutil.rmtree('data', ignore_errors=True)
    shutil.rmtree('output', ignore_errors=True)


# Script parameters
# Number of estimators trained (and evaluated) per weight-decay value.
num_estimators = 5
# Weight-decay strengths. Kept as strings because they are used verbatim in
# the output paths and on the training command line.
lambdas = ['0.0', '0.001', '0.01', '0.1', '0.2', '0.25', '0.5', '1.0']
# 19 evenly spaced confidence levels from 0.05 to 0.95 (inclusive).
confidence_levels = np.linspace(0.05, 0.95, 19)


# Utilities ####################################################################

@torch.no_grad()
def simulate(outputdir, budget):
    r"""Sample ``budget`` parameters from the prior and simulate them.

    The prior samples are written to ``<outputdir>/inputs.npy`` and the
    corresponding simulator outputs to ``<outputdir>/outputs.npy``, both
    converted with ``.float()`` before saving. ``outputdir`` is not created
    here.
    """
    forward_model = Simulator()
    parameter_prior = Prior()
    inputs = parameter_prior.sample((budget,))
    outputs = forward_model(inputs)
    # Store parameters first, then observables.
    for name, tensor in (('inputs', inputs), ('outputs', outputs)):
        np.save(f'{outputdir}/{name}.npy', tensor.float().numpy())


def optimize(outputdir, weight_decay):
    r"""Train one ratio estimator by running the hypothesis training module.

    The trainer is launched as a child process of the current interpreter
    with a fixed configuration (batch size 128, 100 epochs, learning rate
    0.001, learning-rate scheduling on plateau).

    Arguments:
        outputdir: directory passed to the trainer through ``--out``.
        weight_decay: regularization strength passed through
            ``--weight-decay``.

    Raises:
        subprocess.CalledProcessError: if the training process exits with a
            non-zero status.
    """
    import subprocess
    import sys

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact; `sys.executable` guarantees the child runs in
    # the same Python environment as this pipeline.
    command = [
        sys.executable, '-m', 'hypothesis.bin.ratio_estimation.train',
        '--batch-size', '128',
        '--criterion', 'hypothesis.nn.ratio_estimation.BaseCriterion',
        '--data-test', 'ratio_estimation.DatasetJointTest',
        '--data-train', 'ratio_estimation.DatasetJointTrain',
        '--data-validate', 'ratio_estimation.DatasetJointValidate',
        '--epochs', '100',
        '--estimator', 'ratio_estimation.RatioEstimator',
        '--lr', '0.001',
        '--lrsched-on-plateau',
        '--weight-decay', str(weight_decay),
        '--out', str(outputdir),
    ]
    # Surface a failed training run instead of silently discarding the
    # exit status.
    subprocess.run(command, check=True)


# Workflow definition ##########################################################


@aw.timelimit('1:00:00')
@aw.cpus_and_memory(2, "4GB")
@aw.postcondition(aw.exists('data/train/inputs.npy'))
@aw.postcondition(aw.exists('data/train/outputs.npy'))
@aw.postcondition(aw.exists('data/validate/inputs.npy'))
@aw.postcondition(aw.exists('data/validate/outputs.npy'))
@aw.postcondition(aw.exists('data/test/inputs.npy'))
@aw.postcondition(aw.exists('data/test/outputs.npy'))
@aw.postcondition(aw.exists('data/coverage/inputs.npy'))
@aw.postcondition(aw.exists('data/coverage/outputs.npy'))
def generate_simulations():
    r"""Simulate the train, validate, test and coverage datasets.

    Each dataset is written to its own directory under ``data/``. The
    train, validate and test sets hold ``2 ** 14`` simulations each, the
    coverage set holds 10000.
    """
    simulation_budget = 2 ** 14
    # Directory -> number of simulations, in the order they are generated.
    budgets = {
        'data/train': simulation_budget,
        'data/validate': simulation_budget,
        'data/test': simulation_budget,
        'data/coverage': 10000,
    }
    for directory, budget in budgets.items():
        os.makedirs(directory, exist_ok=True)
        simulate(directory, budget)


def train_and_evaluate(regularizer_strength):
    r"""Define the training and coverage tasks for one weight-decay value.

    All artifacts are stored under ``output/estimator/<regularizer_strength>``,
    with one numbered subdirectory per estimator. Both tasks skip work whose
    result file already exists.

    Returns the coverage task, so it can serve as a dependency of later
    workflow steps.
    """
    root = 'output/estimator/' + regularizer_strength
    os.makedirs(root, exist_ok=True)

    # Train the ratio estimator
    @aw.cpus_and_memory(4, "8GB")
    @aw.dependency(generate_simulations)
    @aw.postcondition(aw.num_files(root + '/*/weights.th', num_estimators))
    @aw.tasks(num_estimators)
    @aw.timelimit('12:00:00')
    def train(task_index):
        outputdir = f'{root}/{task_index}'
        os.makedirs(outputdir, exist_ok=True)
        if os.path.exists(f'{outputdir}/weights.th'):
            # Weights are already present, nothing left to train.
            return
        optimize(outputdir, regularizer_strength)

    # Evaluate its expected coverage
    @aw.cpus_and_memory(4, "4GB")
    @aw.dependency(train)
    @aw.postcondition(aw.num_files(root + '/*/coverage.npy', num_estimators))
    @aw.tasks(num_estimators)
    @aw.timelimit('12:00:00')
    def coverage(task_index):
        outputdir = f'{root}/{task_index}'
        result_path = f'{outputdir}/coverage.npy'
        if os.path.exists(result_path):
            # Coverage was computed in an earlier run.
            return
        inputs = torch.from_numpy(np.load('data/coverage/inputs.npy'))
        outputs = torch.from_numpy(np.load('data/coverage/outputs.npy'))
        ratio_estimator = load_estimator(f'{outputdir}/weights.th')
        coverages = estimate_coverage(ratio_estimator, inputs, outputs, confidence_levels)
        np.save(result_path, coverages)

    return coverage


# One coverage task per weight-decay strength. Built with a comprehension so
# no loop variable is left behind in the module namespace.
dependencies = [train_and_evaluate(strength) for strength in lambdas]


@aw.cpus_and_memory(4, "8GB")
@aw.dependency(dependencies)
@aw.postcondition(aw.exists('plots/weight_decay_estimators.pdf'))
@aw.timelimit('1:00:00')
def summarize():
    r"""Execute the summary notebook once all coverage tasks are done.

    The notebook is presumably what produces the plot required by the
    postcondition -- verify against ``summary.ipynb``.
    """
    notebook = 'summary.ipynb'
    # Input and output path are the same file, so the executed notebook
    # replaces the original.
    pm.execute_notebook(notebook, notebook)


if __name__ == '__main__':
    # Run the workflow defined above on the partition selected with
    # `--partition` (None when the flag is not given).
    selected_partition = arguments.partition
    aw.execute(partition=selected_partition)
