from math import inf
from auxfuncs_array import *
import torch


def _eval_hyperplanes(point_idx, X_device, l_device, device):
    """Fit one hyperplane per row of ``point_idx`` and score its 0-1 loss.

    Args:
        point_idx: (M_b, D) integer tensor; each row lists the D data items
            the hyperplane is forced through.
        X_device: (N, D) data matrix.
        l_device: (N,) label vector compared element-wise against sign values.
        device: device on which the right-hand side is allocated.

    Returns:
        (ws, losses): ``ws`` is (M_b, D), the solutions of ``A w = 1``;
        ``losses`` is (M_b,), the smaller of the loss of the hyperplane and
        the loss of its sign-flipped counterpart.
    """
    N, D = X_device.shape
    M_b = point_idx.shape[0]
    # Cast once: int64 is the index dtype accepted by every torch version.
    idx = point_idx.long()

    A_block = X_device[idx]  # (M_b, D, D)
    b_block = torch.ones(M_b, D, 1, device=device, dtype=X_device.dtype)

    # Batched solve of A w = 1, i.e. the hyperplane w.x - 1 = 0 through the rows of A.
    ws = torch.linalg.solve(A_block, b_block).squeeze(-1)  # (M_b, D)
    hyper_dists = torch.matmul(ws, X_device.T) - 1.0  # (M_b, N)
    hyper_asgns = torch.sign(hyper_dists).to(dtype=torch.int8)

    # The defining points lie on the hyperplane, so their sign is numerical
    # noise; overwrite it with the true label so they never count as errors.
    rows = torch.arange(M_b, device=idx.device).unsqueeze(1)
    hyper_asgns[rows, idx] = l_device[idx].to(hyper_asgns.dtype)

    losses_pos = (hyper_asgns != l_device.view(1, -1)).sum(dim=1)  # (M_b,)
    # Flipping the orientation swaps right/wrong on the N - D free points.
    losses_neg = N - D - losses_pos
    return ws, torch.minimum(losses_pos, losses_neg)


def gen_eval_batch(inds, cs, X_device, l_device, blocksize, device):
    """Compute the hyperplane of each combination and return the best one.

    The combinations are processed in chunks of ``blocksize`` rows. When
    ``D <= len(inds) / 2`` each row of ``cs`` is used directly as the D
    defining points; otherwise each row is treated as the items to leave out
    and the hyperplane is fitted through the remaining D items of ``inds``.

    Args:
        inds: sequence of data indices the combinations were drawn from.
        cs: (M, k) integer tensor holding a batch of combinations.
        X_device: (N, D) data matrix, expected on ``device``.
        l_device: (N,) label vector, expected on ``device``.
        blocksize: number of combinations evaluated per chunk.
        device: 'cpu' or a CUDA device; CUDA streams are only used for CUDA.

    Returns:
        Tuple ``(w, comb, loss)`` of the best configuration among ``cs``:
        normal vector, the winning row of ``cs`` and its 0-1 loss (0-dim
        tensor). Ties are resolved in favour of the earliest row.
        NOTE(review): in the complement case ``comb`` is the row of ``cs``
        (the left-out items), not the points defining the hyperplane —
        confirm this is what callers expect.

    Raises:
        ValueError: if ``cs`` has no rows.
    """
    from contextlib import nullcontext

    M = cs.shape[0]
    if M == 0:
        raise ValueError("cs contains no combinations to evaluate")

    D = X_device.shape[1]
    # Derived from the arguments instead of a module-level global so the
    # function also works when it is called on its own.
    use_complement = D > len(inds) / 2
    inds_device = torch.as_tensor(inds, device=device) if use_complement else None

    use_cuda = torch.device(device).type == "cuda"
    streams = []
    result_chunk = []
    for start in range(0, M, blocksize):
        block = cs[start:start + blocksize]
        if use_cuda:
            stream = torch.cuda.Stream()
            # Inputs were produced on the current stream; order after it.
            stream.wait_stream(torch.cuda.current_stream())
            streams.append(stream)
            ctx = torch.cuda.stream(stream)
        else:
            ctx = nullcontext()

        with ctx:
            block_device = block.to(device)
            M_b = block_device.shape[0]
            if use_complement:
                # Mark, per row, which entries of inds occur in the combination.
                matches = (block_device.unsqueeze(2) == inds_device).any(dim=1)
                point_idx = inds_device.expand(M_b, -1)[~matches].view(M_b, D)
            else:
                point_idx = block_device

            ws, losses = _eval_hyperplanes(point_idx, X_device, l_device, device)
            loss_best, min_index = losses.min(dim=0)
            result_chunk.append((ws[min_index], block_device[min_index], loss_best))

    # Make sure every side stream has finished before results are compared.
    for stream in streams:
        stream.synchronize()

    return min(result_chunk, key=lambda x: x[-1])






def E01_ICE(inds, X, l, poly_degree, blocksize, device = 'cuda', verbose=True):
    """Exact 0-1 loss linear classification by incremental cell enumeration.

    Data items are consumed one at a time; after each item the newly
    completed k-combinations are handed to ``gen_eval_batch`` and the best
    configuration seen so far is kept. ``k`` is the embedded dimension D
    when ``D <= len(inds) / 2`` and ``len(inds) - D`` otherwise.

    Args:
        inds (list): indices of the subset of X to solve over; when
            ``len(inds) == N`` the order is replaced by ``reorder_svm``.
        X ((N,D)): input data matrix.
        l ({1,-1}^N): label vector.
        poly_degree: forwarded to ``embedding`` (defined in auxfuncs_array).
        blocksize (int): number of combinations evaluated per chunk, to
            avoid memory inefficiency on the GPU.
        device (cpu or cuda, optional): where the computation runs.
            Defaults to 'cuda'.
        verbose (bool, optional): print the progress of the algorithm.
            Defaults to True.

    Returns:
        ``(None, inf)`` when there are not more indices than dimensions (or
        nothing was evaluated); otherwise ``(w_best.cpu(), loss, comb)``:
        normal vector, 0-1 loss and combination. ``comb`` is a Python list
        when a zero-loss solution ends the search early, and is returned as
        produced by ``gen_eval_batch`` otherwise.
    """
    # The original gen_eval_batch reads this module-level value, so it is
    # still published here for compatibility.
    global block_size
    block_size = len(inds)

    l = torch.tensor(l)
    X = torch.tensor(X).to(torch.float64)

    N, D = X.shape
    if block_size <= D:
        return (None, inf)

    X_embed = embedding(X, poly_degree)

    # move to the target device
    X_device = X_embed.to(device)
    l_device = l.to(device).to(torch.int8)

    # gen_eval_batch solves square systems in the dimension of the embedded
    # data, so the combination size must follow the embedded column count.
    # NOTE(review): identical to the raw D whenever embedding keeps the width.
    D = X_device.shape[1]
    if block_size <= D:
        return (None, inf)

    # if run complete solution, warm up using svm
    if len(inds) == N:
        inds = reorder_svm(X_embed, l)

    # Enumerate the smaller of the D-combinations and their complements.
    k = D if D <= block_size / 2 else block_size - D

    # css[d] stores the d-combinations of the items seen so far
    css = [torch.tensor([[]], dtype=torch.int32) for _ in range(k + 1)]
    opt_cnfg = (None, inf)

    for n, i in enumerate(inds):
        if verbose:
            print(f'This is stage {n}')

        # Walk sizes downwards so item i is added to each combination once.
        for d in reversed(range(min(k, n + 1))):
            extended = upd_array(i, css[d])
            if css[d + 1].size(1) == 0:
                css[d + 1] = extended
            else:
                css[d + 1] = torch.cat((css[d + 1], extended), dim=0)

        if css[k].size(1) == 0:
            continue  # no complete k-combination yet

        w_best, comb, loss = gen_eval_batch(inds, css[k], X_device, l_device, blocksize, device)
        # Evaluated combinations are not needed again; free the level.
        css[k] = torch.tensor([[]], dtype=torch.int32)

        if loss == 0:
            # A zero-loss hyperplane cannot be beaten: stop early.
            comb = comb.tolist()
            opt_cnfg = (w_best.cpu(), loss, comb)
            if verbose:
                print(f'the optimal solution has been found {opt_cnfg}')
            return opt_cnfg

        if loss <= opt_cnfg[1]:
            opt_cnfg = (w_best.cpu(), loss, comb)
            if verbose:
                print(f' the best configuration at stage {n} is {opt_cnfg}')

    return opt_cnfg
                


