import torch.nn as nn
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
from functionset import *
import math

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding, scaled down by a factor of 20.

    The encoding table has shape ``(1, max_len, d_model)``: even columns hold
    ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``. ``forward`` adds the
    first ``x.size(1)`` rows to ``x``, broadcasting over the leading batch
    dimension (and over the last dimension when ``d_model == 1``).

    Args:
        d_model: number of encoding columns. Any positive value works, odd
            values included.
        max_len: number of positions in the table. Inputs longer than this
            fail to broadcast in ``forward``.
    """

    def __init__(self, d_model=10, max_len=200):
        super().__init__()
        # Compute the positional encodings once, with the frequencies in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies feed the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        # A non-persistent buffer follows .to()/.cuda() with the module but is
        # kept out of state_dict, so existing checkpoints still load.
        self.register_buffer("pe", pe.to(DEVICE) / 20, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, : x.size(1)]



class ATTENTION(nn.Module):
    """Mix the individuals of a population with one learned row-stochastic matrix.

    The learnable logits ``attn`` have shape ``(1, popSize, popSize)``. Each row
    is softmax-normalised, so every output individual is a convex combination of
    the input individuals.
    """

    def __init__(self, popSize=100):
        super().__init__()
        self.popSize = popSize
        self.attn = nn.Parameter(torch.randn((1, popSize, popSize)), requires_grad=True)

    def forward(self, x):
        """Return ``softmax(attn, -1) @ x``; ``x`` needs ``popSize`` rows on dim -2."""
        mixing = torch.softmax(self.attn, dim=-1)
        return torch.matmul(mixing, x)
    
  
  
class AttnWithFit(nn.Module):
    """Blend a free learned mixing matrix with fitness-driven attention.

    Two population-mixing matrices are applied to ``x``:

    * ``softmax(attn)``: free logits of shape ``(1, popSize, popSize)``;
    * a single-head attention whose queries and keys are linear embeddings of
      each individual's scalar fitness.

    The two mixed populations are combined with weights ``softmax(F)``.
    """

    def __init__(self, popSize=100, hiddenDim=100):
        super().__init__()
        self.popSize = popSize
        # Free attention logits, row-softmaxed in forward().
        self.attn = nn.Parameter(torch.randn((1, popSize, popSize)), requires_grad=True)
        # Query / key embeddings of the scalar fitness (last dim of fitx is 1).
        self.q = nn.Sequential(nn.Linear(1, hiddenDim))
        self.k = nn.Sequential(nn.Linear(1, hiddenDim))
        self.num_heads = 1
        # Logits of the two blend weights: [free attention, fitness attention].
        self.F = nn.Parameter(torch.randn((2,)), requires_grad=True)

    def forward(self, x, fitx):
        """Mix population ``x`` (B, N, D) using fitness ``fitx`` (B, N, 1)."""
        B, N, _ = fitx.shape
        H = self.num_heads

        def heads(t):
            # (B, N, hidden) -> (B, H, N, hidden // H)
            return t.view(B, N, H, -1).transpose(1, 2)

        scores = heads(self.q(fitx)) @ heads(self.k(fitx)).transpose(2, 3)
        # NOTE(review): scaled by the feature width of x, not the q/k width
        # (hiddenDim) as in standard scaled dot-product attention; kept as-is.
        scores = scores * (x.shape[-1] ** -0.5)
        fit_mix = scores.softmax(dim=-1).squeeze(1)  # (B, N, N) with H == 1
        free_mix = self.attn.softmax(dim=-1)
        w_free, w_fit = self.F.softmax(-1)
        return (free_mix @ x) * w_free + (fit_mix @ x) * w_fit


class MHSA(nn.Module):
    """Multi-head self-attention across the individuals of a population.

    Queries and keys are learned projections (with dropout) of the
    position-encoded input. The values are the position-encoded input itself,
    so each output individual is a convex combination of input individuals.
    """

    # Shared class-level positional-encoding table (d_model=10).
    # NOTE(review): forward() below uses MHSA0.PE (d_model=1) instead, while
    # MHSA0.forward uses this table; confirm the cross-wiring is intended.
    PE = PositionalEncoding().to(DEVICE)

    def __init__(self, num_heads, dim):
        super().__init__()
        # nn.Dropout() defaults to p=0.5; it is applied to the q and k projections.
        self.dropout = nn.Dropout()
        # Q, K, V projections; input and output feature widths are equal.
        self.q = nn.Sequential(nn.Linear(dim, dim))
        self.k = nn.Sequential(nn.Linear(dim, dim))
        # NOTE(review): v and ln are registered (and saved in state_dict) but
        # forward() never uses them; the values are the raw input.
        self.v = nn.Sequential(nn.Linear(dim, dim))
        self.num_heads = num_heads
        self.ln = torch.nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Attend over dim 1 of ``x`` (B, N, C) and return a (B, N, C) tensor."""
        B, N, C = x.shape
        heads = self.num_heads
        x = MHSA0.PE(x)

        def split_heads(t):
            # (B, N, C) -> (B, H, N, C // H)
            return t.view(B, N, heads, -1).transpose(1, 2)

        q = split_heads(self.dropout(self.q(x)))
        k = split_heads(self.dropout(self.k(x)))
        v = split_heads(x)
        # NOTE(review): scaled by 1/sqrt(C) (full width), not 1/sqrt(C // H).
        weights = (q @ k.transpose(2, 3) * (C ** -0.5)).softmax(dim=-1)
        # (B, H, N, C // H) -> (B, N, H, C // H) -> (B, N, C)
        return (weights @ v).transpose(1, 2).reshape(B, N, C)

class MHSA0(nn.Module):
    """Per-channel self-attention over the population.

    Every channel ``c`` of the input ``(B, N, C)`` is treated as an independent
    sequence of ``N`` scalars, one per individual. A 1-d query/key/value
    projection is applied, and attention is computed across the ``N``
    individuals of that channel only.
    """

    # Shared class-level positional-encoding table (d_model=1).
    # NOTE(review): forward() uses MHSA.PE (d_model=10) instead, while MHSA
    # uses this table; confirm the cross-wiring is intended.
    PE = PositionalEncoding(d_model=1).to(DEVICE)

    def __init__(self, num_heads, dim):
        super().__init__()
        # The constructor arguments are overridden: this variant is hard-wired
        # to a single head over scalar (1-d) tokens.
        num_heads = 1
        dim = 1
        # Q, K, V projections; input and output feature widths are equal.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.num_heads = num_heads
        self.ln = torch.nn.LayerNorm(normalized_shape=dim)  # not used in forward()

    def forward(self, x):
        """Attend independently per channel of ``x`` (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape
        # NOTE(review): MHSA.PE has d_model=10, so this addition is only
        # shape-compatible when C == 10.
        x = MHSA.PE(x)
        # (B, N, C) -> (B, C, N, 1) -> (B*C, N, 1): one scalar sequence per (batch, channel).
        x = x.view(B, N, C, 1).permute(0, 2, 1, 3).reshape(-1, N, 1)
        # Split heads while keeping B*C as the batch axis: (B*C, H, N, 1 // H).
        # (Previously reshaped to (B, N, H, -1), which reinterpreted the
        # (B*C, N) memory layout and mixed channels with individuals.)
        q = self.q(x).reshape(B * C, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B * C, N, self.num_heads, -1).permute(0, 2, 1, 3)
        # Values use the v projection (previously computed with self.k by mistake).
        v = self.v(x).reshape(B * C, N, self.num_heads, -1).permute(0, 2, 1, 3)

        # Scaled dot-product attention scores over the N individuals.
        attn = q @ k.transpose(2, 3) * (x.shape[-1] ** -0.5)
        attn = attn.softmax(dim=-1)

        # Weight the values and restore the layout:
        # (B*C, H, N, 1) -> (B*C, N, H, 1) -> (B, C, N, 1) -> (B, N, C, 1) -> (B, N, C)
        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, C, N, 1)
        return out.permute(0, 2, 1, 3).reshape(B, N, C)



class DETNBASE(nn.Module):
    """Shared helpers for the DETN models: population sorting and bound repair.

    ``repaire`` reads ``self.xlb`` / ``self.xub``, which subclasses assign
    before calling it.
    """

    def sortpop(self, x, fun):
        """Sort every population in the batch by ascending fitness.

        Args:
            x: batch of populations, individuals indexed along dim 1.
            fun: dict whose ``fun['fun'](x, fun['bias'])`` returns one fitness
                value per individual, shape ``(b, n)``.

        Returns:
            ``(sorted_x, sorted_fitness)``.
        """
        scores = fun['fun'](x, fun['bias'])
        ordered_scores, order = torch.sort(scores, dim=-1)
        # Broadcast the (b, n) permutation over the trailing dims of x and
        # reorder whole individuals with one gather instead of a Python loop.
        index = order.reshape(tuple(order.shape) + (1,) * (x.dim() - order.dim())).expand_as(x)
        return torch.gather(x, 1, index), ordered_scores

    def repaire(self, x):
        """Clip ``x`` in place to ``[self.xlb, self.xub]`` and return the same tensor."""
        upper, lower = self.xub, self.xlb
        x[x > upper] = upper
        x[x < lower] = lower
        return x


class TEM(DETNBASE):
    """One learned evolution step: attention "crossover", MLP "mutation", blend.

    Args:
        num_heads: not used by the current attention module.
        dim: feature width of each individual (width of the mutation MLP).
        hidden_dim: width of the fitness query/key embeddings in AttnWithFit.
        popSize: number of individuals per population.
    """

    def __init__(self, num_heads=8, dim=64, hidden_dim=100, popSize=10):
        super().__init__()
        self.dim = dim
        # Earlier alternatives: MHSA(num_heads, dim) or ATTENTION(popSize=popSize).
        self.trm = AttnWithFit(popSize=popSize, hiddenDim=hidden_dim)
        self.mut = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        wide = popSize * 4
        # Maps concatenated parent/offspring fitness (b, 2n) to per-individual
        # parent weights (b, n); used only by nnsel().
        self.sel = nn.Sequential(
            nn.Linear(popSize * 2, wide),
            nn.BatchNorm1d(wide),
            nn.ReLU(),
            nn.Linear(wide, popSize),
            nn.ReLU(),
            nn.Softmax(dim=1),
        )
        # Per-individual blend weights for parent, crossover result and mutant.
        self.f1 = nn.Parameter(torch.randn((1, popSize, 1)), requires_grad=True)
        self.f2 = nn.Parameter(torch.randn((1, popSize, 1)), requires_grad=True)
        self.f3 = nn.Parameter(torch.randn((1, popSize, 1)), requires_grad=True)
        self.sm = SM()

    def nnsel(self, batchfather, batchOff, fun):
        """Soft, learned selection between parents and (fitness-sorted) offspring."""
        parent_fit = fun['fun'](batchfather, fun['bias'])
        batchOff, _ = self.sortpop(batchOff, fun)
        off_fit = fun['fun'](batchOff, fun['bias'])
        keep_parent = self.sel(torch.cat((parent_fit, off_fit), dim=1)).unsqueeze(2)
        keep_off = torch.ones_like(keep_parent) - keep_parent
        return keep_parent * batchfather + keep_off * batchOff

    def forward(self, x, fun):
        """Produce offspring for population ``x`` of shape (b, n, d)."""
        b, n, _ = x.shape
        fit_weights = fun['fun'](x, fun['bias']).softmax(dim=-1).view(b, n, 1)
        crossed = self.trm(x, fit_weights)  # attention-based recombination
        mutated = self.mut(crossed)  # MLP-based mutation
        off = self.f1 * x + self.f2 * crossed + self.f3 * mutated
        # NOTE(review): the greedy selection below is computed and discarded;
        # the offspring are returned unselected. The call is kept because
        # fun['fun'] may have side effects (e.g. evaluation counting); confirm intent.
        self.sm(x, off, fun=fun)
        return off



class SM(nn.Module):
    """Greedy one-to-one selection between two equally shaped populations."""

    def __init__(self):
        super().__init__()

    def forward(self, batchpop1, batchpop2, minimize=True, fun=None):
        """Implement the selection operator, keeping the better individual of each pair.

        Minimizes by default; with ``minimize=False`` the objective is maximized.

        Args:
            batchpop1, batchpop2: populations of identical shape ``(b, n, d)``.
            minimize: whether lower fitness is better.
            fun: dict whose ``fun['fun'](pop, fun['bias'])`` returns one fitness
                value per individual, shape ``(b, n)``.

        Returns:
            A tensor shaped like the inputs. Individual ``i`` comes from
            ``batchpop1`` only when it is strictly better; ties (and NaN
            fitness differences) select ``batchpop2``.
        """
        fit1 = fun['fun'](batchpop1, fun['bias'])
        fit2 = fun['fun'](batchpop2, fun['bias'])
        diff = fit1 - fit2  # b, n
        pop1_wins = diff < 0 if minimize else diff > 0
        # Build the mask from the fitness tensor itself (same device and dtype).
        # Forcing it onto a global DEVICE broke populations living elsewhere.
        mask1 = pop1_wins.to(diff.dtype).unsqueeze(2)
        return batchpop1 * mask1 + batchpop2 * (1 - mask1)
            


class TEMonlyMHSA(nn.Module):
    """Evolution step that uses only MHSA to generate offspring, then applies SM selection.

    ``hidden_dim`` and ``popSize`` are accepted but not used here.
    """

    def __init__(self, num_heads=8, dim=64, hidden_dim=100, popSize=10):
        super().__init__()
        self.dim = dim
        self.trm = MHSA(num_heads, dim)
        self.sm = SM()

    def forward(self, x, fun):
        """Return the greedy selection between ``x`` and its MHSA offspring."""
        return self.sm(x, self.trm(x), fun=fun)



   


class DETNws3OnlyMhsa(DETNBASE):
    """Three sort -> TEMonlyMHSA -> repair rounds, all sharing one TEMonlyMHSA."""

    def __init__(self, num_heads=8, dim=64, hidden_dim=100, popSize=10, xlb=-100, xub=100):
        super().__init__()
        self.tem1 = TEMonlyMHSA(num_heads, dim, hidden_dim, popSize)
        self.xlb = xlb
        self.xub = xub

    def forward(self, x, fun, xlb, xub):
        """Run the three rounds on ``x``, clipping to ``[xlb, xub]`` after each one."""
        self.xlb, self.xub = xlb, xub
        for _round in range(3):
            x, _fit = self.sortpop(x, fun)
            x = self.repaire(self.tem1(x, fun))
        return x
        
        


class DETN(DETNBASE):
    """Stack of ``ems`` TEM evolution steps, each preceded by fitness sorting.

    Args:
        num_heads, dim, hidden_dim, popSize: forwarded to every TEM.
        ems: number of evolution steps.
        ws: if True, a single TEM instance is reused for every step; otherwise
            each step has its own TEM.
    """

    def __init__(self, num_heads=8, dim=64, hidden_dim=100, popSize=10, ems=10, ws=False):
        super().__init__()
        self.ems = ems
        self.ws = ws
        # Best fitness seen before each step of the last forward() call made
        # with recordFit is not None (empty otherwise).
        self.fitlist = []
        if self.ws:
            self.detn = TEM(num_heads, dim, hidden_dim, popSize)
        else:
            self.detn = torch.nn.ModuleList([TEM(num_heads, dim, hidden_dim, popSize) for _ in range(ems)])

    def forward(self, x, fun, xlb, xub, recordFit=None):
        """Evolve population ``x`` for ``ems`` steps within bounds ``[xlb, xub]``.

        Args:
            x: batch of populations.
            fun: dict with ``fun['fun'](x, fun['bias'])`` returning per-individual fitness.
            xlb, xub: lower / upper bounds applied after every step.
            recordFit: when not None, the minimum fitness before each step is
                recorded in ``self.fitlist`` (previously it was collected and
                then discarded).

        Returns:
            The evolved (bound-repaired) population.
        """
        self.xlb = xlb
        self.xub = xub
        fitlist = []
        for i in range(self.ems):
            x, fit = self.sortpop(x, fun)
            if recordFit is not None:
                fitlist.append(torch.min(fit).item())
            tem = self.detn if self.ws else self.detn[i]
            x = self.repaire(tem(x, fun))
        self.fitlist = fitlist
        return x



        
        


if __name__ == "__main__":
    # This module only defines model components; there is no script entry point.
    pass