import os
import time
import warnings

import torch
from tqdm import tqdm

from config import Config

class RewardModelPipeline:
    """
    Unified way to interweave training and validation for reward models.

    The pipeline owns a trainer and a validator, forwards minibatch/epoch/
    cleanup calls to both, and can run several epochs in a row while
    periodically validating and writing a resumable checkpoint.
    """

    def __init__(self,
                 model_trainer,
                 model_validator,
                 model_name,
                 validate_every: int=1,
                 checkpoint_every: int=1,
    ):
        """
        :param model_trainer: Object exposing ``minibatch``, ``epoch`` and
            ``cleanup``, plus ``model`` and ``optimizer`` attributes that
            provide ``state_dict``/``load_state_dict``.
        :param model_validator: Object exposing ``minibatch``, ``epoch`` and
            ``cleanup``.
        :param model_name: Name used to build the checkpoint filename.
        :param validate_every: Run validation every this many epochs in
            ``n_epochs``; a falsy value (``0``/``None``) disables it.
        :param checkpoint_every: Write a checkpoint every this many epochs in
            ``n_epochs``; a falsy value (``0``/``None``) disables it.
        """
        self.model_trainer = model_trainer
        self.model_validator = model_validator
        self.model_name = model_name
        self.validate_every = validate_every
        self.checkpoint_every = checkpoint_every
        # Number of epochs started through n_epochs (or restored from a
        # checkpoint); n_epochs resumes from this value.
        self.current_epoch = 0

    def minibatch(self, state: "Board", reward_array: float) -> None:
        """
        Forwards one minibatch to both the trainer and the validator.

        :param state: The board representation of the state.
        :param reward_array: The reward data passed through unchanged.
            NOTE(review): annotated as ``float`` although the name suggests
            an array of rewards -- confirm against the trainer.
        """
        self.model_trainer.minibatch(state, reward_array)
        self.model_validator.minibatch(state, reward_array)

    def epoch(self) -> None:
        """
        Runs a single trainer epoch followed by a single validator epoch.

        Unlike ``n_epochs`` this does not advance ``current_epoch`` and does
        not write a checkpoint.
        """
        self.model_trainer.epoch()
        self.model_validator.epoch()

    def n_epochs(self, n: int) -> None:
        """
        Trains until ``current_epoch`` reaches ``n``, validating and
        checkpointing on the configured cadence.

        The cadence is keyed on the zero-based epoch index, so validation and
        checkpointing fire on indices ``0, every, 2 * every, ...``; a run that
        starts from epoch 0 therefore triggers both after its first epoch.

        :param n: The total number of epochs to reach. Nothing is trained if
            ``current_epoch`` is already at or beyond ``n``.
        """
        config = Config()
        for epoch_index in tqdm(range(self.current_epoch, n)):
            self.current_epoch += 1
            self.model_trainer.epoch()
            if self._is_due(epoch_index, self.validate_every):
                self.model_validator.epoch()
            if self._is_due(epoch_index, self.checkpoint_every):
                # tqdm.write keeps the message from garbling the progress bar.
                tqdm.write('checkpointing')
                self._save_checkpoint(config.checkpoints_dir)

    def load_checkpoint(self) -> None:
        """
        Restores model, optimizer and epoch counter from the checkpoint file.

        If the checkpoint file cannot be opened (``OSError``), a warning is
        emitted and nothing is loaded.
        """
        config = Config()
        try:
            checkpoint = torch.load(self._checkpoint_path(config.checkpoints_dir))
        except OSError:
            warnings.warn(f'no checkpoint available for {self.model_name}; loading nothing...')
            return
        # Read every entry before mutating anything so that a checkpoint with
        # a missing key does not leave the pipeline half-restored.
        model_params = checkpoint['model_params']
        optimizer_params = checkpoint['optimizer_params']
        restored_epoch = checkpoint['current_epoch']
        self.model_trainer.model.load_state_dict(model_params)
        self.model_trainer.optimizer.load_state_dict(optimizer_params)
        self.current_epoch = restored_epoch

    def cleanup(self, save_model_filename: str) -> None:
        """
        Saves the model and runs all of the callback cleanup functions.
        :param save_model_filename: The filename to store the model at.
        """
        self.model_trainer.cleanup(save_model_filename)
        self.model_validator.cleanup()

    @staticmethod
    def _is_due(epoch_index: int, every) -> bool:
        """Returns True when a step with interval ``every`` fires at ``epoch_index``."""
        # A falsy interval means "disabled" instead of a ZeroDivisionError.
        return bool(every) and epoch_index % every == 0

    def _checkpoint_path(self, checkpoints_dir) -> str:
        """Returns the checkpoint file path for this model inside ``checkpoints_dir``."""
        return os.path.join(checkpoints_dir, f'{self.model_name}.checkpoint')

    def _save_checkpoint(self, checkpoints_dir) -> None:
        """Writes the current training state to the checkpoint file atomically."""
        final_path = self._checkpoint_path(checkpoints_dir)
        parent_dir = os.path.dirname(final_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Write to a sibling temp file and rename it into place so that an
        # interrupted save never corrupts the previous good checkpoint.
        tmp_path = f'{final_path}.tmp'
        try:
            torch.save({
                'model_params': self.model_trainer.model.state_dict(),
                'optimizer_params': self.model_trainer.optimizer.state_dict(),
                'current_epoch': self.current_epoch,
            }, tmp_path)
            os.replace(tmp_path, final_path)
        finally:
            # After a successful replace the temp file no longer exists; this
            # only removes the partial file left behind by a failed save.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Best effort: never mask the original save error.
                    pass
