import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import torch
from torch.utils.data import Dataset, DataLoader

# Seed NumPy's global random generator when this module is imported, so the
# np.random.uniform draws made by DataGen.gen below are reproducible across runs.
np.random.seed(1)

class SynDataSet(Dataset):
    '''
    Dataset wrapper used as input to a data loader during training.

    It takes a collection of ``(features, label)`` pairs and stores them as
    two tensors: ``X`` (features, float32) and ``y`` (labels, dtype inferred
    by ``torch.tensor`` from the label values).
    '''

    def __init__(self, dataset, use_S=False):
        '''
        Args:
            dataset: Any sized, iterable collection whose items can be indexed
                with ``[0]`` (feature vector) and ``[1]`` (label); for example
                a list of tuples or a NumPy object array with one pair per row.
            use_S: Accepted for interface compatibility; not used here.
        '''
        # len() instead of .shape[0]: identical for NumPy arrays, but also
        # works for plain Python sequences, which have no .shape attribute.
        self.n_samples = len(dataset)

        # Stack into dense arrays first so torch receives one contiguous
        # buffer instead of a Python list of small arrays.
        X_np = np.array([d[0] for d in dataset])
        y_np = np.array([d[1] for d in dataset])
        self.X = torch.tensor(X_np, dtype=torch.float32)
        self.y = torch.tensor(y_np)

    def __getitem__(self, index):
        '''Return the ``(features, label)`` tensors stored at ``index``.'''
        return self.X[index], self.y[index]

    def __len__(self):
        '''Return the number of examples in the wrapped dataset.'''
        return self.n_samples

class DataGen():
    '''
    Generates, caches (via pickle) and plots a linearly separable synthetic
    dataset.

    After ``gen`` has run, the instance holds:
        dataset   : object array with one ``(x, label)`` pair per example,
                    label being -1 or 1
        dataset_x : array of shape (NUM_EXAMPLES, INPUT_DIM), uniform in [0, 1)
        dataset_y : list of 0/1 labels (0 where ``dataset`` stores -1)
        w, b      : the hyperplane <w, x> + b = 0 used to assign the labels
        xx, yy    : points on that line, used for plotting
    '''

    @staticmethod
    def _pickle_path(NUM_EXAMPLES, INPUT_DIM, input_pkl):
        '''Builds the cache file name shared by gen, readPickle and savePickle.'''
        return f'{input_pkl}_{NUM_EXAMPLES}_{INPUT_DIM}.pkl'

    def gen(self, NUM_EXAMPLES=500, INPUT_DIM=2, input_pkl='data_syn1'):
        '''
        Loads the dataset from its pickle cache if one exists; otherwise
        generates a new dataset and writes the cache.

        Args:
            NUM_EXAMPLES: number of examples to generate.
            INPUT_DIM: feature dimension. Generation only supports 2, because
                the base weight vector and the plotted line are 2-D.
            input_pkl: prefix (may include a directory) of the cache file.

        Raises:
            ValueError: if a new dataset has to be generated and INPUT_DIM != 2.
        '''
        if os.path.exists(self._pickle_path(NUM_EXAMPLES, INPUT_DIM, input_pkl)):
            print('Using previously saved data ...')
            self.readPickle(NUM_EXAMPLES, INPUT_DIM, input_pkl)
        else:
            # The base weights [1, -1] below are two-dimensional. Any other
            # INPUT_DIM used to fail later with an obscure NumPy shape
            # ValueError; fail early with a clear message instead.
            if INPUT_DIM != 2:
                raise ValueError(
                    f'DataGen.gen can only generate data for INPUT_DIM=2, got {INPUT_DIM}'
                )

            self.dataset = []

            # Generating feature vectors
            self.dataset_x = np.random.uniform(size=NUM_EXAMPLES*INPUT_DIM).reshape(NUM_EXAMPLES, INPUT_DIM)

            # Generating a random w and b. We'll use this to generate labels.
            # w is a small perturbation of [1, -1] so that the line splits the
            # unit square instead of leaving every point on one side.
            self.w = np.array([1, -1]) + 0.01*np.random.uniform(size=INPUT_DIM)
            self.b = 0.25*np.random.uniform()

            # Generating labels
            self.dataset_y = []
            count = [0, 0]  # count[0]: negative examples, count[1]: positive examples
            for x in self.dataset_x:
                positive = not (np.dot(self.w, x) + self.b < 0)
                count[int(positive)] += 1
                # dataset keeps -1/+1 labels, dataset_y keeps 0/1 labels
                self.dataset.append((x, 1 if positive else -1))
                self.dataset_y.append(int(positive))

            print(f"Label split: #0s = {count[0]}, #1s = {count[1]}")

            # For plotting lines given by <w1,x> + b1 = 0 for some w1, b1
            self.xx = np.linspace(-0.01, 1, 4)
            self.yy = self.svm_line_compute()
            self.dataset = np.array(self.dataset, dtype=object)
            self.savePickle(NUM_EXAMPLES, INPUT_DIM, input_pkl)

        print(f'Size of the dataset = {len(self.dataset)}')

    def svm_line_compute(self):
        '''
        Returns, for every abscissa in ``self.xx``, the ordinate of the point
        on the line <w, x> + b = 0, i.e. ``-(w[0]*x + b) / w[1]``.
        '''
        xs = np.asarray(self.xx)
        return -1*(self.w[0]*xs + self.b)/self.w[1]

    def visualize(self):
        '''
        Plots the
            i) features with labels
            ii) true separating hyperplane
        and saves the figure to ../plots/syn1_data.png (relative to the
        current working directory), creating the directory if needed.
        '''
        # NOTE(review): clf() clears whatever figure is current before a new
        # one is opened; kept as-is so existing plotting behaviour is unchanged.
        plt.clf()
        plt.figure()

        plt.xlabel('x1')
        plt.ylabel('x2')
        plt.title('Data with separating hyperplane')
        plt.scatter(self.dataset_x[:,0], self.dataset_x[:,1], c=self.dataset_y, label="_nolegend_")
        plt.plot(self.xx, self.yy, color='red')

        # savefig does not create missing directories, so make sure the
        # output directory exists before writing.
        out_dir = os.path.join('..', 'plots')
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, 'syn1_data.png'))

    def readPickle(self, NUM_EXAMPLES=500, INPUT_DIM=2, input_pkl='data_syn1'):
        '''
        Restores every dataset attribute from the pickle cache written by
        savePickle for the same (input_pkl, NUM_EXAMPLES, INPUT_DIM).

        Raises whatever ``open`` / ``pickle.load`` raise if the file is
        missing or unreadable.
        '''
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at cache files produced by savePickle.
        with open(self._pickle_path(NUM_EXAMPLES, INPUT_DIM, input_pkl), 'rb') as fh:
            data = pickle.load(fh, encoding='latin1')

        self.dataset = data['dataset']
        self.dataset_x = data['dataset_x']
        self.dataset_y = data['dataset_y']
        self.xx = data['xx']
        self.yy = data['yy']
        self.w = data['w']
        self.b = data['b']

    def savePickle(self, NUM_EXAMPLES=500, INPUT_DIM=2, input_pkl='data_syn1'):
        '''
        Writes every dataset attribute to the pickle cache for
        (input_pkl, NUM_EXAMPLES, INPUT_DIM), overwriting an existing file.
        '''
        # Build the payload before opening the file so that a missing
        # attribute does not leave an empty cache file behind.
        data = {
            'dataset': self.dataset,
            'dataset_x': self.dataset_x,
            'dataset_y': self.dataset_y,
            'xx': self.xx,
            'yy': self.yy,
            'w': self.w,
            'b': self.b,
        }
        # The context manager closes (and therefore flushes) the file even if
        # pickling fails part-way through.
        with open(self._pickle_path(NUM_EXAMPLES, INPUT_DIM, input_pkl), 'wb') as fh:
            pickle.dump(data, fh)

    def pytorchDataset(self):
        '''Wraps the generated ``dataset`` in a SynDataSet for use with a DataLoader.'''
        return SynDataSet(self.dataset)


if __name__ == "__main__":
    # Script entry point: build (or load the cached) default dataset,
    # dump it to stdout, then write the scatter plot to disk.
    generator = DataGen()
    generator.gen()
    print(f'Dataset: {generator.dataset}')
    generator.visualize()