# Miscellaneous utility functions used throughout the repo.

import torch

import numpy as np
from scipy.spatial import Delaunay, ConvexHull
from scipy.linalg import lstsq
from scipy.optimize import linprog

def _infer_device(model):
    """Return the device of the model's first parameter, else CUDA-if-available/CPU."""
    params = getattr(model, "parameters", None)
    if callable(params):
        try:
            return next(iter(params())).device
        except StopIteration:
            # Parameter-less module: fall back to the default device rule below.
            pass
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def generate_feature_map(dl, model):
    """Run ``model.feature_forward`` over every batch of ``dl`` and stack the outputs.

    Each batch input is flattened to ``(batch, -1)`` and moved to the model's
    device before the forward pass; targets are cast to float.

    Args:
        dl: iterable yielding ``(model_in, target)`` tensor pairs (e.g. a DataLoader).
        model: object exposing ``feature_forward(tensor) -> tensor``. If it has
            parameters, inputs are sent to the device of its first parameter;
            otherwise to CUDA when available, else CPU.

    Returns:
        ``(all_feats, all_targets)`` as numpy arrays concatenated along dim 0
        (features: ``num_ex x num_neuron``; targets: ``num_ex`` leading dim).

    Raises:
        RuntimeError: from ``torch.cat`` if ``dl`` yields no batches.
    """
    device = _infer_device(model)
    all_feats = []
    all_targets = []
    with torch.no_grad():
        for model_in, target in dl:
            model_in = torch.flatten(model_in, start_dim=1).to(device)
            feats = model.feature_forward(model_in)
            # Move each batch off the accelerator immediately so features do not
            # accumulate in GPU memory across the whole dataset.
            all_feats.append(feats.detach().cpu())
            all_targets.append(target.float().cpu())
    all_feats = torch.cat(all_feats, dim=0).numpy()  # num_ex X num_neuron
    all_targets = torch.cat(all_targets, dim=0).numpy()  # num_ex
    return all_feats, all_targets

def solve_feasible_point(feats, target, margin=0.0001):
    """Find simplex weights over feature columns that classify every example.

    Solves a linear program for weights ``w`` with ``w >= 0`` (linprog's default
    bounds) and ``sum(w) == 1`` such that, for every example ``i``::

        feats[i] @ w >= margin     if target[i] > 0.5
        feats[i] @ w <= -margin    otherwise

    Args:
        feats: ``(num_ex, num_hid)`` array of features.
        target: labels for the ``num_ex`` examples; any array-like that flattens
            to ``(num_ex,)`` (e.g. ``(num_ex,)`` or ``(num_ex, 1)``) is accepted.
        margin: required distance from 0 on the correct side for each example.

    Returns:
        The ``scipy.optimize.OptimizeResult`` from ``linprog``. ``success`` is
        True iff such weights exist; ``x`` holds them.
    """
    feats = np.asarray(feats, dtype=float)
    target = np.asarray(target).reshape(-1)

    # Flip the sign of positive-class rows so that both classes share the single
    # form ``row @ w <= -margin``.
    sign = np.where(target > 0.5, -1.0, 1.0)
    A_ub = feats * sign[:, None]
    b_ub = np.full(target.shape, -margin, dtype=float)

    # The only equality constraint: weights lie on the probability simplex.
    num_hid = feats.shape[1]
    A_eq = np.ones((1, num_hid))
    b_eq = [1]

    # With sum(w) == 1 this objective is constant on the feasible set; only
    # feasibility matters.
    c = np.ones(num_hid)
    return linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq)

def check_in_hull(dl, model, margin=0.0001):
    """Return whether ``model``'s features admit a separating simplex combination.

    Extracts features and labels with ``generate_feature_map`` and asks
    ``solve_feasible_point`` whether non-negative weights summing to 1 over the
    feature columns put every example on its label's side of 0 by ``margin``.

    Args:
        dl: iterable of ``(model_in, target)`` batches.
        model: object exposing ``feature_forward``.
        margin: required margin forwarded to ``solve_feasible_point``.

    Returns:
        True if the linear program is feasible, False otherwise.
    """
    feat_map, label_vec = generate_feature_map(dl, model)
    result = solve_feasible_point(feat_map, label_vec, margin=margin)
    return result['success']

def in_hull(p, hull):
    """Test which points of ``p`` fall inside the convex hull of ``hull``.

    ``hull`` may be a ready-made ``Delaunay`` triangulation or an array of points
    to triangulate. Returns the boolean result of ``find_simplex(p) >= 0``.
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)
    simplex_ids = triangulation.find_simplex(p)
    # find_simplex reports -1 for points outside every simplex.
    return simplex_ids >= 0
