import os
import warnings
from datetime import datetime
from typing import Any, Dict, Tuple

import yaml
from optax import OptState

from egxc.dataloading.io import (
    checkpoint_best_path,
    checkpoint_config_path,
    checkpoint_directory,
    checkpoint_step_path,
    hash_dictionary,
    pickle_dictionary,
    unpickle_dictionary,
)
from egxc.utils.typing import NnParams

# Type of the second element of the (OptState, EmaState) optimizer-state tuple
# used in the CheckpointManager signatures below. Presumably the
# exponential-moving-average state (TODO confirm); left untyped as Any.
EmaState = Any


class CheckpointManager:
    """
    Saves and loads model parameters and the run configuration inside a
    dedicated checkpoint directory.

    Attributes:
        directory: Path of the checkpoint directory created by this manager.
        name: Run name; also used to build the config file path.
    """

    def __init__(
        self, directory: str, model_name: str, basis: str, name: str, data_split_seed: int
    ):
        """
        Creates a fresh checkpoint directory for this run.

        The directory path is derived via `checkpoint_directory` from the given
        arguments. If that path already exists, it is left untouched and a new
        directory with a timestamp suffix is created and used instead (a
        warning is emitted in that case).

        Raises:
            OSError: If the directory (or the timestamped fallback) cannot be
                created.
        """
        self.directory = checkpoint_directory(
            directory,
            model_name,
            basis,
            name,
            data_split_seed,
        )
        try:
            # Try to create first instead of checking os.path.exists beforehand:
            # this avoids a race between the check and the creation.
            os.makedirs(self.directory)
        except FileExistsError:
            existing = self.directory
            timestamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')
            self.directory = f'{existing}_{timestamp}'
            # Adjacent literals instead of a backslash continuation inside the
            # string, which would embed the source indentation in the message.
            warnings.warn(
                f'Checkpoint directory {existing} already exists. '
                f'Using {self.directory} instead.',
                stacklevel=2,
            )
            os.makedirs(self.directory)
        self.name = name

    def save_params(self, params: NnParams, step: int, prefix: str = ''):
        """
        Pickles `params` to the checkpoint file for the given `step`.
        The file path is built by `checkpoint_step_path` from the checkpoint
        directory, `step` and `prefix`.
        """
        path = checkpoint_step_path(self.directory, step, prefix)
        pickle_dictionary(params, path)  # type: ignore

    def save_best_params(self, params: NnParams, prefix: str = ''):
        """
        Pickles `params` to the "best" checkpoint file.
        The file path is built by `checkpoint_best_path` from the checkpoint
        directory and `prefix`.
        """
        path = checkpoint_best_path(self.directory, prefix)
        pickle_dictionary(params, path)  # type: ignore

    @staticmethod
    def load_params(
        directory: str, step: int | None = None, prefix: str = ''
    ) -> NnParams:
        """
        Loads the parameters from the checkpoint.
        If step is None, loads the best parameters.
        If step is not None, loads the parameters for the given step.

        Note that `directory` is the checkpoint directory itself (this is a
        static method and does not use a manager instance).

        Raises:
            FileNotFoundError: If no checkpoint file exists at the resolved path.
        """
        path = (
            checkpoint_best_path(directory, prefix)
            if step is None
            else checkpoint_step_path(directory, step, prefix)
        )
        if not os.path.exists(path):
            raise FileNotFoundError(f'Checkpoint not found at {path}')
        # NOTE(review): presumably pickle-based; only load checkpoints from
        # trusted sources - confirm against unpickle_dictionary.
        return unpickle_dictionary(path)

    def load_config(self) -> Dict[str, Any]:
        """
        Reads the YAML config file stored in the checkpoint directory and
        returns its parsed content.
        """
        config_path = checkpoint_config_path(self.directory, self.name)
        with open(config_path, 'r') as f:
            # NOTE(review): FullLoader accepts more than plain YAML types; if
            # configs only hold basic types, yaml.safe_load would be the safer
            # choice for files of unknown origin.
            return yaml.load(f, Loader=yaml.FullLoader)

    def save_optimizer_state(self, state: Tuple[OptState, EmaState], step: int):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError('Not implemented')

    def load_optimizer_state(self, step: int) -> Tuple[OptState, EmaState]:
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError('Not implemented')

    def save_config(self, config: Dict[str, Any]):
        """
        Writes `config` as YAML to the config file in the checkpoint directory.

        Before writing, the result of `hash_dictionary(config)` is stored under
        the 'hash' key. This modifies the caller's dictionary in place, and an
        already present 'hash' entry is part of the hashed input and is then
        overwritten.
        """
        config_path = checkpoint_config_path(self.directory, self.name)
        config['hash'] = hash_dictionary(config)
        with open(config_path, 'w') as f:
            yaml.dump(config, f)
