import torch
import torch.nn as nn
from torch.utils.data import Dataset
import json
import os
import sys

# Locations of the raw dataset and of the cached RoBERTa embeddings.
_DATA_DIR = './data'
HuffPost_path = _DATA_DIR + '/News_Category_Dataset_v3.json'
embedding_path = _DATA_DIR + '/embedding/roberta.t7'
embedding_path_all_token = _DATA_DIR + '/embedding/roberta_all_token.t7'

# Use '' as the torch.hub cache directory.
# NOTE(review): presumably this makes hub downloads land relative to the
# current working directory — confirm this is intended.
torch.hub.set_dir('')

# Device the RoBERTa encoder is moved to during preprocessing.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class HuffPost(Dataset):
    """HuffPost News Category dataset restricted to a date window.

    Each line of the JSON-lines file is parsed into a dict. Entries whose
    ``'date'`` string lies outside ``[t_start, t_end]`` (inclusive, compared
    lexicographically) are dropped. Categories are mapped to integer labels in
    order of first appearance in the file. The kept entries are then reversed
    (presumably the file is newest-first, giving ascending dates;
    ``format_check`` raises if they are not sorted).

    Samples are served from the active window ``[idx_start, idx_end)``, which
    ``set_range_date`` narrows. ``__getitem__`` returns the precomputed RoBERTa
    embeddings (from ``preprocess_RobertaBase`` / ``load_embedding_RobertaBase``)
    plus the integer label.
    """

    def __init__(self, path=HuffPost_path, format_check=False, t_start='2012-02-01', t_end='2022-09-31'):
        """Load entries from ``path`` whose date lies in ``[t_start, t_end]``.

        Args:
            path: JSON-lines file, one article object per line.
            format_check: if True, run ``format_check`` after loading.
            t_start: inclusive lower bound on ``entry['date']`` (string compare).
            t_end: inclusive upper bound on ``entry['date']`` (string compare).
                A day of ``31`` therefore works as "end of month" for any month.
        """
        self.data = []
        self.labels = {}  # category name -> integer label, in first-seen order
        self.n_class = 0
        # Explicit UTF-8: article text may contain non-ASCII characters and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line at EOF).
                    continue
                entry = json.loads(line)
                if entry['date'] < t_start or entry['date'] > t_end:
                    continue
                self.data.append(entry)
                if entry['category'] not in self.labels:
                    self.labels[entry['category']] = self.n_class
                    self.n_class += 1

        # Reverse file order so indices run in (presumably) ascending date order.
        self.data.reverse()

        self.n_tot = len(self.data)
        # Active window [idx_start, idx_end) used by __len__ / __getitem__.
        self.idx_start = 0
        self.idx_end = self.n_tot

        if format_check:
            self.format_check()

    def format_check(self):
        """Print dataset statistics and verify that entries are sorted by date.

        Prints the size of the active window, the number of classes, the
        category -> label map, and the per-class sample counts in the window.

        Raises:
            Exception: if some entry's date is later than the following entry's.
        """
        print("tot_sample: ", self.idx_end - self.idx_start)

        print("tot_class: ", self.n_class)
        print(self.labels)
        # The sortedness check covers all loaded data, not just the window.
        for prev, nxt in zip(self.data, self.data[1:]):
            if prev['date'] > nxt['date']:
                raise Exception("Note that the dates are not sorted!")

        cnt = [0] * self.n_class
        for entry in self.data[self.idx_start:self.idx_end]:
            cnt[self.labels[entry['category']]] += 1

        print(cnt)

    def preprocess_RobertaBase(self, all_token=False):
        """Encode every headline and short description with RoBERTa-base and save them.

        Loads ``roberta.base`` from the ``pytorch/fairseq`` hub repo and runs
        ``extract_features`` on each text. Results are stored on CPU in
        ``self.embedding_headline`` / ``self.embedding_short_description``, one
        tensor per entry in ``self.data`` order, so the lists cover all loaded
        entries rather than just the active window. The lists are then written
        with ``torch.save`` to ``embedding_path``, or to
        ``embedding_path_all_token`` when ``all_token`` is True. The target
        directory is created if needed.

        Args:
            all_token: keep features for every token; otherwise only the first
                token's features (slice ``[:, :1, :]``) are kept.
        """
        roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
        roberta.eval()
        roberta.to(device)

        self.embedding_headline = []
        self.embedding_short_description = []

        # Inference only: no_grad avoids building an autograd graph for every
        # forward pass (the features are detached before storing anyway).
        with torch.no_grad():
            for idx, entry in enumerate(self.data):
                self.embedding_headline.append(
                    self._encode_roberta(roberta, entry['headline'], all_token))
                self.embedding_short_description.append(
                    self._encode_roberta(roberta, entry['short_description'], all_token))

                if (idx + 1) % 1000 == 0:
                    print("%d/%d" % (idx + 1, self.n_tot), flush=True)

        save_path = embedding_path_all_token if all_token else embedding_path
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(
            {
                'embedding_headline': self.embedding_headline,
                'embedding_short_description': self.embedding_short_description,
            },
            save_path)

    @staticmethod
    def _encode_roberta(roberta, text, all_token):
        """Return RoBERTa last-layer features for ``text`` as a detached CPU tensor."""
        features = roberta.extract_features(roberta.encode(text))
        if not all_token:
            # Keep only the first token's features.
            features = features[:, :1, :]
        return features.detach().cpu()

    def load_embedding_RobertaBase(self, all_token=False):
        """Load embeddings previously written by ``preprocess_RobertaBase``.

        Args:
            all_token: load ``embedding_path_all_token`` instead of
                ``embedding_path``. Defaults to False, matching the default of
                ``preprocess_RobertaBase``.
        """
        # NOTE(review): torch.load unpickles its input; only load files you produced.
        embedding = torch.load(embedding_path_all_token if all_token else embedding_path)
        self.embedding_headline = embedding['embedding_headline']
        self.embedding_short_description = embedding['embedding_short_description']

    def _bisect_date(self, t, lo, right):
        """Binary-search ``self.data[lo:]`` (assumed sorted by date) for ``t``.

        Returns the first index ``>= lo`` whose date is ``>= t`` (``right=False``)
        or ``> t`` (``right=True``), or ``self.n_tot`` if there is none.
        """
        hi = self.n_tot
        while lo < hi:
            mid = (lo + hi) // 2
            date = self.data[mid]['date']
            if (date > t) if right else (date >= t):
                hi = mid
            else:
                lo = mid + 1
        return lo

    def set_range_date(self, t_start, t_end):
        """Restrict the active window to entries dated within ``[t_start, t_end]``.

        Both bounds are inclusive, consistent with the date filter in
        ``__init__``. In particular ``'YYYY-MM-31'`` works as an end-of-month
        bound without dropping articles dated on the 31st. Assumes ``self.data``
        is sorted by ascending date (``format_check`` verifies this).

        Returns:
            (idx_start, idx_end): the new half-open window into ``self.data``.
        """
        idx_start = self._bisect_date(t_start, 0, right=False)
        # Searching from idx_start makes an empty or inverted range yield
        # idx_end == idx_start.
        idx_end = self._bisect_date(t_end, idx_start, right=True)

        self.idx_start = idx_start
        self.idx_end = idx_end

        return idx_start, idx_end

    def __len__(self):
        """Number of samples in the active window."""
        return self.idx_end - self.idx_start

    def _window_index(self, index):
        """Map a window-relative ``index`` (negative counts from the end) to an index into ``self.data``."""
        n = self.idx_end - self.idx_start
        pos = index + n if index < 0 else index
        if not 0 <= pos < n:
            # IndexError also ends plain `for x in dataset` iteration at the window end.
            raise IndexError("index %s out of range for window of size %d" % (index, n))
        return self.idx_start + pos

    def __getitem__(self, index):
        """Return ``(headline_embedding, short_description_embedding, label)``.

        ``index`` is relative to the active window. Embeddings must have been
        computed or loaded first.

        Raises:
            IndexError: if ``index`` falls outside the active window.
        """
        idx = self._window_index(index)

        return self.embedding_headline[idx], self.embedding_short_description[idx], self.labels[self.data[idx]['category']]

    def get_category(self, index):
        """Return the category name of window-relative sample ``index``.

        Raises:
            IndexError: if ``index`` falls outside the active window.
        """
        idx = self._window_index(index)

        return self.data[idx]['category']




if __name__ == '__main__':
    # Load articles dated between `start` and `end`, then print per-month
    # counts and summary statistics as a sanity check.
    start = '2012-02-01'
    end = '2017-12-31'

    dataset = HuffPost(format_check=True, t_start=start, t_end=end)

    # Embeddings must already exist on disk; to (re)generate them, call
    # dataset.preprocess_RobertaBase() once instead of loading.
    dataset.load_embedding_RobertaBase()
    print(dataset[0][0].shape)
    print(dataset[0][1].shape)

    # Walk every calendar month whose [01, 31] bounds fit inside [start, end].
    # '-31' is used as the month's upper date bound: dates are compared as
    # strings, so it sorts after every real day of any month. The year range
    # is derived from start/end rather than hard-coded.
    tot = 0
    for year in range(int(start[:4]), int(end[:4]) + 1):
        for month in range(1, 13):
            t_start = '%04d-%02d-01' % (year, month)
            t_end = '%04d-%02d-31' % (year, month)

            if t_start < start or t_end > end:
                continue

            idx_start, idx_end = dataset.set_range_date(t_start, t_end)
            tot += idx_end - idx_start
            print(t_start, t_end, idx_end - idx_start)
    print('tot_remain', tot)

    # Reset the active window to the start/end range and re-run the checks.
    dataset.set_range_date(start, end)
    dataset.format_check()
