# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import treetable as tt

from .._base_explorers import BaseExplorer


class LMExplorer(BaseExplorer):
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        return ["train", "valid"]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        return [
            tt.group(
                "train",
                [
                    tt.leaf("epoch"),
                    tt.leaf("duration", ".1f"),  # duration in minutes
                    tt.leaf("ping"),
                    tt.leaf("ce", ".4f"),  # cross entropy
                    tt.leaf("ppl", ".3f"),  # perplexity
                ],
                align=">",
            ),
            tt.group(
                "valid",
                [
                    tt.leaf("ce", ".4f"),
                    tt.leaf("ppl", ".3f"),
                    tt.leaf("best_ppl", ".3f"),
                ],
                align=">",
            ),
        ]

    def process_sheep(self, sheep, history):
        parts = super().process_sheep(sheep, history)

        track_by = {"ppl": "lower"}  # values should be in ['lower', 'higher']
        best_metrics = {
            k: (1 if v == "lower" else -1) * float("inf") for k, v in track_by.items()
        }

        def comparator(mode, a, b):
            return a < b if mode == "lower" else a > b

        for metrics in history:
            for key, sub in metrics.items():
                for metric in track_by:
                    # for the validation set, keep track of best metrics (ppl in this example)
                    # this is so we can conveniently compare metrics between runs in the grid
                    if (
                        key == "valid"
                        and metric in sub
                        and comparator(
                            track_by[metric], sub[metric], best_metrics[metric]
                        )
                    ):
                        best_metrics[metric] = sub[metric]

        if "valid" in parts:
            parts["valid"].update({f"best_{k}": v for k, v in best_metrics.items()})
        return parts


class GenerationEvalExplorer(BaseExplorer):
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        return ["evaluate"]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        return [
            tt.group(
                "evaluate",
                [
                    tt.leaf("epoch", ".3f"),
                    tt.leaf("duration", ".1f"),
                    tt.leaf("ping"),
                    tt.leaf("ce", ".4f"),
                    tt.leaf("ppl", ".3f"),
                    tt.leaf("fad", ".3f"),
                    tt.leaf("kld", ".3f"),
                    tt.leaf("text_consistency", ".3f"),
                    tt.leaf("chroma_cosine", ".3f"),
                ],
                align=">",
            ),
        ]
