import numpy as np
import torch
import model
import time
import sys

def gauss_qd(x,MODEL):
    """Integrate ``MODEL`` from -1 up to every entry of ``x``.

    The integral is approximated with a 5-point Gauss-Legendre rule: the
    reference nodes on [-1, 1] are mapped affinely onto [-1, x[i]], the
    model is evaluated at the mapped nodes, and the weighted sum is scaled
    by the half-width of the interval.

    Args:
        x: 1-D tensor of upper integration limits (length ``batch``).
        MODEL: callable that receives a ``(batch*5, 1)`` tensor of sample
            points and returns ``batch*5`` values (anything reshapeable to
            ``(batch, 5)``).

    Returns:
        Tensor of shape ``(batch, 1)`` holding the approximate integrals.
    """
    lower=-1
    batch_size=x.shape[0]
    # Affine map of the reference interval [-1, 1] onto [lower, x].
    half_width=((x-lower)/2).unsqueeze(1)
    midpoint=((x+lower)/2).unsqueeze(1).repeat(1,5)
    # Build the constants with the dtype/device of the data they are combined
    # with, so float64 or non-CPU inputs do not hit a mismatch in torch.mm.
    nodes=torch.tensor([[-0.90618,-0.538469,0,0.538469,0.90618]],
                       dtype=half_width.dtype,device=half_width.device)
    samples=torch.mm(half_width,nodes)+midpoint
    samples=samples.reshape(batch_size*5,1)
    values=MODEL(samples).reshape(batch_size,5)
    weights=torch.tensor([[0.236927],[0.478629],[0.568889],[0.478629],[0.236927]],
                         dtype=values.dtype,device=values.device)
    integral=torch.mm(values,weights)
    # Scale by the Jacobian of the affine map (half the interval length).
    return integral*half_width

def cal_inital_entropy():
    """Return the mean softmax entropy at the reference scale ``inital``.

    Random queries of shape ``(batch_size, 1, dim)`` are multiplied with
    random keys of shape ``(batch_size, dim, N)``; the resulting scores are
    scaled by ``inital``, passed through a softmax over the last axis, and
    the Shannon entropy of each row is averaged over the batch.

    Reads the module-level names ``batch_size``, ``dim``, ``N`` and
    ``inital``, which must be defined before the call.

    Returns:
        The batch-averaged entropy as a Python float.
    """
    softmax=torch.nn.Softmax(dim=1)
    q=torch.randn(batch_size,1,dim)
    y=torch.randn(batch_size,dim,N)
    # Drop only the singleton query axis. A bare squeeze() would also remove
    # the batch axis when batch_size == 1 (or the last axis when N == 1) and
    # Softmax(dim=1) would then fail.
    xy=torch.bmm(q,y).squeeze(1)
    sexy=softmax(xy*inital)
    # The small constant keeps log() finite when a probability underflows to 0.
    entropy=-torch.mean(torch.sum(sexy*torch.log(sexy+1e-15),dim=1))
    return entropy.item()

def cal_entropy(lam,x):
    """Mean squared gap between per-sample softmax entropy and the reference.

    One random query batch of shape ``(batch_size, 1, dim)`` is drawn per
    call. For every index ``i`` a fresh random key batch with ``x[i]``
    columns is drawn, the scores are scaled by ``lam[i, 0]``, and the
    batch-averaged softmax entropy is compared with ``inital_entropy``.

    Reads the module-level names ``batch_size``, ``dim`` and
    ``inital_entropy``.

    Args:
        lam: 2-D tensor indexed as ``lam[i, 0]``; the scale applied to the
            scores of sample ``i``.
        x: sequence of integer lengths (tensor elements are accepted), one
            per sample, with at least ``batch_size`` entries.

    Returns:
        Scalar tensor: the squared entropy deviation averaged over the
        ``batch_size`` samples.
    """
    softmax=torch.nn.Softmax(dim=1)
    q=torch.randn(batch_size,1,dim)
    de=0
    for i in range(batch_size):
        # int() makes the size a plain Python integer even when x is a tensor.
        y=torch.randn(batch_size,dim,int(x[i]))
        # Drop only the singleton query axis; a bare squeeze() would also
        # collapse the batch axis when batch_size == 1 and break Softmax(dim=1).
        xy=torch.bmm(q,y).squeeze(1)
        sexy=softmax(xy*lam[i,0])
        # The small constant keeps log() finite when a probability underflows to 0.
        entropy=-torch.mean(torch.sum(sexy*torch.log(sexy+1e-15),dim=1))
        de=de+(entropy-inital_entropy)**2
    return de/batch_size
        
    
if __name__=='__main__':
    # Seed the CPU and CUDA generators so runs are repeatable.
    seed=0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    batch_size=128
    # N is the reference length used by cal_inital_entropy; training lengths
    # are sampled from N+1 .. MAXN.
    N=100
    MAXN=1600
    maxepoch=1200
    lr=1e-3
    originaldim=1
    embedingdim=10
    AFFNnum=3
    MODEL=model.FC(originaldim,embedingdim,AFFNnum)
    optim=torch.optim.Adam(MODEL.parameters(),lr=lr)
    sche=torch.optim.lr_scheduler.ExponentialLR(optim,gamma=0.999)
    # set_detect_anomaly is a function; assigning True to it (as before) only
    # overwrote the attribute and never enabled anomaly detection.
    torch.autograd.set_detect_anomaly(True)
    start=time.time()
    # NOTE: embedingdim is reassigned here; MODEL above was already built with
    # the earlier value of 10. This value only determines ``dim`` below.
    embedingdim=128
    num_head=8
    dim=embedingdim//num_head
    inital=1/np.sqrt(dim)
    inital_entropy=cal_inital_entropy()
    for epoch in range(0,maxepoch+1):
        # Draw batch_size distinct lengths from N+1 .. MAXN.
        x=torch.randperm(MAXN-N)+N+1
        x=x[0:batch_size]
        # Affine map sending N to -1 and MAXN to 1, the interval gauss_qd
        # integrates over (its lower limit is -1).
        x_standard=(2*x-MAXN-N)/(MAXN-N)
        lam=inital+gauss_qd(x_standard,MODEL)
        loss=cal_entropy(lam,x)
        # cal_entropy already returns a scalar, so this mean leaves it unchanged.
        loss=torch.mean(loss)
        optim.zero_grad()
        loss.backward()
        optim.step()
        sche.step()
        if(epoch%100==0):
            # Report the time spent since the previous report and checkpoint.
            end = time.time()
            print('epoch:',epoch,',time:',format(end-start,'.2f'),',loss:',format(loss.item(),'.4e'))
            start=time.time()
            torch.save(MODEL.state_dict(),f'te{N}.pth')


