############################################################
#
# train_model_adversarially.py
# main python file for training HFT models
# February 2020
#
############################################################

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import argparse
from utility import now, to_log_file, SizeWeightedAverage, adjust_learning_rate, MeanSubtract
from utility import STDDivide, save_fast_data_multiprocessing
from dataset_module import OrderbookDataset
from attack_module import attacker_costReg
import os
from tqdm import tqdm
from models.linearpredictor import LinearPredictor
from models.mlp import MLP
from models.lstm import LSTM
import copy


def train(net, optimizer, criterion, trainloader, iteration, args, fast_path, device='cuda', batches_per_update=1):
    """ Function to execute one optimizer update, accumulated over up to `batches_per_update` batches.

    Args:
        net: model to train; it is moved to `device` and put in train mode.
        optimizer: optimizer that is stepped (and then zeroed) once at the end of the call.
        criterion: loss called as criterion(outputs, targets).
        trainloader: iterable yielding (inputs, targets, dataset_idx) batches; `trainloader.dataset` is handed
            to the attacker when adversarially training.
        iteration: zero-based iteration index; `iteration + 1` is passed to adjust_learning_rate.
        args: parsed arguments; `args.adversarial` and `args.model` are read here.
        fast_path: path passed through to the attacker unchanged.
        device: device used for the model, the data and the attack.
        batches_per_update: number of batches whose gradients are summed before the optimizer step.

    Returns:
        (loss, accuracy): sample-weighted mean loss and accuracy in percent over the samples that were trained
        on, or (0, 0) if no sample was used.
    """

    # Set net to train and zeros stats and adjust learning rate
    net.train()
    net = net.to(device)
    correct = 0
    total = 0
    total_loss = 0
    adjust_learning_rate(optimizer, iteration + 1, args)

    for batch_idx, (inputs, targets, dataset_idx) in enumerate(trainloader):

        # break after batches per update
        if batch_idx + 1 > batches_per_update:
            break

        # dataset_idx is deliberately left where the loader put it; only inputs and targets go to the device
        inputs, targets = inputs.to(device), targets.to(device)

        # if adversarially training, replace the inputs by perturbed inputs and keep only the perturbed samples
        if args.adversarial:
            # keep an untouched copy so that perturbed samples can be identified afterwards
            clean_inputs = inputs.detach().clone()

            # every model except the LSTM is attacked in eval mode
            if args.model != 'LSTM':
                net = net.eval()
            for i in range(inputs.shape[0]):
                # loop through batch and adversarially attack each sample on the same device as the model
                # NOTE(review): if attacker_costReg back-propagates through net, the parameter gradients it
                # leaves behind would be added to this update -- confirm against attack_module.
                inputs[i], _, _, _, _ = attacker_costReg(net, inputs[i].unsqueeze(0), targets[i].unsqueeze(0),
                                                         dataset_idx[i].unsqueeze(0), trainloader.dataset, fast_path,
                                                         num_steps=20, step_size=80.0,
                                                         device=device, signed=False, cost_coeff=0.0,
                                                         rand_step_size=True, detectability_coeff=0.0,
                                                         capital_coeff=0.0, criterion=criterion, cap_bound=100000)
            net = net.train()

            # Find perturbed inputs (the double sum over dims 1 and 1 expects 3-dimensional inputs):
            diff_inputs = torch.abs(inputs - clean_inputs)
            idx_of_perturbs = torch.sum(torch.sum(diff_inputs, dim=1), dim=1) > 0

            # if nothing was perturbed, skip the rest of the loop on this iteration
            if not idx_of_perturbs.any():
                continue
            inputs, targets = inputs[idx_of_perturbs], targets[idx_of_perturbs]

        # forward pass, backward pass, and loss and accuracy measurements
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        total_loss += loss.item() * len(targets)
        total += targets.size(0)
        _, pred = outputs.max(1)
        correct += pred.eq(targets).sum().item()

    optimizer.step()
    optimizer.zero_grad()

    # Compute accuracy (handle cases with no data in the batch -- this arises in adversarial training when nothing
    # gets perturbed in a batch).
    if total > 0:
        accuracy = 100 * correct / total
        loss_to_return = total_loss / total
    else:
        accuracy = 0
        loss_to_return = 0

    return loss_to_return, accuracy


def test(net, testloader, criterion, args, device='cuda', test_batches=50):
    """ Function to test performance of the model. """

    # initialize all stats
    total_loss = 0
    correct = 0
    total = 0
    confusion_matrix = torch.zeros((args.label_size, args.label_size))
    targets_balance = torch.zeros(args.label_size)
    net.eval()
    net.to(device)

    with torch.no_grad():
        for batch_idx, (inputs, targets, dataset_idx) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = net(inputs)
            loss = criterion(outputs.squeeze(), targets)
            total_loss += loss.item() * len(targets)
            total += targets.size(0)

            # Compute accuracy and fill confusion matrix
            _, pred = outputs.max(1)
            correct += pred.eq(targets).sum().item()
            for i in range(args.label_size):
                targets_i = targets.eq(i)
                for j in range(args.label_size):
                    outputs_j = pred.eq(j)
                    confusion_matrix[i, j] += (targets_i * outputs_j).sum()

            for i in range(args.label_size):
                targets_balance[i] += targets.eq(i).sum()

            if batch_idx >= test_batches - 1:
                break

        # Compute accuracy
        if outputs.shape[1] > 1:
            accuracy = 100 * correct / total
            # print('confusion matrix: ', confusion_matrix)
            # print('targets balance: ', targets_balance)

    return total_loss / total, accuracy, confusion_matrix


def main():
    """ Parse the command line, build data and model, train with periodic validation and checkpoints, then test.

    Results are written through to_log_file to `log.txt` inside `--output`; checkpoints are written with
    torch.save inside `--checkpoint`. When `--load_path` is given, the smoother, the predictor, the optimizer
    and the iteration counter are restored from that file before training continues.
    """
    print(now(), "train.py running...")

    # Argument parser
    parser = argparse.ArgumentParser(description='Python')
    parser.add_argument('--iterations', default=5000, type=int, help='How many batches?')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--lr_schedule', nargs='+', default=[50, 500, 1000, 2000, 3000, 4000, 4500], type=int,
                        help='how often to decrease lr')
    parser.add_argument('--lr_factor', default=0.5, type=float, help='factor by which to decrease lr')
    parser.add_argument('--output', default='default_out/', type=str, help='output directory name')
    parser.add_argument('--checkpoint', default='checkpoint', type=str, help='Where to save checkpoitns.')
    parser.add_argument('--model', default='linear', type=str, help='Which model?')
    parser.add_argument('--horizon', default=10.0, type=float, help='Horizon used in data_process (in seconds)')
    parser.add_argument('--history', default=60.0, type=float, help='Length of price history used as input.')
    parser.add_argument('--num_testing_days', default=4, type=int, help='How many days in the test set?')
    parser.add_argument('--batch_size', default=2500, type=int, help='Mini-batch size')
    parser.add_argument('--data_path', default='nokia_data', type=str, help='where is the data?')
    parser.add_argument('--val_period', default=500, type=int, help='How often to test.')
    parser.add_argument('--save_period', default=5000000000, type=int, help='How often to save?')
    parser.add_argument('--load_path', default='', type=str, help='path for model to be loaded')
    parser.add_argument('--timestep', default=0.01, type=float, help='gap (in seconds) between rows')
    parser.add_argument('--cpu', action='store_true', help='use cpu')
    parser.add_argument('--label_size', default=3, type=int, help='size of output from networks')
    parser.add_argument('--adversarial', action='store_true', help='adversarially train')
    parser.add_argument('--MLP_depth', default=5, type=int, help='depth of MLP')
    parser.add_argument('--MLP_width', default=8000, type=int, help='with of MLP')
    parser.add_argument('--lstm_hidden_size', default=100, type=int, help='hidden size of lstm')
    parser.add_argument('--lstm_layers', default=3, type=int, help='num layers of lstm')
    parser.add_argument('--batches_per_update', default=1, type=int, help='batches per gradient update')
    parser.add_argument('--optimizer', default='sgd', type=str, help='which optimizer?')
    args = parser.parse_args()

    # log args
    to_log_file(args, args.output, 'log.txt')

    # Set device
    if args.cpu:
        device = 'cpu'
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set smoothing network: the datasets get the averaging-only smoother, the model gets the one that
    # additionally applies MeanSubtract and STDDivide
    input_length = int(args.history / args.timestep)
    smoother_net_full = nn.Sequential(SizeWeightedAverage())
    smoother_net = nn.Sequential(SizeWeightedAverage(), MeanSubtract(), STDDivide(1.0))
    smoother_net = smoother_net.to(device)

    # Get datasets for natural training, use OrderbookDataset
    trainset = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                                train=True, device='cpu', timestep=args.timestep,
                                num_testing_days=args.num_testing_days, attack=args.adversarial)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)

    # the test set reuses the threshold of the training set
    testset = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                               train=False, device='cpu', timestep=args.timestep,
                               num_testing_days=args.num_testing_days,
                               threshold=trainset.threshold, attack=args.adversarial)
    testloader = torch.utils.data.DataLoader(testset, batch_size=250, shuffle=True)

    # initialize model
    if args.model == 'linear':
        predictor_net = LinearPredictor(input_length, args.label_size)
    elif args.model == 'MLP':
        predictor_net = MLP(num_inputs=input_length, num_outputs=args.label_size, width=args.MLP_width,
                            depth=args.MLP_depth, BN=True)
    elif args.model == 'LSTM':
        predictor_net = LSTM(num_outputs=args.label_size, input_size=1, hidden_size=args.lstm_hidden_size,
                             num_layers=args.lstm_layers, device=device)
    else:
        print('Model not implemented.')
        return

    # Set criterion depending on regression vs. classification
    criterion = nn.MSELoss() if args.label_size == 1 else nn.CrossEntropyLoss()
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(predictor_net.parameters(), lr=args.lr, momentum=0.9, weight_decay=2e-4)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(predictor_net.parameters(), lr=args.lr, weight_decay=2e-4)
    else:
        print("Optimizer not yet implemented.")
        return

    if args.load_path:
        print('==> Resuming from checkpoint..')
        # map_location lets a checkpoint written on a gpu be resumed on a cpu-only machine
        checkpoint = torch.load(args.load_path, map_location=device)
        smoother_net.load_state_dict(checkpoint['smoother'])
        predictor_net.load_state_dict(checkpoint['predictor'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # the predictor is only moved to the device further down, so push the optimizer state there explicitly
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
        start_iteration = checkpoint['iteration'] + 1
    else:
        start_iteration = 0

    # Initialize the whole model and move to device
    net = nn.Sequential(smoother_net, predictor_net)
    net = net.to(device)

    if args.adversarial:
        # fast data generation
        fast_path = os.path.join(args.data_path, 'fast_data')
        if not os.path.isdir(fast_path):
            print("Making fast data...")
            os.makedirs(fast_path)
            save_fast_data_multiprocessing(trainset, save_dir=fast_path, mode='train')
            save_fast_data_multiprocessing(testset, save_dir=fast_path, mode='test')
        else:
            print("Fast data found, not re-generating.")
    else:
        fast_path = ''

    # The final report below reads both training stats, so give them a value even if the loop never runs
    # (e.g. a checkpoint that already reached --iterations).
    train_loss, train_acc = 0, 0

    # training loop
    print('\n', now(), '\t==> Training...')
    for iteration in tqdm(range(start_iteration, args.iterations), leave=False):
        train_loss, train_acc = train(net, optimizer, criterion, trainloader, iteration, args, fast_path, device=device,
                                      batches_per_update=args.batches_per_update)
        if (iteration + 1) % 100 == 0:
            print('Train loss: ', train_loss, 'Train acc:', train_acc)

        # validate
        if (iteration + 1) % args.val_period == 0:
            test_loss, test_acc, confusion_matrix = test(net, testloader, criterion, args, device=device)
            results = {'chkpt': args.checkpoint,
                       'training loss': train_loss,
                       'test loss': test_loss,
                       'training acc': train_acc,
                       'test acc': test_acc,
                       'confusion matrix': confusion_matrix}
            to_log_file(results, args.output, 'log.txt')
            print('\nIteration: ', iteration)
            print('Test loss: ', test_loss, 'Test acc:', test_acc)
            print('Train loss: ', train_loss, 'Train acc:', train_acc)
            print('Confusion Matrix: ', confusion_matrix)

        # save on save period and at the end of training
        if (iteration + 1) % args.save_period == 0 or iteration == args.iterations - 1:
            state = {
                'predictor': predictor_net.state_dict(),
                'smoother': smoother_net.state_dict(),
                'iteration': iteration,
                'optimizer': optimizer.state_dict()
            }

            out_str = os.path.join(args.checkpoint, 'adv'+str(args.adversarial)+args.model + 'labels_horizon='
                                   + str(args.horizon)
                                   + '_history=' + str(args.history)
                                   + '_testing_days=' + str(args.num_testing_days)
                                   + '_label_size=' + str(args.label_size)
                                   + '_iteration=' + str(iteration) + '.t7')
            print('\nsaving model to: ', args.checkpoint, '\n')
            os.makedirs(args.checkpoint, exist_ok=True)
            torch.save(state, out_str)
    print(now(), '\tDone training.')

    print(now(), '\tTesting...')
    test_loss, test_acc, test_confusion_matrix = test(net, testloader, criterion, args, device=device)
    results = {'chkpt': args.checkpoint,
               'training loss': train_loss,
               'test loss': test_loss,
               'training acc': train_acc,
               'test acc': test_acc,
               'confusion matrix': test_confusion_matrix}
    to_log_file(results, args.output, 'log.txt')
    print(now(), '\tDone testing.\n\n')


# Run the training pipeline only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
