# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import time
from collections import defaultdict, OrderedDict
from typing import List, Dict, Optional, Tuple

import numpy as np
import pandas as pd

from syne_tune.backend.trial_status import Status, Trial
from syne_tune.constants import ST_WORKER_TIME, ST_WORKER_COST


class MetricsStatistics:
    def __init__(self):
        """
        Allows to maintain simple running statistics (min/max/sum/count) of metrics provided.

        Every metric name appearing in a report passed to `add` is tracked from the first report that
        contains it onwards.
        """
        # metric names, in the order in which they were first seen
        self.metric_names = []
        # number of reports passed to `add` (one counter for all metrics, not a per-metric count)
        self.count = 0
        self.min_metrics = {}
        self.max_metrics = {}
        self.sum_metrics = {}
        # the most recent report passed to `add`
        self.last_metrics = {}

    def add(self, metrics: Dict):
        """
        Updates the running statistics with one report.

        Only the metrics present in `metrics` are updated. Metrics seen in earlier reports but missing from
        this one keep their previous min/max/sum instead of raising a KeyError.
        :param metrics: mapping from metric name to its reported value
        """
        for name in metrics.keys():
            if name not in self.metric_names:
                self.metric_names.append(name)
                # neutral starting values, so the first report sets min/max/sum to the reported value
                self.min_metrics[name] = np.inf
                self.max_metrics[name] = -np.inf
                self.sum_metrics[name] = 0

        # new dicts are built (rather than mutating in place) so previously handed-out dicts stay unchanged
        self.min_metrics = {
            m: min(metrics[m], current) if m in metrics else current
            for m, current in self.min_metrics.items()
        }
        self.max_metrics = {
            m: max(metrics[m], current) if m in metrics else current
            for m, current in self.max_metrics.items()
        }
        self.sum_metrics = {
            m: metrics[m] + current if m in metrics else current
            for m, current in self.sum_metrics.items()
        }
        self.last_metrics = metrics
        self.count += 1


class TuningStatus:
    """
    Information of a tuning job to display as progress or to use to decide whether to stop the tuning job.

    Keeps the last status seen for every trial, running metric statistics (over all results and per trial)
    and one display row per trial.
    """
    def __init__(self, metric_names: List[str]):
        """
        :param metric_names: names of the metrics of the tuning job; stored as an attribute, not read by the
            methods of this class
        """
        self.metric_names = metric_names
        self.start_time = time.perf_counter()

        # statistics over every result received, regardless of which trial reported it
        self.overall_metric_statistics = MetricsStatistics()
        # trial_id -> statistics of that trial, created on first lookup
        self.trial_metric_statistics = defaultdict(MetricsStatistics)

        # trial_id -> last status seen for that trial
        self.last_trial_status_seen = OrderedDict()
        # trial_id -> row of the table returned by `get_dataframe`
        self.trial_rows = OrderedDict()

    def update(self, trial_status_dict: Dict[int, Tuple[Trial, str]], new_results: List[Tuple[int, Dict]]):
        """
        Updates the tuning status given new statuses and results.

        :param trial_status_dict: maps trial_id to a tuple (trial, status)
        :param new_results: list of (trial_id, result) tuples, each result being a dict of metric values
        """
        self.last_trial_status_seen.update(
            {trial_id: status for trial_id, (_, status) in trial_status_dict.items()}
        )

        for trial_id, new_result in new_results:
            self.overall_metric_statistics.add(new_result)
            self.trial_metric_statistics[trial_id].add(new_result)

        for trial_id, (trial, status) in trial_status_dict.items():
            stats = self.trial_metric_statistics[trial_id]
            row = {
                "trial_id": trial_id,
                "status": status,
                "iter": stats.count,
            }
            # later updates win: config values, then the last reported metrics
            row.update(trial.config)
            row.update(stats.last_metrics)

            # the largest value reported so far is shown for worker time and cost
            for metric_key, column in ((ST_WORKER_TIME, "worker-time"), (ST_WORKER_COST, "worker-cost")):
                if metric_key in stats.max_metrics:
                    row[column] = stats.max_metrics[metric_key]

            self.trial_rows[trial_id] = row

    def mark_running_job_as_stopped(self):
        """
        Update the status of all trials still running to be marked as stopped.
        """
        # rebuild as an OrderedDict to keep the container type chosen in __init__
        self.last_trial_status_seen = OrderedDict(
            (trial_id, status if status != Status.in_progress else Status.stopped)
            for trial_id, status in self.last_trial_status_seen.items()
        )
        for row in self.trial_rows.values():
            if row["status"] == Status.in_progress:
                row["status"] = Status.stopped

    @property
    def num_trials_started(self):
        return len(self.last_trial_status_seen)

    def _num_trials(self, status: str):
        return sum(trial_status == status for trial_status in self.last_trial_status_seen.values())

    @property
    def num_trials_completed(self):
        return self._num_trials(status=Status.completed)

    @property
    def num_trials_failed(self):
        return self._num_trials(status=Status.failed)

    @property
    def num_trials_finished(self):
        """
        :return: number of trials that finished, i.e. that completed, were stopped or are stopping, or failed
        """
        finished_statuses = (Status.completed, Status.stopped, Status.stopping, Status.failed)
        # a single pass over the statuses rather than one pass per finished status
        return sum(status in finished_statuses for status in self.last_trial_status_seen.values())

    @property
    def num_trials_running(self):
        return self._num_trials(status=Status.in_progress)

    @property
    def wallclock_time(self):
        """
        :return: the wallclock time spent in the tuner
        """
        return time.perf_counter() - self.start_time

    @property
    def user_time(self):
        """
        :return: the total user time spent in the workers, summing the largest worker time reported by each
            trial (0 if no result ever contained the worker time)
        """
        if ST_WORKER_TIME not in self.overall_metric_statistics.metric_names:
            return 0
        return sum(stats.max_metrics.get(ST_WORKER_TIME, 0) for stats in self.trial_metric_statistics.values())

    @property
    def cost(self):
        """
        :return: the estimated dollar-cost spent while tuning, summing the largest cost reported by each trial
            (0.0 if no result ever contained the worker cost)
        """
        if ST_WORKER_COST not in self.overall_metric_statistics.metric_names:
            return 0.0
        return sum(stats.max_metrics.get(ST_WORKER_COST, 0) for stats in self.trial_metric_statistics.values())

    def get_dataframe(self) -> pd.DataFrame:
        """
        :return: one row per trial, as built by `update`
        """
        return pd.DataFrame(list(self.trial_rows.values()))

    def __str__(self):
        num_running = self.num_trials_running
        # every trial whose last status is not in_progress counts as finished here
        num_finished = self.num_trials_started - num_running

        if self.trial_rows:
            df = self.get_dataframe()
            # columns whose name starts with "st_" are hidden from the printed table
            cols = [col for col in df.columns if not str(col).startswith("st_")]
            res_str = df.loc[:, cols].to_string(index=False, na_rep="-") + "\n"
        else:
            res_str = ""
        res_str += (
            f"{num_running} trials running, "
            f"{num_finished} finished ({self.num_trials_completed} until the end), "
            f"{self.wallclock_time:.2f}s wallclock-time"
        )
        cost = self.cost
        if cost is not None and cost > 0.0:
            res_str += f", ${cost:.2f} estimated cost"
        res_str += "\n"
        return res_str


def print_best_metric_found(
        tuning_status: TuningStatus,
        metric_names: List[str],
        mode: str
) -> Optional[Tuple[int, float]]:
    """
    Prints trial status summary and the best metric found.

    Only the first metric of `metric_names` is considered: showing the best values of several objectives
    would require displaying the Pareto set.
    :param tuning_status: status of the tuning job to summarize
    :param metric_names: metric names of the tuning job; only `metric_names[0]` is used
    :param mode: "min" if the metric is minimized; any other value is treated as maximization
    :return: trial-id and value of the best metric found, or None (and nothing is printed) if no result
        has been reported yet
    """
    if tuning_status.overall_metric_statistics.count == 0:
        return None
    metric_name = metric_names[0]
    print("-" * 20)
    print(f"Resource summary (last result is reported):\n{str(tuning_status)}")
    trial_statistics = tuning_status.trial_metric_statistics.items()
    # trials that never reported `metric_name` get the worst possible value. min/max return the first best
    # trial in iteration order, i.e. the same trial the first element of a stable sort would give
    if mode == 'min':
        best_trialid, best_metric = min(
            ((trial_id, stats.min_metrics.get(metric_name, np.inf)) for trial_id, stats in trial_statistics),
            key=lambda x: x[1],
        )
    else:
        best_trialid, best_metric = max(
            ((trial_id, stats.max_metrics.get(metric_name, -np.inf)) for trial_id, stats in trial_statistics),
            key=lambda x: x[1],
        )
    print(f"{metric_name}: best {best_metric} for trial-id {best_trialid}")
    print("-" * 20)
    return best_trialid, best_metric
