############################################################
#
# test_model_visualize.py
# evaluates a trained value predictor under adversarial attack and saves perturbations for visualization
# August 2019
#
############################################################

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import argparse
from utility import now, to_results_table, SizeWeightedAverage, MeanSubtract, STDDivide, save_fast_data_multiprocessing
from attack_module import attacker_costReg, volumePropagationRandom
from dataset_module import OrderbookDataset
import os
from models.linearpredictor import LinearPredictor
from models.mlp import MLP
from models.temporalconvnet import TemporalConvNet
from models.lstm import LSTM, Bidirectional_LSTM
import copy


def test(net, net_transfer, testloader, dataset, fast_path, criterion, attack_type, args, device='cuda', test_batches=50,
         save_batches=100):
    """Evaluate ``net`` under an optional attack and save perturbations for visualization.

    For each batch (up to ``test_batches``) the inputs are optionally perturbed by the
    chosen attack. Both the clean and the perturbed inputs are classified by ``net``, and
    loss, accuracy and "fooled" statistics are accumulated. For perturbed batches with
    ``batch_idx < save_batches`` several tensors are written to ``args.output`` with
    ``torch.save``: the propagated difference, the unpropagated tensor returned by the
    attacker (when one is returned), the clean inputs and the perturbed inputs.

    Args:
        net: model being evaluated.
        net_transfer: model the adversarial attack is computed against (may be ``net``).
        testloader: iterable yielding ``(inputs, targets, dataset_idx)``. Saved file names
            use ``dataset_idx.item()`` and ``targets.item()``, so saving requires a batch
            size of 1.
        dataset: dataset passed through to the attack helpers. ``dataset[0][0]`` fixes the
            shape of the running-mean perturbation tensors.
        fast_path: directory of precomputed "fast data" passed to the attack helpers.
        criterion: loss function called as ``criterion(outputs, targets)``.
        attack_type: ``"adv"`` (attacker_costReg), ``"random"`` (volumePropagationRandom),
            ``"universal"`` (not implemented, no attack) or anything else (no attack).
        args: namespace providing ``label_size`` and ``output``. For ``"adv"`` it must also
            provide ``cost_coeff``, ``detectability_coeff`` and ``capital_coeff``.
        device: torch device string.
        test_batches: maximum number of batches to evaluate.
        save_batches: only perturbed batches with ``batch_idx < save_batches`` are saved.

    Returns:
        ``(avg_loss, accuracy, fooled, perturbed, avg_capital, avg_cost,
        avg_detectability, running_perturbation_unprop, running_perturbation_prop)``.

        - ``fooled`` counts samples whose clean prediction was correct and changed under
          the attack.
        - ``perturbed`` counts batches whose inputs changed.
        - The ``avg_*`` values are the per-batch statistic weighted by that batch's fooled
          count, divided by ``fooled``. They are 0 when nothing was fooled.
        - The running perturbations are means over the perturbed batches.
    """
    total_loss = 0
    correct = 0
    clean_correct = 0
    fooled = 0
    total_capital = 0
    total_detectability = 0
    total_cost = 0
    perturbed = 0
    unprop_count = 0  # perturbed batches for which the attack returned an unpropagated tensor
    total = 0
    # Accumulated per batch but not returned; kept for inspection/debugging.
    confusion_matrix = torch.zeros((args.label_size, args.label_size))
    targets_balance = torch.zeros(args.label_size)
    running_perturbation_prop = torch.zeros_like(dataset[0][0]).unsqueeze(0).to(device)
    running_perturbation_unprop = torch.zeros_like(dataset[0][0]).unsqueeze(0).to(device)

    net.eval()
    net.to(device)
    if net_transfer is not None:
        # The attack is computed against this model. Left in training mode, any BatchNorm or
        # dropout layers would behave differently from evaluation.
        net_transfer.eval()
        net_transfer.to(device)

    # NOTE(review): the attacks run inside no_grad; attacker_costReg presumably re-enables
    # gradients internally -- confirm.
    with torch.no_grad():
        for batch_idx, (inputs, targets, dataset_idx) in enumerate(testloader):
            print('batch_idx', batch_idx)
            if batch_idx >= test_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            clean_inputs = inputs.clone()

            # Defaults for branches that produce no attack statistics ("universal" / none).
            cost = torch.tensor([0])
            capital_required = torch.tensor([0])
            detectability = torch.tensor([0])
            inputs_unpropagated = None

            if attack_type == "adv":
                inputs, _, cost, capital_required, detectability, inputs_unpropagated = attacker_costReg(
                    net_transfer, inputs, targets, dataset_idx, dataset, fast_path,
                    num_steps=200, step_size=40.0, device=device, signed=False,
                    cost_coeff=args.cost_coeff, rand_step_size=True,
                    detectability_coeff=args.detectability_coeff,
                    capital_coeff=args.capital_coeff, criterion=criterion, visualize=True)
            elif attack_type == "random":
                inputs, _, cost, capital_required, detectability = volumePropagationRandom(
                    net, inputs, dataset_idx, dataset, clamp_start=0.0, clamp_end=1.0,
                    open_path=fast_path, device=device)
            elif attack_type == "universal":
                print("Universal attacks not yet implemented in test_model.py -- NOT ATTACKING!!!")

            outputs = net(inputs)
            clean_outputs = net(clean_inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * len(targets)
            total += targets.size(0)

            # Accuracy on perturbed and clean inputs
            _, pred = outputs.max(1)
            _, clean_pred = clean_outputs.max(1)
            correct += pred.eq(targets).sum().item()
            clean_correct += clean_pred.eq(targets).sum().item()
            # A sample counts as fooled when its clean prediction was correct and changed.
            batch_fooled = (clean_pred.ne(pred) & clean_pred.eq(targets)).sum().item()
            fooled += batch_fooled
            print('fooled', batch_fooled / targets.size(0))
            print('capital_required', capital_required.item())
            print('detectability', detectability.item())

            if torch.norm(clean_inputs - inputs) > 0:
                perturbed += 1
                total_cost += cost.item() * batch_fooled
                total_capital += capital_required.item() * batch_fooled
                total_detectability += detectability.item() * batch_fooled
                diff_prop = inputs - clean_inputs
                if batch_idx < save_batches:
                    was_fooled = '_fooled' if batch_fooled > 0 else '_not_fooled'
                    prefix = os.path.join(args.output,
                                          str(dataset_idx.item()) + was_fooled + '_target=' + str(targets.item()))
                    torch.save(diff_prop, prefix + '_diffProp.pth')
                    if inputs_unpropagated is not None:
                        torch.save(inputs_unpropagated, prefix + '_diffUnprop.pth')
                    torch.save(clean_inputs, prefix + '_clean_inputs.pth')
                    torch.save(inputs, prefix + '_perturbed_inputs.pth')
                # Incremental means over perturbed batches only (skipped batches do not count).
                running_perturbation_prop = running_perturbation_prop + (diff_prop - running_perturbation_prop) / perturbed
                if inputs_unpropagated is not None:
                    unprop_count += 1
                    running_perturbation_unprop = (running_perturbation_unprop
                                                   + (inputs_unpropagated - running_perturbation_unprop) / unprop_count)

            for i in range(args.label_size):
                targets_i = targets.eq(i)
                targets_balance[i] += targets_i.sum()
                for j in range(args.label_size):
                    confusion_matrix[i, j] += (targets_i & pred.eq(j)).sum()

        # Aggregate; guard against an empty evaluation and against nothing being fooled.
        accuracy = 100 * correct / total if total else 0.0
        if fooled:
            avg_capital = total_capital / fooled
            avg_cost = total_cost / fooled
            avg_detectability = total_detectability / fooled
        else:
            avg_capital = 0
            avg_cost = 0
            avg_detectability = 0

        print("accuracy: ", accuracy)
        print("number perturbed: ", perturbed)
        print("fooled: ", fooled)
        print("avg capital: ", avg_capital)
        print("avg cost: ", avg_cost)
        print("avg detectability: ", avg_detectability)

    avg_loss = total_loss / total if total else 0.0
    return (avg_loss, accuracy, fooled, perturbed, avg_capital, avg_cost, avg_detectability,
            running_perturbation_unprop, running_perturbation_prop)


# This helper lives in a file matching pytest's ``test_*.py`` pattern; stop pytest from
# collecting it as a test (it would fail looking for fixtures named after its parameters).
test.__test__ = False


def _build_arg_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description='Python')
    parser.add_argument('--output', default='default_out/', type=str, help='output directory name')
    parser.add_argument('--model', default='linear', type=str, help='Which model?')
    parser.add_argument('--model_transfer', default='linear', type=str, help='Which model to use for transfer attacks?')
    parser.add_argument('--horizon', default=10.0, type=float, help='Horizon used in data_process (in seconds)')
    parser.add_argument('--history', default=60.0, type=float, help='Length of price history used as input.')
    parser.add_argument('--data_path', default='nokia_data', type=str, help='where is the data?')
    parser.add_argument('--load_path', default='', type=str, help='path for model to be loaded')
    parser.add_argument('--load_path_transfer', default='', type=str, help='path for model to be loaded for generating transfer attacks')
    parser.add_argument('--timestep', default=0.01, type=float, help='gap (in seconds) between rows')
    parser.add_argument('--cpu', action='store_true', help='use cpu')
    parser.add_argument('--plot_data', action='store_true', help='plot data?')
    parser.add_argument('--label_size', default=3, type=int, help='size of output from networks')
    parser.add_argument('--no_attack', action='store_true', help='dont run attacks')
    parser.add_argument('--no_clean', action='store_true', help='dont run clean tests')
    parser.add_argument('--MLP_depth', default=3, type=int, help='depth of MLP')
    parser.add_argument('--MLP_width', default=5000, type=int, help='width of MLP')
    parser.add_argument('--capital_coeff', default=0.0, type=float, help='coefficient of capital penalty.')
    parser.add_argument('--cost_coeff', default=0.0, type=float, help='coefficient of cost penalty.')
    parser.add_argument('--detectability_coeff', default=0.0, type=float, help='coefficient of detectability penalty.')
    return parser


def _build_predictor(model_name, input_length, label_size, device, mlp_width, mlp_depth,
                     lstm_hidden_size, lstm_num_layers):
    """Construct the predictor called ``model_name``, or return None if the name is unknown.

    MLP and LSTM sizes are parameters because the evaluated model and the transfer model
    are built with different sizes. The other architectures use fixed settings.
    """
    if model_name == 'linear':
        return LinearPredictor(input_length, label_size)
    if model_name == 'MLP':
        return MLP(num_inputs=input_length, num_outputs=label_size, width=mlp_width, depth=mlp_depth, BN=True)
    if model_name == 'TemporalConvNet':
        return TemporalConvNet(num_inputs=input_length, num_channels=[1000, 500, 100, label_size], kernel_size=2, dropout=0.2)
    if model_name == 'LSTM':
        return LSTM(num_outputs=label_size, input_size=1, hidden_size=lstm_hidden_size,
                    num_layers=lstm_num_layers, device=device)
    if model_name == 'Bidirectional_LSTM':
        return Bidirectional_LSTM(num_outputs=label_size, input_size=1, hidden_size=2, num_layers=1, device=device)
    return None


def main():
    """Load a trained predictor, attack it on the test split and record the results.

    Steps:
      1. Parse arguments and build the datasets.
      2. Build and load the predictor, plus an optional separate transfer model that the
         attack is computed against.
      3. Generate the "fast data" cache if it is missing.
      4. Unless ``--no_attack`` is given, run the adversarial evaluation and save the
         running-mean perturbations into ``args.output``.
      5. Write a results table.
    """
    print(now(), "test_model.py running...")
    args = _build_arg_parser().parse_args()

    # Fall back to CPU when requested or when CUDA is unavailable.
    device = 'cpu' if args.cpu or not torch.cuda.is_available() else 'cuda'

    input_length = int(args.history / args.timestep)

    print('==> Getting data...')
    # The datasets are built with SizeWeightedAverage only. The evaluated network prepends
    # SizeWeightedAverage + MeanSubtract + STDDivide to the predictor.
    smoother_net_full = nn.Sequential(SizeWeightedAverage())
    smoother_net = nn.Sequential(SizeWeightedAverage(), MeanSubtract(), STDDivide(1.0))

    trainset = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                                train=True, device='cpu', timestep=args.timestep, num_testing_days=4, attack=True)

    # Only used for the most-populous-class baseline below.
    testset_full = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon,
                                    args.label_size,
                                    train=False, device='cpu', timestep=args.timestep, num_testing_days=4,
                                    threshold=trainset.threshold, attack=False)

    testset_attack = OrderbookDataset(args.data_path, smoother_net_full, args.history, args.horizon, args.label_size,
                                      train=False, device='cpu', timestep=args.timestep, num_testing_days=4,
                                      threshold=trainset.threshold, attack=True)
    # Batch size 1: test() names its saved files after each sample's index and target.
    testloader_attack = torch.utils.data.DataLoader(testset_attack, batch_size=1, shuffle=True)

    print('==> Building model...')
    predictor_net = _build_predictor(args.model, input_length, args.label_size, device,
                                     mlp_width=args.MLP_width, mlp_depth=args.MLP_depth,
                                     lstm_hidden_size=1, lstm_num_layers=1)
    if predictor_net is None:
        print('Model not implemented.')
        return

    # Set criterion depending on regression vs. classification
    criterion = nn.MSELoss() if args.label_size == 1 else nn.CrossEntropyLoss()

    if not args.load_path:
        print('No model loaded.')
        return
    print('==> Resuming from checkpoint..')
    # map_location lets a GPU-saved checkpoint be evaluated on CPU (--cpu / no CUDA).
    checkpoint = torch.load(args.load_path, map_location=device)
    # NOTE(review): the original author questioned whether restoring the smoother is needed.
    smoother_net.load_state_dict(checkpoint['smoother'])
    predictor_net.load_state_dict(checkpoint['predictor'])

    if args.load_path_transfer:
        print('==> Building transfer model...')
        predictor_net_transfer = _build_predictor(args.model_transfer, input_length, args.label_size, device,
                                                  mlp_width=8000, mlp_depth=5,
                                                  lstm_hidden_size=100, lstm_num_layers=3)
        if predictor_net_transfer is None:
            print('Transfer model not implemented.')
            return
        checkpoint = torch.load(args.load_path_transfer, map_location=device)
        predictor_net_transfer.load_state_dict(checkpoint['predictor'])
    else:
        # Without a transfer checkpoint the attack targets the evaluated model itself.
        predictor_net_transfer = predictor_net

    # Both networks share the same smoother; Module.to moves parameters in place.
    net = nn.Sequential(smoother_net, predictor_net).to(device)
    net_transfer = nn.Sequential(smoother_net, predictor_net_transfer).to(device)

    fast_path = os.path.join(args.data_path, 'fast_data')
    if not args.no_attack:
        # fast data generation
        if not os.path.isdir(fast_path):
            print("Making fast data...")
            os.makedirs(fast_path)
            save_fast_data_multiprocessing(trainset, save_dir=fast_path, mode='train')
            save_fast_data_multiprocessing(testset_attack, save_dir=fast_path, mode='test')
        else:
            print("Fast data found, not re-generating.")

    # Only the adversarial evaluation is run here. The other result slots are reported as
    # False, as are the adversarial ones when --no_attack is given.
    test_acc = clean_acc = False
    rand_acc = rand_perturbed = rand_fooled = rand_cost = rand_capital = rand_detectability = False
    adv_acc = adv_perturbed = adv_fooled = adv_cost = adv_capital = adv_detectability = False

    # test() and the saves below write into args.output.
    os.makedirs(args.output, exist_ok=True)

    if not args.no_attack:
        print(now(), '\tTesting adversarial attack...')
        (_adv_loss, adv_acc, adv_fooled, adv_perturbed, adv_capital, adv_cost, adv_detectability,
         running_perturbation_unprop, running_perturbation_prop) = test(net, net_transfer, testloader_attack,
                                                                        testset_attack, fast_path, criterion,
                                                                        "adv", args, device=device,
                                                                        test_batches=5000)
        torch.save(running_perturbation_unprop, os.path.join(args.output, 'running_perturbation_unprop.pth'))
        torch.save(running_perturbation_prop, os.path.join(args.output, 'running_perturbation_prop.pth'))

    # Baselines: accuracy of always predicting the most populous class of train / test.
    train_baseline = 100.0*torch.sum(testset_attack.labels==trainset.get_most_populous())/testset_attack.labels.shape[0]
    train_baseline = train_baseline.item()
    test_baseline = 100.0*torch.sum(testset_attack.labels==testset_full.get_most_populous())/testset_attack.labels.shape[0]
    test_baseline = test_baseline.item()
    stats = {'chkpt': args.load_path,
             'model': args.model,
             'dataset': args.data_path,
             'train baseline': train_baseline,
             'test baseline': test_baseline,
             'test acc': test_acc,
             'clean acc': clean_acc,
             'adv acc': adv_acc,
             'adv perturbed': adv_perturbed,
             'adv fooled': adv_fooled,
             'adv cost': adv_cost,
             'adv capital': adv_capital,
             'adv detectability': adv_detectability,
             'rand acc': rand_acc,
             'rand perturbed': rand_perturbed,
             'rand fooled': rand_fooled,
             'rand cost': rand_cost,
             'rand capital': rand_capital,
             'rand detectability': rand_detectability}

    to_results_table(stats, args.output)
    print(now(), '\tDone testing.\n\n')


# Script entry point: run the evaluation only when executed directly, never on import.
if __name__ == '__main__':
    main()
