import os
import numpy as np
import torch
from torch import nn, optim
from torch.optim import optimizer
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import sys
import shutil
import copy

from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D

# Resolve the Classifier base class differently depending on how this file is
# loaded: as a script (extend sys.path, import the module by its bare name) or
# as an imported module (import it through the `models` package).
if __name__=="__main__":
    # NOTE(review): presumably the third-party `path` package that provides
    # path.Path with .abspath()/.parent — confirm it is installed.
    import path
    # Directory two levels above this file.
    folder_path= (path.Path(__file__).abspath()).parent.parent
    sys.path.append(folder_path)
    from classifier_base import Classifier
else:
    from models.classifier_base import Classifier

class Loss:
    """Base class for losses that are evaluated through a classifier.

    Stores the classifier whose outputs a concrete loss is computed on.
    Subclasses in this file implement ``__call__``.
    """

    def __init__(self, classifier):
        """Remember the classifier this loss is attached to."""
        self.classifier = classifier

    def get_loss(self, inputs, labels, requires_mean=True):
        """Placeholder hook; the base implementation computes nothing.

        Returns:
            None, regardless of the arguments.
        """
        return

class HingeLoss(Loss):
    """Hinge (SVM) loss evaluated on the wrapped classifier's scores."""

    def __call__(self, inputs, labels, requires_mean=True):
        """Compute the hinge loss ``max(0, 1 - y_i * score(x_i))``.

        ``score(x_i)`` is whatever ``self.classifier(inputs)`` returns, with
        singleton dimensions squeezed away before it is multiplied by the
        labels (for the linear SVM in this file that is ``<w, x_i> + b``).

        Args:
            inputs: Batch passed straight to the classifier.
            labels: Labels multiplied element-wise with the squeezed scores.
            requires_mean: If True, return the mean loss over the batch
                (useful when training); if False, return the per-sample
                losses (useful for analysis and plotting).

        Returns:
            A tensor holding the mean loss, or the per-sample losses.
        """
        scores = self.classifier(inputs).squeeze()
        per_sample = torch.clamp(1 - labels*scores, min=0)
        return torch.mean(per_sample) if requires_mean else per_sample

class SmoothHingeLoss(Loss):
    """Smoothed variant of the hinge loss, controlled by a width ``sigma``."""

    def __init__(self, classifier, sigma=0.5):
        """Attach to ``classifier`` and store the smoothing parameter ``sigma``."""
        super().__init__(classifier)
        self.sigma = sigma

    def PhiM(self, v):
        """Return ``(1 + v / sqrt(1 + v^2)) / 2`` element-wise."""
        return (1 + (v/torch.sqrt(1+v**2)))/2

    def phiM(self, v):
        """Return ``1 / (2 * sqrt(1 + v^2))`` element-wise."""
        return 1/(2*torch.sqrt(1+v**2))

    def __call__(self, inputs, labels, requires_mean=True):
        """Evaluate the smooth hinge loss on a batch.

        With ``m = y * score(x)`` and ``v = (1 - m) / sigma`` the per-sample
        loss is ``PhiM(v) * (1 - m) + sigma * phiM(v)``.

        Args:
            inputs: Batch passed straight to the classifier.
            labels: Labels multiplied element-wise with the squeezed scores.
            requires_mean: If True, return the batch mean; otherwise return
                the per-sample losses.

        Returns:
            A tensor holding the mean loss, or the per-sample losses.
        """
        margin = labels*self.classifier(inputs).squeeze()
        assert labels.shape == margin.shape  # sanity check
        slack = 1 - margin
        v = slack/self.sigma
        per_sample = self.PhiM(v)*slack + self.sigma*self.phiM(v)
        return torch.mean(per_sample) if requires_mean else per_sample

class SVM(nn.Module, Classifier):
    """Linear SVM classifier.

    Scores an input as ``<w, x> + b`` with a single ``nn.Linear`` layer and is
    trained with either the hinge loss or a smoothed hinge loss.

    NOTE(review): ``train`` defined here shadows ``nn.Module.train(mode)``;
    ``nn.Module.eval()`` is implemented in terms of ``train(False)``, so
    calling ``eval()`` on an instance would dispatch to this training loop.
    Confirm no caller relies on the ``nn.Module`` meaning.
    """

    def __init__(self, INPUT_DIM=2, loss_type="hinge", param_dict=None):
        """Build the linear scorer and select the loss.

        Args:
            INPUT_DIM: Number of input features.
            loss_type: ``"hinge"`` or ``"smooth-hinge"``.
            param_dict: Optional dict of loss parameters; only the key
                ``"sigma"`` is read (for the smooth hinge loss).
        """
        super(SVM, self).__init__()
        self.svm = nn.Linear(INPUT_DIM, 1)
        self.loss_type = loss_type
        self.param_dict = param_dict
        self.loss_fn = self.parse_loss_type()

    def forward(self, inputs):
        """Return the raw scores of the linear layer for ``inputs``."""
        return self.svm(inputs)

    def parse_loss_type(self):
        """Map ``self.loss_type`` to a loss object bound to this model.

        Returns:
            A ``HingeLoss`` or ``SmoothHingeLoss`` instance, or None when
            ``loss_type`` is not one of the two recognised names (in which
            case ``get_loss`` will fail when it tries to call it).
        """
        if self.loss_type == "hinge":
            return HingeLoss(self)
        if self.loss_type == "smooth-hinge":
            sigma = 0.5
            if self.param_dict is not None and "sigma" in self.param_dict:
                sigma = self.param_dict["sigma"]
            return SmoothHingeLoss(self, sigma=sigma)
        return None

    def get_loss(self, inputs, labels, requires_mean=True):
        """Evaluate the configured loss on ``(inputs, labels)``."""
        return self.loss_fn(inputs, labels, requires_mean)

    def predict(self, inputs):
        """Return predictions in {-1, 1} with the shape of the layer output.

        Scores strictly greater than zero map to 1, everything else to -1.
        """
        with torch.no_grad():
            outputs = self.svm(inputs)
            preds = 2*(outputs>0) - 1
            return preds

    def compute_accuracy(self, preds, labels):
        """Return the fraction of ``preds`` equal to ``labels``.

        The result is a tensor with a value in [0, 1].
        """
        with torch.no_grad():
            correct = torch.sum(preds.squeeze()==labels)
            return correct/labels.numel()

    def zero_grad(self):
        """Reset the gradients of the linear layer's parameters."""
        return self.svm.zero_grad()

    def train(self, dataset, train_args=None):
        """Train the classifier on the input dataset.

        Args:
            dataset (Pytorch Dataset): The input dataset.
            train_args (dictionary, optional): Must contain ``"epochs"`` and
                ``"batch size"`` when given. Defaults to None, which means
                500 epochs with batch size 32.
        """
        if train_args:
            epochs = train_args["epochs"]
            batch_size = train_args["batch size"]
        else:
            epochs = 500
            batch_size = 32

        print("Training!")
        optimizer = optim.Adam(self.svm.parameters(), lr=1e-3, weight_decay=0.01)
        train_dl = DataLoader(dataset, batch_size=batch_size)
        for epoch in range(epochs):
            if (epoch+1)%50==0:
                print(f"Starting epoch {epoch+1} of {epochs}")

            for (X, y) in train_dl:
                # get_loss runs the forward pass itself, so no separate
                # (unused) call to self(X) is needed here.
                optimizer.zero_grad()
                loss = self.get_loss(X, y)
                loss.backward()
                optimizer.step()

    def train_with_adv(self, dataset, S_curr, A_curr, orig_model, train_args=None):
        """Train the classifier to defend against an attack.

        Each epoch minimises
        ``loss(A_curr.get_perturbed(X_S), y_S) + reg * ||f(X_D) - f_orig(X_D)||^2``
        where ``S`` is ``dataset[S_curr]`` and ``D`` is the whole dataset.
        After training, the three loss curves are plotted, saved under
        ``svm_adv_loss_plots/<reg>/``, shown, and pickled alongside the plots.

        Args:
            dataset: Dataset that supports ``dataset[S_curr]`` returning an
                ``(X, y)`` pair, ``len()``, and iteration via ``DataLoader``.
            S_curr: Index (or indices) selecting the subset S.
            A_curr: Attack object exposing ``get_perturbed(X)``.
            orig_model: Reference model exposing ``forward(X)``.
            train_args (dictionary, optional): Must contain ``"epochs"`` and
                ``"lambda"`` (the regularisation weight) when given. Defaults
                to None, which means 100 epochs and ``reg = 0.01``.
        """
        X_s, y_s = dataset[S_curr]  # samples in the set S
        num_D = len(dataset)

        if train_args:
            epochs = train_args["epochs"]
            reg = train_args["lambda"]
        else:
            epochs = 100
            reg = 0.01
        print("Training!")
        optimizer = optim.Adam(self.svm.parameters(), lr=1e-3, weight_decay=0.01)
        # One full-dataset batch; the loader is epoch-invariant, so build it once.
        train_dl = DataLoader(dataset, batch_size=num_D)
        epoch_loss_list = []
        svm_loss_list = []
        change_loss_list = []
        for epoch in range(epochs):
            if (epoch+1)%50==0:
                print(f"Starting epoch {epoch+1} of {epochs}")

            # Loss on the perturbed samples of S.
            A_x = A_curr.get_perturbed(X_s)
            svm_loss = self.get_loss(A_x, y_s)

            # Squared change in scores w.r.t. the original model over all of D.
            change_loss = 0
            for (X, y) in train_dl:
                change_loss += torch.linalg.vector_norm(self.forward(X) - orig_model.forward(X))**2

            epoch_loss = svm_loss + change_loss*reg
            # Store plain floats, not tensors that still carry the autograd graph.
            epoch_loss_list.append(epoch_loss.item())
            svm_loss_list.append(svm_loss.item())
            change_loss_list.append(change_loss.item())

            optimizer.zero_grad()
            epoch_loss.backward()
            optimizer.step()

        # Output directory: svm_adv_loss_plots/<reg>/ (created once, race-free).
        reg_dir = os.path.join("svm_adv_loss_plots", str(reg))
        os.makedirs(reg_dir, exist_ok=True)

        histories = {
            "Total": epoch_loss_list,
            "SVM": svm_loss_list,
            "Change": change_loss_list,
        }

        x_range = np.arange(1, len(epoch_loss_list)+1)
        for name, values in histories.items():
            plt.figure()
            plt.plot(x_range, values)
            plt.title(f'{name} loss vs epochs (reg={reg})')
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.savefig(os.path.join(reg_dir, f"{name}_loss_{reg}.png"))

        plt.show()

        data_dict = dict(histories)
        data_dict["reg"] = reg
        data_dict["epochs"] = epochs
        self.serialize_data(os.path.join(reg_dir, f"svm_data_{reg}.pkl"), data_dict)

    def serialize_data(self, fname, ddict):
        """Pickle ``ddict`` to the file ``fname``.

        ``train_with_adv`` uses this to store the total, SVM and change loss
        histories together with ``reg`` and ``epochs``.
        """
        # Context manager guarantees the handle is closed even if dump raises.
        with open(fname, 'wb') as fp:
            pickle.dump(ddict, fp)

def extractor(d):
    """Extract the weight vector and bias from a two-entry parameter mapping.

    Args:
        d: Mapping with exactly two tensor values, in the order
            (weight, bias) — e.g. the ``state_dict()`` of a model whose only
            parameters are one linear layer's weight and bias.

    Returns:
        ``(w, b)`` where ``w`` is the weight flattened to a 1-D NumPy array
        (any input dimension, not just 2) and ``b`` is the bias as a Python
        float.

    Raises:
        ValueError: If ``d`` does not hold exactly two values.
    """
    w, b = d.values()
    # detach()/cpu() make this work for tensors that require grad or live on
    # another device; reshape(-1) flattens regardless of the input dimension.
    w = w.detach().cpu().numpy().reshape(-1)
    # .item() converts a one-element tensor to a scalar without relying on
    # float() of a non-0-dim NumPy array, which newer NumPy deprecates.
    b = float(b.item())
    return w, b


# Demo script: build a linearly separable synthetic 2-D dataset, plot it, train
# an SVM with the smooth hinge loss on it and report the final loss/accuracy.
if __name__ == "__main__":

    NUM_EXAMPLES = 500
    INPUT_DIM = 2
    eps = 0  # not referenced again in this script
        
    # Seed NumPy's global RNG; every np.random call below depends on this order.
    # NOTE(review): torch's RNG is not seeded here — confirm whether the model
    # initialisation is meant to be reproducible as well.
    np.random.seed(1)
    # Optional first command-line argument: rho (defaults to 1).
    n = len(sys.argv)
    rho = 1 if n<=1 else float(sys.argv[1])
    print("rho is ", rho)
    rho_str = str(rho)
    # Store integral values of rho as int; rho is not used further below.
    if rho-int(rho)==0:
        rho = int(rho)

    # Generating synthetic data. Here's the process:
    # First we generate a bunch of x's which are drawn so that they are from the uniform distribution over [0, 1]^INPUT_DIM
    # Then, we pick a w and b and use that to label the points. This creates the dataset
    dataset_x = np.random.uniform(size=NUM_EXAMPLES*INPUT_DIM).reshape(NUM_EXAMPLES, INPUT_DIM)

    # Generating a random w and b. We'll use this to generate labels. Need to make sure it splits the dataset.
    w = np.array([1, -1]) + 0.01*np.random.uniform(size=INPUT_DIM)
    b = 0.25*np.random.uniform()

    # For plotting lines given by <w1,x> + b1 = 0 for some w1, b1
    xx = np.linspace(-0.01, 1, 4)
    def svm_line_compute(x_points, w_svm, b_svm):
        # Solve w0*x + w1*y + b = 0 for y at each x in x_points.
        return np.array([-1*(w_svm[0]*xp + b_svm)/w_svm[1] for xp in x_points])
        
    # Draw the ground-truth labelling line in red.
    yy = svm_line_compute(xx, w, b)
    plt.plot(xx, yy, color='red')

    # What is the split of the data among positive and negative examples?
    # `dataset` holds (x, label) pairs with labels in {-1, 1};
    # `dataset_y` holds 0/1 values used only to colour the scatter plot.
    dataset = []
    dataset_y = []
    count = [0, 0]  # [number of negative examples, number of positive examples]
    for x in dataset_x:
        if np.dot(w, x) + b < 0:
            count[0]+=1
            dataset.append((x, -1))
            dataset_y.append(0)
        else:
            count[1]+=1
            dataset.append((x, 1))
            dataset_y.append(1)

    plt.scatter(dataset_x[:,0], dataset_x[:,1], c=dataset_y, label="_nolegend_")

    print(count[0], count[1], "split")
    print(len(dataset))
    # Object array of shape (NUM_EXAMPLES, 2): column 0 is x, column 1 the label.
    dataset = np.array(dataset, dtype=object)

    class SynDataSet(Dataset):
        '''
            Dataset wrapper around the synthetic (x, label) array, used as input
            to the data loader during training.
            Exposes the features as self.X (float32 tensor) and the labels as
            self.y (tensor with the dtype inferred from the label array).
        '''
        def __init__(self, dataset, use_S=False):
            # `use_S` is accepted but not used in this implementation.
            self.n_samples = dataset.shape[0]
            X_np = np.array([d[0] for d in dataset])
            y_np = np.array([d[1] for d in dataset])
            self.X = torch.tensor(X_np, dtype=torch.float32)
            self.y = torch.tensor(y_np)
        
        def __getitem__(self, index):
            # Returns the (features, label) pair(s) selected by `index`.
            return self.X[index], self.y[index]
        
        def __len__(self):
            return self.n_samples
    
    ds = SynDataSet(dataset)
    
    svm = SVM(loss_type="smooth-hinge", param_dict={"sigma":0.4})
    # Snapshot of the initial parameters; not used again in this script.
    d1 = copy.deepcopy(svm.state_dict())
    # NOTE(review): `train_model` is not defined on SVM in this file; presumably
    # it is provided by the Classifier base class — confirm.
    svm.train_model(ds, {"epochs":500, "batch size":32})
    
    # Report loss and accuracy on the full training set.
    with torch.no_grad():
        acc = svm.compute_accuracy(svm.predict(ds.X), ds.y)
        loss = svm.get_loss(ds.X, ds.y)
        print(f"Training complete...Loss is {loss.item():.3f} and accuracy is {acc.item():.2f}")