import torch
import torch.nn as nn
from torch.utils.data import Dataset
import json
import os
import sys
from HuffPost import *


def month_windows(start, end):
    """Yield one ``(t_start, t_end)`` date-string window per calendar month.

    The windows cover the range ``[start, end]``. Dates are zero-padded
    ``'YYYY-MM-DD'`` strings, so lexicographic comparison matches
    chronological order. Each month window runs from day ``01`` to the
    sentinel day ``32``, which sorts after every real day of that month.
    The window is then clipped to ``[start, end]``, so the first and last
    months may be partial.

    The previous inline loop compared the unclipped ``'...-32'`` bound
    against ``end``. That silently dropped the final month whenever ``end``
    was a real date, e.g. ``'2017-12-32' > '2017-12-31'``.

    Args:
        start: first date of the range, ``'YYYY-MM-DD'``.
        end: last date of the range, ``'YYYY-MM-DD'``.

    Yields:
        ``(t_start, t_end)`` string pairs in chronological order. Nothing is
        yielded when ``start > end``.
    """
    for year in range(int(start[:4]), int(end[:4]) + 1):
        for month in range(1, 13):
            t_start = max(f'{year:04d}-{month:02d}-01', start)
            t_end = min(f'{year:04d}-{month:02d}-32', end)
            if t_start > t_end:
                # This month lies entirely outside [start, end].
                continue
            yield t_start, t_end


if __name__ == '__main__':
    start = '2012-02-01'
    end = '2017-12-31'

    dataset = HuffPost(format_check=True, t_start=start, t_end=end)

    dataset.preprocess_RobertaBase()
    # Alternative to preprocessing: load previously computed embeddings.
    # dataset.load_embedding_RobertaBase()
    print(dataset[0][0].shape)
    print(dataset[0][1].shape)

    # Count samples month by month. The month windows tile [start, end], so
    # the total should match the size of the full range selected below.
    # NOTE(review): assumes set_range_date returns a (lo, hi) index pair,
    # as the subtraction implies — confirm against HuffPost.
    tot = 0
    for t_start, t_end in month_windows(start, end):
        idx_range = dataset.set_range_date(t_start, t_end)
        n_month = idx_range[1] - idx_range[0]
        tot += n_month
        print(t_start, t_end, n_month)
    print('tot_remain', tot)

    # Restore the full date range before the final format check.
    dataset.set_range_date(start, end)
    dataset.format_check()
