import torch
import torch.nn as nn
from crm import CRM
import imports
from sm import SM
from utils import *
import time

class EM(nn.Module):
    """
    Basic module combining CRM (offspring generation) and SM (selection);
    stacking these modules directly builds a DECN.

    Input:  a batch of populations (batchSize of them), each laid out as a
            (dim+1) x w x h tensor.
    Output: the next-generation population.
    Channel 0 holds each individual's fitness.

    NOTE(review): ``batchSize``, ``w``, ``h`` and ``fitF`` are stored on the
    instance but ``forward`` never reads them -- offspring fitness is computed
    with ``imports.TRAINFUN``, not ``self.fitF``. Confirm this is intended.
    """

    def __init__(self, batchSize=16, w=10, h=10, fitF=None, useRepair=True):
        super().__init__()
        # Population geometry, kept as plain (non-parameter) attributes.
        self.batchSize, self.w, self.h = batchSize, w, h
        # The attribute names CRM / SM become the state_dict key prefixes and
        # their assignment order fixes parameters() order -- keep both stable.
        self.CRM = CRM(useRepaire=useRepair)
        self.SM = SM()
        self.fitF = fitF

    def forward(self, batchPop, xlb, xub):
        """
        Produce the next generation from ``batchPop``.

        Steps:
          1) sortIndiv: sort individuals by fitness and arrange the population
             as (batch, dim+1, w, h).
          2) CRM: generate offspring from the chromosome channels (1:, i.e.
             everything except the fitness channel) using the bounds xlb/xub;
             per the original notes this is a depth-wise convolution over
             (batch, dim, w, h).
          3) calFitness: evaluate the offspring with ``imports.TRAINFUN``.
          4) SM: population selection between parents and offspring.
        """
        ranked = sortIndiv(batchPop)
        # Arguments evaluate left to right: CRM runs before TRAINFUN is read.
        offspring = calFitness(self.CRM(ranked[:, 1:, :, :], xlb, xub),
                               imports.TRAINFUN)
        return self.SM(ranked, offspring)
        
        
        
    

if __name__ == '__main__':
    # Toy training loop: push EM's output chromosomes toward an all-zero
    # target with MSE loss, printing the loss each step.
    # NOTE(review): requires imports.TRAINFUN to be configured to whatever
    # calFitness expects before running -- confirm.
    dim = 2
    chrom = torch.randn((2, dim, 10, 10))
    # calFitness is the module-level helper from utils (not an EM method);
    # call it the same way EM.forward does.
    data = calFitness(chrom, imports.TRAINFUN)
    # EM takes no `dim` argument; the chromosome size comes from the data.
    net = EM()
    # NOTE(review): placeholder search bounds -- confirm the type/shape CRM
    # expects for xlb/xub (scalar vs. per-dimension tensor).
    xlb, xub = -5.0, 5.0
    # Target for the chromosome channels only (channel 0 is fitness).
    # Assumes net(...) returns (batch, dim+1, w, h) -- TODO confirm.
    label = torch.zeros((2, dim, 10, 10))
    lf = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for i in range(1000):
        y = net(data, xlb, xub)
        loss = lf(y[:, 1:, :, :], label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
        # Deliberate pacing so the printed loss can be followed by eye.
        time.sleep(0.1)

    
    

