############################################################
#
# attack_universal.py
# main python file for generating a universal perturbation attack and evaluating it on the test set
# August 2019
#
############################################################

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import argparse
from utility import now, to_log_file, plot_data, SizeWeightedAverage, LinearSmoothing, MeanSubtract, STDDivide, save_fast_data_multiprocessing
from attack_module import universalPerturbation_costReg, measure_universal_stats, propagate_universal_perturbation
from dataset_module import OrderbookDataset
import os
import sys
from tqdm import tqdm
from models.linearpredictor import LinearPredictor
from models.mlp import MLP
from models.temporalconvnet import TemporalConvNet
from models.lstm import LSTM, Bidirectional_LSTM


def test(net, net_transfer, testloader, dataset, fast_path, criterion, args, universal_perturbation, device='cuda', test_batches=5000):
    """Evaluate ``net`` on test inputs with the universal perturbation applied.

    For every batch the clean prediction is computed first.  The
    perturbation, mapped onto the batch by
    ``propagate_universal_perturbation``, is then added to the inputs and
    the loss and accuracy of the perturbed prediction are accumulated.  The
    three values returned by ``measure_universal_stats`` are printed for
    every batch and averaged over the fooled samples.

    Args:
        net: model under evaluation; put in eval mode and moved to ``device``.
        net_transfer: unused here; kept so existing callers keep working.
        testloader: iterable yielding ``(inputs, targets, dataset_idx)``.
            A batch only counts as "originally correct" when exactly one
            of its samples is classified correctly, so the fooled
            statistics are meant for a batch size of 1.
        dataset: forwarded unchanged to the ``attack_module`` helpers.
        fast_path: forwarded unchanged to the ``attack_module`` helpers.
        criterion: loss applied to the perturbed outputs and the targets.
        args: namespace providing ``label_size``.
        universal_perturbation: forwarded to the ``attack_module`` helpers.
        device: device the model and the batches are moved to.
        test_batches: maximum number of batches to evaluate.

    Returns:
        Tuple ``(loss, accuracy, confusion_matrix, originally_correct,
        fooled, avg_cost, avg_capital, avg_detectability)``.  ``loss`` is
        the sample-weighted mean loss.  ``accuracy`` is a percentage when
        the network has more than one output, otherwise
        ``1 - residual / mean_err`` over all collected outputs.  The three
        averages are taken over fooled samples and are 0 when nothing was
        fooled.  When no batch is evaluated, ``loss`` and ``accuracy`` are
        NaN and all counters are 0.
    """
    all_targets = []
    all_outputs = []
    total_loss = 0
    correct = 0
    total = 0
    # Kept on the CPU and updated with Python numbers so that it does not
    # matter which device the batches live on.
    confusion_matrix = torch.zeros((args.label_size, args.label_size))
    originally_correct = 0
    fooled = 0
    fooled_cost = 0
    fooled_capital = 0
    fooled_proportion = 0
    is_classification = False
    net.eval()
    net.to(device)

    with torch.no_grad():
        for batch_idx, (inputs, targets, dataset_idx) in enumerate(testloader):
            if batch_idx >= test_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)

            # Prediction on the unperturbed input.
            _, clean_pred = net(inputs).max(1)
            clean_correct = clean_pred.eq(targets).sum().item()
            orig_success = clean_correct == 1
            if orig_success:
                originally_correct += clean_correct
                print('working')

            cost, capital_required, detectability = measure_universal_stats(universal_perturbation, fast_path, dataset, dataset_idx, device)
            print('cost', cost.item())
            print('capital_required', capital_required.item())
            print('proportion of order book occupied', detectability.item())

            # Prediction on the perturbed input.
            outputs = net(inputs+propagate_universal_perturbation(universal_perturbation, fast_path, dataset, dataset_idx, device))
            loss = criterion(outputs, targets)
            total_loss += loss.item() * len(targets)
            total += targets.size(0)
            _, pred = outputs.max(1)

            # Only batches that were right before the attack can be fooled.
            if orig_success:
                num_wrong = pred.ne(targets).sum().item()
                if num_wrong > 0:
                    fooled += num_wrong
                    fooled_cost += cost.item()
                    fooled_capital += capital_required.item()
                    fooled_proportion += detectability.item()

            # More than one output means classification; a single output
            # is treated as regression and scored after the loop.
            is_classification = outputs.shape[1] > 1
            if is_classification:
                correct += pred.eq(targets).sum().item()
                for i in range(args.label_size):
                    targets_i = targets.eq(i)
                    for j in range(args.label_size):
                        outputs_j = pred.eq(j)
                        confusion_matrix[i, j] += (targets_i & outputs_j).sum().item()
            else:
                all_outputs.append(outputs)
                all_targets.append(targets)

        if total == 0:
            # Nothing was evaluated (empty loader or test_batches <= 0).
            nan = float('nan')
            return nan, nan, confusion_matrix, originally_correct, fooled, 0, 0, 0

        # Defined for both the classification and the regression path so
        # that the return statement below never sees an unbound name.
        if fooled:
            avg_cost = fooled_cost / fooled
            avg_capital = fooled_capital / fooled
            avg_detectability = fooled_proportion / fooled
        else:
            avg_cost = 0
            avg_capital = 0
            avg_detectability = 0

        if is_classification:
            accuracy = 100 * correct / total
        else:
            all_outputs = torch.cat(all_outputs).squeeze()
            all_targets = torch.cat(all_targets).squeeze()
            residual = torch.sum((all_outputs - all_targets) ** 2)
            mean_err = torch.sum((torch.mean(all_targets) - all_targets) ** 2)
            accuracy = (1 - (residual / mean_err)).item()

    return total_loss / total, accuracy, confusion_matrix, originally_correct, fooled, avg_cost, avg_capital, avg_detectability


def _build_predictor(model_name, input_length, args, device):
    """Return a new predictor for ``model_name``, or None if the name is unknown."""
    if model_name == 'linear':
        return LinearPredictor(input_length, args.label_size)
    if model_name == 'MLP':
        return MLP(num_inputs=input_length, num_outputs=args.label_size, width=args.MLP_width, depth=args.MLP_depth, BN=True)
    if model_name == 'TemporalConvNet':
        return TemporalConvNet(num_inputs=input_length, num_channels=[1000, 500, 100, args.label_size], kernel_size=2, dropout=0.2)
    if model_name == 'LSTM':
        return LSTM(num_outputs=args.label_size, input_size=1, hidden_size=100, num_layers=3, device=device)
    if model_name == 'Bidirectional_LSTM':
        return Bidirectional_LSTM(num_outputs=args.label_size, input_size=1, hidden_size=2, num_layers=1, device=device)
    return None


def main():
    """Command-line entry point: build a universal perturbation and test it.

    Steps, in order: parse the arguments, build the train/test datasets,
    restore the predictor (and optionally a separate transfer model) from
    checkpoints, generate the fast-data cache if its directory is missing,
    call ``universalPerturbation_costReg`` on the transfer network, then
    evaluate the perturbation with ``test`` and print the results.

    Returns early, after printing a message, when the requested model name
    is unknown or when ``--load_path`` is not given.
    """
    print(now(), "attack_universal.py running...")

    # Argument parser
    parser = argparse.ArgumentParser(description='Python')
    parser.add_argument('--output', default='default_out/', type=str, help='output directory name')
    parser.add_argument('--model', default='linear', type=str, help='Which model?')
    parser.add_argument('--model_transfer', default='linear', type=str, help='Which model to use for transfer attacks?')
    parser.add_argument('--horizon', default=10.0, type=float, help='Horizon used in data_process (in seconds)')
    parser.add_argument('--history', default=60.0, type=float, help='Length of price history used as input.')
    parser.add_argument('--data_path', default='nokia_data', type=str, help='where is the data?')
    parser.add_argument('--smoothing_coeff', default=1.0, type=float, help='Smoothing coefficient in linear smoother.')
    parser.add_argument('--window', default=1, type=int, help='Window for smoothing.')
    parser.add_argument('--load_path', default='', type=str, help='path for model to be loaded')
    parser.add_argument('--load_path_transfer', default='', type=str, help='path for model to be loaded for generating transfer attacks')
    parser.add_argument('--load_universal_perturbation', default='', type=str, help='path for loading universal perturbation')
    parser.add_argument('--timestep', default=0.01, type=float, help='gap (in seconds) between rows')
    parser.add_argument('--cpu', action='store_true', help='use cpu')
    parser.add_argument('--plot_data', action='store_true', help='plot data?')
    parser.add_argument('--label_size', default=3, type=int, help='size of output from networks')
    parser.add_argument('--attack_batch_size', default=20, type=int, help='batch size for making attacks')
    parser.add_argument('--MLP_depth', default=5, type=int, help='depth of MLP')
    parser.add_argument('--MLP_width', default=8000, type=int, help='width of MLP')
    parser.add_argument('--save_path', default='debug.t7', type=str, help='file name, relative to --output, used as save path for the universal perturbation')
    parser.add_argument('--capital_coeff', default=0.0, type=float, help='regularizer coeff')
    parser.add_argument('--detectability_coeff', default=0.0, type=float, help='regularizer coeff')
    parser.add_argument('--cost_coeff', default=0.0, type=float, help='regularizer coeff')
    parser.add_argument('--target', default=3, type=int, help='target.  3 means untargeted.')
    parser.add_argument('--outer_steps', default=0, type=int, help='outer steps in Univ pert computation.')

    args = parser.parse_args()
    # The value 3 is the command-line sentinel for an untargeted attack.
    if args.target == 3:
        args.target = None

    # log args
    #to_log_file(args, args.output, 'log.txt')

    # Set device
    if args.cpu:
        device = 'cpu'
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Number of rows of history fed to the predictor.
    input_length = int(args.history/args.timestep)

    # Set smoothing networks: the datasets use the plain size-weighted
    # average, the model additionally normalizes its input.
    smoother_net_full = nn.Sequential(SizeWeightedAverage())
    smoother_net = nn.Sequential(SizeWeightedAverage(), MeanSubtract(), STDDivide(1.0))

    # Get datasets; the test set reuses the threshold computed on the train set.
    trainset = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                                train=True, device='cpu', timestep=args.timestep, num_testing_days=4, attack=True)

    testset = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                               train=False, device='cpu', timestep=args.timestep, num_testing_days=4,
                               threshold=trainset.threshold, attack=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

    # Define model. If output size is 1 we are doing regression, otherwise
    # classification; the criterion below is chosen to match.
    print('==> Building model...')
    predictor_net = _build_predictor(args.model, input_length, args, device)
    if predictor_net is None:
        print('Model not implemented.')
        return

    # Set criterion depending on regression vs. classification
    criterion = nn.MSELoss() if args.label_size == 1 else nn.CrossEntropyLoss()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location lets a checkpoint saved on a GPU be restored on a CPU run.
    if args.load_path:
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.load_path, map_location=device)
        smoother_net.load_state_dict(checkpoint['smoother'])# DO WE REALLY NEED THIS?
        predictor_net.load_state_dict(checkpoint['predictor'])
    else:
        print('No model loaded.')
        return

    # Without a transfer checkpoint the attack is computed on the evaluated model itself.
    if args.load_path_transfer:
        print('==> Building transfer model...')
        predictor_net_transfer = _build_predictor(args.model_transfer, input_length, args, device)
        if predictor_net_transfer is None:
            print('Transfer model not implemented.')
            return
        checkpoint = torch.load(args.load_path_transfer, map_location=device)
        predictor_net_transfer.load_state_dict(checkpoint['predictor'])
    else:
        predictor_net_transfer = predictor_net

    # Initialize the whole models and move them to the device; both share
    # the same smoother module, which is moved along with them.
    net = nn.Sequential(smoother_net, predictor_net)
    net_transfer = nn.Sequential(smoother_net, predictor_net_transfer)
    net = net.to(device)
    net_transfer = net_transfer.to(device)

    # fast data generation (skipped whenever the directory already exists)
    fast_path = os.path.join(args.data_path, 'fast_data')
    if not os.path.isdir(fast_path):
        print("Making fast data...")
        os.makedirs(fast_path)
        save_fast_data_multiprocessing(trainset, save_dir=fast_path, mode='train')
        save_fast_data_multiprocessing(testset, save_dir=fast_path, mode='test')
    else:
        print("Fast data found, not re-generating.")

    print(now(), '==> Generating universal perturbation')
    if args.load_universal_perturbation:
        load_path = os.path.join(args.output, args.load_universal_perturbation)
    else:
        load_path = ''
    universal_perturbation = universalPerturbation_costReg(
        net_transfer, trainset, fast_path,
        attack_batch_size=args.attack_batch_size,
        outer_steps=args.outer_steps,
        device=device,
        detectability_coeff=args.detectability_coeff,
        cost_coeff=args.cost_coeff,
        capital_coeff=args.capital_coeff,
        criterion=criterion,
        load_path=load_path,
        save_path=os.path.join(args.output, args.save_path),
        target=args.target)

    print(now(), '\tTesting...')
    test_loss, test_acc, test_confusion_matrix, originally_correct, fooled, avg_cost, avg_capital, avg_detectability = test(net, net_transfer, testloader, testset, fast_path, criterion, args, universal_perturbation, device=device)
    print(now(), '\tDone testing.\n\n')

    print('test loss: ', test_loss)
    print('test_acc: ', test_acc)
    print('confusion_matrix: ', test_confusion_matrix)
    print('originally_correct', originally_correct)
    print('fooled', fooled)
    print('avg_cost: ', avg_cost)
    print('avg_capital: ', avg_capital)
    print('avg_detectability: ', avg_detectability)

# Run the attack pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
