import os
import torch
import argparse
import scipy.io
from topmost.utils import static_utils
from utils.data.dataset import BasicDatasetHandler
from utils.data import file_utils
from trainers.trainer import Trainer


def parse_args(argv=None):
    """Parse command-line options for a topic-model training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``model_config``, ``dataset``, ``num_topics``,
        ``num_top_words`` and ``test_index``.

    Raises:
        SystemExit: via argparse, when ``--model_config`` is missing or an
            option value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser()
    # Required: the config path './configs/<model_config>.yaml' and the output
    # file names are built from this value, so omitting it would otherwise
    # silently produce './configs/None.yaml'.
    parser.add_argument('--model_config', '-m', required=True,
                        help='name of the YAML file under ./configs (without extension)')
    parser.add_argument('--dataset', '-d', default='20NG',
                        help='dataset directory name under ./data')
    parser.add_argument('--num_topics', '-k', type=int, default=50,
                        help='number of topics')
    parser.add_argument('--num_top_words', type=int, default=15,
                        help='number of top words exported per topic')
    parser.add_argument('--test_index', type=int, default=0,
                        help='run index used in the output file names')
    return parser.parse_args(argv)


def main():
    """Train a topic model from the command line and save its outputs.

    Steps:
      1. Parse CLI options and pass ``./configs/<model_config>.yaml`` to
         ``file_utils.update_args``.
      2. Load the dataset from ``./data/<dataset>`` on CUDA when available,
         otherwise on the CPU.
      3. Build a ``Trainer`` from ``args`` and train it on the dataset.
      4. Save the output of ``static_utils.print_topic_words`` to
         ``<prefix>_T<num_top_words>``, and ``beta`` / ``train_theta`` /
         ``test_theta`` to ``<prefix>_params.mat``, where ``<prefix>`` is
         ``output/<dataset>/<model_config>_K<num_topics>_<test_index>th``.
    """
    # No trailing slash: the dataset path is built as f'{DATASET_DIR}/<name>',
    # and the former './data/' produced a doubled separator ('./data//<name>').
    DATASET_DIR = './data'

    args = parse_args()
    # NOTE(review): update_args presumably copies the YAML settings onto
    # ``args``; confirm how it resolves clashes with command-line flags.
    file_utils.update_args(args, path=f'./configs/{args.model_config}.yaml')

    # Common path prefix shared by every artifact written for this run.
    output_prefix = f'output/{args.dataset}/{args.model_config}_K{args.num_topics}_{args.test_index}th'
    file_utils.make_dir(os.path.dirname(output_prefix))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataset = BasicDatasetHandler(f'{DATASET_DIR}/{args.dataset}', device=device)

    # Runtime and dataset-derived settings are handed to the Trainer via ``args``.
    args.device = device
    args.vocab_size = dataset.vocab_size
    args.train_texts = dataset.train_texts

    trainer = Trainer(args)
    trainer.train(dataset)

    beta = trainer.export_beta()
    top_words = static_utils.print_topic_words(beta, dataset.vocab, args.num_top_words)
    file_utils.save_text(top_words, f'{output_prefix}_T{args.num_top_words}')

    train_theta, test_theta = trainer.export_theta(dataset)

    # NOTE(review): scipy.io.savemat expects array-like values; this assumes
    # export_beta/export_theta return NumPy arrays rather than (CUDA) tensors
    # -- confirm against Trainer.
    params = {
        'beta': beta,
        'train_theta': train_theta,
        'test_theta': test_theta,
    }

    scipy.io.savemat(f'{output_prefix}_params.mat', params)


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()
