import numpy as np
import torch
import model
import time
from scipy.stats import ttest_rel
import sys

    
def orientation_transform(l,w,h,num):
    """Return the item's (length, width, height) for orientation code ``num``.

    Codes 0-4 select the permutations (l,w,h), (l,h,w), (w,l,h), (w,h,l)
    and (h,l,w); any other value (normally 5) yields (h,w,l).
    """
    permutations = ((l, w, h), (l, h, w), (w, l, h), (w, h, l), (h, l, w))
    # Compare with == (not indexing/dict lookup) so 0-dim tensor codes work too.
    for code, dims in enumerate(permutations):
        if num == code:
            return dims
    return h, w, l
    
def overlaps_space(space1,space2):
    """Report whether one axis-aligned box contains the other.

    Boxes are sequences ``[x_min, y_min, z_min, x_max, y_max, z_max]``.

    Returns:
        1 if ``space1`` lies entirely inside ``space2`` (also when the two
        boxes are identical), 2 if ``space2`` lies entirely inside
        ``space1``, and 0 otherwise -- including when the boxes do not
        overlap with positive volume.
    """
    # Volume of the intersection box. Clamping each side at zero matters for
    # disjoint boxes: otherwise the gap between them is mistaken for an
    # overlap and a box can be reported as "contained" in one it never touches.
    inter = 1
    for axis in range(3):
        low = max(space1[axis], space2[axis])
        high = min(space1[axis + 3], space2[axis + 3])
        inter *= max(high - low, 0)
    if inter <= 0:
        return 0
    vol1 = (space1[3] - space1[0]) * (space1[4] - space1[1]) * (space1[5] - space1[2])
    vol2 = (space2[3] - space2[0]) * (space2[4] - space2[1]) * (space2[5] - space2[2])
    # The intersection is a sub-box of each input, so equal volume means the
    # input coincides with the intersection, i.e. it is contained in the other.
    if inter == vol1:
        return 1
    if inter == vol2:
        return 2
    return 0
    
def intersect_space(space1,space2):
    """Return True iff two boxes overlap with strictly positive volume.

    Boxes are sequences ``[x_min, y_min, z_min, x_max, y_max, z_max]``;
    boxes that merely share a face, edge or corner do not intersect.
    """
    lows_a = (space1[0], space1[1], space1[2])
    highs_a = (space1[3], space1[4], space1[5])
    lows_b = (space2[0], space2[1], space2[2])
    highs_b = (space2[3], space2[4], space2[5])
    # Separated along some axis => no intersection.
    separated = any(
        lo_b >= hi_a or lo_a >= hi_b
        for lo_a, hi_a, lo_b, hi_b in zip(lows_a, highs_a, lows_b, highs_b)
    )
    return not separated
    
def solveES(space1,space2,lwhmin):
    """Split free space ``space2`` into the maximal sub-spaces left around ``space1``.

    Both boxes are ``[x_min, y_min, z_min, x_max, y_max, z_max]`` and are
    expected to intersect (callers check ``intersect_space`` first).  Up to
    six candidate boxes are produced -- the parts of ``space2`` lying below /
    above the intersection along each axis -- and candidates whose smallest
    edge is shorter than ``lwhmin`` (too small for any item) are discarded.

    Returns:
        The surviving candidate boxes as a list of 6-element lists, in the
        fixed order: y-below, x-above, y-above, z-above, x-below, z-below.
    """
    x1, y1, z1, x3, y3, z3 = (space1[k] for k in range(6))
    x2, y2, z2, x4, y4, z4 = (space2[k] for k in range(6))
    # For two overlapping intervals the middle two of the four sorted
    # endpoints are the low and high ends of their overlap.
    xs = sorted([x1, x2, x3, x4])
    ys = sorted([y1, y2, y3, y4])
    zs = sorted([z1, z2, z3, z4])
    newx, newxb = xs[1], xs[2]
    newy, newyb = ys[1], ys[2]
    newz, newzb = zs[1], zs[2]
    candidates = [
        [x2, y2, z2, x4, newy, z4],
        [newxb, y2, z2, x4, y4, z4],
        [x2, newyb, z2, x4, y4, z4],
        [x2, y2, newzb, x4, y4, z4],
        [x2, y2, z2, newx, y4, z4],
        [x2, y2, z2, x4, y4, newz],
    ]
    # Filter by position rather than list.remove(value): remove() deletes the
    # first *equal* element, which drops the wrong box when candidates repeat.
    return [
        box for box in candidates
        if not (min(box[3] - box[0], box[4] - box[1], box[5] - box[2]) < lwhmin)
    ]


def train_or_test(MODEL,x,isTrain):
    """Roll out one packing episode per batch instance with ``MODEL``.

    Items are placed one at a time. At each step the model picks which
    unplaced item to pack next (``DecodingSequence``). It then picks a
    position code (``DecodingPosition``) that encodes a y coordinate and one
    of six orientations as ``y*6 + orientation``. The x and z coordinates
    are then chosen heuristically from the per-instance free-space list
    ``KPlist``. That list is updated with ``solveES`` after every placement.

    Relies on names that this file assigns only inside its ``__main__``
    block: ``device``, ``L``, ``W``, ``H``, ``lwhmax`` and ``lwhmin``. They
    must exist as module globals before this function is called.

    Args:
        MODEL: policy network exposing ``clacb``, ``FrontierEmbedding``,
            ``DecodingSequence`` and ``DecodingPosition`` with the call
            signatures used below (defined in the ``model`` module).
        x: item sizes indexed as ``x[batch, item, dim]``, dim 0/1/2 = l/w/h.
            The ``__main__`` block passes an integer tensor of shape
            ``(batch_size, N, 3)``.
        isTrain: True -> the next item is sampled from the policy and
            log-probabilities are accumulated; False -> the item is chosen
            greedily (argmax). The last argument to ``DecodingPosition`` is
            False/True accordingly; its meaning lives in ``model``.

    Returns:
        ``(ru, pro)`` if ``isTrain`` else ``ru``.

        ``ru`` is, per instance, ``sum(l*w*h) / (L*W*height)``. Here
        ``height`` is the maximum of the frontier heightmap as last updated.

        ``pro`` is the sum of ``log`` of the chosen items' probabilities plus
        the ``pospro`` terms returned by ``DecodingPosition``. Those terms are
        presumably position log-probabilities -- confirm in ``model``.
    """
    batch_size=x.size()[0]
    num_item=x.size()[1]
    softmax=torch.nn.Softmax(dim=1)
    pro=torch.zeros(batch_size).to(device)
    # batch indices, used to gather each row's chosen-item probability
    templist=[i for i in range(batch_size)]
    # free spaces [x1,y1,z1,x2,y2,z2] per instance; initially the whole bin
    KPlist=[[[0,0,0,L,W,H]] for i in range(batch_size)]
    # indices of items already placed, per instance
    maskseq=[[] for i in range(batch_size)]
    # frontier: (batch, 2, L, W); only channel 1 (heightmap) is written below
    frontier=torch.zeros(batch_size,2,L,W).to(device)
    x11=x/lwhmax
    # second view of the items with w and h swapped; item probabilities are
    # averaged over both encodings
    x12=x11[:,:,[0,2,1]]
    # every position code y*6+orientation for y in [0, W)
    positionset=set(i for i in range(6*W))
    inputb=MODEL.clacb(x11)
    inputb2=MODEL.clacb(x12)
    orientation_list=[0,1,2,3,4,5]
    n=x11.size()[1]
    # total item volume divided by the bin footprint L*W
    volume=torch.zeros(batch_size)
    height=torch.zeros(batch_size).to(device)
    lastheight=torch.zeros(batch_size).to(device)
    for jk in range(batch_size):
        volume[jk]=torch.sum(x[jk,:,0]*x[jk,:,1]*x[jk,:,2])/L/W
    volume=volume.to(device)
    # one placement per step; n steps pack every item
    for i in range(n):
        inputf=MODEL.FrontierEmbedding(frontier.detach().clone())
        ct21=MODEL.DecodingSequence(inputb,inputf,maskseq)
        ct22=MODEL.DecodingSequence(inputb2,inputf,maskseq)
        # already-placed items get zero probability after softmax
        for jk in range(batch_size):
            ct21[jk,maskseq[jk]]=-np.inf
            ct22[jk,maskseq[jk]]=-np.inf
        ct21=softmax(ct21)
        ct22=softmax(ct22)
        ct2=(ct21+ct22)/2
        if(isTrain):
            seqidx=torch.multinomial(ct2,1)
            seqidx=torch.squeeze(seqidx)
            seqpro=ct2[templist,seqidx]
            pro=pro+torch.log(seqpro)
        else:
            seqidx=torch.max(ct2,1)[1]
        # Position codes y*6+k are allowed if the chosen item, in orientation
        # k, fits entirely inside some free space whose y_min is y.
        unmaskposition=[set() for jk in range(batch_size)]
        for jk in range(batch_size):
            maskseq[jk]=maskseq[jk]+[seqidx[jk].item()]
            l=x[jk,seqidx[jk],0]
            w=x[jk,seqidx[jk],1]
            h=x[jk,seqidx[jk],2]
            for j in KPlist[jk]:
                x1=j[0]
                y1=j[1]
                z1=j[2]
                x2=j[3]
                y2=j[4]
                z2=j[5]
                for k in orientation_list:
                    l2,w2,h2=orientation_transform(l,w,h,k)
                    if(x1+l2<=x2 and y1+w2<=y2 and z1+h2<=z2):
                        unmaskposition[jk].add(y1*6+k)
        # NOTE(review): if no code is feasible for an instance, every position
        # is masked; how DecodingPosition handles that is defined in model.
        maskposition=[list(positionset.difference(unmaskposition[jk])) for jk in range(batch_size)]
        if(isTrain):
            posidx,pospro=MODEL.DecodingPosition(seqidx,inputb,inputf,maskseq,maskposition,False)
            pro=pro+pospro
        else:
            posidx=MODEL.DecodingPosition(seqidx,inputb,inputf,maskseq,maskposition,True)
        # decode position code -> chosen y coordinate (tensor) and orientation
        y1=torch.div(posidx,6,rounding_mode='trunc')
        ori1=posidx.int()%6
        for jk in range(batch_size):
            zmin=1e6
            # undo last step's normalisation (frontier was divided by height)
            frontier[jk,:,:,:]=frontier[jk,:,:,:]*lastheight[jk]
            # Among free spaces starting at the chosen y, take the one giving
            # the lowest item top. Only the bin's L/W bounds are checked here,
            # not the free space's own x2/y2/z2 bounds.
            # NOTE(review): if no space matches, x1/x2/z1/z2/z3 keep values
            # from an earlier instance or step (or are unbound). This relies on
            # DecodingPosition only returning unmasked codes -- confirm.
            for j in KPlist[jk]:
                l=x[jk,seqidx[jk],0]
                w=x[jk,seqidx[jk],1]
                h=x[jk,seqidx[jk],2]
                if(j[1]==y1[jk]):
                    l2,w2,h2=orientation_transform(l,w,h,ori1[jk])
                    if(j[2]+h2<zmin and j[0]+l2<=L and j[1]+w2<=W):
                        x1=j[0]
                        z1=j[2]
                        x2=x1+l2
                        y2=y1[jk]+w2
                        z2=z1+h2
                        zmin=z2
                        z3=j[5]
            # The heightmap and height are only updated when the chosen free
            # space reaches the bin ceiling H.
            # NOTE(review): placements into spaces capped below H leave the
            # heightmap unchanged -- confirm this is intended.
            if(z3==H):
                frontier[jk,1,x1:x2,y1[jk]:y2]=z2
                height[jk]=torch.max(frontier[jk,1,:,:])
            frontier[jk,:,:,:]=frontier[jk,:,:,:]/height[jk]
            # the placed box as [x1,y1,z1,x2,y2,z2]
            itKPpace=[x1,y1[jk].item(),z1,x2.item(),y2.item(),z2.item()]
            # if(jk==0):
            #     aa[i,0]=x1
            #     aa[i,1]=y1[jk]
            #     aa[i,2]=z1
            #     aa[i,3]=x2
            #     aa[i,4]=y2
            #     aa[i,5]=z2
            # Replace every free space hit by the placed box with the pieces
            # of it that remain (solveES).
            deletelst=[]
            newKPlist=[]
            for i1 in range(len(KPlist[jk])):
                if(intersect_space(itKPpace,KPlist[jk][i1])):
                    lst3=solveES(itKPpace,KPlist[jk][i1],lwhmin)
                    deletelst.append(i1)
                    if(len(lst3)>0):
                        for j in lst3:
                            newKPlist.append(j)
            deletelst.sort(reverse=True)
            for i1 in deletelst:
                KPlist[jk].remove(KPlist[jk][i1])
            # Drop new spaces fully contained in another new space.
            deletelst=[]
            for i1 in range(len(newKPlist)-1):
                for j in range(i1+1,len(newKPlist)):
                    if(intersect_space(newKPlist[i1],newKPlist[j])):
                        a=overlaps_space(newKPlist[i1],newKPlist[j])
                        if(a==1):
                            if(i1 not in deletelst):
                                deletelst.append(i1)
                        elif(a==2):
                            if(j not in deletelst):
                                deletelst.append(j)
            deletelst.sort(reverse=True)
            for i1 in deletelst:
                newKPlist.remove(newKPlist[i1])
            # Drop new spaces involved in containment with an existing space.
            # NOTE(review): a>0 also fires when the NEW space contains the old
            # one (a==2). In that case the larger new space is discarded and
            # the smaller old one kept -- confirm this is intended.
            deletelst=[]
            for i1 in range(len(newKPlist)):
                for j in range(len(KPlist[jk])):
                    if(intersect_space(newKPlist[i1],KPlist[jk][j])):
                        a=overlaps_space(newKPlist[i1],KPlist[jk][j])
                        if(a>0):
                            if(i1 not in deletelst):
                                deletelst.append(i1)
                                break
            deletelst.sort(reverse=True)
            for i1 in deletelst:
                newKPlist.remove(newKPlist[i1])
            for i1 in newKPlist:
                KPlist[jk].append(i1)
        # After the first step lastheight and height are the same tensor
        # object. This still works because frontier is rescaled by
        # lastheight[jk] above before height[jk] is overwritten in that step.
        lastheight=height
    ru=volume/height
    # aa=aa.numpy()
    # np.savetxt('a.txt',aa,fmt='%d')
    if(isTrain):
        return ru,pro
    else:
        return ru
                      
if __name__=='__main__':
    # Only the training script needs to create the checkpoint directory.
    import os

    seed=0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    # Problem size: N items per instance, bin footprint L x W, item edges
    # drawn from [lwhmin, lwhmax]. These names are read as globals by
    # train_or_test.
    N=100
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size=32
    maxepoch=10000
    L=120
    W=100
    lwhmax=50
    lwhmin=10
    # Bin height large enough to stack every item on top of each other.
    H=N*lwhmax
    lr=1e-5
    # Attend2Pack hyper-parameters.
    originaldim=3
    embedingdim=128
    num_head=8
    FFNdim=512
    AFFNnum=3
    C=5
    inchannel=2
    outchannel=4
    outdim=6*W
    alpha=0.05
    MODEL=model.Attend2Pack(originaldim,embedingdim,num_head,FFNdim,AFFNnum,inchannel,outchannel,C,outdim,L,W).to(device)
    # The greedy baseline network starts as an exact copy of MODEL.
    torch.save(MODEL.state_dict(),f'KPattend2_L{L}W{W}N{N}.pth')
    MODELG= model.Attend2Pack(originaldim,embedingdim,num_head,FFNdim,AFFNnum,inchannel,outchannel,C,outdim,L,W).to(device)
    MODELG.load_state_dict(torch.load(f'KPattend2_L{L}W{W}N{N}.pth'))
    optim=torch.optim.Adam(MODEL.parameters(),lr=lr)
    # set_detect_anomaly must be *called*. Assigning to it replaces the
    # function and never enables detection. Anomaly mode slows backward a
    # lot, so it stays off; pass True here when hunting NaNs.
    torch.autograd.set_detect_anomaly(False)
    sche=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim,T_max=maxepoch)
    # Baseline updates are checkpointed under 3DBPP/. torch.save fails if the
    # directory is missing.
    os.makedirs('3DBPP',exist_ok=True)
    start=time.time()
    for epoch in range(1,maxepoch+1):
        x=torch.randint(low=lwhmin,high=lwhmax+1,size=[batch_size,N,3])
        CW,proarr=train_or_test(MODEL,x,True)
        # The greedy baseline only supplies a reward, so no autograd graph is
        # needed.
        with torch.no_grad():
            CWG=train_or_test(MODELG,x,False)
        # REINFORCE with a greedy-rollout baseline: minimising this raises the
        # log-probability of rollouts that beat the baseline (CW > CWG).
        loss=torch.mean((CWG-CW)*proarr)
        optim.zero_grad()
        loss.backward()
        optim.step()
        sche.step()
        if(epoch%10==0):
            x=torch.randint(low=lwhmin,high=lwhmax+1,size=[batch_size,N,3])
            # Both are greedy evaluations. ru2 (current policy) is printed as
            # "sample_ru"; ru (baseline) is printed as "greedy_ru".
            with torch.no_grad():
                ru2=train_or_test(MODEL,x,False)
                ru=train_or_test(MODELG,x,False)
            # scipy needs host arrays; the results may live on the GPU.
            t,p=ttest_rel(ru.cpu().numpy(),ru2.cpu().numpy(),alternative='less')
            # If the current policy is significantly better than the baseline,
            # copy it into the baseline.
            if(t<0 and p<=0.05):
                torch.save(MODEL.state_dict(),f'3DBPP/KPattend2_L{L}W{W}N{N}.pth')
                MODELG.load_state_dict(torch.load(f'3DBPP/KPattend2_L{L}W{W}N{N}.pth'))
            # If the current policy is significantly worse, decay the LR.
            if(p>0.95):
                for lrr in optim.param_groups:
                    lrr['lr']*=0.95
            end = time.time()
            print('epoch:',epoch,',greedy_ru:',format(torch.mean(ru).item()*100,'.2f'),',sample_ru:',format(torch.mean(ru2).item()*100,'.2f'),'%,time:',format(end-start,'.2f'),
                  ',loss:',format(loss.item(),'.2f'),',pro:',format(torch.mean(proarr).item(),'.2f'),',t:',format(t,'.2f'),',p:',format(p*100,'.2f'),'%')
            start=time.time()


