import torch.utils.data as data_
import torch
import argparse
import common_io
import numpy as np
import sys
import time
import os
import math
import random
import json

from .tool import load_arguments



class TableDataset(torch.utils.data.IterableDataset):
    def __init__(self, local_args, args, selected_cols, mode=0,
                 slice_id=int(os.environ.get('RANK', 0)),
                 slice_count=int(os.environ.get('WORLD_SIZE', 1)),
                 train=True,
                 row_count=False,
                 if_predict=False
                 ):
        """Store the configuration and probe the table slice for this process.

        A short-lived reader is opened only to learn the slice's row count
        and its ``start_pos`` / ``end_pos``; it is always closed again before
        the constructor returns, even if probing fails.

        Args:
            local_args: Object whose ``tables`` attribute is a comma-separated
                string of table paths.
            args: Mapping of settings. Keys read here: ``max_len``,
                ``info_element_cnt``, ``voiceTimes_cnt``,
                ``action_element_cnt``, ``predict_max_len``,
                ``cloud_user_feature_seg_num``, ``cur_to_seg_mean`` and
                ``cur_to_seg_std``.
            selected_cols: Column selection forwarded to the reader created
                in ``__iter__``.
            mode: Index into the comma-separated ``local_args.tables`` list.
            slice_id: Slice index passed to the probe reader. The default is
                read from the ``RANK`` environment variable once, when this
                ``def`` is evaluated, not on every call.
            slice_count: Number of slices passed to the probe reader. The
                default is read from ``WORLD_SIZE`` in the same way.
            train: Not read by the code visible in this constructor; kept so
                existing callers keep working.
            row_count: Optional row-count override. Any falsy value
                (including ``0``) means "ask the reader".
            if_predict: Stored on the instance; selects the prediction
                branch of ``random_token``.
        """
        super(TableDataset, self).__init__()

        self.args = args
        self.if_predict = if_predict
        self.max_len = args['max_len']
        # One extra slot on top of the configured info element count.
        self.info_onehot_cnt = args['info_element_cnt'] + 1
        self.voiceTimes_onehot_cnt = args['voiceTimes_cnt']
        self.action_onehot_cnt = args["action_element_cnt"]
        self.predict_max_len = args["predict_max_len"]
        # Same source key as voiceTimes_onehot_cnt, kept under a second name.
        self.max_voice_times = args["voiceTimes_cnt"]
        self.cloud_user_feature_seg_num = args["cloud_user_feature_seg_num"]
        self.sub_mean = args["cur_to_seg_mean"]
        self.sub_std = args["cur_to_seg_std"]
        self.selected_cols = selected_cols
        self.table_path = local_args.tables.split(',')[mode]

        print('slice_id', slice_id, slice_count)

        reader = common_io.table.TableReader(
            self.table_path,
            slice_id=slice_id,
            slice_count=slice_count,
            num_threads=0,
        )
        try:
            self.row_count = row_count if row_count else reader.get_row_count()
            self.start_pos = reader.start_pos
            self.end_pos = reader.end_pos
        finally:
            # Release the probe reader even when one of the reads above fails.
            reader.close()

    def __len__(self):
        """Return the row count stored on the instance by the constructor."""
        return self.row_count

    def __iter__(self):
        """Return a generator over the rows assigned to the current worker.

        The slice's rows are divided among DataLoader workers with
        ``_get_slice_range`` (offset by ``self.start_pos``), and a reader is
        opened on ``"<table_path>?start=<a>&end=<b>"`` for that sub-range.

        NOTE(review): the per-row parsing that builds ``output`` is not
        present in this view of the file (only a placeholder string remains),
        so ``output`` is unbound as written.
        """
        # get_worker_info() is None when no worker info is available; treat
        # that case as a single worker with id 0.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        # Row range for this worker, expressed relative to the slice start.
        table_start, table_end = self._get_slice_range(self.row_count, worker_id, num_workers, self.start_pos)
        table_path = "{}?start={}&end={}".format(self.table_path, table_start, table_end)

        def table_data_iterator():
            """Yield one ``output`` per loop iteration until the reader is exhausted."""

            reader = common_io.table.TableReader(table_path,selected_cols=self.selected_cols, num_threads=1, capacity=1024)
            # NOTE(review): the reader is closed only on OutOfRangeException;
            # if the consumer stops early or another exception escapes, it
            # stays open. Consider try/finally once the body is restored.
            while True:
                try:
                    '''
                    Domain prior knowledge
                    '''

                except common_io.exception.OutOfRangeException:
                    # Presumably raised by the (elided) read call at end of
                    # range; this is the normal loop exit.
                    reader.close()
                    break
                yield output

        return table_data_iterator()

    def one_hot(self, labels, num):
        seq_len = labels.size(0)
        onehot = torch.LongTensor(np.eye(num)[labels.reshape(-1)]).reshape(seq_len, -1)
        return onehot

    def random_token(self, output, used_seq_len):
        if self.if_predict and self.predict_max_len>self.max_len:
            output["random_mask"] = torch.zeros(self.predict_max_len)
        else:
            output["random_mask"] = torch.zeros(self.max_len)
        if not self.if_predict:
            output["play_trigger"][mask_index] = torch.FloatTensor([-1])

            mask_num = used_seq_len

            mask_indexs = [i for i in range(mask_num)]
            for i in range(0, self.max_len):
                if i in mask_indexs:
                    output["random_mask"][i] = 1
        else: 
            mask_index = used_seq_len-1

            output["random_mask"][mask_index] = 1
            output["padding_mask"][mask_index+1:] = 0

            output["play_info"][mask_index] = torch.ones(self.info_onehot_cnt) * -1
            output["play_voiceTimes"][mask_index] = torch.ones(self.voiceTimes_onehot_cnt) * -1
            output["play_action"][mask_index] = torch.ones(self.action_onehot_cnt) * -1
            output["play_trigger"][mask_index] = torch.FloatTensor([-1])
            
        return output

    def _get_slice_range(self, row_count, worker_id, num_workers, baseline=0):
        """Return the half-open ``(start, end)`` row range for one worker.

        ``row_count`` rows are split into ``num_workers`` contiguous chunks;
        the first ``row_count % num_workers`` workers each take one extra
        row. Integer arithmetic is used throughout, so the result stays
        exact for row counts too large to represent exactly as a float.

        Args:
            row_count: Total number of rows to distribute.
            worker_id: Zero-based index of the worker.
            num_workers: Number of workers sharing the rows.
            baseline: Offset added to both ends of the range.

        Returns:
            Tuple ``(start, end)`` with ``end`` exclusive.
        """
        size, split_point = divmod(row_count, num_workers)
        # Every worker before this one took `size` rows, and the first
        # `split_point` of them took one extra.
        start = baseline + worker_id * size + min(worker_id, split_point)
        end = start + size + (1 if worker_id < split_point else 0)
        return start, end

    def len(self):
        """Return the stored row count (same value as ``__len__``)."""
        return self.row_count
        
    def feature_v5_to_v4(self, v5_feature):
        '''
        Domain prior knowledge

        Presumably converts a "v5" feature into its "v4" form (inferred from
        the method name only; TODO confirm).

        NOTE(review): the conversion logic is not present in this view of the
        file. As written, ``v4_feature`` is never assigned, so calling this
        method raises NameError.
        '''
        return v4_feature


class LocalTableDataset(torch.utils.data.IterableDataset):
    def __init__(self, args, lines, if_predict=False):
        """Keep the in-memory ``lines`` and copy the needed settings from ``args``.

        Args:
            args: Mapping of settings. Keys read here: ``max_len``,
                ``info_element_cnt``, ``voiceTimes_cnt``,
                ``action_element_cnt``, ``predict_max_len``,
                ``cloud_user_feature_seg_num``, ``cur_to_seg_mean`` and
                ``cur_to_seg_std``.
            lines: Sized collection of records; its length is the dataset
                length.
            if_predict: Selects the prediction branch of ``random_token``.
        """
        # Plain references handed in by the caller.
        self.args = args
        self.lines = lines
        self.if_predict = if_predict

        # Sequence lengths and one-hot widths.
        self.max_len = args['max_len']
        self.info_onehot_cnt = 1 + args['info_element_cnt']
        self.voiceTimes_onehot_cnt = args['voiceTimes_cnt']
        self.action_onehot_cnt = args["action_element_cnt"]
        self.predict_max_len = args["predict_max_len"]
        self.max_voice_times = args["voiceTimes_cnt"]

        # Segment count and normalisation statistics.
        self.cloud_user_feature_seg_num = args["cloud_user_feature_seg_num"]
        self.sub_mean = args["cur_to_seg_mean"]
        self.sub_std = args["cur_to_seg_std"]

        super().__init__()

    def __len__(self):
        """Return how many records are held in ``self.lines``."""
        records = self.lines
        return len(records)

    def random_token(self, output, used_seq_len):
        if self.predict_max_len>self.max_len:
            output["random_mask"] = torch.zeros(self.predict_max_len)
        else:
            output["random_mask"] = torch.zeros(self.max_len)
        if not self.if_predict:
            mask_num = used_seq_len

            mask_indexs = [i for i in range(mask_num)]
            for i in range(0, self.max_len):
                if i in mask_indexs:
                    output["random_mask"][i] = 1
        else: 
            mask_index = used_seq_len-1

            output["random_mask"][mask_index] = 1
            output["padding_mask"][mask_index+1:] = 0

            output["play_info"][mask_index] = torch.ones(self.info_onehot_cnt) * -1
            output["play_voiceTimes"][mask_index] = torch.ones(self.voiceTimes_onehot_cnt) * -1
            output["play_action"][mask_index] = torch.ones(self.action_onehot_cnt) * -1
            output["play_trigger"][mask_index] = torch.FloatTensor([-1])
            
        return output

    def __iter__(self):
        """Return a generator that yields one ``output`` per record.

        NOTE(review): the per-record parsing that builds ``output`` (and
        presumably advances ``i`` through ``self.lines``) is not present in
        this view of the file; only a placeholder string remains.
        """
        def table_data_iterator():
            i = 0
            while True:
                try:
                    '''
                    Domain prior knowledge
                    '''

                except Exception:
                    # Any ordinary error ends the stream, as before. A bare
                    # `except:` would also swallow KeyboardInterrupt and
                    # SystemExit and quietly truncate the data, so those now
                    # propagate.
                    # NOTE(review): narrow this further (e.g. to the
                    # end-of-data error) once the parsing body is restored.
                    break
                yield output
        return table_data_iterator()

    def feature_v5_to_v4(self, v5_feature):
        '''
        Domain prior knowledge

        Presumably converts a "v5" feature into its "v4" form (inferred from
        the method name only; TODO confirm).

        NOTE(review): the conversion logic is not present in this view of the
        file. As written, ``v4_feature`` is never assigned, so calling this
        method raises NameError.
        '''
        return v4_feature
    
    def one_hot(self, labels, num):
        seq_len = labels.size(0)
        onehot = torch.LongTensor(np.eye(num)[labels.reshape(-1)]).reshape(seq_len, -1)
        return onehot

    def len(self):
        return len(self.lines)
