import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union, cast

import hydra
from omegaconf import DictConfig, OmegaConf

from pkg.data import BaseData
from pkg.logic.evaluation import Evaluation
from pkg.model import BaseModel
from pkg.utils.logging import CONFIG_FILE_NAME, save_snapshot_of_source_code
from pkg.utils.reproduce import save_config, seed_everything
from pkg.utils.setup import set_device
from train import TrainConfig

logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class EvalConfig(DictConfig):
    """Typed view of the evaluation config composed by Hydra (``pkg/config/eval``).

    Used as the annotation for ``main``'s ``config`` argument. At runtime
    ``main`` merges this config over the training config restored from
    ``restore_from``.

    NOTE(review): ``@dataclass`` generates an ``__init__`` that does not call
    ``DictConfig.__init__``, so constructing ``EvalConfig(...)`` directly would
    likely yield a broken object. This class appears to serve only as a static
    type hint -- confirm before instantiating it anywhere.
    """

    # Hydra instantiation spec for the evaluation object; ``main`` instantiates
    # it with ``data``, ``model`` and ``device`` and then calls ``.run()``.
    evaluation: Any
    # Directory of a previous training run: ``main`` reads its saved config
    # (``CONFIG_FILE_NAME``) and restores the model from it.
    restore_from: str

    # Passed to ``set_device``; allowed to override the training run's value.
    device: str
    # Not read directly in ``main``; presumably consumed by the evaluation or
    # data config -- TODO confirm.
    batch_size: int
    # Not read directly in ``main``; allowed to override the training run's value.
    debug: bool


@hydra.main(config_path="pkg/config", config_name="eval", version_base="1.3")
def main(config: EvalConfig) -> None:
    """Evaluate a model restored from a previous training run.

    Loads the training config saved under ``config.restore_from``, merges the
    evaluation config on top of it, rebuilds data and model from the merged
    config, restores the model from ``restore_from`` and runs the configured
    evaluation.

    Args:
        config: Evaluation config composed by Hydra from ``pkg/config/eval``.

    Raises:
        ValueError: If the restored training config is not a mapping, or if
            the evaluation config redefines a key of the training config other
            than ``device`` or ``debug``.
    """
    # config
    train_config_path = Path(config.restore_from) / CONFIG_FILE_NAME
    train_config = OmegaConf.load(train_config_path)
    if not isinstance(train_config, DictConfig):
        # OmegaConf.load may also return a ListConfig, which has no keys().
        raise ValueError(
            f"Expected a mapping in {train_config_path}, "
            f"got {type(train_config).__name__}"
        )

    # The eval config is merged *over* the training config below, so any shared
    # key would override the value from the training run. Only `device` and
    # `debug` may do that. An explicit raise (not `assert`) keeps this check
    # active under `python -O`.
    overridable_keys = {"device", "debug"}
    train_keys = set(train_config.keys())  # built once, O(1) membership
    conflicting_keys = [
        key
        for key in config.keys()
        if key in train_keys and key not in overridable_keys
    ]
    if conflicting_keys:
        raise ValueError(
            "Eval config must not redefine keys from the restored training "
            f"config (except {sorted(overridable_keys)}): {conflicting_keys}"
        )

    logger.info("Merge config with restored config from %s", config.restore_from)
    logger.info("Config:\n%s", OmegaConf.to_yaml(config))
    logger.info("Restored config:\n%s", OmegaConf.to_yaml(train_config))
    # Later arguments win in OmegaConf.merge, so eval values override training ones.
    merged_config = cast(
        Union[EvalConfig, TrainConfig], OmegaConf.merge(train_config, config)
    )
    save_config(config=merged_config)

    # save code
    # NOTE: source snapshotting is currently disabled for evaluation runs;
    # uncomment to archive the code alongside the saved config.
    # save_snapshot_of_source_code(file_name=__file__)

    # device
    device = set_device(device=merged_config.device)

    # data
    # Re-seed before each stage, presumably so data and model construction each
    # start from the same RNG state regardless of what the other consumed.
    seed_everything(seed=merged_config.seed)
    data: BaseData = hydra.utils.instantiate(
        merged_config.data, data_path=merged_config.data_path
    )

    # model
    seed_everything(seed=merged_config.seed)
    model: BaseModel = hydra.utils.instantiate(
        merged_config.model,
        num_concepts=data.num_concepts,
        num_questions=data.num_questions,
        max_concepts=data.max_concepts,
        max_len=data.max_len,
    )
    model.restore(restore_dir=merged_config.restore_from)
    model.to(device)

    # dispatch to evaluation
    evaluation: Evaluation = hydra.utils.instantiate(
        merged_config.evaluation, data=data, model=model, device=device
    )
    evaluation.run()

    logger.info("Done")


if __name__ == "__main__":
    # The @hydra.main decorator composes the config (pkg/config/eval plus any
    # command-line overrides) and passes it to main, so no argument is given.
    main()
