import numpy as np
import math
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import time

def init_seeds(seed=0):
    """
    Seed PyTorch's global RNG; for the default seed also select cuDNN's fast mode.

    Arguments:
    seed -- value passed to ``torch.manual_seed``. When it equals 0, cuDNN
            autotuning is enabled (``benchmark = True``) and deterministic
            algorithms are disabled. Any other seed leaves the cuDNN flags
            untouched.
    """
    torch.manual_seed(seed)
    if seed != 0:
        return
    # Speed over bit-exact reproducibility for the default seed.
    cudnn.benchmark = True
    cudnn.deterministic = False


def load_dataset():
    """
    Load the train/test split stored in ``datasets/train_signs.h5`` and
    ``datasets/test_signs.h5``.

    Returns:
    train_set_x_orig -- training features read from "train_set_x"
    train_set_y_orig -- training labels from "train_set_y", reshaped to (1, number of training examples)
    test_set_x_orig -- test features read from "test_set_x"
    test_set_y_orig -- test labels from "test_set_y", reshaped to (1, number of test examples)
    classes -- array read from "list_classes" in the test file
    """
    # np.array(...[:]) copies each dataset into memory, so the HDF5 files can
    # (and should) be closed as soon as the reads are done.
    with h5py.File('datasets/train_signs.h5', "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:]) # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:]) # train set labels

    with h5py.File('datasets/test_signs.h5', "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:]) # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:]) # test set labels
        classes = np.array(test_dataset["list_classes"][:]) # the list of classes

    # Labels become row vectors of shape (1, number of examples).
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes

def random_mini_batches_GCN(X, Y, L, mini_batch_size, seed):
    """
    Shuffle samples, labels and a sample-by-sample matrix together and batch them.

    Rows of ``X`` and ``Y`` are samples. ``L`` is indexed by sample on both
    axes (the reshapes below require it to be (m, m)); presumably a graph
    adjacency/Laplacian — confirm with callers. Its rows and columns receive
    the same permutation, and each batch gets the diagonal sub-block for its
    samples.

    Arguments:
    X -- features, shape (m, n_features)
    Y -- labels, shape (m, n_labels)
    L -- matrix of shape (m, m)
    mini_batch_size -- number of samples per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X_batch, Y_batch, L_batch) tuples, one per complete
                    batch (leftover samples are not batched), followed by the
                    original, unshuffled (X, Y, L) as the final entry.
    """
    n_samples = X.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_samples))

    X_shuffled = X[perm, :]
    Y_shuffled = Y[perm, :].reshape((n_samples, Y.shape[1]))
    # Permute rows, then columns, so L stays aligned with the shuffled samples.
    L_rows_shuffled = L[perm, :].reshape((L.shape[0], L.shape[1]), order = "F")
    L_shuffled = L_rows_shuffled[:, perm].reshape((L.shape[0], L.shape[1]), order = "F")

    batches = []
    for k in range(math.floor(n_samples / mini_batch_size)):
        window = slice(k * mini_batch_size, (k + 1) * mini_batch_size)
        batches.append((X_shuffled[window, :], Y_shuffled[window, :], L_shuffled[window, window]))

    # The full, unshuffled dataset is always appended as the last entry.
    batches.append((X, Y, L))
    return batches

def random_mini_batches_GCN1(X, X1, Y, L, mini_batch_size, seed):
    """
    Shuffle two feature arrays, labels and a sample-by-sample matrix together and batch them.

    ``X``, ``X1`` and ``Y`` have one sample per row. ``L`` must be (m, m): its
    rows and columns are permuted with the same order, and each batch receives
    the diagonal sub-block for its samples.

    Arguments:
    X -- first feature array, shape (m, n_features)
    X1 -- second feature array, shape (m, n_features1)
    Y -- labels, shape (m, n_labels)
    L -- matrix of shape (m, m)
    mini_batch_size -- number of samples per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X_batch, X1_batch, Y_batch, L_batch) tuples, one per
                    complete batch (leftovers are not batched), followed by the
                    original, unshuffled (X, X1, Y, L) as the final entry.
    """
    n_samples = X.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_samples))

    X_shuffled = X[perm, :]
    X1_shuffled = X1[perm, :]
    Y_shuffled = Y[perm, :].reshape((n_samples, Y.shape[1]))
    # Rows first, then columns, with the same permutation.
    L_rows_shuffled = L[perm, :].reshape((L.shape[0], L.shape[1]), order = "F")
    L_shuffled = L_rows_shuffled[:, perm].reshape((L.shape[0], L.shape[1]), order = "F")

    batches = []
    for k in range(math.floor(n_samples / mini_batch_size)):
        window = slice(k * mini_batch_size, (k + 1) * mini_batch_size)
        batches.append((X_shuffled[window, :],
                        X1_shuffled[window, :],
                        Y_shuffled[window, :],
                        L_shuffled[window, window]))

    # The full, unshuffled dataset is always appended as the last entry.
    batches.append((X, X1, Y, L))
    return batches
        
def random_mini_batches(X1, X2, Y, mini_batch_size, seed):
    """
    Pair the whole shuffled (X1, Y) set with successive mini-batches of shuffled X2.

    ``X1`` and ``Y`` share one row permutation and are NOT split: every batch
    holds the same full shuffled arrays. ``X2`` gets its own permutation and
    is cut into complete batches of ``mini_batch_size`` rows; leftover X2 rows
    are dropped.

    Arguments:
    X1 -- array with one example per row, shape (m, n1)
    X2 -- array with one example per row, shape (m1, n2)
    Y -- labels aligned with X1, shape (m, n_labels)
    mini_batch_size -- rows of X2 per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1_shuffled, X2_batch, Y_shuffled) tuples
    """
    n_first = X1.shape[0]
    n_second = X2.shape[0]
    np.random.seed(seed)

    # Draw order matters for reproducibility: X1/Y permutation first, then X2.
    perm_first = list(np.random.permutation(n_first))
    X1_shuffled = X1[perm_first, :]
    Y_shuffled = Y[perm_first, :].reshape((n_first, Y.shape[1]))

    perm_second = list(np.random.permutation(n_second))
    X2_shuffled = X2[perm_second, :]

    n_batches = math.floor(n_second / mini_batch_size)
    return [
        (X1_shuffled, X2_shuffled[k * mini_batch_size:(k + 1) * mini_batch_size, :], Y_shuffled)
        for k in range(n_batches)
    ]

def random_mini_batches2(X1, X2, Y1, Y2, mini_batch_size, seed):
    """
    Pair the whole shuffled (X1, Y1) set with successive batches of shuffled (X2, Y2).

    (X1, Y1) share one row permutation and are passed whole in every batch.
    (X2, Y2) share a second permutation and are split into
    ``floor(len(X1) / mini_batch_size)`` batches. Each batch has
    ``floor(len(X2) / number of batches)`` rows; leftover X2 rows are dropped.

    Arguments:
    X1, Y1 -- arrays with one example per row, m rows each
    X2, Y2 -- arrays with one example per row, m1 rows each
    mini_batch_size -- divides len(X1) to decide how many batches are produced
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1_shuffled, X2_batch, Y1_shuffled, Y2_batch) tuples;
                    empty when X1 has fewer than mini_batch_size rows.
    """
    m = X1.shape[0]
    m1 = X2.shape[0]
    mini_batches = []
    np.random.seed(seed)

    permutation = list(np.random.permutation(m))
    shuffled_X1 = X1[permutation, :]
    shuffled_Y1 = Y1[permutation, :].reshape((m, Y1.shape[1]))

    permutation1 = list(np.random.permutation(m1))
    shuffled_X2 = X2[permutation1, :]
    shuffled_Y2 = Y2[permutation1, :].reshape((m1, Y2.shape[1]))

    num_complete_minibatches = math.floor(m / mini_batch_size)
    if num_complete_minibatches <= 0:
        # No complete batch: return no batches instead of dividing by zero below.
        # This matches the other helpers that drop incomplete batches.
        return mini_batches
    mini_batch_size1 = math.floor(m1 / num_complete_minibatches)

    for k in range(num_complete_minibatches):
        lo = k * mini_batch_size1
        hi = lo + mini_batch_size1
        mini_batches.append((shuffled_X1, shuffled_X2[lo:hi, :], shuffled_Y1, shuffled_Y2[lo:hi, :]))

    return mini_batches

def random_mini_batches_single(X1, Y, mini_batch_size, seed):
    """
    Shuffle (X1, Y) row-wise in unison and return only the complete mini-batches.

    Arguments:
    X1 -- inputs, one example per row, shape (m, n_features)
    Y -- labels aligned with X1, shape (m, n_labels)
    mini_batch_size -- rows per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1_batch, Y_batch) tuples, each with exactly
                    mini_batch_size rows; the final m % mini_batch_size rows are dropped.
    """
    n_rows = X1.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_rows))
    X1_perm = X1[perm, :]
    Y_perm = Y[perm, :].reshape((n_rows, Y.shape[1]))

    starts = [k * mini_batch_size for k in range(math.floor(n_rows / mini_batch_size))]
    return [(X1_perm[s:s + mini_batch_size, :], Y_perm[s:s + mini_batch_size, :]) for s in starts]

def random_mini_batches_ccc(X1, X2, X1_FULL, X2_FULL, X1_UN, X2_UN, Y_P, Y, mini_batch_size, seed):
    """
    Shuffle two groups of row-aligned arrays and split them into matching mini-batches.

    Group 1 (X1, X2, X1_FULL, X2_FULL, Y; m rows) shares one permutation and is
    cut into batches of ``mini_batch_size`` rows. Group 2 (X1_UN, X2_UN, Y_P;
    m1 rows) shares a second permutation and is cut into the same number of
    complete batches, each ``floor(m1 / number of complete batches)`` rows.

    Arguments:
    X1, X2, X1_FULL, X2_FULL -- arrays with m rows (one example per row)
    X1_UN, X2_UN, Y_P -- arrays with m1 rows (one example per row)
    Y -- labels aligned with X1, shape (m, n_labels)
    mini_batch_size -- rows of group 1 per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of tuples
        (X1, X2, X1_FULL, X2_FULL, X1_UN, X2_UN, Y_P, Y) batches. If m is not a
        multiple of mini_batch_size, a final batch holds the leftover group-1
        rows together with the leftover group-2 rows (possibly empty).
    """
    m = X1.shape[0]
    m1 = X1_UN.shape[0]
    mini_batches = []
    np.random.seed(seed)

    permutation = list(np.random.permutation(m))
    shuffled_X1 = X1[permutation, :]
    shuffled_X2 = X2[permutation, :]
    shuffled_X1_FULL = X1_FULL[permutation, :]
    shuffled_X2_FULL = X2_FULL[permutation, :]
    shuffled_Y = Y[permutation, :].reshape((m, Y.shape[1]))

    permutation1 = list(np.random.permutation(m1))
    shuffled_X1_UN = X1_UN[permutation1, :]
    shuffled_X2_UN = X2_UN[permutation1, :]
    shuffled_X1_UN_FULL = Y_P[permutation1, :]

    num_complete_minibatches = math.floor(m / mini_batch_size)
    if num_complete_minibatches != 0:
        mini_batch_size1 = math.floor(m1 / num_complete_minibatches)
    else:
        # Fewer group-1 rows than one batch: avoid dividing by zero. The
        # leftover batch below then takes every group-2 row (slice 0:m1).
        mini_batch_size1 = m1

    def _batch(lo, hi, lo_un, hi_un):
        # Slice both groups; tuple order is part of the public contract.
        return (shuffled_X1[lo:hi, :],
                shuffled_X2[lo:hi, :],
                shuffled_X1_FULL[lo:hi, :],
                shuffled_X2_FULL[lo:hi, :],
                shuffled_X1_UN[lo_un:hi_un, :],
                shuffled_X2_UN[lo_un:hi_un, :],
                shuffled_X1_UN_FULL[lo_un:hi_un, :],
                shuffled_Y[lo:hi, :])

    for k in range(num_complete_minibatches):
        mini_batches.append(_batch(k * mini_batch_size, (k + 1) * mini_batch_size,
                                   k * mini_batch_size1, (k + 1) * mini_batch_size1))

    if m % mini_batch_size != 0:
        # The leftover batch was previously built but never appended, silently
        # discarding the last m % mini_batch_size group-1 rows.
        mini_batches.append(_batch(num_complete_minibatches * mini_batch_size, m,
                                   num_complete_minibatches * mini_batch_size1, m1))

    return mini_batches

def random_mini_batches_un(X1, X2, X1_UN, X1_FULL, X2_FULL, Y, mini_batch_size, seed):
    """
    Shuffle row-aligned arrays plus a separate array X1_UN and split them into mini-batches.

    X1, X2, X1_FULL, X2_FULL and Y (m rows) share one permutation and are cut
    into batches of ``mini_batch_size`` rows. X1_UN (m1 rows) gets its own
    permutation and is cut into the same number of complete batches, each
    ``floor(m1 / number of complete batches)`` rows.

    Arguments:
    X1, X2, X1_FULL, X2_FULL -- arrays with m rows (one example per row)
    X1_UN -- array with m1 rows (one example per row)
    Y -- labels aligned with X1, shape (m, n_labels)
    mini_batch_size -- rows of X1 per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1, X2, X1_UN, X1_FULL, X2_FULL, Y) batches. If m is
                    not a multiple of mini_batch_size, the last batch holds the
                    leftover rows of both groups (the X1_UN part may be empty).
    """
    m = X1.shape[0]
    m1 = X1_UN.shape[0]
    mini_batches = []
    np.random.seed(seed)

    permutation = list(np.random.permutation(m))
    shuffled_X1 = X1[permutation, :]
    shuffled_X2 = X2[permutation, :]
    shuffled_X1_FULL = X1_FULL[permutation, :]
    shuffled_X2_FULL = X2_FULL[permutation, :]
    shuffled_Y = Y[permutation, :].reshape((m, Y.shape[1]))

    permutation1 = list(np.random.permutation(m1))
    shuffled_X1_UN = X1_UN[permutation1, :]

    num_complete_minibatches = math.floor(m / mini_batch_size)
    if num_complete_minibatches != 0:
        mini_batch_size1 = math.floor(m1 / num_complete_minibatches)
    else:
        # Fewer than mini_batch_size rows: avoid ZeroDivisionError. The leftover
        # batch below then receives every X1_UN row (slice 0:m1).
        mini_batch_size1 = m1

    def _batch(lo, hi, lo_un, hi_un):
        # Tuple order is part of the public contract.
        return (shuffled_X1[lo:hi, :],
                shuffled_X2[lo:hi, :],
                shuffled_X1_UN[lo_un:hi_un, :],
                shuffled_X1_FULL[lo:hi, :],
                shuffled_X2_FULL[lo:hi, :],
                shuffled_Y[lo:hi, :])

    for k in range(num_complete_minibatches):
        mini_batches.append(_batch(k * mini_batch_size, (k + 1) * mini_batch_size,
                                   k * mini_batch_size1, (k + 1) * mini_batch_size1))

    if m % mini_batch_size != 0:
        mini_batches.append(_batch(num_complete_minibatches * mini_batch_size, m,
                                   num_complete_minibatches * mini_batch_size1, m1))

    return mini_batches

def random_mini_batches_unimodal(X1, mini_batch_size, seed):
    """
    Shuffle the rows of X1 and split them into mini-batches.

    Arguments:
    X1 -- input data, one example per row, shape (number of examples, n_features)
    mini_batch_size -- rows per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of arrays (not tuples). Every batch has mini_batch_size
                    rows except possibly the last, which holds the leftover rows.
    """
    n_rows = X1.shape[0]
    np.random.seed(seed)
    X1_perm = X1[list(np.random.permutation(n_rows)), :]

    n_full = math.floor(n_rows / mini_batch_size)
    spans = [(k * mini_batch_size, (k + 1) * mini_batch_size) for k in range(n_full)]
    if n_rows % mini_batch_size != 0:
        spans.append((n_full * mini_batch_size, n_rows))
    return [X1_perm[lo:hi, :] for lo, hi in spans]




def random_mini_batches_bimodal(X1, X2, X1_FULL, X2_FULL, mini_batch_size, seed):
    """
    Shuffle four row-aligned arrays in unison and split them into mini-batches.

    Arguments:
    X1, X2, X1_FULL, X2_FULL -- arrays sharing the same number of rows (one example per row)
    mini_batch_size -- rows per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1_batch, X2_batch, X1_FULL_batch, X2_FULL_batch)
                    tuples; the last batch may be smaller and holds the leftover rows.
    """
    n_rows = X1.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_rows))
    arrays = (X1[perm, :], X2[perm, :], X1_FULL[perm, :], X2_FULL[perm, :])

    n_full = math.floor(n_rows / mini_batch_size)
    spans = [(k * mini_batch_size, (k + 1) * mini_batch_size) for k in range(n_full)]
    if n_rows % mini_batch_size != 0:
        spans.append((n_full * mini_batch_size, n_rows))
    return [tuple(a[lo:hi, :] for a in arrays) for lo, hi in spans]

def random_mini_batches_standard(X, Y, mini_batch_size, seed):
    """
    Shuffle (X, Y) row-wise in unison and split them into mini-batches.

    Arguments:
    X -- input data, one example per row, shape (number of examples, n_features)
    Y -- labels aligned with X, shape (number of examples, n_labels)
    mini_batch_size -- rows per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X_batch, Y_batch) tuples; the last batch may be
                    smaller and holds the leftover rows.
    """
    n_rows = X.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_rows))
    X_perm = X[perm, :]
    Y_perm = Y[perm, :].reshape((n_rows, Y.shape[1]))

    n_full = math.floor(n_rows / mini_batch_size)
    spans = [(k * mini_batch_size, (k + 1) * mini_batch_size) for k in range(n_full)]
    if n_rows % mini_batch_size != 0:
        spans.append((n_full * mini_batch_size, n_rows))
    return [(X_perm[lo:hi, :], Y_perm[lo:hi, :]) for lo, hi in spans]

def random_mini_batches_standardtwoModality(X1, X2, Y, mini_batch_size,seed):
    """
    Shuffle two row-aligned inputs and their labels in unison and split them into mini-batches.

    Arguments:
    X1 -- first input, one example per row, shape (m, n1)
    X2 -- second input aligned with X1, shape (m, n2)
    Y -- labels aligned with X1, shape (m, n_labels)
    mini_batch_size -- rows per batch
    seed -- seed for the global NumPy RNG (reseeded on every call)

    Returns:
    mini_batches -- list of (X1_batch, X2_batch, Y_batch) tuples; the last
                    batch may be smaller and holds the leftover rows.
    """
    n_rows = X1.shape[0]
    np.random.seed(seed)
    perm = list(np.random.permutation(n_rows))
    X1_perm = X1[perm, :]
    X2_perm = X2[perm, :]
    Y_perm = Y[perm, :].reshape((n_rows, Y.shape[1]))

    n_full = math.floor(n_rows / mini_batch_size)
    spans = [(k * mini_batch_size, (k + 1) * mini_batch_size) for k in range(n_full)]
    if n_rows % mini_batch_size != 0:
        spans.append((n_full * mini_batch_size, n_rows))
    return [(X1_perm[lo:hi, :], X2_perm[lo:hi, :], Y_perm[lo:hi, :]) for lo, hi in spans]

def Get_mini_batches_standard(X, Y, mini_batch_size):
    """
    Split (X, Y) into consecutive mini-batches without shuffling.

    Arguments:
    X -- inputs, one example per row
    Y -- labels aligned with X, one example per row
    mini_batch_size -- rows per batch

    Returns:
    mini_batches -- list of (X_batch, Y_batch) tuples in the original row order;
                    the last batch may be smaller and holds the leftover rows.
    """
    total = X.shape[0]
    n_full = math.floor(total / mini_batch_size)
    edges = [(i * mini_batch_size, (i + 1) * mini_batch_size) for i in range(n_full)]
    if total % mini_batch_size:
        edges.append((n_full * mini_batch_size, total))
    return [(X[a:b, :], Y[a:b, :]) for a, b in edges]


def Get_mini_batches(X1, X2, Y, mini_batch_size):
    """
    Split two row-aligned inputs and their labels into consecutive mini-batches (no shuffling).

    Arguments:
    X1, X2 -- inputs with one example per row, same number of rows
    Y -- labels aligned with X1, one example per row
    mini_batch_size -- rows per batch

    Returns:
    mini_batches -- list of (X1_batch, X2_batch, Y_batch) tuples in the original
                    row order; the last batch may be smaller and holds the leftover rows.
    """
    total = X1.shape[0]
    n_full = math.floor(total / mini_batch_size)
    edges = [(i * mini_batch_size, (i + 1) * mini_batch_size) for i in range(n_full)]
    if total % mini_batch_size:
        edges.append((n_full * mini_batch_size, total))
    return [(X1[a:b, :], X2[a:b, :], Y[a:b, :]) for a, b in edges]


def convert_to_one_hot(Y, C):
    """
    One-hot encode integer class labels.

    Arguments:
    Y -- integer labels in [0, C), any shape (flattened before encoding)
    C -- number of classes

    Returns:
    float array of shape (C, number of labels); column j is the one-hot vector of the j-th label.
    """
    identity = np.eye(C)
    flat_labels = Y.reshape(-1)
    # Row lookup gives (n, C); transpose to one example per column.
    return identity[flat_labels].T

def predict(X, parameters):
    """
    Predict the class of each example with a trained three-layer network.

    Computes LINEAR->RELU->LINEAR->RELU->LINEAR with the given weights and takes
    the arg-max over the class axis (axis 0).

    Arguments:
    X -- inputs of shape (n_x, number of examples), one example per column. A 1-D
         vector of length n_x is treated as a single example.
    parameters -- dict with "W1", "b1", "W2", "b2", "W3", "b3" (numpy arrays or
                  tensors), W_l of shape (n_l, n_{l-1}) and b_l of shape (n_l, 1)

    Returns:
    prediction -- numpy integer array of shape (number of examples,)
    """
    # The previous body used TensorFlow (never imported here), an undefined
    # forward_propagation, and called the training function `model`; it could
    # not run. This is a direct PyTorch forward pass instead.
    params = {name: torch.as_tensor(parameters[name], dtype=torch.float32)
              for name in ("W1", "b1", "W2", "b2", "W3", "b3")}

    x = torch.as_tensor(X, dtype=torch.float32)
    if x.dim() == 1:
        # A single flattened example: make it one column.
        x = x.reshape(-1, 1)

    with torch.no_grad():
        A1 = torch.relu(params["W1"] @ x + params["b1"])
        A2 = torch.relu(params["W2"] @ A1 + params["b2"])
        z3 = params["W3"] @ A2 + params["b3"]
        prediction = torch.argmax(z3, dim=0)

    return prediction.numpy()
    

def create_placeholders(n_x, n_y):
    """
    Create TensorFlow 1.x placeholders for a batch of inputs and labels.

    NOTE(review): no `tensorflow` import is present in this module, so `tf`
    is unresolved here. This looks like a leftover from the TensorFlow version
    of the code and has no PyTorch counterpart.

    Arguments:
    n_x -- scalar, size of an image vector (num_px * num_px = 64 * 64 * 3 = 12288)
    n_y -- scalar, number of classes (from 0 to 5, so -> 6)

    Returns:
    X -- placeholder for the data input, of shape [n_x, None] and dtype "float"
    Y -- placeholder for the input labels, of shape [n_y, None] and dtype "float"

    The second dimension is None so the same graph accepts any number of
    examples (train and test sets differ in size).
    """
    input_shape = [n_x, None]
    label_shape = [n_y, None]
    X = tf.placeholder("float", input_shape)
    Y = tf.placeholder("float", label_shape)
    return X, Y


def initialize_parameters():
    """
    Initializes parameters of the three-layer network as PyTorch tensors. The shapes are:
                        W1 : [25, 12288]
                        b1 : [25, 1]
                        W2 : [12, 25]
                        b2 : [12, 1]
                        W3 : [6, 12]
                        b3 : [6, 1]

    Weights use Xavier/Glorot uniform initialization, the default of the TF 1.x
    ``xavier_initializer`` this replaces. Biases are zeros. Every tensor has
    ``requires_grad=True`` so it can be handed to a ``torch.optim`` optimizer.

    Returns:
    parameters -- a dictionary of tensors containing W1, b1, W2, b2, W3, b3
    """
    # PyTorch port: this module does not import TensorFlow, and tf.contrib no
    # longer exists in TF 2.
    torch.manual_seed(1)  # reseeds the global torch RNG so initial weights are reproducible

    def xavier(rows, cols):
        W = torch.empty(rows, cols)
        nn.init.xavier_uniform_(W)
        return W.requires_grad_()

    def zeros(rows):
        return torch.zeros(rows, 1, requires_grad=True)

    parameters = {"W1": xavier(25, 12288),
                  "b1": zeros(25),
                  "W2": xavier(12, 25),
                  "b2": zeros(12),
                  "W3": xavier(6, 12),
                  "b3": zeros(6)}

    return parameters


def compute_cost(z3, Y):
    """
    Computes the mean softmax cross-entropy cost.

    Arguments:
    z3 -- output of forward propagation (output of the last LINEAR unit), a float
          tensor of shape (number of classes, number of examples)
    Y -- "true" labels as a one-hot (or class-probability) float tensor, same shape as z3

    Returns:
    cost -- scalar tensor: cross-entropy averaged over examples
    """
    # nn.CrossEntropyLoss expects (batch, classes), so put examples on axis 0.
    # torch.transpose needs both dims to swap.
    logits = torch.transpose(z3, 0, 1)
    labels = torch.transpose(Y, 0, 1)

    # Construct the loss module, then call it. Passing logits/labels to the
    # constructor is a TypeError. Float one-hot targets are treated as class
    # probabilities.
    cost = torch.mean(nn.CrossEntropyLoss()(logits, labels))

    return cost

class NetWork(nn.Module):
    """A single fully-connected layer followed by a LeakyReLU activation."""

    def __init__(self, in_features=12288, out_features=6, *args, **kwargs) -> None:
        """
        Arguments:
        in_features -- size of each input row. The default 12288 follows the
                       64 * 64 * 3 image size documented elsewhere in this file.
        out_features -- size of each output row (default 6, the documented class count)
        *args, **kwargs -- forwarded to ``nn.Module.__init__``
        """
        super().__init__(*args, **kwargs)
        # nn.Linear requires explicit sizes; calling it with none raised TypeError.
        self.linear = nn.Sequential(nn.Linear(in_features, out_features),
                                    nn.LeakyReLU())

    def forward(self, x):
        """Map x of shape (batch, in_features) to (batch, out_features)."""
        x = self.linear(x)
        return x

class Loss_Func(nn.Module):
    """Mean softmax cross-entropy for logits and labels stored one example per column."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Build the criterion once; the old code passed tensors to the constructor.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z3, Y):
        """
        Arguments:
        z3 -- logits tensor of shape (number of classes, number of examples)
        Y -- one-hot (or class-probability) float tensor, same shape as z3

        Returns:
        cost -- scalar tensor, cross-entropy averaged over examples
        """
        # CrossEntropyLoss wants (batch, classes); transpose needs both dims.
        logits = torch.transpose(z3, 0, 1)
        labels = torch.transpose(Y, 0, 1)

        cost = torch.mean(self.criterion(logits, labels))
        return cost

import matplotlib.pyplot as plt



def model(X_train, Y_train, X_test, Y_test, learning_rate = 0.0001,
          num_epochs = 1500, minibatch_size = 32, print_cost = True):
    """
    Train a three-layer PyTorch network: LINEAR->RELU->LINEAR->RELU->LINEAR->SOFTMAX.

    Arguments:
    X_train -- training inputs, shape (input size n_x, number of training examples); one example per column
    Y_train -- one-hot training labels, shape (output size n_y, number of training examples)
    X_test -- test inputs, shape (n_x, number of test examples)
    Y_test -- one-hot test labels, shape (n_y, number of test examples)
    learning_rate -- learning rate of the Adam optimizer
    num_epochs -- number of passes over the training set
    minibatch_size -- examples per minibatch (the last one may be smaller)
    print_cost -- True to print the cost every 100 epochs and record it every 5 epochs

    Returns:
    parameters -- dict of trained numpy arrays "W1", "b1", "W2", "b2", "W3", "b3",
                  usable by ``predict``.
    """
    # The previous body mixed TF1 (sessions, placeholders, undefined tf) with
    # PyTorch and could not run. This version is self-contained PyTorch.
    torch.manual_seed(1)                              # to keep consistent results
    seed = 3                                          # to keep consistent results
    (n_x, m) = X_train.shape                          # (n_x: input size, m : number of examples in the train set)
    n_y = Y_train.shape[0]                            # n_y : output size
    costs = []                                        # To keep track of the cost

    # Xavier-uniform weights, zero biases; hidden widths 25 and 12 follow the
    # shapes documented in initialize_parameters.
    layer_sizes = [n_x, 25, 12, n_y]
    parameters = {}
    for layer in range(1, len(layer_sizes)):
        W = torch.empty(layer_sizes[layer], layer_sizes[layer - 1])
        nn.init.xavier_uniform_(W)
        parameters["W%d" % layer] = W.requires_grad_()
        parameters["b%d" % layer] = torch.zeros(layer_sizes[layer], 1, requires_grad=True)

    def forward_propagation(X):
        """LINEAR->RELU->LINEAR->RELU->LINEAR on column-wise examples; returns logits Z3."""
        A1 = torch.relu(parameters["W1"] @ X + parameters["b1"])
        A2 = torch.relu(parameters["W2"] @ A1 + parameters["b2"])
        return parameters["W3"] @ A2 + parameters["b3"]

    criterion = nn.CrossEntropyLoss()   # softmax + cross-entropy, averaged over the batch
    optimizer = optim.Adam(list(parameters.values()), lr=learning_rate)

    for epoch in range(num_epochs):
        epoch_cost = 0.
        seed = seed + 1
        # random_mini_batches_standard expects one example per ROW, hence the transposes.
        minibatches = random_mini_batches_standard(X_train.T, Y_train.T, minibatch_size, seed)

        for minibatch_X, minibatch_Y in minibatches:
            X_batch = torch.as_tensor(minibatch_X, dtype=torch.float32)
            Y_batch = torch.as_tensor(minibatch_Y, dtype=torch.float32)

            # Forward propagation: logits are (n_y, batch); the loss wants (batch, n_y).
            z3 = forward_propagation(X_batch.T)
            loss = criterion(z3.T, Y_batch)

            # Backward propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch cost is the mean over all m examples.
            epoch_cost += loss.item() * X_batch.shape[0] / m

        # Report once per epoch (previously this ran once per minibatch).
        if print_cost and epoch % 100 == 0:
            print ("Cost after epoch %i: %f" % (epoch, epoch_cost))
        if print_cost and epoch % 5 == 0:
            costs.append(epoch_cost)

    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    plt.xlabel('epochs (per fives)')
    plt.title("Learning rate =" + str(learning_rate))
    plt.show()

    print ("Parameters have been trained!")

    def accuracy(X, Y):
        """Fraction of columns of X whose predicted class matches the arg-max of Y."""
        with torch.no_grad():
            z3 = forward_propagation(torch.as_tensor(X, dtype=torch.float32))
        correct_prediction = torch.eq(torch.argmax(z3, dim=0),
                                      torch.argmax(torch.as_tensor(Y), dim=0))
        return torch.mean(correct_prediction.type(torch.float)).item()

    print ("Train Accuracy:", accuracy(X_train, Y_train))
    print ("Test Accuracy:", accuracy(X_test, Y_test))

    return {name: value.detach().numpy() for name, value in parameters.items()}


def time_synchronized():
    """
    Return ``time.time()``, first waiting for queued CUDA work when a GPU is available.

    Synchronizing first makes timings around asynchronous GPU kernels
    meaningful. On CPU-only machines this is just ``time.time()``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()