import torch
import torch.nn as nn
import torch.utils.data

from .utils import to_device


class DeepSpeedTrainer:
    """Minimal step-based training loop driving a DeepSpeed model engine.

    Args:
        model_engine: DeepSpeed engine wrapping the model. Its ``module`` must
            expose a ``criterion`` callable used as
            ``criterion(batch_tuple, output_tuple) -> (loss, info)`` where
            ``info`` is a mapping of scalar metrics to log.
        trainloader: iterable of batches; it is re-iterated once per epoch.
        logger: object providing ``logkv_mean``, ``set_timestep``, ``dumpkvs``
            and the ``checkpoint_dir`` / ``model_dir`` output paths.
    """

    def __init__(self, model_engine, trainloader, logger):
        self.model_engine = model_engine
        self.local_rank = model_engine.local_rank
        # Each process places its batches on the GPU matching its local rank.
        self.device = torch.device(f'cuda:{self.local_rank}')
        self.trainloader = trainloader
        self.logger = logger

        # The loss function is owned by the wrapped model.
        self.criterion = self.model_engine.module.criterion

        # Optimizer steps taken so far; persists across train() calls.
        self.global_step = 0

    def train(self, train_steps, log_freq=200, save_freq=5000, max_epochs=None):
        """Train until ``global_step`` reaches ``train_steps``.

        Args:
            train_steps: total number of optimizer steps to reach (counted
                across calls, since ``global_step`` is not reset).
            log_freq: dump averaged metrics every ``log_freq`` steps (rank 0
                only). A falsy value disables periodic dumping.
            save_freq: write a DeepSpeed checkpoint (all ranks) and a plain
                ``model.pth`` (rank 0) every ``save_freq`` steps. A falsy
                value disables periodic saving.
            max_epochs: optional cap on passes over ``trainloader``; ``None``
                means iterate as many epochs as ``train_steps`` requires.

        On exit, rank 0 writes the final weights to ``<model_dir>/model.pth``.
        """
        self.model_engine.train()
        epoch = 0
        while self.global_step < train_steps:
            if max_epochs is not None and epoch >= max_epochs:
                break
            epoch += 1
            steps_at_epoch_start = self.global_step
            for batch in self.trainloader:
                self._train_step(batch, log_freq, save_freq)
                if self.global_step >= train_steps:
                    break
            if self.global_step == steps_at_epoch_start:
                # The loader produced no batches; looping again would never
                # make progress, so stop instead of spinning forever.
                break

        # NOTE(review): local_rank == 0 holds on one process per node, so on
        # multi-node runs several processes log/save; consider the global
        # rank if that matters — TODO confirm deployment topology.
        if self.local_rank == 0:
            self._save_model()

    def _train_step(self, batch, log_freq, save_freq):
        """Run forward/backward/step on one batch, then log and checkpoint."""
        batch = to_device(batch, self.device)
        output = self.model_engine(batch)
        loss, info = self.criterion(tuple(batch), tuple(output))
        self.model_engine.backward(loss)
        self.model_engine.step()

        is_main = self.local_rank == 0
        if is_main:
            for k, v in info.items():
                self.logger.logkv_mean(k, v)

        self.global_step += 1

        if log_freq and self.global_step % log_freq == 0 and is_main:
            self.logger.set_timestep(self.global_step)
            self.logger.dumpkvs()

        if save_freq and self.global_step % save_freq == 0:
            # save_checkpoint must be called by every rank (DeepSpeed gathers
            # per-rank state); the plain state_dict dump is rank-0 only to
            # avoid all ranks writing the same file concurrently.
            client_sd = {'step': self.global_step}
            self.model_engine.save_checkpoint(
                self.logger.checkpoint_dir, tag=self.global_step, client_state=client_sd
            )
            if is_main:
                self._save_model()

    def _save_model(self):
        """Write the unwrapped model's state_dict to ``<model_dir>/model.pth``."""
        torch.save(self.model_engine.module.state_dict(), f'{self.logger.model_dir}/model.pth')
        