import torch
from dataset.HuffPost import HuffPost
from learner.learner import Learner
import argparse
import os

from init import *
# Command-line options for the per-month training run.
parser = argparse.ArgumentParser(
    description='Train one model per month on a sliding window of months '
                'and evaluate each model on the full date range.')
parser.add_argument('--n_month_train', default = 12, type=int,
                    help='training window size in months, counting the '
                         'current month (must be >= 1)')
parser.add_argument('--name', default = 'tmp', type=str,
                    help='run name; the learner is saved to '
                         './checkpoints/<name>.t7')
parser.add_argument('--init_from_last', action='store_true',
                    help="initialize each month's model from the previous "
                         "month's model")
parser.add_argument('--epochs', default=50, type=int,
                    help='training epochs per month')

args = parser.parse_args()

# A window smaller than one month would make the training range start after
# it ends (month_start > month_end) in the main loop.
if args.n_month_train < 1:
    parser.error('--n_month_train must be >= 1')

# Load embeddings only after arguments are parsed, so `--help` and invalid
# arguments fail fast instead of waiting on the (presumably slow) load.
dataset.load_embedding_RobertaBase()

# Directory that holds the serialized learner for every run.
checkpoint_dir = './checkpoints'

# exist_ok=True creates the directory if needed without the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(checkpoint_dir, exist_ok=True)

# Destination of the saved learner, e.g. ./checkpoints/tmp.t7 for --name tmp.
save_path = os.path.join(checkpoint_dir, args.name + '.t7')

learner = Learner()

# Run-wide settings; they do not change between months, so set them once.
n_month_train = args.n_month_train
inputs = 'headline'
model_type = 'linear'

# Model configuration shared by training and both evaluations. Keeping it in
# one dict guarantees each model is evaluated with the configuration it was
# trained with.
model_kwargs = {
    'input_dim': 768,
    'inputs': inputs,
    'normalize': True,
    'model_type': model_type,
}

for month_end in range(len(months)):
    print('[month_end: %d]' % (month_end, ), flush = True)

    # Training window: the n_month_train most recent months ending at
    # month_end (inclusive), clipped at the first month.
    month_start = max(0, month_end - n_month_train + 1)
    dataset.set_range_date(months[month_start][0], months[month_end][1])

    # Key of the previous month's model when warm-starting is requested;
    # '' for the first month or when --init_from_last is not set.
    if month_end > 0 and args.init_from_last:
        init_key = str(month_end - 1)
    else:
        init_key = ''

    learner.train(dataset, key = str(month_end),
                  init_key = init_key,
                  seed = month_end,
                  n_class = 41,
                  batch_size = 256,
                  epochs = args.epochs,
                  lr = 1e-3,
                  wd = 5e-4,
                  optimizer_name = 'AdamW',
                  **model_kwargs)

    # Evaluate on the training window (the dataset range set above).
    print('[train]' + learner.eval(dataset, key = str(month_end),
                                   **model_kwargs))

    # Evaluate on the full date range (first through last month); save=True
    # is forwarded to Learner.eval.
    dataset.set_range_date(months[0][0], months[-1][1])
    print('[eval]' + learner.eval(dataset, key = str(month_end), save = True,
                                  **model_kwargs))

# Serialize the whole Learner object with torch.save (pickle-based), so loading
# it later requires the Learner class to be importable.
torch.save(learner, save_path)
