from operator import index
import torch
import torch.nn as nn
import numpy as np

from imports import DEVICE


class SM(nn.Module):
    """Pairwise (greedy) selection between two batched populations.

    At every (batch, row, col) position the candidate from ``batchpop1`` or
    ``batchpop2`` with the better fitness is kept, where channel 0 of each
    population is the fitness value that gets compared.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batchpop1, batchpop2, minimize=True):
        """Select, position by position, the better of two candidates.

        By default the objective is minimized; pass ``minimize=False`` to
        treat it as a maximization problem instead.

        Args:
            batchpop1: 4-D tensor, assumed (B, C, H, W); ``batchpop1[:, 0]``
                is its fitness. The remaining channels presumably hold the
                decision variables (TODO confirm against callers).
            batchpop2: tensor with the same shape as ``batchpop1``.
            minimize: if True a strictly lower fitness wins, otherwise a
                strictly higher fitness wins.

        Returns:
            Tensor shaped like the inputs. At each position, all channels come
            from ``batchpop1`` when its fitness is strictly better. Otherwise
            they come from ``batchpop2``, which covers ties and comparisons
            involving NaN. Neither input is modified.
        """
        # Keep a singleton channel axis so the mask is (B, 1, H, W).
        fit1 = batchpop1[:, 0:1, :, :]
        fit2 = batchpop2[:, 0:1, :, :]
        # Strict comparison: equal fitness keeps batchpop2.
        if minimize:
            take_first = fit1 < fit2
        else:
            take_first = fit1 > fit2
        # torch.where broadcasts the mask across every channel, so no repeat()
        # copy is needed. The result also stays on the inputs' own device.
        # Unlike blending with pop1*mask + pop2*(1-mask), an inf in the
        # rejected candidate cannot turn into NaN (inf * 0 == nan).
        return torch.where(take_first, batchpop1, batchpop2)
            
        
                    
if __name__ == '__main__':
    # Scratch demo: zero out every element greater than 2 using a boolean mask.
    demo = torch.from_numpy(np.array([[[1, 2, 3], [1, 2, 3]],
                                      [[1, 2, 3], [1, 2, 3]]]))
    print(demo.shape)
    demo.masked_fill_(demo > 2, 0)
    print(demo)
    
         
        