############################################################
#
# attack_module.py
# classes and functions for adversarial attack on trading
# programs
# February 2020
#
############################################################

import os
import torch
import _pickle as pickle
from torch.utils.data import Dataset
from tqdm import tqdm
from random import random
import copy


def dict_to_tensor(inputs_orig, fast_data):
    """Write the volumes held in ``fast_data`` into ``inputs_orig`` in place.

    ``fast_data[1]`` is read as a two-column index table (first column indexes
    the first axis of ``inputs_orig``, second column the second axis) and
    ``fast_data[2]`` supplies the value stored at each indexed position.
    Nothing is returned; ``inputs_orig`` itself is modified.
    """
    positions, volumes = fast_data[1], fast_data[2]
    first_axis = positions[:, 0]
    second_axis = positions[:, 1]
    inputs_orig[first_axis, second_axis] = volumes[:]


def inputs_to_perturbation_tensor(fast_data, perturbation_inputs, perturbation_tensor):
    """Fill ``perturbation_tensor`` in place from ``perturbation_inputs``.

    The entries of ``perturbation_inputs`` at the (first-axis, second-axis)
    positions listed in the two columns of ``fast_data[1]`` are gathered and
    copied into ``perturbation_tensor``.  Nothing is returned.
    """
    positions = fast_data[1]
    gathered = perturbation_inputs[positions[:, 0], positions[:, 1]]
    perturbation_tensor[:] = gathered


class PropagateVolumePerturbation(torch.nn.Module):
    """Adds a perturbation to the fast-data volumes, accumulating it per chunk.

    For every key of ``fast_data[0]`` the associated sequence is treated as
    consecutive chunk boundaries; inside each chunk the running (cumulative)
    sum of the perturbation is added to the volumes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, fast_data, perturbation_tensor):
        """Return ``[boundaries, positions, perturbed_volumes]``.

        All three returned elements are deep copies / new objects, so the
        ``fast_data`` passed in is left untouched.
        """
        boundaries_by_key = copy.deepcopy(fast_data[0])
        volumes = copy.deepcopy(fast_data[2])
        for boundaries in boundaries_by_key.values():
            # Walk over consecutive (start, end) boundary pairs.
            for lo, hi in zip(boundaries[:-1], boundaries[1:]):
                running_total = torch.cumsum(perturbation_tensor[lo:hi], dim=0)
                volumes[lo:hi] += running_total
        positions = copy.deepcopy(fast_data[1])
        return [boundaries_by_key, positions, volumes]


class CalculateCost(torch.nn.Module):
    """Net balance of the perturbation, marked against the final running price.

    Each key of ``fast_data[0]`` is parsed as a float and used both as the
    per-unit multiplier of the volume and, through its sign, as the direction
    in which the volume changes the share balance.
    """

    def __init__(self):
        super().__init__()

    def forward(self, fast_data, perturbation_tensor, dataset, dataset_idx):
        """Return ``cost - share_balance * end_price``.

        ``end_price`` is the last element of
        ``dataset.get_running_price(dataset_idx)``.

        NOTE(review): only the first ``len(boundaries) - 2`` chunks of every
        key are counted, i.e. the final chunk is skipped, whereas the
        propagation module walks all ``len(boundaries) - 1`` chunks.  This is
        preserved as-is -- confirm it is intentional.
        """
        total_cost = 0.0
        net_shares = 0.0
        final_price = dataset.get_running_price(dataset_idx)[-1]
        for key, boundaries in fast_data[0].items():
            # Consecutive boundary pairs, leaving out the last chunk.
            for lo, hi in zip(boundaries[:-2], boundaries[1:-1]):
                multiplier = float(key)
                direction = multiplier / abs(multiplier)
                chunk_volume = torch.sum(perturbation_tensor[lo:hi], dim=0)
                total_cost += multiplier * chunk_volume
                net_shares += direction * chunk_volume
        return total_cost - net_shares * final_price


class CalculateCapital(torch.nn.Module):
    """Sums the perturbation volume of every key, weighted by ``abs(float(key))``."""

    def __init__(self):
        super().__init__()

    def forward(self, fast_data, perturbation_tensor):
        """Return the weighted total over all keys of ``fast_data[0]``.

        For each key the perturbation is summed from the first to the last
        boundary stored under that key.
        """
        total = 0.0
        for key, boundaries in fast_data[0].items():
            first, last = boundaries[0], boundaries[-1]
            weight = abs(float(key))
            total += weight * torch.sum(perturbation_tensor[first:last], dim=0)
        return total


class ProportionOccupied(torch.nn.Module):
    """Share of the propagated total volume that was added by the perturbation."""

    def __init__(self, fast_data):
        super().__init__()
        # Total volume before any perturbation; the reference for forward().
        self.orig_volume = torch.sum(fast_data[2])

    def forward(self, propagated):
        """Return ``(new_total - original_total) / (new_total + 1e-12)``.

        ``propagated[2]`` holds the perturbed volumes; the tiny constant in
        the denominator avoids a division by zero.
        """
        epsilon = 1e-12
        total_after = torch.sum(propagated[2])
        added = total_after - self.orig_volume
        return added / (total_after + epsilon)


def open_fast_data(open_path, dataset, dataset_idx):
    """Load the pickled fast data stored for one dataset index.

    The file is read from ``<open_path>/<split>/<dataset_idx>`` where
    ``<split>`` is ``'train'`` when ``dataset.train`` is truthy and ``'test'``
    otherwise.  The unpickled object is returned unchanged.
    """
    split = 'train' if dataset.train else 'test'
    location = os.path.join(open_path, split, str(dataset_idx))
    with open(location, 'rb') as handle:
        return pickle.load(handle)


def create_perturbation_tensor(fast_data, mode='zeros'):
    """Create a constant tensor shaped like the volumes ``fast_data[2]``.

    Args:
        fast_data: Sequence whose third element is the volume tensor used as
            the template (shape, dtype and device are inherited from it).
        mode: ``'zeros'`` for an all-zero tensor, ``'ones'`` for all ones.

    Returns:
        The new tensor.

    Raises:
        ValueError: If ``mode`` is neither ``'zeros'`` nor ``'ones'``.
            Previously an unknown mode silently returned ``None``, which only
            failed later and far away from the actual mistake.
    """
    if mode == 'zeros':
        return torch.zeros_like(fast_data[2])
    if mode == 'ones':
        return torch.ones_like(fast_data[2])
    raise ValueError(
        "mode must be 'zeros' or 'ones', got {!r}".format(mode))


def get_step_size(step_size, rand=True):
    """Return the step size to use for one attack step.

    With ``rand`` falsy the given ``step_size`` is returned unchanged;
    otherwise it is scaled by twice a fresh ``random()`` draw.
    """
    if not rand:
        return step_size
    return 2 * step_size * random()


def volumePropagationRandom(net, inputs, dataset_idx, dataset, clamp_start=0.0, clamp_end=10.0,
                            open_path='./data/fast_data', device='cuda'):
    """Random-noise volume attack whose perturbation is propagated downstream.

    Every entry of the perturbation is drawn with ``torch.randint_like`` from
    ``int(clamp_start)`` (inclusive) to ``int(clamp_end)`` (exclusive).  The
    propagated volumes are written into ``inputs[0]`` in place.  ``net`` is
    accepted but not used by this function.

    Returns:
        ``(inputs, perturbation_tensor, cost, capital_required, detectability)``.
    """
    sample_index = dataset_idx[0].item()
    with torch.no_grad():
        fast_data = open_fast_data(open_path, dataset, sample_index)
        fast_data[1] = fast_data[1].to(device)
        fast_data[2] = fast_data[2].to(device)

        low, high = int(clamp_start), int(clamp_end)
        perturbation_tensor = torch.randint_like(fast_data[2], low=low, high=high)

        propagated = PropagateVolumePerturbation().to(device)(fast_data, perturbation_tensor)
        dict_to_tensor(inputs[0], propagated)

        # Statistics describing the sampled perturbation.
        cost = CalculateCost()(fast_data, perturbation_tensor, dataset, dataset_idx)
        capital_required = CalculateCapital()(fast_data, perturbation_tensor)
        detectability = ProportionOccupied(fast_data)(propagated)
    return inputs, perturbation_tensor, cost, capital_required, detectability


def attacker_costReg(net, inputs, targets, dataset_idx, dataset, fast_path, num_steps=20, step_size=20.0,
                     device='cuda', signed=False, cost_coeff=0.0, rand_step_size=True,
                     detectability_coeff=0.0, capital_coeff=0.0, criterion=torch.nn.CrossEntropyLoss(),
                     target = None, rounding_threshold = 0.45, visualize=False, cap_bound=False):
    """Gradient-based volume attack whose perturbation is propagated downstream.

    Starting from an all-zero perturbation, every step first evaluates the
    rounded perturbation and stops early when it already changes the
    prediction of ``net`` (or reaches ``target``) or, with ``cap_bound`` set,
    when it needs at least ``cap_bound`` capital.  Otherwise one
    gradient-ascent step is taken on the loss and the perturbation is clamped
    to be non-negative.

    Args:
        net: Model, called on ``inputs_copy[0].unsqueeze(0)``.
        inputs: Batch whose first element is attacked; ``inputs[0]`` is
            overwritten in place with the final propagated volumes.
        targets: Labels; ``targets[0]`` is the label of the attacked sample.
        dataset_idx: Index tensor; ``dataset_idx[0].item()`` selects the
            fast-data file.
        dataset: Dataset passed to ``open_fast_data`` and the cost criterion.
        fast_path: Root directory of the fast-data files.
        num_steps: Maximum number of gradient steps.
        step_size: Base step size (see ``get_step_size``).
        device: Device the fast-data tensors are moved to.
        signed: If true, only the sign of the gradient is used.
        cost_coeff, detectability_coeff, capital_coeff: Weights of the
            penalties subtracted from the loss; a penalty whose weight is
            ``0.0`` is not computed.
        rand_step_size: Forwarded to ``get_step_size`` as ``rand``.
        criterion: Loss applied to the network outputs.
        target: ``None`` for an untargeted attack (loss is ``criterion``
            against the true label); otherwise the loss is the negated
            ``criterion`` against ``target``.
        rounding_threshold: Added to the perturbation before rounding.
        visualize: If true, additionally return a copy of ``inputs`` holding
            the un-propagated perturbation.
        cap_bound: Falsy to disable; otherwise the capital bound.

    Returns:
        ``(inputs, perturbation_tensor, cost, capital_required,
        detectability)`` and, when ``visualize`` is true, a sixth element
        ``inputs_unpropagated``.
    """
    inputs_copy = copy.deepcopy(inputs)
    cost_criterion = CalculateCost()
    capital_criterion = CalculateCapital()
    with torch.no_grad():
        fast_data = open_fast_data(fast_path, dataset, dataset_idx[0].item())
        fast_data[1], fast_data[2] = fast_data[1].to(device), fast_data[2].to(device)
        detectability_criterion = ProportionOccupied(fast_data)
        perturbation_tensor = create_perturbation_tensor(fast_data).to(device).requires_grad_()

        propagator = PropagateVolumePerturbation().to(device)

        # Last perturbation known to respect cap_bound.  Initialised here so
        # that hitting the bound on the very first step falls back to the
        # starting perturbation instead of raising UnboundLocalError.
        old_pert = copy.deepcopy(perturbation_tensor.detach())

        for _ in range(num_steps):
            with torch.enable_grad():
                rounded_perturbation_tensor = (perturbation_tensor+rounding_threshold).round().detach()
                rounded_propagated = propagator(fast_data, rounded_perturbation_tensor)
                # replaces volumes from inputs_copy[0] with volumes from rounded_propagated
                dict_to_tensor(inputs_copy[0], rounded_propagated)
                rounded_outputs = net(inputs_copy[0].unsqueeze(0))
                if cap_bound:
                    capital = capital_criterion(fast_data, rounded_perturbation_tensor.detach())
                    if capital >= cap_bound:
                        # Bound reached: revert to the previous perturbation.
                        perturbation_tensor = old_pert
                        break
                    else:
                        old_pert = copy.deepcopy(perturbation_tensor.detach())
                predicted = rounded_outputs.argmax().item()
                if predicted != targets[0].item() or (predicted == target and target is not False):
                    # The rounded perturbation already succeeds.
                    break

                perturbation_tensor = perturbation_tensor.requires_grad_()

                # propagates perturbation tensor as perturbation to fast_data
                propagated = propagator(fast_data, perturbation_tensor)

                # replaces volumes from inputs_copy[0] with volumes from propagated
                dict_to_tensor(inputs_copy[0], propagated)
                outputs = net(inputs_copy[0].unsqueeze(0))
                if cost_coeff != 0.0:
                    cost = cost_criterion(fast_data, perturbation_tensor, dataset, dataset_idx)
                else:
                    cost = 0.0
                if detectability_coeff != 0.0:
                    detectability = detectability_criterion(propagator(fast_data, perturbation_tensor))
                else:
                    detectability = 0.0
                if capital_coeff != 0.0:
                    capital = capital_criterion(fast_data, perturbation_tensor)
                else:
                    capital = 0.0
                if target is None:
                    loss = criterion(outputs, targets[0].unsqueeze(0))
                else:
                    loss = -1.0*criterion(outputs, target*torch.ones_like(targets[0].unsqueeze(0)))
                loss -= cost_coeff*cost + detectability_coeff*detectability + capital_coeff*capital
            # Back under no_grad: the in-place update below is not tracked.
            grad = torch.autograd.grad(loss, [perturbation_tensor])[0].detach()
            if signed:
                grad = torch.sign(grad)
            perturbation_tensor += get_step_size(step_size, rand=rand_step_size)*grad # adds step to perturbation_tensor
            perturbation_tensor = torch.clamp(perturbation_tensor, min=0.0).detach()
        # Final statistics are computed on the rounded perturbation.
        perturbation_tensor = (perturbation_tensor+rounding_threshold).round()
        propagated = propagator(fast_data, perturbation_tensor)
        dict_to_tensor(inputs[0], propagated)
        cost = cost_criterion(fast_data, perturbation_tensor, dataset, dataset_idx)
        capital_required = capital_criterion(fast_data, perturbation_tensor)
        detectability = detectability_criterion(propagated)
    if visualize:
        inputs_unpropagated = copy.deepcopy(inputs)
        dict_to_tensor(inputs_unpropagated[0], [fast_data[0], fast_data[1], perturbation_tensor])
        return inputs, perturbation_tensor, cost, capital_required, detectability, inputs_unpropagated
    else:
        return inputs, perturbation_tensor, cost, capital_required, detectability


def attacker_UniversalPerturbation_costReg(perturbation_inputs, net, inputs, targets, dataset_idx, dataset, fast_path, num_steps = 20, step_size=20.0,
                     device='cuda', signed = False, cost_coeff = 0.0, rand_step_size=True, detectability_coeff = 0.0, capital_coeff = 0.0,
                     criterion = torch.nn.CrossEntropyLoss(), target=None):
    """Compute the per-sample update of a universal volume perturbation.

    The current universal perturbation (``perturbation_inputs[0]``) is read
    at the positions of this sample's fast data and kept fixed; an additional
    perturbation, starting at zero, is optimised by gradient ascent for up to
    ``num_steps`` steps.  The loop stops as soon as the network prediction
    differs from ``targets[0]``.  After every step the sum of both
    perturbations is clamped to be non-negative.

    Args:
        perturbation_inputs: Current universal perturbation; only
            ``perturbation_inputs[0]`` is read.
        net: Model, called on ``inputs_copy[0].unsqueeze(0)``.
        inputs: Batch whose first element is attacked (a deep copy is
            modified, not ``inputs`` itself).
        targets: Labels; ``targets[0]`` is the label of the attacked sample.
        dataset_idx: Index tensor; ``dataset_idx[0].item()`` selects the
            fast-data file.
        dataset: Dataset passed to ``open_fast_data`` and the cost criterion.
        fast_path: Root directory of the fast-data files.
        num_steps: Maximum number of gradient steps.
        step_size: Base step size (see ``get_step_size``).
        device: Device the fast-data tensors are moved to.
        signed: If true, only the sign of the gradient is used.
        cost_coeff, detectability_coeff, capital_coeff: Weights of the
            penalties subtracted from the loss; a penalty whose weight is
            ``0.0`` is not computed.
        rand_step_size: Forwarded to ``get_step_size`` as ``rand``.
        criterion: Loss applied to the network outputs.
        target: ``None`` for an untargeted attack; otherwise the loss is the
            negated ``criterion`` against ``target``.

    Returns:
        A tensor shaped like ``inputs`` that is zero everywhere except at the
        fast-data positions of element 0, which hold the additional
        perturbation found for this sample.
    """
    inputs_copy = copy.deepcopy(inputs)
    cost_criterion = CalculateCost()
    capital_criterion = CalculateCapital()
    with torch.no_grad():
        fast_data = open_fast_data(fast_path, dataset, dataset_idx[0].item())
        fast_data[1], fast_data[2] = fast_data[1].to(device), fast_data[2].to(device)
        detectability_criterion = ProportionOccupied(fast_data)
        propagator = PropagateVolumePerturbation().to(device)
        # Fixed part: the universal perturbation at this sample's positions.
        old_perturbation_tensor = torch.zeros_like(fast_data[2])
        inputs_to_perturbation_tensor(fast_data, perturbation_inputs[0], old_perturbation_tensor)
        # Trainable part: the update computed for this sample.
        new_perturbation_tensor = torch.zeros_like(old_perturbation_tensor)
        for _ in range(num_steps):
            with torch.enable_grad():
                new_perturbation_tensor = new_perturbation_tensor.requires_grad_()
                perturbation_tensor = new_perturbation_tensor+old_perturbation_tensor
                propagated = propagator(fast_data, perturbation_tensor)
                # replaces volumes from inputs_copy[0] with volumes from propagated
                dict_to_tensor(inputs_copy[0], propagated)
                outputs = net(inputs_copy[0].unsqueeze(0))
                if outputs.argmax().item() != targets[0].item():
                    break
                if cost_coeff != 0.0:
                    cost = cost_criterion(fast_data, perturbation_tensor, dataset, dataset_idx)
                else:
                    cost = 0.0
                if detectability_coeff != 0.0:
                    detectability = detectability_criterion(propagator(fast_data, perturbation_tensor))
                else:
                    detectability = 0.0
                # Skipped when unweighted, matching attacker_costReg; this
                # also keeps a non-finite capital from turning the loss into
                # NaN through 0 * inf.
                if capital_coeff != 0.0:
                    capital = capital_criterion(fast_data, perturbation_tensor)
                else:
                    capital = 0.0
                if target is None:
                    loss = criterion(outputs, targets[0].unsqueeze(0))
                else:
                    loss = -1.0*criterion(outputs, target*torch.ones_like(targets[0].unsqueeze(0)))
                loss -= cost_coeff*cost + detectability_coeff*detectability + capital_coeff*capital
            # Back under no_grad: the in-place update below is not tracked.
            grad = torch.autograd.grad(loss, [new_perturbation_tensor])[0].detach()
            if signed:
                grad = torch.sign(grad)
            perturbation_tensor += get_step_size(step_size, rand=rand_step_size)*grad # adds step to perturbation_tensor
            # Clamp the total, then recover the trainable part from it.
            new_perturbation_tensor = torch.clamp(perturbation_tensor, min=0.0).detach()-old_perturbation_tensor.detach()
        new_perturbation_inputs = torch.zeros_like(inputs)
        new_perturbation_inputs[0][fast_data[1][:, 0], fast_data[1][:, 1]] = new_perturbation_tensor[:]
        return new_perturbation_inputs


def universalPerturbation_costReg(net, dataset, fast_path, attack_batch_size=50, outer_steps=10000, inner_steps=5, outer_step_size=1.5, inner_step_size=10.0, device='cuda',
                                  signed = False, cost_coeff = 0.0, rand_step_size=True, detectability_coeff = 0.0, capital_coeff = 0.0,
                                  criterion = torch.nn.CrossEntropyLoss(), load_path='', save_path='debug.t7', target=None):
    """Build a universal volume perturbation over ``dataset``.

    Samples are drawn one at a time from a shuffled loader.  For each sample
    ``attacker_UniversalPerturbation_costReg`` produces an update which is
    scaled by ``outer_step_size`` and accumulated.  The accumulated value
    becomes the active perturbation whenever
    ``(iterations_done + 1) % attack_batch_size == 0``, at which point its
    max/mean/std are printed.  The rounded active perturbation is saved to
    ``save_path`` after every sample and returned after ``outer_steps``
    samples.  A truthy ``load_path`` replaces the all-zero starting
    perturbation with ``torch.load(load_path)``.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    perturbation_inputs = torch.zeros_like(dataset[0][0]).unsqueeze(0).to(device)
    if load_path:
        perturbation_inputs = torch.load(load_path)
    running_perturbation_inputs = copy.deepcopy(perturbation_inputs)

    # Options shared by every call of the per-sample attack.
    inner_attack_options = dict(
        num_steps=inner_steps, step_size=inner_step_size, device=device,
        signed=signed, cost_coeff=cost_coeff, rand_step_size=rand_step_size,
        detectability_coeff=detectability_coeff, capital_coeff=capital_coeff,
        criterion=criterion, target=target,
    )
    summary_stats = (('max', torch.max), ('mean', torch.mean), ('std', torch.std))

    iterations_done = 0
    # The loader is walked repeatedly until outer_steps samples were used.
    while iterations_done < outer_steps:
        for _, (inputs, targets, dataset_idx) in tqdm(enumerate(dataloader), leave=False):
            if iterations_done == outer_steps:
                break
            iterations_done += 1
            inputs, targets = inputs.to(device), targets.to(device)
            sample_update = attacker_UniversalPerturbation_costReg(
                perturbation_inputs, net, inputs, targets, dataset_idx, dataset,
                fast_path, **inner_attack_options)
            running_perturbation_inputs = (running_perturbation_inputs.detach()
                                           + outer_step_size*sample_update.detach())
            # NOTE(review): the counter was already incremented above, so the
            # first promotion happens after attack_batch_size - 1 samples --
            # confirm the "+1" is intended.
            if (iterations_done+1) % attack_batch_size == 0:
                perturbation_inputs = running_perturbation_inputs
                for label, statistic in summary_stats:
                    print(label, statistic(perturbation_inputs))

            torch.save((perturbation_inputs).round(), save_path)
    return perturbation_inputs.round()


def measure_universal_stats(perturbation_inputs, fast_path, dataset, dataset_idx, device):
    """Evaluate a universal perturbation on one sample.

    ``perturbation_inputs[0]`` is read at the positions of the sample's fast
    data, propagated, and scored.

    Returns:
        ``(cost, capital_required, detectability)``.
    """
    fast_data = open_fast_data(fast_path, dataset, dataset_idx[0].item())
    fast_data[1] = fast_data[1].to(device)
    fast_data[2] = fast_data[2].to(device)

    # Flatten the universal perturbation to this sample's fast-data layout.
    flat_perturbation = create_perturbation_tensor(fast_data).to(device)
    inputs_to_perturbation_tensor(fast_data, perturbation_inputs[0], flat_perturbation)
    propagated = PropagateVolumePerturbation().to(device)(fast_data, flat_perturbation)

    cost = CalculateCost()(fast_data, flat_perturbation, dataset, dataset_idx)
    capital_required = CalculateCapital()(fast_data, flat_perturbation)
    detectability = ProportionOccupied(fast_data)(propagated)
    return cost, capital_required, detectability


def propagate_universal_perturbation(universal_perturbation, fast_path, dataset, dataset_idx, device):
    """Return a copy of ``universal_perturbation`` with propagated volumes.

    Element 0 of the copy is read at the positions of the sample's fast data,
    propagated together with the sample's original volumes, and the resulting
    volumes are written back to the same positions.  The argument itself is
    not modified.
    """
    result = copy.deepcopy(universal_perturbation)

    fast_data = open_fast_data(fast_path, dataset, dataset_idx[0].item())
    fast_data[1] = fast_data[1].to(device)
    fast_data[2] = fast_data[2].to(device)

    flat_perturbation = torch.zeros_like(fast_data[2])
    inputs_to_perturbation_tensor(fast_data, result[0], flat_perturbation)

    propagated = PropagateVolumePerturbation().to(device)(fast_data, flat_perturbation)
    dict_to_tensor(result[0], propagated)
    return result




'''
def volumePropagationPGD(net, inputs, targets, dataset_idx, dataset, steps=30, step_size=0.1, clamp_start=0.0,
                         clamp_end=1.0, open_path='./data/fast_data', device='cuda', clamp=True, rounding=True,
                         l1_coeff=0.0, perturbation_perturbation=0.0, grad_threshold=0.0):

    criterion = torch.nn.MSELoss()

    with torch.no_grad():
        fast_data = open_fast_data(open_path, dataset, dataset_idx[0].item())
        fast_data[1], fast_data[2] = fast_data[1].to(device), fast_data[2].to(device)
        perturbation_tensor = create_perturbation_tensor(fast_data).to(device).requires_grad_()
        if clamp:
            clamp_tensor = torch.ones_like(perturbation_tensor)
        propagator = PropagateVolumePerturbation().to(device)
        for step in range(steps):
            with torch.enable_grad():
                perturbation_tensor = perturbation_tensor.requires_grad_()

                # propagates perturbation tensor as perturbation to fast_data
                perturbed = propagator(fast_data, perturbation_tensor)

                # replaces volumes from inputs[0] with volumes from perturbed
                dict_to_tensor(inputs[0], perturbed)
                outputs = net(inputs[0].unsqueeze(0))
                loss = criterion(outputs, targets[0].unsqueeze(0))+l1_coeff*torch.norm(perturbation_tensor, p=1)

            grad = torch.autograd.grad(loss, [perturbation_tensor])[0]
            grad[grad<grad_threshold] = 0.0
            perturbation_tensor += step_size*torch.sign(grad.detach()) # adds step to perturbation_tensor
            if clamp:
                perturbation_tensor = torch.min(torch.max(perturbation_tensor, clamp_start*clamp_tensor), clamp_end*clamp_tensor)
        cost_criterion = CalculateCost()
        print('before', cost_criterion(fast_data, perturbation_tensor, dataset, dataset_idx))
        if rounding:
            perturbation_tensor = perturbation_tensor.round()
        perturbed = propagator(fast_data, perturbation_tensor)
        dict_to_tensor(inputs[0], perturbed)

        cost_criterion = CalculateCost()
        cost = cost_criterion(fast_data, perturbation_tensor, dataset, dataset_idx)
        print('after', cost)
        capital_calculator = CalculateCapital()
        capital_required = capital_calculator(fast_data, perturbation_tensor)
    return inputs, perturbation_tensor, cost, capital_required
    '''