import json
import logging
from pathlib import Path
from typing import Union

import torch

from pkg.data import BaseData
from pkg.data.pykt.eval_dataset_format_combinatorial_dense import (
    eval as eval_combinatorial_dense,
)
from pkg.data.pykt.eval_dataset_format_set_dense import eval as eval_set_dense
from pkg.logic.trainer import State
from pkg.model import BaseModel
from pkg.utils.logging import BEST_FILE_NAME, Timer, get_checkpoint_path_without_suffix

# Module-level logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)


class Evaluation:
    """Evaluate a model on the test split and record the resulting metrics.

    `run` restores the trainer `State` from a checkpoint directory, evaluates
    every mode in `eval_modes` (currently only ``"test"``) with the pykt
    evaluation helpers and appends the metrics as one JSON line to
    `BEST_FILE_NAME`.

    NOTE(review): only the trainer `State` is restored by this class; no model
    weights are loaded in the code visible here -- confirm that the caller
    passes in a model whose weights are already restored.
    """

    def __init__(
        self,
        data: BaseData,
        model: BaseModel,
        device: torch.device,
        restore_from: str,
        batch_size: int,
    ) -> None:
        """Store the evaluation inputs and log the model summary.

        Args:
            data: Dataset description; its `format` selects the evaluation helper.
            model: Model to evaluate.
            device: Stored on the instance; not otherwise used by this class.
            restore_from: Checkpoint directory handed to `restore_state` by `run`.
            batch_size: Batch size forwarded to the evaluation helpers.

        Raises:
            NotImplementedError: If `data.format` is ``"expanded"``.
        """
        self.restore_from = restore_from
        self.data = data
        self.model = model
        self.device = device
        self.eval_modes = ["test"]
        self.batch_size = batch_size

        if self.data.format == "expanded":
            raise NotImplementedError("Use `pykt` for that ...")

        # logging
        # Only parameters that require gradients are counted.
        self.trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        logger.info(f"Model:\n{self.model}\nTrainable params: {self.trainable_params}")

    def run(self) -> None:
        """Restore the state, evaluate every mode in `eval_modes` and log the results."""
        self.restore_state(self.restore_from)

        logger.info("Run ...")

        result = {}
        for mode in self.eval_modes:
            result.update(self.eval(mode=mode))
        self.log_results(result=result)

    def eval(self, mode: str = "test") -> dict:
        """Evaluate the model and return the metrics keyed by `mode`.

        Args:
            mode: Evaluation split; only ``"test"`` is supported.

        Returns:
            A dict with ``{mode}_time``, ``{mode}_metric`` (the AUC),
            ``{mode}_auc``, ``{mode}_accuracy``,
            ``{mode}_num_evaluated_questions``, plus ``best_iteration`` and
            ``best_val_metric`` taken from the restored state.

        Raises:
            ValueError: If `mode` is not ``"test"`` or `data.format` is unsupported.
            RuntimeError: If no state has been restored yet.
        """
        # Explicit check instead of `assert`, so it is not stripped under `python -O`.
        if mode != "test":
            raise ValueError(f"Evaluation mode {mode!r} is not supported.")

        # Fail before the (potentially long) evaluation instead of after it:
        # the result dict below needs the restored state.
        state = getattr(self, "state", None)
        if state is None:
            raise RuntimeError(
                "No state has been restored; call `restore_state` (or `run`) "
                "before `eval`."
            )

        self.model.eval()
        timer = Timer()

        # Only pykt evaluation modes are supported
        with torch.no_grad():
            if self.data.format == "set_dense":
                # Sanity check on the dataset attribute this format requires.
                assert isinstance(self.data.test_file, Path)
                auc, acc, num_evaluated_questions = eval_set_dense(
                    model=self.model,
                    test_file=self.data.test_file,
                    max_concepts=self.data.max_concepts,
                    swap_q_and_c=self.data.swap_q_and_c,
                    batch_size=self.batch_size,
                )
            elif self.data.format == "combinatorial_dense":
                # Sanity checks on the dataset attributes this format requires.
                assert isinstance(self.data.test_file, Path)
                assert self.data.original_max_concepts is not None
                assert self.data.unique_concept_mapping is not None
                auc, acc, num_evaluated_questions = eval_combinatorial_dense(
                    model=self.model,
                    test_file=self.data.test_file,
                    original_max_concepts=self.data.original_max_concepts,
                    unique_concept_mapping=self.data.unique_concept_mapping,
                    swap_q_and_c=self.data.swap_q_and_c,
                    batch_size=self.batch_size,
                )
            else:
                raise ValueError(f"Data format {self.data.format} is not supported.")

        result = {
            f"{mode}_time": timer.get_duration_in_seconds(),
            # The AUC is reported twice: once as the generic metric, once by name.
            f"{mode}_metric": auc,
            f"{mode}_auc": auc,
            f"{mode}_accuracy": acc,
            f"{mode}_num_evaluated_questions": num_evaluated_questions,
            "best_iteration": state.best_iteration,
            "best_val_metric": state.best_val_metric,
        }

        return result

    def restore_state(
        self, restore_dir: str, iteration: Union[str, int] = "best"
    ) -> None:
        """Load the trainer `State` of a checkpoint into `self.state`.

        Args:
            restore_dir: Directory passed to `get_checkpoint_path_without_suffix`.
            iteration: Checkpoint selector passed to the same helper;
                defaults to ``"best"``.
        """
        logger.info(f"Restoring state from {restore_dir}")
        checkpoint_path_without_suffix = get_checkpoint_path_without_suffix(
            dir=restore_dir, iteration=iteration
        )
        # NOTE(review): the file carries a `.yaml` suffix but is parsed with
        # `json.load`, so its content must be valid JSON -- confirm against
        # the code that writes it.
        state_path = checkpoint_path_without_suffix.with_suffix(".yaml")
        with open(file=state_path, mode="r") as file:
            state_dict: dict = json.load(fp=file)
        logger.info(f"With {state_dict=}")
        self.state = State(**state_dict)

    def log_results(self, result: dict[str, float]) -> None:
        """Append the results as one JSON line to `BEST_FILE_NAME` and log them.

        Note that `result` is updated in place with the restored state's
        `epoch` and `iteration` and with the trainable-parameter count.

        Args:
            result: Metrics to record, e.g. as returned by `eval`.
        """
        result.update(
            epoch=self.state.epoch,
            iteration=self.state.iteration,
            trainable_params=self.trainable_params,
        )

        # Write to json + log
        # `skipkeys=True` drops entries whose keys are not basic JSON key types
        # instead of raising a TypeError.
        result_json: str = json.dumps(result, skipkeys=True)
        # Append mode: each call adds one line, earlier results are kept.
        with open(file=BEST_FILE_NAME, mode="a") as file:
            file.write(f"{result_json}\n")
        logger.info(result_json)
