"""Generic hyperparameter tuning support."""

import csv
import hashlib
import itertools
import json
import math
import os
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from threading import Lock
from typing import Callable, Iterable

# Search space: hyperparameter name -> list of candidate values to try.
HParamRanges = dict[str, list[int | float | str]]
# One concrete configuration: hyperparameter name -> chosen value.
HParams = dict[str, int | float | str]
# Function evaluated on one configuration; returns a dict of named metrics.
TargetFn = Callable[[HParams], dict[str, float]]


class HyperparamTuner:
    """Hyperparameter tuning class.

    A tuning is multi-threaded and resumable. Results are logged to a file.
    The tuning algo is a simple grid search.
    """

    def __init__(
        self,
        hparam_ranges: HParamRanges,
        target_fn: TargetFn,
        opt_metric_name: str,
        log_file_path: str,
    ):
        """Initialize the hyperparameter tuner."""
        self.hparam_ranges = hparam_ranges
        self.hparam_names = sorted(hparam_ranges.keys())
        self.log_file_header = ["score", "hash", "time"] + self.hparam_names + ["log"]

        self.target_fn = self.wrap_target_fn(target_fn)
        self.opt_metric_name = opt_metric_name
        self.log_file_path = log_file_path

        # Map from hparam hash to score.
        self.hparam_scores: dict[int, float] = self._load_scores_from_log()

        self.lock = Lock()

    def wrap_target_fn(self, target_fn: TargetFn) -> TargetFn:
        """Wrap the target function to catch errors."""

        def wrapped_target_fn(hparams: HParams) -> float:
            try:
                return target_fn(hparams)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Hyperparameter tuning failed for hparams: {hparams}")
                print(f"Error in target_fn: {e}")
                traceback.print_exception(e)
                return {self.opt_metric_name: float("-inf")}

        return wrapped_target_fn

    def grid_search(self, num_threads: int = 1) -> tuple[int, float]:
        """Perform a multi-threaded grid search of the hyperparameter space."""
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = {
                executor.submit(self.target_fn, params): params
                for params in self._generate_new_hparams()
            }
            print(f"Running #{len(futures)} experiments for hparam tuning.")
            for future in as_completed(futures):
                hparams = futures[future]
                log = future.result()
                score = log.get(self.opt_metric_name, math.inf)
                with self.lock:
                    self.hparam_scores[hash_hparams(hparams)] = score
                    self._write_score_to_log(score, hparams, log)

        best_hparam_hash = min(self.hparam_scores, key=self.hparam_scores.get)  # type: ignore
        return (best_hparam_hash, self.hparam_scores[best_hparam_hash])

    def _generate_new_hparams(self) -> Iterable[HParams]:
        keys, values = zip(*self.hparam_ranges.items())
        for hparam_values in itertools.product(*values):
            hparams = dict(zip(keys, hparam_values))
            is_already_tested = hash_hparams(hparams) in self.hparam_scores
            if not is_already_tested:
                yield hparams

    def _write_score_to_log(
        self, score: float, hparams: HParams, log: dict[str, float]
    ) -> None:
        with open(self.log_file_path, "a", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            sorted_hparam_values = [hparams[name] for name in self.hparam_names]
            try:
                writer.writerow(
                    [score, hash_hparams(hparams), current_time_str()]
                    + sorted_hparam_values
                    + [json.dumps(log, sort_keys=True)]
                )
            except TypeError as e:
                for k, v in log.items():
                    try:
                        json.dumps(v)
                    except TypeError:
                        print(f"Error converting {k=} with {v=} to JSON.")
                raise e

    def _load_scores_from_log(self) -> dict[int, float]:
        if not os.path.isfile(self.log_file_path):
            self._initialize_log_file()
            return {}

        scores = {}
        with open(self.log_file_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            header = next(reader)
            if not header == self.log_file_header:
                raise ValueError(
                    f"Hparam tuning cannot be resumed at '{self.log_file_path}'. "
                    "Headers do not match: '{header}' vs '{expected_header}'."
                )
            for row in reader:
                score = float(row[0])
                hparam_hash = int(row[1])
                scores[hparam_hash] = score
        print(
            f"Log file found at '{self.log_file_path}'. "
            f"Resuming the previous session which has #{len(scores)} results logged."
        )
        return scores

    def _initialize_log_file(self) -> None:
        print(f"Creating a new log file at '{self.log_file_path}'.")
        with open(self.log_file_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            print(f"Log file header: {self.log_file_header}")
            writer.writerow(self.log_file_header)


def hash_hparams(hparams: HParams) -> int:
    """Deterministic hash for a hyperparameter configuration.

    The configuration is serialized to JSON with sorted keys, so the result
    does not depend on dict insertion order. A SHA-256 digest is used rather
    than the built-in ``hash()``, because string hashing is salted per
    interpreter process: hashes written to a log file by one run would not
    match the hashes computed by the next run.

    Args:
        hparams: The configuration to hash; must be JSON-serializable.

    Returns:
        A non-negative 64-bit integer that is stable across processes.

    Raises:
        TypeError: If ``hparams`` is not JSON-serializable.
    """
    serialized = json.dumps(hparams, sort_keys=True)
    digest = hashlib.sha256(serialized.encode("utf-8")).digest()
    # 8 bytes keep the value compact in the CSV log while making accidental
    # collisions within a grid negligible.
    return int.from_bytes(digest[:8], "big")


def current_time_str() -> str:
    """Format the current time as ``YYYY-MM-DD_HH:MM:SS`` for log rows."""
    return f"{datetime.now():%Y-%m-%d_%H:%M:%S}"
