import os
import numpy as np
import torch
from torch import nn, optim
from torch.optim import optimizer
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import sys
import shutil
import copy

from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D

if __name__ == "__main__":
    # When this file is run directly as a script, add the parent of the
    # directory that contains this file to sys.path. This is presumably the
    # project root holding the `models` package imported below; confirm
    # against the repository layout.
    # os.path (already imported) is used instead of the third-party `path`
    # package, so running the script needs no extra dependency.
    folder_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(folder_path)

from models.attack_model_base import AttackModel


class Matrix(AttackModel):
    """
    Implementation of the matrix attack model.
    In this case, the perturbation for the point x is given by Ax, where A is an n-by-n matrix and x is n-dimensional.
    The perturbed points are clamped to [0, 1] so that they do not exceed the allowed input bounds.
    """
    def __init__(self, defender, rho=1, A_init=None, epochs=15, dim=2):
        """
        Initialize the matrix attack model.

        Args:
            defender (Defender): The defender that the adversary is currently attacking.
            rho (float, optional): The balance parameter between defender loss and reconstruction loss. Defaults to 1.
            A_init (n-by-n PyTorch Tensor, optional): The initialization matrix. If None, A is initialized to the
                dim-by-dim identity matrix.
            epochs (int, optional): Default number of training epochs (one full-set optimizer step each) used by
                train_on_set. Defaults to 15.
            dim (int, optional): Size of the identity matrix used when A_init is None. Ignored when A_init is given.
                Defaults to 2.
        """
        # Identity check, not `!= None`: `!=` on a tensor is an overloaded
        # (elementwise) comparison, not a test for "argument was supplied".
        self.A = A_init if A_init is not None else torch.eye(dim)
        self.defender = defender
        self.rho = rho
        self.epochs = epochs
        self.requires_training = True

    def clone(self, requires_grad=False):
        """
        Clone this attack model.

        Args:
            requires_grad (bool, optional): Is the matrix in the cloned model trainable? Defaults to False.

        Returns:
            new_attack_model: The cloned attack model. It shares the defender but has its own copy of A.
        """
        # Detach first so the copy is not recorded in the autograd graph; the
        # result is a fresh leaf tensor that can be handed to an optimizer.
        A_clone = self.A.detach().clone().requires_grad_(requires_grad)
        return Matrix(self.defender, self.rho, A_clone, self.epochs)

    def get_loss(self, points, labels, requires_mean=True, save_losses=False):
        """
        Get the attacker's objective on the points.

        The objective is reconstruction_loss - rho * defender_loss. Minimizing it keeps the perturbation
        small while increasing the defender's classification loss.

        Args:
            points (Tensor): The set of points on which the loss is required.
            labels (Tensor): The labels of the corresponding points.
            requires_mean (bool, optional): If True return the mean loss, otherwise a tensor of per-point
                losses. Defaults to True.
            save_losses (bool, optional): If True, store the reconstruction and defender losses on
                self.reconstruction_loss and self.defender_loss. Defaults to False.

        Returns:
            total_loss: The mean loss if requires_mean is True, otherwise a vector of losses (assuming the
                defender's get_classifier_loss also returns per-point losses in that case).
        """
        perturbed_points = self.get_perturbed(points)
        defender_loss = self.defender.get_classifier_loss(perturbed_points, labels, requires_mean)
        reconstruction_loss = self.get_reconstruction_loss(points, perturbed_points, requires_mean)
        if save_losses:
            self.defender_loss = defender_loss
            self.reconstruction_loss = reconstruction_loss
        return reconstruction_loss - self.rho * defender_loss

    def train_on_set(self, points, labels, epochs=None, lr=1e-3):
        """
        Train the attack matrix for a given set of points and labels.

        Each epoch performs one Adam step on the mean loss over the whole set.

        Args:
            points (PyTorch Tensor): Tensor of points.
            labels (PyTorch Tensor): Corresponding Tensor of labels.
            epochs (int, optional): Number of epochs for which to train. Defaults to None, which is replaced
                with self.epochs.
            lr (float, optional): Adam learning rate. Defaults to 1e-3.
        """
        if epochs is None:
            epochs = self.epochs

        # A may have been created with requires_grad=False (e.g. by clone());
        # make it trainable before handing it to the optimizer.
        if not self.A.requires_grad:
            self.A.requires_grad_(True)

        # Named `opt` so it does not shadow the module-level `optimizer` import.
        opt = optim.Adam([self.A], lr=lr)

        for _ in range(epochs):
            opt.zero_grad()
            # backward() also reaches the defender's parameters; presumably this
            # clears those gradients so attack training does not leave them
            # accumulated on the defender.
            self.defender.zero_grad()

            loss = self.get_loss(points, labels, requires_mean=True)

            loss.backward()
            opt.step()

    def get_perturbed(self, points):
        """
        Perturb the given points using the attack model.

        Each point x (a vector along the last dimension) is mapped to A @ x, and the result is clamped to [0, 1].

        Args:
            points (Tensor): Tensor of points, one point per vector along the last dimension.

        Returns:
            perturbed (Tensor): Corresponding perturbed points, with the same shape as `points`.
        """
        # unsqueeze(-1) turns each point into a column vector for the batched
        # matmul. squeeze(-1) removes only that column dimension; a bare
        # squeeze() would also drop a batch or feature dimension of size 1.
        perturbed = torch.clamp((self.A @ points.unsqueeze(-1)).squeeze(-1), min=0, max=1)
        return perturbed

    def get_reconstruction_loss(self, points, perturbed_points, requires_mean=True):
        """
        Utility. Get the reconstruction loss of the set of points and the corresponding perturbed points.

        The pointwise loss is the Euclidean (L2) norm of perturbed_points - points taken over dim=1. This
        assumes points are stored as rows of a (batch, n) tensor.

        Args:
            points (Tensor): Tensor of original points.
            perturbed_points (Tensor): Tensor of corresponding perturbed points.
            requires_mean (bool, optional): If False, returns the pointwise loss, else returns the mean.
                Defaults to True.

        Returns:
            reconstruction loss: The mean loss if requires_mean is True, else a Tensor of pointwise losses.
        """
        pointwise = torch.norm(perturbed_points - points, p=2, dim=1)
        return torch.mean(pointwise) if requires_mean else pointwise

if __name__ == "__main__":
    # Script entry point: announce that this module is being exercised directly.
    print("Testing the matrix attack model in matrix.py!")