import torch

class Trainer:
    """Minimal epoch-based training loop configured through chainable setters.

    Every ``set_*`` method stores one component and returns ``self`` so calls
    can be chained, e.g. ``Trainer().set_model(m).set_epochs(3).run()``.

    Contract visible from the loop below:
      * each batch yielded by a dataloader is unpacked as ``model(**data)``,
        so batches must be mappings of keyword arguments;
      * the model must return a dict, because the loss is computed as
        ``loss_fn(**(result | data))`` (dataset fields win on key collisions);
      * callbacks expose ``step(ctx)`` (called after every batch) and
        ``epoch(ctx)`` (called after every pass), where ``ctx`` is the
        calling frame's ``locals()`` dict.

    Required before ``run()``: epochs, model, optimizer, loss_fn and
    train_dataloader. Everything else is optional; a ``None`` component is
    simply skipped.
    """

    def __init__(self):
        self.global_step = 0
        # Optional components; ``None`` means "skip this part of the loop".
        self.etv = 1  # validate every ``etv`` epochs (starting at epoch 0)
        self.scheduler = None
        self.valid_dataloader = None
        self.test_dataloader = None
        self.train_callback = None
        self.valid_callback = None
        self.test_callback = None

    # ----- chainable setters ------------------------------------------------
    def set_epochs(self, value): self.epochs = value; return self
    def set_etv(self, value): self.etv = value; return self
    def set_train_dataloader(self, value): self.train_dataloader = value; return self
    def set_valid_dataloader(self, value): self.valid_dataloader = value; return self
    def set_test_dataloader(self, value): self.test_dataloader = value; return self
    # Backward-compatible alias for the original, misspelled setter name.
    set_test_dataloaer = set_test_dataloader
    def set_model(self, value): self.model = value; return self
    def set_optimizer(self, value): self.optimizer = value; return self
    def set_scheduler(self, value): self.scheduler = value; return self
    # NOTE(review): batch_size / shuffle / drop_last are stored but not read
    # anywhere in this class; presumably consumed by code building the loaders.
    def set_batch_size(self, value): self.batch_size = value; return self
    def set_loss_fn(self, value): self.loss_fn = value; return self
    def set_shuffle(self, value): self.shuffle = value; return self
    def set_drop_last(self, value): self.drop_last = value; return self
    def set_train_callback(self, value): self.train_callback = value; return self
    def set_valid_callback(self, value): self.valid_callback = value; return self
    def set_test_callback(self, value): self.test_callback = value; return self

    # ----- evaluation -------------------------------------------------------
    @torch.no_grad()
    def _evaluate(self, dataloader, callback):
        """Run the model over ``dataloader`` in eval mode without gradients.

        The model's previous train/eval mode is restored afterwards, even if
        the model or a callback raises. Does nothing when ``dataloader`` is None.
        """
        if dataloader is None:
            return
        was_training = self.model.training
        self.model.eval()
        try:
            for step, data in enumerate(dataloader):
                result = self.model(**data)
                if callback is not None:
                    callback.step(locals())
            if callback is not None:
                callback.epoch(locals())
        finally:
            self.model.train(was_training)

    def valid(self):
        """Evaluate on ``valid_dataloader``, reporting to ``valid_callback``."""
        self._evaluate(self.valid_dataloader, self.valid_callback)

    def test(self):
        """Evaluate on ``test_dataloader``, reporting to ``test_callback``."""
        self._evaluate(self.test_dataloader, self.test_callback)

    # ----- training ---------------------------------------------------------
    def run(self):
        """Train for ``self.epochs`` epochs, validating periodically, then test.

        The scheduler (if any) is stepped once per batch. Validation runs
        after every epoch whose index is a multiple of ``self.etv``.
        """
        # The model may arrive in eval mode (e.g. freshly loaded or evaluated);
        # make sure dropout / batch-norm behave as in training from step one.
        self.model.train()
        for epoch in range(self.epochs):
            for step, data in enumerate(self.train_dataloader):
                self.global_step += 1
                self.optimizer.zero_grad()
                result = self.model(**data)
                # Dataset fields override model outputs on key collisions.
                loss = self.loss_fn(**(result | data))
                loss.backward()
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                if self.train_callback is not None:
                    self.train_callback.step(locals())

            if epoch % self.etv == 0:
                self.valid()
            if self.train_callback is not None:
                self.train_callback.epoch(locals())

        self.test()




