import logging
import pickle
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import jax
import yaml
from jax import Array
from wonderwords import RandomWord

from neural_pfaffian.systems import Systems
from neural_pfaffian.vmc import VMCState

_CRITICAL_CONFIG_KEYS = ['wave_function', 'vmc', 'pretraining', 'seed']
"""For two configs to be considered exchangeable, they must have the same values for the following keys"""


@dataclass
class Checkpoint:
    """Snapshot of a run (VMC state, systems, auxiliary data) that can be stored on disk.

    On disk a checkpoint is a directory containing:
      - ``vmc_state.pkl``: the pickled ``vmc_state``,
      - ``systems.pkl``: the pickled ``systems``,
      - ``aux_data/<key>.pkl``: one pickle per entry of ``aux_data``.

    Note: loading uses :mod:`pickle`, which can execute arbitrary code. Only load
    checkpoints from trusted locations.
    """

    vmc_state: VMCState
    systems: Systems
    aux_data: dict[str, Array]

    def __post_init__(self):
        # After initialization, move all data to the CPU. ``jax.device_get`` does not
        # modify its argument in place; it returns a new pytree whose array leaves are
        # host arrays. The results therefore have to be assigned back. Otherwise the
        # device->host copy is made and immediately discarded.
        self.vmc_state = jax.device_get(self.vmc_state)
        self.systems = jax.device_get(self.systems)
        self.aux_data = jax.device_get(self.aux_data)

    def save(self, path: Path, allow_overwrite: bool = False):
        """Write this checkpoint to the directory ``path``.

        Args:
            path: Target checkpoint directory. Parent directories are created as needed.
            allow_overwrite: If True, anything already at ``path`` is deleted first.

        Raises:
            FileExistsError: If ``path`` exists and ``allow_overwrite`` is False.
        """
        # Ensure the parent directory exists
        path.parent.mkdir(parents=True, exist_ok=True)

        if path.exists():
            if not allow_overwrite:
                raise FileExistsError(f'File exists: {path}')
            # Delete whatever currently lives at path. rmtree only handles directories,
            # so a stray file at that location is unlinked instead.
            if path.is_dir():
                shutil.rmtree(path)
            else:
                path.unlink()

        # Create the checkpoint folder
        path.mkdir()

        # Save each property individually
        with open(path / 'vmc_state.pkl', 'wb') as f:
            pickle.dump(self.vmc_state, f)
        with open(path / 'systems.pkl', 'wb') as f:
            pickle.dump(self.systems, f)

        # One pickle file per aux_data entry, named after its key
        aux_data_path = path / 'aux_data'
        aux_data_path.mkdir()
        for key, value in self.aux_data.items():
            with open(aux_data_path / f'{key}.pkl', 'wb') as f:
                pickle.dump(value, f)

    @staticmethod
    def load(path: Path) -> 'Checkpoint':
        """Read a checkpoint previously written by :meth:`save`.

        Args:
            path: Checkpoint directory.

        Returns:
            The restored checkpoint. Its data is moved to host memory by ``__post_init__``.

        Raises:
            FileNotFoundError: If ``path`` or one of the expected files/folders is missing.
        """
        # Check whether the checkpoint exists
        if not path.exists():
            raise FileNotFoundError(f'Checkpoint not found: {path}')

        # Load each property individually
        with open(path / 'vmc_state.pkl', 'rb') as f:
            vmc_state = pickle.load(f)
        with open(path / 'systems.pkl', 'rb') as f:
            systems = pickle.load(f)

        # Only read the files save() writes (*.pkl), in a deterministic order, so that
        # unrelated files in the folder are not unpickled.
        aux_data = {}
        aux_data_path = path / 'aux_data'
        aux_files = sorted(p for p in aux_data_path.iterdir() if p.suffix == '.pkl')
        for aux_file in aux_files:
            with open(aux_file, 'rb') as f:
                aux_data[aux_file.stem] = pickle.load(f)

        return Checkpoint(vmc_state, systems, aux_data)

    @property
    def step(self) -> int:
        """Step counter stored in the VMC state, as a Python int."""
        return self.vmc_state.step.item()


class CheckpointManager:
    """Enables saving and loading checkpoints for a run.
    Example config:
    ```yaml
          checkpoint:
            checkpoint_interval: 100
            max_num_checkpoints: 5
            base_dir: checkpoints
    ```
    With this config, checkpoints are saved automatically: `should_save` is True every
    `checkpoint_interval` steps, and only the `max_num_checkpoints` most recent periodic
    checkpoints are kept.

    For continuing runs, a `run_id` can be specified. If the run exists, the checkpoint manager will
    continue from the last checkpoint. If the run does not exist, a new run will be created.

    Additionally, you can specify a specific checkpoint to continue from by setting `continue_run_from`.
    This can be either 'best', 'last', or the name of a custom checkpoint.

    By default, a custom checkpoint called `pretrained` is saved after pretraining.

    If the run was logged with wandb, the run will be resumed in the same wandb run.

    Directory layout of a run (`<base_dir>/<run_id>/`):
      - `config.yaml`: config the run was created with.
      - `wandb_id.txt`: wandb run id, if stored via `save_wandb_id`.
      - `<step:08d>/`: periodic checkpoints written by `save`.
      - `best/`: checkpoint with the lowest loss passed to `save`, plus `best_loss.txt`.
      - `custom_chkpts/<name>/`: named checkpoints written by `custom_save`.
    """

    checkpoint_interval: int
    base_dir: Path
    run_id: str
    run_path: Path
    max_num_checkpoints: int
    """Number of periodic checkpoints to keep. The best checkpoint is always kept."""
    best_loss: float = float('inf')
    """Best loss encountered so far. Used to determine the best checkpoint."""
    continue_run_from: str | None = None
    wandb_id: str | None = None
    resume_wandb_run: bool = True

    def __init__(
        self,
        base_dir: Path | str,
        config: dict[str, Any],
        run_id: str | None = None,
        systems: Systems | None = None,
        max_num_checkpoints: int = 5,
        checkpoint_interval: int = 100,
        continue_run_from: str = 'last',
        resume_wandb_run: bool = True,
        **_,
    ):
        """Create a new run directory or attach to an existing one.

        Args:
            base_dir: Directory holding all runs. Created if missing.
            config: Full run config. Written to `config.yaml` for a new run. For an
                existing run it is compared with the saved config on `_CRITICAL_CONFIG_KEYS`.
            run_id: Name of the run directory. If None, a timestamp-based name is generated.
            systems: Only used to make a generated run id more descriptive.
            max_num_checkpoints: Number of periodic checkpoints to keep.
            checkpoint_interval: `should_save` is True every this many steps.
            continue_run_from: 'best', 'last' or a custom checkpoint name. Only used
                when the run already exists.
            resume_wandb_run: Whether to read the stored wandb id of an existing run.
            **_: Additional config entries are ignored.
        """
        # Basic properties
        self.checkpoint_interval = checkpoint_interval
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_num_checkpoints = max_num_checkpoints
        self.resume_wandb_run = resume_wandb_run

        self.run_id = self._get_run_id(run_id, systems)
        self.run_path = self.base_dir / self.run_id

        # Check if the run exists
        if self.run_path.exists():
            logging.info(f'Run exists: {self.run_id}, continuing run')
            self._init_from_existing_run_path(
                self.run_path,
                config,
                continue_run_from,
                resume_wandb_run,
            )
        else:
            # Create a fresh run. continue_run_from stays None, so
            # continue_previous_run is False.
            self.run_path.mkdir()
            with open(self.run_path / 'config.yaml', 'w') as f:
                yaml.dump(config, f)

    def _init_from_existing_run_path(
        self,
        run_path: Path,
        config: dict[str, Any],
        continue_run_from: str,
        resume_wandb_run: bool,
    ):
        """Restore best loss and wandb id from `run_path` and check config compatibility."""
        # Load the best loss. The best checkpoint is written before best_loss.txt, so a
        # crash in between can leave the file missing. In that case keep best_loss at
        # +inf, so that the next save() re-establishes the best checkpoint.
        best_path = run_path / 'best'
        best_loss_path = best_path / 'best_loss.txt'
        if best_loss_path.exists():
            self.best_loss = float(best_loss_path.read_text())
        elif best_path.exists():
            logging.warning(f'Best checkpoint found but {best_loss_path} is missing; resetting best loss')

        # Load the wandb id (strip guards against a trailing newline from manual edits)
        if resume_wandb_run:
            try:
                self.wandb_id = (run_path / 'wandb_id.txt').read_text().strip()
            except FileNotFoundError:
                logging.warning("Tried to load wandb id, but it wasn't found")

        # Check if saved config matches the current config. An empty file loads as
        # None; treat it as an empty config so it is reported as a mismatch instead of crashing.
        with open(run_path / 'config.yaml') as f:
            saved_config = yaml.safe_load(f) or {}

        if not self.verify_config(config, saved_config):
            logging.warning(
                'Saved config does not match current config. '
                'Setups may not be compatible. Expect errors. '
                'The current config will not be saved to the checkpoint. '
                'The current config will only be used for this run.',
            )

        self.continue_run_from = continue_run_from

    def _get_run_id(self, run_id: str | None, systems: Systems | None) -> str:
        """Return `run_id` if given, otherwise generate a timestamp-based run name."""
        if run_id is not None:
            logging.info(f'Got run_id from config: {run_id}')
            return run_id

        # Use a timestamp as the checkpoint name
        time_str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if systems is None:
            return f'chkpt_{time_str}'
        # The word generator is only needed for the descriptive variant
        word = RandomWord().word(word_max_length=6)
        return f'chkpt_{systems!s}_{word}_{time_str}'

    @staticmethod
    def verify_config(left, right) -> bool:
        """Return True if both configs agree on every key in `_CRITICAL_CONFIG_KEYS`."""
        return all(left.get(key) == right.get(key) for key in _CRITICAL_CONFIG_KEYS)

    @property
    def continue_previous_run(self) -> bool:
        """True if this manager attached to an existing run."""
        return self.continue_run_from is not None

    def should_save(self, step: int) -> bool:
        """Return True if a periodic checkpoint is due at `step`."""
        return step % self.checkpoint_interval == 0

    def save_wandb_id(self, wandb_id: str):
        """Remember `wandb_id` and persist it so a continued run can resume it."""
        self.wandb_id = wandb_id
        # Save the wandb id to a separate file
        with open(self.run_path / 'wandb_id.txt', 'w') as f:
            f.write(wandb_id)

    def save(self, checkpoint: Checkpoint, loss: float | None = None):
        """Save a periodic checkpoint, update the best checkpoint and prune old ones.

        Args:
            checkpoint: Checkpoint to store under `<run_path>/<step:08d>`.
            loss: If given and lower than `best_loss`, the checkpoint is also stored as `best`.
        """
        checkpoint_name = f'{checkpoint.step:08d}'
        save_path = self.run_path / checkpoint_name
        checkpoint.save(save_path, allow_overwrite=True)

        if loss is not None and loss < self.best_loss:
            self.best_loss = loss
            best_path = self.run_path / 'best'
            checkpoint.save(best_path, allow_overwrite=True)
            # Save the best loss to a separate file
            with open(best_path / 'best_loss.txt', 'w') as f:
                f.write(str(loss))

        # Remove old checkpoints
        self.remove_old_checkpoints()

    def custom_save(self, checkpoint: Checkpoint, name: str):
        """Save `checkpoint` as `custom_chkpts/<name>`. Raises FileExistsError if it exists."""
        save_path = self.run_path / 'custom_chkpts' / name
        checkpoint.save(save_path)

    def remove_old_checkpoints(self):
        """Delete the oldest periodic checkpoints, keeping `max_num_checkpoints`."""
        checkpoints = self.get_saved_checkpoints(include_best=False)
        # Compute the count explicitly: checkpoints[:-k] is empty for k == 0, which
        # would keep every checkpoint when max_num_checkpoints is 0.
        num_to_remove = len(checkpoints) - max(self.max_num_checkpoints, 0)
        for checkpoint in checkpoints[: max(num_to_remove, 0)]:
            shutil.rmtree(checkpoint)

    @staticmethod
    def _checkpoint_step(name: str) -> int | None:
        """Parse the step from a periodic checkpoint directory name, or None if it is not one."""
        # save() names periodic checkpoints f'{step:08d}' (a '-' sign is possible for negative steps)
        if not name.removeprefix('-').isdecimal():
            return None
        return int(name)

    def get_saved_checkpoints(self, include_best: bool = False) -> list[Path]:
        """Return the periodic checkpoint directories of this run, oldest step first.

        Only directories named after a step (as written by `save`) are returned.
        Any other directory in the run path (e.g. 'custom_chkpts', 'evaluation') is
        never treated as a checkpoint, so `remove_old_checkpoints` cannot delete it.
        Sorting is numeric, so it stays correct beyond 8-digit step numbers.

        Args:
            include_best: If True and a `best` checkpoint exists, append it last.
        """
        steps: dict[Path, int] = {}
        for entry in self.run_path.iterdir():
            step = self._checkpoint_step(entry.name)
            if step is not None and entry.is_dir():
                steps[entry] = step
        checkpoints = sorted(steps, key=lambda p: (steps[p], p.name))

        best_path = self.run_path / 'best'
        if include_best and best_path.is_dir():
            checkpoints.append(best_path)
        return checkpoints

    def load_best(self) -> Checkpoint:
        """Load the `best` checkpoint. Raises FileNotFoundError if there is none."""
        best_path = self.run_path / 'best'
        # Check if the best checkpoint exists
        if not best_path.exists():
            raise FileNotFoundError('Best checkpoint not found')

        return Checkpoint.load(best_path)

    def load_last(self) -> Checkpoint:
        """Load the periodic checkpoint with the highest step. Raises FileNotFoundError if none."""
        checkpoints = self.get_saved_checkpoints()
        if not checkpoints:
            raise FileNotFoundError('No checkpoints found')
        last_checkpoint = checkpoints[-1]
        return Checkpoint.load(last_checkpoint)

    def load(self) -> Checkpoint:
        """Load the checkpoint selected by `continue_run_from`.

        'best' and 'last' select the best or most recent periodic checkpoint. Any
        other value is looked up by name in `custom_chkpts`.

        Raises:
            FileNotFoundError: If the selected checkpoint does not exist.
        """
        if self.continue_run_from == 'best':
            logging.info('Continuing from best checkpoint')
            return self.load_best()
        if self.continue_run_from == 'last':
            logging.info('Continuing from last checkpoint')
            return self.load_last()
        # Guard the directory so a missing custom_chkpts folder still produces the
        # informative "Checkpoint ... not found" error below.
        custom_dir = self.run_path / 'custom_chkpts'
        if custom_dir.is_dir():
            for p in custom_dir.iterdir():
                if p.name == self.continue_run_from:
                    logging.info(
                        f"Continuing from custom checkpoint '{self.continue_run_from}'",
                    )
                    return Checkpoint.load(p)
        raise FileNotFoundError(f'Checkpoint {self.continue_run_from} not found')
