import numpy as np
import torch
from itertools import combinations_with_replacement
from sklearn.svm import SVC


def embedding(X,poly_degree):
    """Embedding a data set to higher dimensional space

    Args:
        X ((N,D)): input data set
        poly_degree (int): the degree of polynomial for defining the hypersurface, k=1: original data set,  k>1: polynomial hypersurface of degree k, k=-1: circles

    Returns:
        X_embed: embedded datasets
    """
    N,D = X.shape
    X_hom = torch.cat((torch.ones(N,1),X), dim = 1)
    match poly_degree:
       case 1:
            return X
       case m if m >1:
            elements = [i for i in range(D+1)]
            monos = list(combinations_with_replacement(elements, m))# Gather elements from x using indices in monos
            monos = torch.tensor([i for i in monos])
            selected = X_hom[:,monos]  # Shape: (K, M)
            X_embed = torch.prod(selected, dim=-1)  # Shape: (K,)
            X_embed = X_embed[:,1:]
            # print(f'the dimension of the embedded data is {X_embed.shape[-1]}')
            return X_embed
       case -1:
            X_embed = X**2
            X_embed = torch.cat((X,X_embed), dim = 1)
            # print(f'the dimension of the embedded data is {X_embed.shape[-1]}')
            return X_embed



def pred_linear(X, w):
    """Classify points by which side of the hyperplane ``w . x = 1`` they lie on.

    Args:
        X ((N, D) numpy array or torch.Tensor): points, one per row.
        w ((D,) numpy array or torch.Tensor): normal vector of the hyperplane.

    Returns:
        torch.Tensor: (N,) int8 tensor holding +1 where ``X @ w > 1``,
        -1 where ``X @ w < 1`` and 0 for points exactly on the hyperplane.
    """
    # as_tensor avoids the extra copy (and the "copy construct" UserWarning)
    # that torch.tensor() causes when X is already a tensor.
    X = torch.as_tensor(X)
    w = torch.as_tensor(w, device=X.device)
    # matmul requires matching dtypes; promote both operands instead of
    # failing on e.g. float32 data with a float64 weight vector.
    dtype = torch.promote_types(X.dtype, w.dtype)
    # Subtracting a Python scalar keeps the result on X's device.
    dists = torch.matmul(X.to(dtype), w.to(dtype)) - 1
    return torch.sign(dists).to(torch.int8)



def upd_array(a, cs: torch.Tensor):
    """Append ``a`` as a new last column to every combination in ``cs``.

    Update step for CGC combination generation: each row of ``cs`` is one
    combination of size k. The result holds the same combinations, each
    extended by ``a`` to size k + 1.

    Args:
        a (int): value appended to every combination.
        cs (torch.Tensor): 2-D tensor of shape (M, k), where M is the number of
            combinations and k their size.

    Returns:
        torch.Tensor: tensor of shape (M, k + 1) with ``cs.dtype`` and
        ``cs.device`` whose last column is ``a``. When k == 0 the result is
        always the single combination ``[[a]]`` (shape (1, 1)), whatever M is.
    """
    if cs.size(1) == 0:
        # A zero-width input means "no elements chosen yet": start one fresh
        # combination containing only a.
        # NOTE(review): an (M, 0) input with M > 1 yields 1 row, not M rows;
        # confirm callers only pass (0, 0) or (1, 0) tensors here.
        return torch.tensor([[a]], dtype=cs.dtype, device=cs.device)
    last_col = torch.full((cs.shape[0], 1), a, dtype=cs.dtype, device=cs.device)
    return torch.cat((cs, last_col), dim=1)

def reorder_svm(X, y):
    """Order samples by their distance to a linear SVM's separating hyperplane.

    Fits ``SVC(kernel='linear')`` on ``(X, y)`` and ranks every sample by
    ``|w . x + b|``. This score is proportional to the Euclidean distance to
    the hyperplane; the ``1 / ||w||`` factor does not change the order.

    Args:
        X ((N, D) torch.Tensor or numpy array): training samples.
        y ((N,) array-like): class labels.

    Returns:
        torch.Tensor: (N,) indices into X, closest-to-hyperplane first.
    """
    # Train SVM with linear kernel
    svm = SVC(kernel='linear', random_state=42)
    svm.fit(X, y)

    # Normal vector w and bias b of the first decision function (the only
    # one for binary labels).
    w = svm.coef_[0]  # Shape: (D,)
    b = svm.intercept_[0]  # Scalar

    # sklearn exposes coef_ as a read-only array, so copy it before wrapping.
    w_torch = torch.from_numpy(w.copy())
    # Compute in w's float64 dtype: matmul rejects mixed dtypes, so float32
    # (or numpy) inputs would otherwise fail.
    X_t = torch.as_tensor(X, dtype=w_torch.dtype)

    # sklearn's linear decision function is X @ coef_.T + intercept_, so the
    # bias is added, not subtracted.
    distances = torch.abs(torch.matmul(X_t, w_torch) + b)  # Shape: (N,)

    # Indices for sorting from small to large distance
    return torch.argsort(distances)