from tqdm import tqdm
import torch
from topmost.utils import static_utils

from models.FASTopic import FASTopic


class Trainer:
    """Train a ``FASTopic`` model and export its outputs as NumPy arrays.

    The trainer builds the model from a configuration object, optimises it
    with Adam, and offers helpers that convert the model's ``get_theta`` /
    ``get_beta`` outputs to NumPy arrays on the CPU.
    """

    def __init__(self, args):
        """Build the model from ``args`` and move it to ``args.device``.

        Args:
            args: Configuration object. This class reads
                ``args.training.epochs``, ``args.training.learning_rate``
                and ``args.device``; the whole object is also passed to the
                ``FASTopic`` constructor.
        """
        self.args = args
        self.epochs = args.training.epochs
        self.learning_rate = args.training.learning_rate

        self.model = FASTopic(args)
        self.model = self.model.to(args.device)

    def make_optimizer(self) -> torch.optim.Optimizer:
        """Return a fresh Adam optimizer over all model parameters.

        The learning rate is ``self.learning_rate``; every other Adam
        setting is left at the PyTorch default.
        """
        args_dict = {
            'params': self.model.parameters(),
            'lr': self.learning_rate,
        }

        optimizer = torch.optim.Adam(**args_dict)
        return optimizer

    def train(self, dataset_handler, verbose: bool = False) -> None:
        """Optimise the model for ``self.epochs`` epochs.

        Each epoch performs exactly one forward pass, one backward pass and
        one optimizer step on ``dataset_handler.train_data`` as a whole.

        Args:
            dataset_handler: Object exposing a ``train_data`` attribute that
                the model accepts as its forward input.
            verbose: When true, log every entry of the model's result
                mapping on every fifth epoch.
        """
        optimizer = self.make_optimizer()

        self.model.train()
        for epoch in tqdm(range(1, self.epochs + 1)):

            # The model returns a mapping; its 'loss' entry drives training.
            rst_dict = self.model(dataset_handler.train_data)
            batch_loss = rst_dict['loss']

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            if verbose and epoch % 5 == 0:
                # NOTE(review): the ':.3f' format assumes every entry is a
                # Python number or a 0-dim tensor -- confirm against the
                # model's forward().
                metrics = ''.join(
                    f' {key}: {rst_dict[key]:.3f}' for key in rst_dict
                )
                # tqdm.write keeps the log line from breaking up the
                # progress bar, which a bare print() does.
                tqdm.write(f'Epoch: {epoch:03d}{metrics}')

    def test(self, input_data):
        """Return ``self.model.get_theta(input_data)`` as a NumPy array.

        The model is switched to evaluation mode and is left in that mode
        afterwards; gradients are disabled for the duration of the call.

        Args:
            input_data: Input forwarded unchanged to ``model.get_theta``.

        Returns:
            The ``get_theta`` output, detached and copied to the CPU.
            Presumably the document-topic distribution -- confirm against
            the model.
        """
        with torch.no_grad():
            self.model.eval()
            theta = self.model.get_theta(input_data)
            theta = theta.detach().cpu().numpy()

        return theta

    def export_beta(self):
        """Return ``self.model.get_beta()`` as a NumPy array on the CPU.

        Presumably the topic-word distribution -- confirm against the model.
        """
        # The result is only exported, so skip autograd bookkeeping, the
        # same way ``test`` does.
        with torch.no_grad():
            beta = self.model.get_beta()
        return beta.detach().cpu().numpy()

    def export_top_words(self, vocab, num_top_words: int = 15):
        """Pass the exported beta to ``static_utils.print_topic_words``.

        Args:
            vocab: Vocabulary forwarded to ``print_topic_words``.
            num_top_words: Number of words requested per topic.

        Returns:
            Whatever ``static_utils.print_topic_words`` returns (presumably
            the top words of each topic -- confirm against topmost).
        """
        beta = self.export_beta()
        top_words = static_utils.print_topic_words(beta, vocab, num_top_words)
        return top_words

    def export_theta(self, dataset_handler):
        """Run ``test`` on the handler's training and test inputs.

        Args:
            dataset_handler: Object exposing ``train_texts`` and
                ``test_texts`` attributes.

        Returns:
            A ``(train_theta, test_theta)`` tuple of NumPy arrays.
        """
        train_theta = self.test(dataset_handler.train_texts)
        test_theta = self.test(dataset_handler.test_texts)
        return train_theta, test_theta
