# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from syne_tune.backend.trial_status import Trial, Status
from syne_tune.tuning_status import TuningStatus, print_best_metric_found


def test_status():
    """Check that TuningStatus aggregates per-trial and overall metric statistics.

    Two trials report results in a first update. A second update then adds one
    more result for trial 0. After each update, the max / min / sum (and, after
    the second update, last) metric dictionaries are compared against
    hand-computed values. Finally, ``print_best_metric_found`` is queried in
    both ``min`` and ``max`` mode.
    """
    metric_names = ['NLL', 'time']
    status = TuningStatus(metric_names=metric_names)

    def expect_stats(stats, maximum, minimum, total):
        # Compare the max / min / sum metric dictionaries of one statistics object.
        assert stats.max_metrics == maximum
        assert stats.min_metrics == minimum
        assert stats.sum_metrics == total

    first_trial = Trial(trial_id=0, config={"x": 1.0}, creation_time=None)
    second_trial = Trial(trial_id=1, config={"x": 5.0}, creation_time=None)

    # First batch: trial 0 reports twice, trial 1 reports once.
    status.update(
        trial_status_dict={
            0: (first_trial, Status.in_progress),
            1: (second_trial, Status.in_progress),
        },
        new_results=[
            (0, {"NLL": 2.0, "time": 10.0}),
            (0, {"NLL": 1.0, "time": 12.0}),
            (1, {"NLL": 3.0, "time": 5.0}),
        ],
    )
    overall = status.overall_metric_statistics
    assert overall.max_metrics
    assert status.num_trials_started == 2
    expect_stats(
        overall,
        maximum={'NLL': 3.0, 'time': 12.0},
        minimum={'NLL': 1.0, 'time': 5.0},
        total={'NLL': 6.0, 'time': 27.0},
    )
    expect_stats(
        status.trial_metric_statistics[0],
        maximum={'NLL': 2.0, 'time': 12.0},
        minimum={'NLL': 1.0, 'time': 10.0},
        total={'NLL': 3.0, 'time': 22.0},
    )

    # Second batch: one more result for trial 0 only.
    status.update(
        trial_status_dict={
            0: (first_trial, Status.in_progress),
        },
        new_results=[
            (0, {"NLL": 0.0, "time": 20.0}),
        ],
    )
    # Re-read the statistics object in case update() replaced it.
    trial0_stats = status.trial_metric_statistics[0]
    expect_stats(
        trial0_stats,
        maximum={'NLL': 2.0, 'time': 20.0},
        minimum={'NLL': 0.0, 'time': 10.0},
        total={'NLL': 3.0, 'time': 42.0},
    )
    assert trial0_stats.last_metrics == {'NLL': 0.0, 'time': 20.0}

    # print() calls str() on its argument, matching the original explicit str().
    print(status)

    # (mode, expected best trial id, expected best metric value)
    expectations = (
        ('min', 0, 0.0),
        ('max', 1, 3.0),
    )
    for mode, expected_trial, expected_value in expectations:
        best_trialid, best_metric = print_best_metric_found(
            tuning_status=status,
            metric_names=metric_names,
            mode=mode,
        )
        assert best_trialid == expected_trial
        assert best_metric == expected_value