############################################################
#
# utility.py
# utility functions for trading-attacks
# February 2020
#
############################################################

import os
from matplotlib import pyplot as plt
from matplotlib import colors
import datetime
import torch
import _pickle as pickle
import numpy as np
import scipy.sparse as sp
import multiprocessing as mp
from multiprocessing import Process
import csv

# Punctuation set: string.punctuation without ',', '-', '.', ':' and '_'.
my_punc = '!"#$%&\'()*+/;<=>?@[\\]^`{|}~'
# Every figure produced by this module uses a Times New Roman serif font.
plt.rcParams.update({"font.family": "Times New Roman"})


def now():
    """ Return the current local time formatted as 'YYYYMMDD HH:MM:SS' (for log lines). """
    current = datetime.datetime.now()
    stamp = current.strftime("%Y%m%d %H:%M:%S")
    return stamp


def to_log_file(out_dict, out_dir, log_name="log.txt"):
    """ Append one timestamped line to a log file.

    The line written is ``now() + " " + str(out_dict)`` followed by a newline,
    so ``out_dict`` may be any object with a useful ``str``.

    Args:
        out_dict: object to log (converted with ``str``).
        out_dir: directory holding the log; created if missing. An empty
            string means the current working directory.
        log_name: file name of the log inside ``out_dir``.
    """
    # exist_ok=True avoids the check-then-create race (another process may
    # create the directory between an isdir() check and makedirs()).
    # "" is skipped because os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fname = os.path.join(out_dir, log_name)
    with open(fname, 'a') as f:
        f.write(now() + " " + str(out_dict) + "\n")
    print('logging done in ' + str(out_dir) + ',' + ' in ' + log_name)


def to_results_table(stats, out_dir, log_name="results.csv"):
    """ Append one row of results to a CSV table, writing a header first if needed.

    The header (the keys of ``stats``, in iteration order) is written when the
    file does not exist yet or is empty. The row is always written using the
    keys of ``stats`` as columns. If an existing header lists different
    columns, the row is NOT re-aligned to it.

    Args:
        stats: mapping of column name -> value for this row.
        out_dir: directory holding the table; created if missing. An empty
            string means the current working directory.
        log_name: file name of the table inside ``out_dir``.
    """
    # exist_ok=True avoids the check-then-create race; "" would make makedirs raise.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fname = os.path.join(out_dir, log_name)
    fieldnames = list(stats.keys())
    # A file that exists but is empty also still needs its header.
    needs_header = not os.path.exists(fname) or os.path.getsize(fname) == 0
    # newline='' lets the csv module control line endings itself (per the csv docs).
    with open(fname, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(stats)
    print('results logged in  ' + str(out_dir) + ',' + ' in ' + log_name)


def plot_data(net, smoother_net, dataloader, out_dir, pltname, device='cpu'):
    """ Plot the first batch yielded by ``dataloader`` and save the figure.

    For every sample in the batch this draws:
      * ``smoother_net(inputs)`` against ``inputs[:, :, 0]`` (black line),
      * the targets (blue circles) and the ``net`` outputs (red crosses) at the
        times returned by ``dataloader.dataset.get_time_of_label``.

    Args:
        net: model producing the predictions; moved to ``device`` (in place
            for ``torch.nn.Module``).
        smoother_net: callable applied to the inputs; also moved to ``device``
            when it is a ``torch.nn.Module``.
        dataloader: yields ``(inputs, targets, dataset_idx)`` batches; its
            ``dataset`` must provide ``get_time_of_label(idx)`` returning a tensor.
        out_dir: existing directory the figure is saved into.
        pltname: file name of the saved figure.
        device: device the batch and models are moved to.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net = net.to(device)
    if isinstance(smoother_net, torch.nn.Module):
        # The smoother is applied to on-device inputs, so it must live there too.
        smoother_net = smoother_net.to(device)

    try:
        inputs, targets, dataset_idx = next(iter(dataloader))
    except StopIteration:
        raise ValueError('plot_data: dataloader yielded no batches') from None
    inputs, targets = inputs.to(device), targets.to(device)

    # Inference only: every result is detached below, so no autograd graph is needed.
    with torch.no_grad():
        outputs = net(inputs)
        signals = smoother_net(inputs)

    signals = signals.detach().cpu().numpy()
    inputs = inputs.detach().cpu().numpy()
    targets = targets.detach().cpu().numpy()
    outputs = outputs.detach().cpu().numpy()
    # Column 0 of the last axis is used as the time axis of each sample.
    times = inputs[:, :, 0]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # Plot every sample in the batch, including the last one.
    for idx in range(inputs.shape[0]):
        label_times = dataloader.dataset.get_time_of_label(dataset_idx[idx]).cpu().numpy()
        ax.plot(times[idx, :], signals[idx, :], 'k')
        ax.plot(label_times, targets[idx], 'o', color='b', markersize=6)
        ax.plot(label_times, outputs[idx], 'xr', markersize=9)
    fig.savefig(os.path.join(out_dir, pltname))
    # pyplot keeps every figure alive until closed; close to avoid leaking memory.
    plt.close(fig)


def save_fast_data_function(dataset_idx, dataset, save_dir, mode):
    """ Save out fast-data, a quick-to-perturb representation of the order book.

    Every (price, volume) cell of sample ``dataset_idx`` is grouped by price,
    and ``[price_dict, fast_indices, fast_volumes]`` is pickled to
    ``<save_dir>/<mode>/<dataset_idx>``.

    How ``current_data`` (the first element of ``dataset[dataset_idx]``) is
    read: it is indexed as 2-D, every odd column is a price, and the column
    right after it holds the matching volume. Prices from columns with
    ``col % 4 == 1`` are negated, so they get keys distinct from equal prices
    in columns with ``col % 4 == 3``. Keys are Python floats from ``.item()``,
    so grouping uses exact float equality.

    Pickled objects:
        price_dict: maps each (signed) price to a list of offsets into the two
            tensors below. The list holds the start offset of every block
            recorded for that price, followed by the (exclusive) end offset
            of that price's entries.
        fast_indices: int64 tensor of shape (N, 2) holding
            ``[row, volume_col]`` coordinates into ``current_data``. Entries
            are contiguous per price, in the dict's insertion order.
        fast_volumes: tensor of shape (N,) with the volumes at those coordinates.

    Args:
        dataset_idx: index into ``dataset``; also used as the output file name.
        dataset: indexable returning a 3-tuple whose first item is the order
            book; must provide ``get_transactions(idx)`` with one entry per row.
        save_dir: root output directory.
        mode: subdirectory of ``save_dir`` to write into (created if missing).
    """
    file_name = str(dataset_idx)
    current_data, _, _ = dataset[dataset_idx]
    shape = current_data.shape
    current_dict = {}
    # Each row's entry is unioned with a set below, so it is presumably a set
    # of unsigned prices traded at that row — confirm against the dataset.
    transaction_list = dataset.get_transactions(dataset_idx)

    # Transaction prices seen so far that have not yet been matched to a price occurrence.
    stuff_to_clear = set()

    for row_idx in range(shape[0]):
        transactions = transaction_list[row_idx]
        for col_idx in range(1, shape[1], 2):
            if col_idx % 4 == 1:
                price = -1.0 * current_data[row_idx, col_idx].item()
            else:
                price = current_data[row_idx, col_idx].item()

            # Fill the current_dict
            #   order of elements is indices, volumes, block starts, num times we have seen this price
            # NOTE(review): torch.cat per occurrence makes this quadratic in the
            # number of occurrences of a price; collecting lists and
            # concatenating once would be linear.
            if price not in current_dict:
                current_dict[price] = [[], [], [0], 0]
                current_dict[price][0] = torch.tensor([row_idx, col_idx+1]).unsqueeze(0).long()
                current_dict[price][1] = torch.tensor([current_data[row_idx, col_idx+1].item()]).unsqueeze(0)
            else:
                current_dict[price][0] = torch.cat([current_dict[price][0], torch.tensor([row_idx, col_idx+1]).unsqueeze(0).long()])
                current_dict[price][1] = torch.cat([current_dict[price][1], torch.tensor([current_data[row_idx, col_idx+1].item()]).unsqueeze(0)])
            # NOTE(review): the union runs once per column, so a price removed
            # below is re-added by the next column of the same row whenever it
            # is in this row's transactions.
            stuff_to_clear = stuff_to_clear | transactions
            if abs(price) in stuff_to_clear:
                stuff_to_clear.remove(abs(price))
                # A pending transaction at this price starts a new block, unless
                # this is the first occurrence (its start 0 is already recorded).
                if current_dict[price][3] > 0:
                    current_dict[price][2].append(current_dict[price][3])   # block starts
            current_dict[price][3] += 1

    # Flatten all prices into two tensors (entries of each price stay contiguous).
    # An order book with no price columns leaves current_dict empty, and
    # torch.cat of an empty list raises.
    fast_indices = torch.cat([current_dict[price][0] for price in current_dict.keys()], dim=0)
    fast_volumes = torch.cat([current_dict[price][1] for price in current_dict.keys()], dim=0)

    # Replace each price's state with global offsets. First come its block
    # starts shifted past all earlier prices' entries, then its end offset.
    count = 0
    for price in current_dict.keys():
        current_length = current_dict[price][0].shape[0]
        current_dict[price] = [i+count for i in current_dict[price][2]]
        count += current_length
        current_dict[price].append(count)

    fast_data = [current_dict, fast_indices, fast_volumes.squeeze(-1)]
    out_dir = os.path.join(save_dir, mode)
    # exist_ok=True: this function runs concurrently in several worker
    # processes (save_fast_data_multiprocessing). An exists-check followed by
    # makedirs can race there and raise FileExistsError.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, file_name)
    with open(path, 'wb') as pickle_file:
        pickle.dump(fast_data, pickle_file)
    if dataset_idx % 100 == 0:
        print(now(), ' done with data', dataset_idx)


def save_fast_data_loop(dataset_idx_list, dataset, save_dir, mode):
    """ Worker body: write fast-data for every index in ``dataset_idx_list``, in order.

    Each index is handed to save_fast_data_function(); this is the process
    ``target`` used by save_fast_data_multiprocessing().
    """
    for sample_idx in dataset_idx_list:
        save_fast_data_function(sample_idx, dataset, save_dir, mode)


def save_fast_data_multiprocessing(dataset, save_dir, mode, num_workers=None):
    """ For each element of the dataset, we create a dict and two tensors.
    The dict contains a list of start/end indices corresponding to price levels in the two tensors.
    The first tensor is coordinates, and the second tensor is volumes.

    Work is split round-robin over worker processes: worker ``i`` handles
    indices ``i, i + num_workers, i + 2 * num_workers, ...``.

    Args:
        dataset: sized, indexable dataset passed to each worker.
        save_dir: root output directory.
        mode: subdirectory of ``save_dir`` to write into.
        num_workers: number of processes; defaults to ``mp.cpu_count()``.
            Never more processes than samples are started.

    Raises:
        ValueError: if ``num_workers`` is given and is < 1.
        RuntimeError: if any worker process exits with a non-zero exit code.
    """
    dataset_len = len(dataset)
    if num_workers is None:
        num_workers = mp.cpu_count()
    if num_workers < 1:
        raise ValueError('num_workers must be >= 1, got {}'.format(num_workers))
    # Extra workers beyond the number of samples would only receive empty chunks.
    num_workers = min(num_workers, dataset_len)

    chunks = [range(i, dataset_len, num_workers) for i in range(num_workers)]

    procs = []
    for chunk in chunks:
        proc = Process(target=save_fast_data_loop, args=(chunk, dataset, save_dir, mode))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()

    # A crashed worker only shows up in its exit code; surface it instead of
    # silently returning with part of the data missing.
    failed = [(i, proc.exitcode) for i, proc in enumerate(procs) if proc.exitcode != 0]
    if failed:
        raise RuntimeError('fast-data workers failed (worker, exitcode): {}'.format(failed))


def adjust_learning_rate(optimizer, epoch, args):
    """ Apply a step decay of the learning rate on scheduled epochs.

    When ``epoch`` is listed in ``args.lr_schedule``, ``args.lr`` is multiplied
    by ``args.lr_factor`` (mutating ``args``). The new value is then written
    into every param group of ``optimizer``. Other epochs change nothing.
    """
    if epoch not in args.lr_schedule:
        return
    new_lr = args.lr * args.lr_factor
    print('lr drop: ', args.lr, ' --> ', new_lr)
    args.lr = new_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr


class MeanSubtract(torch.nn.Module):
    """ Mean subtraction layer.

    Subtracts the mean taken over dimension 1. For a 2-D input each row is
    centred to zero mean. ``keepdim=True`` keeps the reduced axis, so the
    mean broadcasts back correctly for inputs of any rank >= 2. For 2-D
    inputs this is identical to ``mean(x, dim=1).unsqueeze(-1)``.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x - torch.mean(x, dim=1, keepdim=True)


class STDDivide(torch.nn.Module):
    """ Standard deviation division layer.

    Divides the input by its standard deviation over dimension 1 (torch's
    default Bessel-corrected estimator) and scales by ``std``. For a 2-D input
    with positive ``std``, each row ends up with standard deviation ``std``.
    ``keepdim=True`` keeps the reduced axis so the result broadcasts correctly
    for inputs of any rank >= 2; 2-D results are unchanged. Slices with zero
    spread, or a single element along dim 1, produce inf/nan.
    """
    def __init__(self, std=1.0):
        super().__init__()
        self.std = std  # target standard deviation after rescaling

    def forward(self, x):
        return self.std * x / torch.std(x, dim=1, keepdim=True)


class SizeWeightedAverage(torch.nn.Module):
    """ Size-weighted average over the last dimension.

    Along the last axis, positions 1, 3, 5, ... are read as prices and
    positions 2, 4, 6, ... as the matching sizes; position 0 is ignored.
    Returns ``sum(price * size) / sum(size)``, which removes the last axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        prices, sizes = x[..., 1::2], x[..., 2::2]
        weighted_total = (prices * sizes).sum(dim=-1)
        total_size = sizes.sum(dim=-1)
        return weighted_total / total_size


class LinearSmoothing(torch.nn.Module):
    """ Linear smoothing layer computing a causal, exponentially weighted moving average.

    A fixed ``input_size x input_size`` matrix is built whose entry
    ``(i, i - k)`` is ``coeff ** k`` for ``0 <= k < window``; all other
    entries are zero. Each row is then normalised to sum to 1. ``forward``
    keeps the first ``input_size`` columns of a 2-D input (rows are samples).
    It maps column ``i`` to the weighted average of columns
    ``max(0, i - window + 1) .. i``.
    """
    def __init__(self, input_size, coeff=0.9, window=100):
        super().__init__()
        self.input_size = input_size
        weights = [coeff ** k for k in range(window)]
        # Row k of `band` is constant weights[k]; it fills sub-diagonal -k.
        band = np.asarray(weights)[:, None] * np.ones(input_size)
        offsets = -np.arange(window)
        matrix = sp.spdiags(band, offsets, input_size, input_size).toarray()
        matrix /= matrix.sum(axis=1, keepdims=True)
        # Stored as a frozen Parameter so it follows .to(device) and state_dict.
        self.filter = torch.nn.Parameter(torch.as_tensor(matrix).float(), requires_grad=False)

    def forward(self, x):
        # Samples are rows; transpose so the filter mixes along the feature axis.
        trimmed = x[:, :self.input_size]
        return (self.filter @ trimmed.t()).t()


def uninterleave(inputs):
    """ Un-interleave buys and sells in the orderbook (used for plotting).

    From a 2-D tensor, takes columns 2, 6, 10, ... as the ask side and
    columns 4, 8, 12, ... as the bid side. Returns a numpy array with the bid
    columns in reversed order followed by the ask columns.
    """
    book = inputs.cpu().numpy()
    ask_part = book[:, 2::4]
    bid_part_reversed = book[:, 4::4][:, ::-1]
    return np.concatenate((bid_part_reversed, ask_part), axis=1)


def plot_orderbook(inputs, title="", out_dir='plots'):
    """ Plot the order book as a heat map showing size distribution and save it as a PDF.

    Panels and colours:
      * The first 10 columns of ``inputs`` are drawn in the left panel
        ("Buy Orders") and the remaining columns in the right panel
        ("Sell Orders").
      * Both panels share one colour scale spanning the min and max of
        ``inputs``.
      * Rows are flipped vertically before drawing. With imshow's default
        top-left origin, the first row therefore ends up at the bottom.

    The 10-column split presumably matches ``uninterleave`` output for a
    10-level book; confirm for other book depths.

    Args:
        inputs: 2-D numpy array to draw.
        title: file stem; the figure is written to ``<out_dir>/<title>.pdf``.
        out_dir: output directory, created if missing. Defaults to
            ``'plots'``, relative to the current working directory.
    """
    fig, [ax1, ax2, cax] = plt.subplots(1, 3, sharey=False, sharex=False,
                                        gridspec_kw={"width_ratios": [1, 1, 0.05]})
    im1 = ax1.imshow(inputs[::-1, :10], cmap='hot', interpolation='nearest', aspect='auto')
    im2 = ax2.imshow(inputs[::-1, 10:], cmap='hot', interpolation='nearest', aspect='auto')

    # Hide all tick marks; only the panel titles and the axis label are shown.
    for ax in (ax1, ax2):
        ax.set_yticks([])
        ax.set_xticks([])

    # One shared normalisation so colours are comparable across both panels.
    vmin = np.min(inputs)
    vmax = np.max(inputs)
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    im1.set_norm(norm)
    im2.set_norm(norm)
    ax1.set_ylabel('Time $\\longrightarrow$', fontsize=20)
    ax1.set_title('Buy Orders', fontsize=20, y=-0.09)
    ax2.set_title('Sell Orders', fontsize=20, y=-0.09)
    fig.colorbar(im1, cax=cax, ticks=[vmin, vmax])
    cax.set_title('Size', fontsize=16, y=-0.08)

    # exist_ok=True avoids the check-then-create race; "" means the cwd.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, title + '.pdf'))
    # pyplot keeps every figure alive until closed; close to avoid leaking memory.
    plt.close(fig)
