import torch
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm

from .utils import to_device
from src.models import DynamicModel


class Trainer:
    """Step-based training loop for a model that carries its own loss function.

    Each batch triggers one optimizer step and one scheduler step. The ``info``
    dict returned by the criterion is accumulated into the logger every step
    and flushed every ``log_freq`` steps. Periodic checkpoints go to
    ``logger.checkpoint_dir`` and the final weights to
    ``logger.model_dir/model.pth``.
    """

    def __init__(self, model, optim, scheduler, train_loader, logger, device):
        """Store the training components.

        Args:
            model: module to train. It must expose a ``criterion`` attribute,
                which is called as ``criterion(tuple(batch), tuple(output))``
                and must return ``(loss, info)`` where ``info`` is a dict.
            optim: optimizer for ``model``'s parameters.
            scheduler: learning-rate scheduler, stepped once per batch.
            train_loader: iterable of training batches.
            logger: object providing ``logkv_mean``, ``set_timestep``,
                ``dumpkvs``, ``checkpoint_dir`` and ``model_dir``.
            device: anything accepted by ``torch.device``. Batches are moved
                there via ``to_device``. NOTE(review): the model itself is
                not moved in this class; presumably the caller does that.
                Confirm.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.device = torch.device(device)
        self.train_loader = train_loader
        self.logger = logger

        # The loss function is supplied by the model itself.
        self.criterion = self.model.criterion

        # Optimizer steps taken so far; persists across calls to train().
        self.global_step = 0

    def train(self, train_steps, log_freq=100, save_freq=10000):
        """Train until ``self.global_step`` reaches ``train_steps``, then save.

        The loader is iterated for as many passes as needed to reach the step
        target; there is no fixed epoch limit. The final weights are always
        written to ``{logger.model_dir}/model.pth`` once the target is reached,
        including when ``global_step`` was already at or past ``train_steps``
        on entry (no further steps are taken in that case).

        Args:
            train_steps: total number of optimizer steps to reach.
            log_freq: flush accumulated log values every this many steps.
            save_freq: write ``model_{step}.pth`` to ``logger.checkpoint_dir``
                every this many steps.

        Raises:
            ValueError: if ``train_loader`` yields no batches in a pass. The
                step target could never be reached, and looping would never
                terminate.
        """
        self.model.train()
        while self.global_step < train_steps:
            steps_at_pass_start = self.global_step
            for batch in tqdm(self.train_loader):
                info = self._train_step(batch)
                self._log_and_checkpoint(info, log_freq, save_freq)
                if self.global_step >= train_steps:
                    break
            if self.global_step == steps_at_pass_start:
                raise ValueError(
                    'train_loader yielded no batches; cannot reach train_steps'
                )

        torch.save(self.model.state_dict(), f'{self.logger.model_dir}/model.pth')

    def _train_step(self, batch):
        """Run one forward/backward/optimizer/scheduler update and return the criterion's info dict."""
        batch = to_device(batch, self.device)
        output = self.model(batch)
        loss, info = self.criterion(tuple(batch), tuple(output))
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        # Scheduler is stepped per batch, after the optimizer step.
        self.scheduler.step()

        self.global_step += 1
        return info

    def _log_and_checkpoint(self, info, log_freq, save_freq):
        """Write a periodic checkpoint, accumulate ``info`` and flush logs every ``log_freq`` steps."""
        if self.global_step % save_freq == 0:
            torch.save(self.model.state_dict(), f'{self.logger.checkpoint_dir}/model_{self.global_step}.pth')

        for k, v in info.items():
            self.logger.logkv_mean('train/' + k, v)

        if self.global_step % log_freq == 0:
            self.logger.set_timestep(self.global_step)
            self.logger.dumpkvs()
