############################################################
#
# dataset_module.py
# dataset class definitions for trading-attacks
# January 2020
#
############################################################

import os
import torch
import _pickle as pickle
from torch.utils.data import Dataset
import glob
from tqdm import tqdm


class OrderbookDataset(Dataset):
    """Sliding-window dataset over per-day order-book tensors with cached labels.

    Every file under ``<root_dir>/long_data`` is loaded with ``torch.load`` and
    the results are concatenated along dim 0. A sample is ``input_length``
    consecutive rows; its label is derived from ``smoother_net`` applied to a
    longer window that also covers the prediction horizon. In sorted file
    order, the last ``num_testing_days`` files form the test split and the
    rest form the training split.

    Labels are pickled to a cache file under ``root_dir``. The cache name
    encodes horizon, history, num_testing_days, label_size and the split, but
    NOT the smoother net or the timestep; delete stale cache files when those
    change.
    """

    def __init__(self, root_dir, smoother_net, history, horizon, label_size, train=True, device='cuda', timestep=0.01,
                 num_testing_days=1, threshold=None, attack=False):
        """
        Args:
            root_dir: directory containing ``long_data/`` (and ``transactions/``
                when ``attack`` is set); label caches are written here too.
            smoother_net: callable applied to a window batched with
                ``unsqueeze(0)``; its output is indexed as ``[0, t]`` and must
                yield a single-element tensor.
            history: input window duration, in the same units as ``timestep``.
            horizon: prediction horizon, in the same units as ``timestep``.
            label_size: 1 = raw price change, 2 = up / not-up,
                3 = down / flat / up.
            train: use the training files (True) or the test files (False).
            device: device the smoother net is run on when computing labels.
            timestep: duration represented by one data row.
            num_testing_days: number of trailing files held out for testing.
            threshold: flat-band half-width for 3-class labels. Derived from the
                data when ``train`` is True; otherwise it must be supplied
                (or be present in the label cache).
            attack: also load transactions and keep only every
                ``15 / timestep``-th sample.

        Raises:
            RuntimeError: if the split has no data files, or a cached label
                file does not match the number of samples.
        """
        self.history = history
        self.device = device
        self.horizon = horizon
        self.timestep = timestep
        self.root_dir = root_dir
        self.train = train
        self.label_size = label_size
        self.data = []
        self.start_idx_list = []
        self.smoother_net = smoother_net
        self.normalization_mean = False
        self.normalization_std = False
        # Rounded (not truncated) so float noise such as 0.29 / 0.01 == 28.999...
        # does not silently lose a step.
        self.input_length = self._num_steps(history, timestep)
        self.horizon_idx = self._num_steps(history + horizon, timestep) - 1
        self.threshold = threshold
        self.num_testing_days = num_testing_days

        window_len = self.horizon_idx + 1  # rows fed to the smoother per sample
        len_so_far = 0
        for fh in self._split_paths('long_data', '*'):
            current_day = torch.load(fh)
            current_day_len = current_day.shape[0]
            self.data.append(current_day)
            # Windows never cross a day boundary.
            # NOTE(review): the exclusive upper bound drops the last start whose
            # window still fits (start + window_len == day end); kept as-is so
            # existing label caches stay aligned with these indices.
            self.start_idx_list.extend(range(len_so_far, len_so_far + current_day_len - window_len))
            len_so_far += current_day_len

        if not self.data:
            raise RuntimeError('No data files found for this split under '
                               + os.path.join(root_dir, 'long_data'))
        self.data = torch.cat(self.data)

        # The cache name does not identify the smoother net or the timestep;
        # it must be extended if those are varied.
        label_file_name = os.path.join(root_dir, 'labels_horizon=' + str(self.horizon) + '_history='
                                       + str(self.history) + '_testing_days=' + str(num_testing_days)
                                       + '_label_size=' + str(self.label_size))
        label_file_name = label_file_name + '_train' if self.train else label_file_name + '_test'

        if not os.path.isfile(label_file_name):
            self.save_labels(label_file_name)
        else:
            self.load_labels(label_file_name)
            # A cache built with different settings would silently misalign labels.
            if len(self.labels) != len(self.start_idx_list):
                raise RuntimeError(
                    f'Cached labels in {label_file_name} have {len(self.labels)} entries but the '
                    f'dataset has {len(self.start_idx_list)} samples; delete the stale cache file.')

        if attack:
            self.load_transactions()
            stride = self._num_steps(15, timestep)
            self.start_idx_list = self.start_idx_list[::stride]
            self.labels = self.labels[::stride]

    @staticmethod
    def _num_steps(duration, timestep):
        """Return ``duration / timestep`` rounded to the nearest integer."""
        return int(round(duration / timestep))

    def _split_paths(self, subdir, pattern):
        """Return sorted files in ``root_dir/subdir`` matching ``pattern`` for this split.

        The last ``num_testing_days`` files are the test split; the rest are
        the training split.
        """
        paths = sorted(glob.glob(os.path.join(self.root_dir, subdir, pattern)))
        # Clamp so an oversized num_testing_days yields an empty training split
        # instead of a negative slice bound that splits the files arbitrarily.
        train_days = max(0, len(paths) - self.num_testing_days)
        return paths[:train_days] if self.train else paths[train_days:]

    def _input_bounds(self, idx):
        """Return the ``[start, end)`` global row range of sample ``idx``'s input window."""
        start = self.start_idx_list[idx]
        return start, start + self.input_length

    def save_labels(self, label_file_name):
        """Compute labels for every sample and pickle them to ``label_file_name``.

        For each start index the smoother net is applied to
        ``data[start : start + horizon_idx + 1]``; the label is derived from its
        output at ``horizon_idx`` minus its output at ``input_length``. With
        ``label_size == 3`` the threshold is also pickled to
        ``label_file_name + '_threshold'``.
        """
        price_predictions = []
        last_prices = []
        window_len = self.horizon_idx + 1
        # Only scalar values are kept, so no autograd graph is needed.
        with torch.no_grad():
            for start in tqdm(self.start_idx_list):
                full_data = self.data[start:start + window_len].to(self.device)
                smoothed_signal = self.smoother_net(full_data.unsqueeze(0)).to('cpu')
                price_predictions.append(smoothed_signal[0, self.horizon_idx].item())
                # NOTE(review): index input_length is the first step AFTER the
                # input window (which ends at input_length - 1); confirm this is
                # the intended reference price.
                last_prices.append(smoothed_signal[0, self.input_length].item())

        price_changes = torch.tensor(price_predictions) - torch.tensor(last_prices)

        if self.label_size == 1:
            self.labels = price_changes
        else:
            self.labels = self.get_classification_labels(price_changes)
        if self.label_size == 3:
            with open(label_file_name + '_threshold', 'wb') as fh:
                pickle.dump(self.threshold, fh)
        with open(label_file_name, 'wb') as fh:
            pickle.dump(self.labels, fh)

    def load_labels(self, label_file_name):
        """Load cached labels (and, for ``label_size == 3``, the cached threshold).

        The cached threshold overrides any threshold passed to the constructor.
        Note: pickle can execute code from the file; only load trusted caches.
        """
        with open(label_file_name, 'rb') as fh:
            self.labels = pickle.load(fh)
        if self.label_size == 3:
            with open(label_file_name + '_threshold', 'rb') as fh:
                self.threshold = pickle.load(fh)

    def get_classification_labels(self, price_changes):
        """Map a 1-D tensor of price changes to integer (long) class labels.

        label_size == 2: 1 where the change is positive, else 0.
        label_size == 3: 0 (down) for change < -threshold, 2 (up) for
        change > threshold, 1 (flat) otherwise, boundaries included. In
        training the threshold is set to the ``len // 3``-th smallest
        ``|change|``, so roughly a third of the samples are flat; otherwise
        ``self.threshold`` must already be set.

        Raises:
            ValueError: for an unsupported ``label_size``, a missing test-time
                threshold, or no training samples to derive a threshold from.
        """
        if self.label_size == 2:
            return (price_changes > 0).long()
        if self.label_size != 3:
            raise ValueError(f'Classification with {self.label_size} classes is not handled; '
                             f'label_size must be 1, 2 or 3')
        if self.train:
            if price_changes.numel() == 0:
                raise ValueError('Cannot derive a 3-class threshold from zero samples')
            sorted_changes = torch.sort(torch.abs(price_changes))[0]
            self.threshold = sorted_changes[len(sorted_changes) // 3].item()
        elif self.threshold is None:
            raise ValueError('A threshold is required for 3-class labels when train=False')
        # Start everything as flat so |change| == threshold is not mislabelled.
        labels = torch.ones_like(price_changes, dtype=torch.long)
        labels[price_changes < -self.threshold] = 0
        labels[price_changes > self.threshold] = 2
        return labels

    def load_transactions(self):
        """Load per-day transactions and running prices for this split.

        For each 8-character file in ``root_dir/transactions``, the pickled
        sequence is appended to ``transactions_list`` and the tensor at
        ``<path>running`` is concatenated into ``running_price``.
        Note: pickle can execute code from the file; only load trusted data.
        """
        self.transactions_list = []
        running = []
        for fh in self._split_paths('transactions', '????????'):
            with open(fh, 'rb') as trans_file:
                self.transactions_list.extend(pickle.load(trans_file))
            running.append(torch.load(fh + 'running'))
        self.running_price = torch.cat(running)

    def get_transactions(self, idx):
        """Return the transactions covering sample ``idx``'s input window.

        Assumes ``transactions_list`` has one entry per data row — TODO confirm.
        """
        start, end = self._input_bounds(idx)
        return self.transactions_list[start:end]

    def get_running_price(self, idx):
        """Return the running-price rows covering sample ``idx``'s input window."""
        start, end = self._input_bounds(idx)
        return self.running_price[start:end]

    def get_most_populous(self):
        """Return the most frequent label value."""
        return torch.mode(self.labels.squeeze())[0]

    def __len__(self):
        return len(self.start_idx_list)

    def __getitem__(self, idx):
        """Return ``(input_window, label, idx)`` for sample ``idx``."""
        start, end = self._input_bounds(idx)
        return self.data[start:end], self.labels[idx], idx
