import enum
import glob
import json
import logging
import os
import pathlib
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import pydantic
from tensorboard.backend.event_processing import event_accumulator

# By default TensorBoard (TB) tries to be smart about what it loads in memory to
# avoid out-of-memory (OOM) errors. Since we expect every step to be there when
# we do our comparisons, we explicitly set the size guidance to 0 so that we
# load everything. That is okay given our tests are small/short.
SIZE_GUIDANCE = {event_accumulator.TENSORS: 0, event_accumulator.SCALARS: 0}

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)


def approximate_threshold(rtol: float) -> Callable:
    """Build a callable that derives a tolerance from two series of values.

    Args:
        rtol: Factor applied to the combined mean of the two series.

    Returns:
        A function taking ``(y_pred, y_true)`` that returns the mean of the two
        per-series means, multiplied by ``rtol``.
    """

    def _scaled_mean(y_pred: List[Union[float, int]], y_true: List[Union[float, int]]):
        series_means = [np.mean(y_pred), np.mean(y_true)]
        combined_mean = np.mean(series_means)
        return combined_mean * rtol

    return _scaled_mean


@enum.unique
class TypeOfTestResult(enum.Enum):
    """Kind of comparison a test performs against the golden values."""

    # ``enum.auto()`` numbers members from 1 in declaration order, so the
    # values remain APPROXIMATE == 1 and DETERMINISTIC == 2.
    APPROXIMATE = enum.auto()
    DETERMINISTIC = enum.auto()


class Test(pydantic.BaseModel):
    """Base pydantic model shared by ``ApproximateTest`` and ``DeterministicTest``."""

    pass


class NotApproximateError(Exception):
    """Raised if a comparison is not within the approximate (tolerance) bounds."""


class NotDeterminsticError(Exception):
    """Raised if an exact (deterministic) comparison fails."""


class ApproximateTest(Test):
    """Tolerance-based comparison settings for a metric.

    ``pipeline`` forwards ``rtol`` and the absolute tolerance to ``np.allclose``.
    """

    # Absolute tolerance; used by ``pipeline`` when ``atol_func`` is None.
    atol: Union[int, float] = 0
    # Optional callable; when set, ``pipeline`` calls it with the actual and
    # golden value lists and uses its result as the absolute tolerance.
    atol_func: Optional[Callable] = None
    # Relative tolerance.
    rtol: float = 1e-5

    @property
    def type_of_test_result(self) -> TypeOfTestResult:
        """Return the kind of this test, always ``TypeOfTestResult.APPROXIMATE``."""
        return TypeOfTestResult.APPROXIMATE

    def error_message(self, metric_name: str) -> NotApproximateError:
        """Build (without raising) the error describing a failure for ``metric_name``."""
        return NotApproximateError(f"Approximate comparison of {metric_name}: FAILED")


class DeterministicTest(Test):
    """Exact comparison settings for a metric: ``atol`` and ``rtol`` are fixed at zero."""

    @property
    def atol(self) -> Union[int, float]:
        """Absolute tolerance; always 0."""
        return 0

    # Optional callable; when set, ``pipeline`` uses its result as the absolute
    # tolerance instead of ``atol``.
    # NOTE(review): this allows a non-zero tolerance on an "exact" test --
    # confirm that is intended.
    atol_func: Optional[Callable] = None

    @property
    def rtol(self) -> float:
        """Relative tolerance; always 0.0."""
        return 0.0

    @property
    def type_of_test_result(self) -> TypeOfTestResult:
        """Return the kind of this test, always ``TypeOfTestResult.DETERMINISTIC``."""
        return TypeOfTestResult.DETERMINISTIC

    def error_message(self, metric_name: str) -> NotDeterminsticError:
        """Build (without raising) the error describing a failure for ``metric_name``."""
        return NotDeterminsticError(f"Exact comparison of {metric_name}: FAILED")


class GoldenValueMetric(pydantic.BaseModel):
    """Values of a single metric keyed by step, plus the sampled step range."""

    start_step: int
    end_step: int
    step_interval: int
    values: Dict[int, Union[int, float, str]]

    def __repr__(self):
        """Render the step range followed by every ``(step, value)`` pair."""
        header = f"Values ({self.start_step},{self.end_step},{self.step_interval})"
        pairs = [f"({step}, {value})" for step, value in self.values.items()]
        return f"{header}: {', '.join(pairs)}"


class GoldenValues(pydantic.RootModel):
    """Root model wrapping the mapping of metric name to its golden values."""

    # Metric name -> GoldenValueMetric.
    root: Dict[str, GoldenValueMetric]


class MissingTensorboardLogsError(Exception):
    """Raised if the expected TensorBoard logs (or a metric in them) are not found."""


class UndefinedMetricError(Exception):
    """Raised if a golden values metric has no test definition."""


class SkipMetricError(Exception):
    """Raised if a metric shall be skipped."""


def read_tb_logs_as_list(
    path, index: int = 0, train_iters: int = 50, start_idx: int = 1, step_size: int = 5
) -> Optional[Dict[str, GoldenValueMetric]]:
    """Read scalar summaries from the TensorBoard event files found under ``path``.

    Event files are searched for in ``path`` and ``path/results`` and ordered
    by modification time (oldest first).

    Args:
        path: Directory containing the event files (directly or in ``results/``).
        index: Position, in the mtime-ordered file list, of the single event
            file to read. ``-1`` reads every file; a step already recorded
            from an earlier file is not overwritten by a later one.
        train_iters: Last step (inclusive) considered when sampling values.
        start_idx: First step that is kept.
        step_size: After ``start_idx``, only steps divisible by ``step_size``
            are kept.

    Returns:
        A mapping of scalar tag to ``GoldenValueMetric``. Values are rounded to
        5 decimals; sampled steps that were not logged hold the string
        ``"nan"``. ``None`` is returned if no event file was found.

    Raises:
        ValueError: If scalars were read but the sampling parameters select no
            step at all.
    """
    files = glob.glob(f"{path}/events*tfevents*")
    files += glob.glob(f"{path}/results/events*tfevents*")

    if not files:
        logger.error("File not found matching: %s/events* || %s/results/events*", path, path)
        return None

    # glob already returns usable paths. Rebuilding them as `path/<basename>`
    # pointed at a non-existent file for matches located in `path/results/`.
    files.sort(key=os.path.getmtime)

    selected_files = files if index == -1 else [files[index]]
    accumulators = []
    for event_file in selected_files:
        ea = event_accumulator.EventAccumulator(event_file, size_guidance=SIZE_GUIDANCE)
        ea.Reload()
        accumulators.append(ea)

    summaries = {}
    for ea in accumulators:
        for scalar_name in ea.Tags()["scalars"]:
            if scalar_name in summaries:
                # Tag already seen in an earlier file: only fill in new steps.
                for x in ea.Scalars(scalar_name):
                    if x.step not in summaries[scalar_name]:
                        summaries[scalar_name][x.step] = round(x.value, 5)

            else:
                summaries[scalar_name] = {
                    x.step: round(x.value, 5) for x in ea.Scalars(scalar_name)
                }

    golden_values = {}
    if not summaries:
        return golden_values

    # The sampled steps are identical for every metric, so compute them once.
    sampled_steps = [
        step
        for step in range(1, train_iters + 1)
        if step == start_idx or (step > start_idx and step % step_size == 0)
    ]
    if not sampled_steps:
        raise ValueError(
            f"No steps selected with train_iters={train_iters}, "
            f"start_idx={start_idx}, step_size={step_size}"
        )

    for metric, logged_values in summaries.items():
        # Add missing values
        values = {step: logged_values.get(step, "nan") for step in sampled_steps}

        golden_values[metric] = GoldenValueMetric(
            # sampled_steps is ascending, so its ends are the min and max step.
            start_step=sampled_steps[0],
            end_step=sampled_steps[-1],
            step_interval=step_size,
            values=values,
        )

    return golden_values


def read_golden_values_from_json(
    golden_values_path: Union[str, pathlib.Path]
) -> Dict[str, GoldenValueMetric]:
    """Load golden values from a JSON file.

    Args:
        golden_values_path: Path of the JSON file. Its top-level object is
            passed to ``GoldenValues`` for validation.

    Returns:
        The validated mapping held in ``GoldenValues.root``.

    Raises:
        ValueError: If ``golden_values_path`` does not exist.
    """
    # Check before opening: opening first raised FileNotFoundError on a missing
    # file, which made the ValueError below unreachable.
    if not os.path.exists(golden_values_path):
        raise ValueError(f"File {golden_values_path} not found!")

    with open(golden_values_path) as f:
        return GoldenValues(**json.load(f)).root


def _filter_checks(
    checks: List[Union[ApproximateTest, DeterministicTest]], filter_for_type_of_check
):
    """Return the checks whose ``type_of_test_result`` equals ``filter_for_type_of_check``."""
    selected = []
    for candidate in checks:
        if candidate.type_of_test_result == filter_for_type_of_check:
            selected.append(candidate)
    return selected


def pipeline(
    compare_approximate_results: bool,
    golden_values: Dict[str, GoldenValueMetric],
    actual_values: Dict[str, GoldenValueMetric],
    checks: Dict[str, List[Union[ApproximateTest, DeterministicTest]]],
):
    """Compare actual metric values against golden values for every check.

    Args:
        compare_approximate_results: When true, deterministic tests are skipped.
        golden_values: Reference values per metric name.
        actual_values: Measured values per metric name.
        checks: Tests to run, per metric name.

    Raises:
        MissingTensorboardLogsError: If a metric listed in ``checks`` is absent
            from ``actual_values``.
        KeyError: If a metric listed in ``checks`` is absent from
            ``golden_values``.
        AssertionError: If at least one test failed; the message lists the
            failed metrics.
    """
    all_test_passed = True
    failed_metrics = []

    for metric_name, metric_thresholds in checks.items():
        if metric_name not in actual_values:
            raise MissingTensorboardLogsError(
                f"Metric {metric_name} not found in Tensorboard logs! Please modify `model_config.yaml` to record it."
            )

        for test in metric_thresholds:
            if (
                compare_approximate_results
                and test.type_of_test_result == TypeOfTestResult.DETERMINISTIC
            ):
                continue

            try:
                golden_value = golden_values[metric_name]
                actual_by_step = actual_values[metric_name].values
                golden_value_list = list(golden_value.values.values())
                # Walk the golden steps so both lists have the same length and
                # order. A step missing from the actual values gets the "nan"
                # placeholder, which only matches a golden "nan"; previously a
                # missing step produced mismatched array lengths.
                actual_value_list = [
                    actual_by_step.get(step, "nan") for step in golden_value.values
                ]

                if metric_name == "iteration-time":
                    actual_value_list = actual_value_list[3:-1]
                    golden_value_list = golden_value_list[3:-1]
                    logger.info(
                        "For metric `%s`, the first 3 and the last scalars are removed from the list to reduce noise.",
                        metric_name,
                    )

                # String placeholders (e.g. "nan") become +inf so that a
                # placeholder only compares equal to another placeholder.
                actual_value_list = [
                    np.inf if isinstance(v, str) else v for v in actual_value_list
                ]
                golden_value_list = [
                    np.inf if isinstance(v, str) else v for v in golden_value_list
                ]

                actual = np.array(actual_value_list)
                golden = np.array(golden_value_list)

                # NOTE(review): atol_func receives the lists including any inf
                # placeholders; a mean-based function then yields an infinite
                # tolerance -- confirm whether placeholders should be excluded.
                if test.atol_func is not None:
                    atol = test.atol_func(actual_value_list, golden_value_list)
                else:
                    atol = test.atol

                # Tolerance check
                passing = np.allclose(actual, golden, rtol=test.rtol, atol=atol)

                if not passing:
                    logger.info("Actual values: %s", ", ".join([str(v) for v in actual_value_list]))
                    logger.info("Golden values: %s", ", ".join([str(v) for v in golden_value_list]))
                    raise test.error_message(metric_name)

                result = f"{test.type_of_test_result.name} test for metric {metric_name}: PASSED"
                passed = True

            except (NotApproximateError, NotDeterminsticError, MissingTensorboardLogsError) as e:
                result = str(e)
                passed = False
            except SkipMetricError:
                logger.info(
                    "%s test for %s: SKIPPED", test.type_of_test_result.name, metric_name
                )
                continue

            if passed:
                logger.info(result)
            else:
                logger.error(result)
                all_test_passed = False
                failed_metrics.append(metric_name)

    # Raised explicitly (rather than via `assert`) so the check still runs when
    # Python is started with -O; the exception type and message are unchanged.
    if not all_test_passed:
        raise AssertionError(f"The following metrics failed: {', '.join(failed_metrics)}")
