import time
from collections import defaultdict

import torch

from gflownet.utils import Logger
from gflownet.algo import BaseAlgorithm



class Trainer:
    """Iteration-based training loop with event callbacks and checkpointing.

    The trainer pulls batches from ``dataloader``, hands each one to
    ``algo.update(model, batch)``, reports progress through ``logger`` and
    invokes user callbacks registered for named events.

    Events fired by :meth:`run`:

    * ``'on_batch_end'`` -- after every ``algo.update`` call.  It is fired
      *before* ``iter_num`` is incremented and before ``iter_dt`` is
      refreshed, so callbacks observe the values of the previous iteration.
    * ``'on_train_end'`` -- once, when :meth:`run` finishes.

    Attributes:
        iter_num: Number of updates performed so far (saved and restored by
            :meth:`save` / :meth:`load`).
        iter_start: ``time.time()`` at the start of the current run.
        iter_time: ``time.time()`` at the end of the latest iteration.
        iter_dt: Wall-clock duration of the latest iteration, in seconds.
    """

    def __init__(self, model, dataloader, algo: BaseAlgorithm, logger: Logger):
        self.model = model
        self.dataloader = dataloader
        self.algo = algo
        self.logger = logger

        # Event name -> list of callables, each invoked as ``callback(trainer)``.
        self.callbacks = defaultdict(list)

        self.iter_num = 0
        self.iter_start = 0.0
        self.iter_time = 0.0
        self.iter_dt = 0.0

    def add_callback(self, on_event: str, callback):
        """Register ``callback`` for ``on_event``, keeping existing ones."""
        self.callbacks[on_event].append(callback)

    def set_callback(self, on_event: str, callback):
        """Make ``callback`` the only callback registered for ``on_event``."""
        self.callbacks[on_event] = [callback]

    def trigger_callbacks(self, on_event: str):
        """Call every callback registered for ``on_event`` with this trainer."""
        # ``.get`` (rather than indexing) avoids inserting an empty list into
        # the defaultdict for events nobody subscribed to.
        for callback in self.callbacks.get(on_event, []):
            callback(self)

    def run(self, iters=1, device='cpu', print_every=1):
        """Train for up to ``iters`` additional iterations.

        Args:
            iters: Number of updates to perform, counted from the current
                ``iter_num``.  A value <= 0 performs no update.
            device: Passed to ``model.to``.  Batches are handed to
                ``algo.update`` exactly as the dataloader yields them.
            print_every: Log the update info every ``print_every`` iterations.
                ``0`` or ``None`` disables the periodic log.

        The run stops when ``iters`` updates have been made or when the
        dataloader is exhausted, whichever comes first.  In both cases the
        ``'on_train_end'`` callbacks fire and the total time is logged.
        """
        max_iters = self.iter_num + iters
        self.model.train()
        self.model.to(device)

        self.iter_time = self.iter_start = time.time()

        # The stop condition is checked after each update, so without this
        # guard a request for zero iterations would still consume a batch
        # and perform one update.
        if iters > 0:
            for batch in self.dataloader:
                info = self.algo.update(self.model, batch)
                self.trigger_callbacks('on_batch_end')

                self.iter_num += 1
                tnow = time.time()
                self.iter_dt = tnow - self.iter_time
                self.iter_time = tnow

                # A falsy print_every disables logging instead of raising
                # ZeroDivisionError / TypeError on the modulo.
                if print_every and self.iter_num % print_every == 0:
                    self.logger.info('-'*50)
                    self.logger.info(f'iteration: {self.iter_num}, ctime: {time.strftime("%Y-%m-%d %H:%M:%S")}')
                    self.logger.log_dict(info)
                    self.logger.info(f'training_time: {self.iter_dt}')

                if self.iter_num >= max_iters:
                    break

        # Finalisation lives after the loop so that it also runs when the
        # dataloader runs out of batches before ``max_iters`` is reached.
        self.logger.info('-'*50)
        self.trigger_callbacks('on_train_end')
        self.logger.info(f'total_training_time: {time.time() - self.iter_start}')

    def save(self, path):
        """Write the iteration counter, model state and optimizer state to ``path``."""
        torch.save({
            'iteration': self.iter_num,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.algo.optimizer.state_dict()
        }, path)

    def load(self, path, map_location='cpu'):
        """Restore the state written by :meth:`save` from ``path``.

        Args:
            path: Checkpoint file to read.
            map_location: Forwarded to ``torch.load``.
        """
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or
        # consider ``weights_only=True`` where the installed PyTorch supports it.
        state_dicts = torch.load(path, map_location=map_location)
        self.iter_num = state_dicts['iteration']
        self.model.load_state_dict(state_dicts['model_state_dict'])
        self.algo.optimizer.load_state_dict(state_dicts['optimizer_state_dict'])