import os
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from pypoman import compute_polytope_vertices #Please install this package

import itertools

from scipy.spatial import Delaunay

def compute_vertices(input_dim, w1, Dict, mask, bound=1.0):
    """Return the vertices of the input region with a fixed layer-1 activation pattern.

    The region is the set of inputs ``x`` in the box ``[-bound, bound]^input_dim``
    for which hidden neuron ``i`` of ``layer1`` has pre-activation
    ``W_i x + b_i >= 0`` when ``mask[i] == 1`` and ``<= 0`` when ``mask[i] == 0``.
    It is written as the H-representation ``A x <= b`` and passed to
    ``compute_polytope_vertices``.

    Args:
        input_dim: Dimension of the input space.
        w1: Number of hidden neurons in ``layer1``; only ``mask[:w1]`` is used.
        Dict: Mapping holding ``'layer1.weight'`` (shape ``(w1, input_dim)``) and
            ``'layer1.bias'`` (shape ``(w1,)``) as NumPy arrays.
        mask: Sequence (list, tuple or array) of 0/1 activation flags with at
            least ``w1`` entries.
        bound: Half-width of the bounding box. Defaults to 1.0, i.e. ``[-1, 1]``.

    Returns:
        Whatever ``compute_polytope_vertices`` returns for the region: a list of
        vertices, presumably empty when the pattern is infeasible.

    Raises:
        ValueError: If ``mask`` has fewer than ``w1`` entries.
    """
    mask = np.asarray(mask)
    if len(mask) < w1:
        raise ValueError(
            f"mask must have at least w1={w1} entries, got {len(mask)}")

    # Map {0, 1} flags to {-1, +1} signs: +1 = active, -1 = inactive.
    signs = 2 * mask[:w1] - 1

    # Bounding box: -x <= bound and x <= bound.
    eye = np.eye(input_dim)
    box_weight = np.concatenate((-eye, eye), axis=0)
    box_bias = np.full(2 * input_dim, bound, dtype=float)

    # Neuron i with sign s_i must satisfy s_i * (W_i x + b_i) >= 0,
    # i.e. -s_i * W_i x <= s_i * b_i.
    neuron_weight = -signs[:, None] * Dict['layer1.weight']
    neuron_bias = signs * Dict['layer1.bias']

    # Combine all inequalities into a single A x <= b system.
    weight = np.concatenate((neuron_weight, box_weight), axis=0)
    bias = np.concatenate((neuron_bias, box_bias), axis=0)

    return compute_polytope_vertices(weight, bias)


class MLP(nn.Module):
    """Two-layer perceptron computing ReLU(layer2(ReLU(layer1(x)))).

    Args:
        input_dim: Number of input features.
        w1: Width of the single hidden layer.
    """

    def __init__(self, input_dim, w1):
        super().__init__()
        # The attribute names fix the parameter names ('layer1.weight', ...)
        # that are looked up by name elsewhere in this file.
        self.layer1 = nn.Linear(input_dim, w1)
        self.layer2 = nn.Linear(w1, 1)

    def forward(self, x):
        """Run the forward pass; the result has a trailing dimension of 1."""
        hidden = F.relu(self.layer1(x))
        return F.relu(self.layer2(hidden))
  
def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; other modules are left untouched.

    The weight is drawn with Xavier/Glorot uniform initialisation. The bias, if
    the layer has one, is set to 0.01. Designed to be passed to ``Module.apply``.

    Args:
        m: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    # ``xavier_uniform`` (no trailing underscore) is deprecated; the in-place
    # variant is the supported API.
    torch.nn.init.xavier_uniform_(m.weight)
    # Layers built with ``bias=False`` have ``m.bias is None``.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)
        
        
if __name__ == '__main__':

    # Set a fixed random number seed so the initial weights, and therefore the
    # region counts, are reproducible across runs.
    torch.manual_seed(42)

    # Initialize the MLP
    input_dim = 3
    w1 = 10
    mlp = MLP(input_dim, w1)
    mlp.apply(init_weights)

    # Snapshot the parameters as NumPy arrays keyed by parameter name
    # (e.g. 'layer1.weight'), the form compute_vertices reads.
    params = {}
    for name, param in mlp.named_parameters():
        print(name)
        params[name] = param.detach().numpy()

    simplex_counts = []
    # Enumerate all 2**w1 on/off patterns of the hidden neurons.
    for pattern in itertools.product([0, 1], repeat=w1):
        mask = np.array(pattern)
        vertices = compute_vertices(input_dim, w1, params, mask)

        # Delaunay needs at least input_dim + 1 points to form a simplex;
        # fewer vertices means the region is empty or lower-dimensional.
        if len(vertices) <= input_dim:
            continue

        print(len(vertices))
        tri = Delaunay(np.array(vertices))
        print(tri.simplices.shape[0])
        simplex_counts.append(tri.simplices.shape[0])

    # Report the total once, after all activation patterns have been visited.
    print('Sum:', sum(simplex_counts))
  
  
  
  
  
  
  
  
  
  
  
  
  
  