import torch
import torch.nn as nn
from em import EM
import time
from tqdm import tqdm

class DECNnws15(nn.Module):
    """Chain of independently-constructed ``EM`` modules applied in sequence.

    Unlike ``DECNws3`` / ``DECNws30``, which reuse a single ``EM`` instance,
    every step here has its own ``EM`` instance ("nws" presumably means
    "no weight sharing" -- TODO confirm).

    Args:
        batchSize, w, h: forwarded unchanged to every ``EM`` constructor.
        xlb, xub: stored on the instance; ``forward`` does not read them and
            uses the bounds passed at call time instead.
        useRepaire: forwarded to ``EM`` as ``useRepair`` (misspelt name kept
            for backward compatibility with existing callers).
        numLayers: number of chained ``EM`` modules (default 15, matching the
            class name). Sub-modules are registered as ``em1`` ... ``em<N>``,
            so attribute access (``net.em1``) and ``state_dict`` keys are the
            same as the original hand-written layout.

    Raises:
        ValueError: if ``numLayers`` is smaller than 1.
    """

    def __init__(self,batchSize=16,w=10,h=10,xlb=-10,xub=10,useRepaire=True,*,numLayers=15):
        super().__init__()
        if numLayers<1:
            raise ValueError('numLayers must be >= 1, got %r'%(numLayers,))
        self.w=w
        self.h=h
        self.batchSize=batchSize
        self.xlb=xlb
        self.xub=xub
        self.numLayers=numLayers
        # Names are 1-based to match the historical em1..em15 attributes, which
        # keeps previously saved checkpoints loadable.
        self._emNames=tuple('em%d'%i for i in range(1,numLayers+1))
        for name in self._emNames:
            self.add_module(name,EM(batchSize=self.batchSize,w=self.w,h=self.h,useRepair=useRepaire))

    def forward(self,x0,xlb,xub):
        """Feed ``x0`` through em1 ... em<N> in order.

        Each module receives the previous module's output together with the
        same ``xlb`` / ``xub`` passed to this call. Returns the last module's
        output.
        """
        x=x0
        for name in self._emNames:
            x=getattr(self,name)(x,xlb,xub)
        return x
    
            
            

class DECNws3(nn.Module):
    """Apply one shared ``EM`` module several times in a row (weight sharing).

    Args:
        batchSize, w, h: forwarded unchanged to the ``EM`` constructor.
        useRepaire: forwarded to ``EM`` as ``useRepair`` (misspelt name kept
            for backward compatibility).
        xlb, xub: keyword-only; accepted and stored so this class's signature
            matches ``DECNws30`` / ``DECNnws15``. Presumably models are built
            from ``MODEL_DICT`` with shared keyword arguments -- TODO confirm.
            ``forward`` does not read these and uses its own arguments instead.
        numSteps: keyword-only; how many times the shared ``EM`` is applied
            (default 3, matching the class name).

    Raises:
        ValueError: if ``numSteps`` is smaller than 1.
    """

    def __init__(self,batchSize=16,w=10,h=10,useRepaire=True,*,xlb=-10,xub=10,numSteps=3):
        super().__init__()
        if numSteps<1:
            raise ValueError('numSteps must be >= 1, got %r'%(numSteps,))
        self.w=w
        self.h=h
        self.batchSize=batchSize
        self.xlb=xlb
        self.xub=xub
        self.numSteps=numSteps
        self.decn= EM(batchSize=self.batchSize,w=self.w,h=self.h,useRepair=useRepaire)

    def forward(self,data,xlb,xub):
        """Apply the shared ``EM`` ``numSteps`` times.

        Each application receives the previous output and the same
        ``xlb`` / ``xub`` given to this call. Returns the final output.
        """
        for _ in range(self.numSteps):
            data=self.decn(data,xlb,xub)
        return data
    
    
class DECNws30(nn.Module):
    """Apply one shared ``EM`` module repeatedly (weight sharing across steps).

    Args:
        batchSize, w, h: forwarded unchanged to the ``EM`` constructor.
        xlb, xub: stored on the instance; ``forward`` does not read them and
            uses the bounds passed at call time instead.
        useRepaire: forwarded to ``EM`` as ``useRepair`` (misspelt name kept
            for backward compatibility).
        numSteps: keyword-only; how many times the shared ``EM`` is applied
            (default 30, matching the class name).

    Raises:
        ValueError: if ``numSteps`` is smaller than 1.
    """

    def __init__(self,batchSize=16,w=10,h=10,xlb=-10,xub=10,useRepaire=True,*,numSteps=30):
        super().__init__()
        if numSteps<1:
            raise ValueError('numSteps must be >= 1, got %r'%(numSteps,))
        self.w=w
        self.h=h
        self.batchSize=batchSize
        self.xlb=xlb
        self.xub=xub
        self.numSteps=numSteps
        self.decn= EM(batchSize=self.batchSize,w=self.w,h=self.h,useRepair=useRepaire)

    def forward(self,data,xlb,xub):
        """Apply the shared ``EM`` ``numSteps`` times.

        Each application receives the previous output and the same
        ``xlb`` / ``xub`` given to this call. Returns the final output.
        """
        for _ in range(self.numSteps):
            data=self.decn(data,xlb,xub)
        return data

# Registry from lowercase model key to the nn.Module class that implements it.
# Key order (decnws3, decnws30, decnnws15) is preserved for anyone iterating it.
MODEL_DICT=dict(
    decnws3=DECNws3,
    decnws30=DECNws30,
    decnnws15=DECNnws15,
)


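# NOTE: stale usage sketch kept for reference only. It refers to DECN, DEVICE
# and calFitness, none of which are defined in this module, and it calls
# net(x) with one argument although every forward() here also requires xlb
# and xub.
# if __name__=='__main__':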
#     data=torch.randn((2,100,10,10)).to(DEVICE)
#     net=DECN().to(DEVICE)
#     data=calFitness(data)
#     label=torch.zeros((2,100,10,10)).to(DEVICE)
#     lf=torch.nn.MSELoss()
#     optimizer=torch.optim.Adam(net.parameters(),lr=0.001)
#     for i in range(3):
#         y=net(data)
#         # loss=lf(y,label)
#         loss=lf(y[:,1:,:,:],label)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         print(loss.item())
#     x=torch.randn((1,100,10,10)).to(DEVICE)
#     x=calFitness(x)
#     bar=tqdm(range(50))
#     for i in bar:
#         x=net(x)
#         fitness=x[0][0,:,:]
#         bestFitness=torch.min(fitness)
#         bar.set_description('epoch | %d bestFitness=%0.6f'%(i,bestFitness))
#         time.sleep(1)

