import os
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from pypoman import compute_polytope_vertices # Please install this package
from sklearn.model_selection import train_test_split
import torch.optim as optim
from matplotlib import pyplot as plt
from scipy.special import yn, itairy, kn
import time as time
from sklearn.utils import shuffle

import itertools

from scipy.spatial import Delaunay

from scipy.spatial.distance import cdist

def compute_vertices(input_dim, w1, w2, Dict, mask):
    """Return the vertices of the polytope selected by one activation pattern.

    A linear system ``A x <= b`` is assembled and handed to
    ``compute_polytope_vertices``.  Its rows are, in this order: one
    inequality per first-layer neuron, one per second-layer neuron, and the
    ``2 * input_dim`` faces of the box ``-1 <= x_k <= 1``.

    Args:
        input_dim: Number of input coordinates (columns of the layer-1 weight).
        w1: Number of first-layer entries at the start of ``mask``.
        w2: Number of second-layer entries that follow them in ``mask``.
        Dict: Mapping holding the arrays ``'layer1.weight'``,
            ``'layer1.bias'``, ``'layer2.weight'`` and ``'layer2.bias'``.
        mask: 1-D array whose first ``w1 + w2`` entries give the state of each
            hidden neuron; the caller in this file fills it with 0/1 values
            from ``np.heaviside`` (1 = active, 0 = inactive).

    Returns:
        The value returned by ``compute_polytope_vertices`` for the system.
    """
    W1, b1 = Dict['layer1.weight'], Dict['layer1.bias']
    W2, b2 = Dict['layer2.weight'], Dict['layer2.bias']

    def as_columns(vec):
        # (n,) -> (n, input_dim): every column is a copy of ``vec`` so that it
        # can scale a weight matrix row by row.
        return np.transpose(np.tile(vec, (input_dim, 1)))

    # 0/1 activation indicators and their -1/+1 counterparts, per layer.
    active1 = mask[0:w1]
    active2 = mask[w1:w1 + w2]
    sign1 = 2 * active1 - 1
    sign2 = 2 * active2 - 1

    # Second-layer pre-activation written as an affine function of the input
    # once the inactive first-layer neurons have been zeroed out.
    W2_eff = np.dot(W2, np.multiply(W1, as_columns(active1)))
    b2_eff = np.dot(W2, np.multiply(active1, b1)) + b2

    # A neuron with sign +1 contributes  -(w.x) <= b  (pre-activation >= 0);
    # a neuron with sign -1 contributes   (w.x) <= -b (pre-activation <= 0).
    rows_l1 = np.multiply(W1, -as_columns(sign1))
    rows_l2 = np.multiply(W2_eff, -as_columns(sign2))
    rhs_l1 = np.multiply(b1, sign1)
    rhs_l2 = np.multiply(b2_eff, sign2)

    # Bounding box:  -x_k <= 1  followed by  x_k <= 1.
    identity = np.eye(input_dim)
    box_rhs = np.ones((2 * input_dim,))

    A = np.concatenate((rows_l1, rows_l2, -identity, identity), axis=0)
    b = np.concatenate((rhs_l1, rhs_l2, box_rhs), axis=0)

    return compute_polytope_vertices(A, b)


def ComputeNumberOfSimplicesAndConvexRegions(input_dim, w1, w2, BigX_all, net):
    """Count the activation regions hit by the samples and their simplices.

    Each row of ``BigX_all`` is passed through ``net`` and the on/off state of
    the two hidden layers is recorded.  For every distinct pattern the
    region's vertices are obtained from ``compute_vertices`` and, when there
    are at least five of them, triangulated with ``scipy.spatial.Delaunay``.

    The weights, the device and the number of samples are all taken from the
    arguments, so the function does not depend on module-level globals.

    Args:
        input_dim: Input dimension, forwarded to ``compute_vertices``.
        w1: Width of the first hidden layer.
        w2: Width of the second hidden layer.
        BigX_all: 2-D NumPy array with one sample per row (float32 in this
            file's caller; presumably it must match the dtype of ``net``).
        net: ``torch.nn.Module`` whose forward pass returns
            ``(out1, out2, out3)`` with ``out1``/``out2`` the post-ReLU hidden
            activations, and whose parameters are named ``layer1.weight``,
            ``layer1.bias``, ``layer2.weight`` and ``layer2.bias``.

    Returns:
        list[int]: the number of Delaunay simplices of every region that has
        at least five vertices and could be triangulated.  Its length is the
        value printed as ``Convex Region Sum``.
    """
    # Read the device and the weights from the network itself.  ``.cpu()`` is
    # needed because ``.numpy()`` cannot read a tensor that lives on a GPU.
    first_param = next(net.parameters(), None)
    net_device = first_param.device if first_param is not None else torch.device('cpu')
    params = {name: p.detach().cpu().numpy() for name, p in net.named_parameters()}

    n_samples = BigX_all.shape[0]
    # Report progress about every 1% of the samples.
    progress_step = max(1, n_samples // 100)

    Mask_all = []
    with torch.no_grad():  # inference only, no autograd graph required
        for i in range(n_samples):
            if i % progress_step == 0:
                print(i)

            out1, out2, _ = net(torch.from_numpy(BigX_all[i, :]).to(net_device))

            # 1 where the ReLU output is strictly positive, 0 otherwise.
            mask1 = np.heaviside(out1.detach().cpu().numpy(), 0)
            mask2 = np.heaviside(out2.detach().cpu().numpy(), 0)
            Mask_all.append(np.concatenate((mask1, mask2), axis=0))

    # One row per distinct activation pattern.
    if Mask_all:
        MA_effective = np.unique(np.array(Mask_all), axis=0)
    else:
        MA_effective = np.empty((0, w1 + w2))

    Count = []
    for region_mask in MA_effective:
        L_vertices = compute_vertices(input_dim, w1, w2, params, region_mask)

        # NOTE(review): 5 equals input_dim + 2 for the 3-D input used in this
        # file; regions with fewer vertices are neither triangulated nor
        # counted.  Confirm whether the threshold should follow input_dim.
        if len(L_vertices) < 5:
            continue

        print(len(L_vertices))
        try:
            tri = Delaunay(np.array(L_vertices))
        except Exception:
            # Best effort, as before: a vertex set that cannot be triangulated
            # is skipped.  Unlike a bare ``except`` this lets
            # KeyboardInterrupt / SystemExit stop the run.
            continue

        n_simplices = tri.simplices.shape[0]
        print(n_simplices)
        Count.append(n_simplices)

    print('Simplices Sum:', sum(Count))
    print('Convex Region Sum:', len(Count))

    return Count


class MLP(nn.Module):
    """Fully connected network: two ReLU hidden layers, then a linear output."""

    def __init__(self, input_dim, w1, w2):
        super().__init__()
        # The attribute names determine the parameter keys
        # ('layer1.weight', 'layer1.bias', ...).
        self.layer1 = nn.Linear(in_features=input_dim, out_features=w1)
        self.layer2 = nn.Linear(in_features=w1, out_features=w2)
        self.layer3 = nn.Linear(in_features=w2, out_features=1)

    def forward(self, x):
        """Return both hidden activations together with the final output."""
        pre1 = self.layer1(x)
        act1 = F.relu(pre1)

        pre2 = self.layer2(act1)
        act2 = F.relu(pre2)

        # No non-linearity after the last layer.
        result = self.layer3(act2)
        return act1, act2, result

def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; ignore any other module.

    Intended for ``net.apply(init_weights)``.  The weight is drawn with
    Kaiming-normal initialisation and the bias, when the layer has one, is
    set to the constant 0.01.

    Args:
        m: Any ``torch.nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return

    # Use the in-place ``*_`` initialiser; the name without the trailing
    # underscore is a deprecated alias.  Alternatives with the same calling
    # convention: xavier_uniform_, xavier_normal_, orthogonal_.
    nn.init.kaiming_normal_(m.weight)

    # ``nn.Linear(..., bias=False)`` has ``bias is None``.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)
  
  
if __name__ == '__main__':

  # NOTE(review): no random seed is set here, so the initial weights (and the
  # resulting histogram) differ between runs.  Call torch.manual_seed(...)
  # at this point if reproducible results are wanted.

  # Network sizes: 3-D input, two hidden layers of width w1 and w2
  input_dim = 3
  w1 = 40
  w2 = 20

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  # Build and initialise the network.  No training loop follows: EPOCHS,
  # criterion, optimizer and the loss lists are created but not used in the
  # code below.
  EPOCHS = 400
  net = MLP(input_dim, w1, w2)
  net.apply(init_weights)
  net.to(device)
  criterion = nn.MSELoss()
  optimizer = optim.Adam(net.parameters(), lr=0.02)
  Loss_train = []
  Loss_test = []
  Count_CR = []
  Count_S = []
  BigSampleN = 100  # grid points per axis

  # Parameters as NumPy arrays, keyed by parameter name.  .cpu() is required
  # because .numpy() cannot read a tensor that lives on a CUDA device.
  Dict = {}
  for name, param in net.named_parameters():
    Dict[name] = param.detach().cpu().numpy()

  # Regular grid of BigSampleN**3 sample points covering the cube [-1, 1]^3
  Bigx1 = np.linspace(-1,1,BigSampleN)
  Bigx2 = np.linspace(-1,1,BigSampleN)
  Bigx3 = np.linspace(-1,1,BigSampleN)

  BigX1, BigX2, BigX3 = np.meshgrid(Bigx1, Bigx2, Bigx3)

  BigX1_r = np.reshape(BigX1, (-1,1))
  BigX2_r = np.reshape(BigX2, (-1,1))
  BigX3_r = np.reshape(BigX3, (-1,1))

  # One sample per row, cast to float32 before being fed to the network
  BigX_all = np.float32(np.concatenate((BigX1_r, BigX2_r, BigX3_r), axis=1))

  Count = ComputeNumberOfSimplicesAndConvexRegions(input_dim, w1, w2, BigX_all, net)

  Count_S.append(Count)


#%%
  # Histogram of the number of simplices per triangulated region
  fig, ax = plt.subplots(1)
  ax.hist(Count_S[0])

  plt.show()


  
  
  
  
  
  
  
  
  
  
  
  