#
# Bayesian Optimization of Combinatorial Structures
#
# Copyright (C) 2018 R. Baptista & M. Poloczek
# 
# BOCS is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# BOCS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License 
# along with BOCS.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright (C) 2018 MIT & University of Arizona
# Authors: Ricardo Baptista & Matthias Poloczek
# E-mails: rsb@mit.edu & poloczek@email.arizona.edu
#

import numpy as np
import matplotlib.pyplot as plt
from localglobal.boCS import BOCS
from sample_models import sample_models
import pickle
from localglobal.test_funcs.contamination import Contamination
from localglobal.test_funcs.MaxSAT.maximum_satisfiability import MaxSAT60
import torch

# Number of binary decision variables used for the contamination problems.
CONTAMINATION_N_STAGES = 25

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Experiment configuration
    # ------------------------------------------------------------------
    num_repeats = 5
    runs = []
    inputs = {'n_init': 10, 'lambda': 1e-4}

    # Objective to optimize: one of 'maxsat60', 'contamination_0',
    # 'contamination_0.01'.
    obj_fn = 'maxsat60'

    # Variant string forwarded to BOCS(), e.g. 'SA' or 'SDP-l1'.
    # Named ``bocs_type`` so the builtin ``type`` is not shadowed.
    bocs_type = 'SA'

    print(f'running problem={obj_fn}, type=BOCS-{bocs_type}')

    # Each branch sets the dimension, the evaluation budget and the
    # objective callable that BOCS will minimize.
    if obj_fn == 'maxsat60':
        maxsat_ = MaxSAT60()
        inputs['n_vars'] = maxsat_.n_variables
        inputs['evalBudget'] = 250
        # The MaxSAT evaluator is fed a torch tensor built from the numpy input.
        inputs['model'] = lambda x: maxsat_.evaluate(torch.from_numpy(x))
    elif obj_fn == 'contamination_0.01':
        inputs['n_vars'] = CONTAMINATION_N_STAGES
        inputs['evalBudget'] = 150
        a = Contamination(lamda=1e-2, normalize=False)
        inputs['model'] = lambda x: a.compute(x)
    elif obj_fn == 'contamination_0':
        inputs['n_vars'] = CONTAMINATION_N_STAGES
        inputs['evalBudget'] = 150
        a = Contamination(lamda=0, normalize=False)
        inputs['model'] = lambda x: a.compute(x)
    else:
        # Fail fast instead of hitting a KeyError on inputs['n_vars'] later.
        raise ValueError(
            f"unknown obj_fn {obj_fn!r}; expected 'maxsat60', "
            f"'contamination_0' or 'contamination_0.01'"
        )

    # No sparsity penalty. An L1-style alternative would be:
    #   lambda x: inputs['lambda'] * np.sum(x, axis=1)
    inputs['penalty'] = lambda x: 0

    # ------------------------------------------------------------------
    # Repeated independent runs
    # ------------------------------------------------------------------
    for i in range(num_repeats):
        print(f'Starting repeat {i+1}/{num_repeats}:')

        # Fresh initial design for this repeat, evaluated one sample at a time.
        inputs['x_vals'] = sample_models(inputs['n_init'], inputs['n_vars'])
        inputs['y_vals'] = np.array(
            [inputs['model'](xv) for xv in inputs['x_vals']]
        )

        # Run BOCS with a second-order model and the selected variant only.
        # A shallow copy is passed in case BOCS modifies its argument.
        _, bocs_obj = BOCS(inputs.copy(), 2, bocs_type)

        # Running minimum: best objective value seen so far along bocs_obj.
        bocs_opt = np.minimum.accumulate(bocs_obj)

        runs.append(bocs_opt)
        print(bocs_opt)

    runs = np.array(runs)

    # Persist all best-so-far traces; the context manager guarantees the
    # file is closed even if pickling fails.
    out_path = obj_fn + f'_baseline_result_bocs-{bocs_type}_fullruns.pkl'
    with open(out_path, 'wb') as f:
        pickle.dump({'runs': runs}, f)
