import numpy as np
import torch
import model
import time
from scipy.stats import ttest_rel
import sys
import math

def orientation_transform(l,w,h,num):
    """Return the box dimensions (l, w, h) permuted by orientation code ``num``.

    Codes 0..4 select, in order: (l,w,h), (l,h,w), (w,l,h), (w,h,l), (h,l,w).
    Any other value, including 5, yields (h,w,l).
    """
    candidates=((l,w,h),(l,h,w),(w,l,h),(w,h,l),(h,l,w))
    # Compare against each code in turn, just like an if/elif chain would.
    for code,dims in enumerate(candidates):
        if num==code:
            return dims
    return h,w,l
    
def overlaps_space(space1,space2):
    """Classify containment between two boxes given as [x0,y0,z0,x1,y1,z1].

    The shared volume is taken from the two middle values of each axis' sorted
    endpoints. That equals the true intersection only when the boxes actually
    intersect, which is why the callers in this file test intersect_space first.

    Returns:
        1 if the shared volume equals space1's volume (space1 inside space2),
        otherwise 2 if it equals space2's volume (space2 inside space1),
        otherwise 0.
    """
    def middle_span(axis):
        # Sorted endpoints of both boxes along one axis; the inner pair bounds the overlap.
        _,low,high,_=sorted([space1[axis],space2[axis],space1[axis+3],space2[axis+3]])
        return high-low

    shared=middle_span(1)*middle_span(0)*middle_span(2)
    vol_first=(space1[3]-space1[0])*(space1[4]-space1[1])*(space1[5]-space1[2])
    vol_second=(space2[3]-space2[0])*(space2[4]-space2[1])*(space2[5]-space2[2])
    if shared==vol_first:
        return 1
    return 2 if shared==vol_second else 0
    
def intersect_space(space1,space2):
    """Return True when two boxes [x0,y0,z0,x1,y1,z1] overlap with positive depth on every axis.

    Boxes that merely touch (a shared face, edge or corner) do not count as intersecting.
    """
    for axis in range(3):
        start_a,end_a=space1[axis],space1[axis+3]
        start_b,end_b=space2[axis],space2[axis+3]
        # Separated (or only touching) along this axis -> no intersection.
        if start_b>=end_a or start_a>=end_b:
            return False
    return True
    
def solveES(space1,space2,lwhmin,min_size=0.5):
    """Split empty space ``space2`` around box ``space1`` and return the residual spaces.

    Both arguments are boxes [x0,y0,z0,x1,y1,z1]. Six candidate residuals are
    built. Each one is ``space2`` with a single face moved onto the matching face
    of ``space1``. Candidates whose shortest edge is below ``min_size`` are dropped.

    Args:
        space1: the box that was placed.
        space2: the empty space it overlaps. The sorted-endpoint trick below
            assumes the two boxes intersect; the callers in this file check
            intersect_space before calling.
        lwhmin: unused; kept so existing three-argument calls keep working.
            NOTE(review): possibly intended as the pruning threshold, since no
            box edge is shorter than lwhmin. Confirm before passing it as min_size.
        min_size: minimum edge length a residual must have to be kept
            (default 0.5, the previously hard-coded value).

    Returns:
        A new list of residual boxes, in the fixed order: y below the box,
        x beyond it, y beyond it, z above it, x before it, z below it.
    """
    # For intersecting boxes, each axis' sorted endpoints are
    # [min start, max start, min end, max end]; the middle pair bounds the overlap.
    xs=sorted([space1[0],space2[0],space1[3],space2[3]])
    ys=sorted([space1[1],space2[1],space1[4],space2[4]])
    zs=sorted([space1[2],space2[2],space1[5],space2[5]])
    ex0,ey0,ez0,ex1,ey1,ez1=space2[0],space2[1],space2[2],space2[3],space2[4],space2[5]
    candidates=[
        [ex0,ey0,ez0,ex1,ys[1],ez1],   # part of space2 with y below the box
        [xs[2],ey0,ez0,ex1,ey1,ez1],   # part with x beyond the box
        [ex0,ys[2],ez0,ex1,ey1,ez1],   # part with y beyond the box
        [ex0,ey0,zs[2],ex1,ey1,ez1],   # part above the box
        [ex0,ey0,ez0,xs[1],ey1,ez1],   # part with x before the box
        [ex0,ey0,ez0,ex1,ey1,zs[1]],   # part below the box
    ]
    # Filter instead of list.remove(value): removal by value can delete the
    # wrong entry when two candidates compare equal.
    return [c for c in candidates
            if min(c[3]-c[0],c[4]-c[1],c[5]-c[2])>=min_size]

def _update_spaces(spaces,placed):
    """Update the empty-space list ``spaces`` in place after box ``placed`` is put down.

    1. Every space the box intersects is removed and replaced by its residuals (solveES).
    2. A new residual is dropped when overlaps_space reports it contained in
       (or equal to) another new residual.
    3. A new residual is dropped when it intersects a remaining old space and
       overlaps_space reports containment in either direction. The survivors
       are appended.

    Deletions are done by index rather than list.remove(value), so equal spaces
    can never cause the wrong entry to be removed.
    """
    hit=[i for i,space in enumerate(spaces) if intersect_space(placed,space)]
    fresh=[]
    for i in hit:
        fresh.extend(solveES(placed,spaces[i],lwhmin))
    for i in reversed(hit):
        del spaces[i]

    drop=set()
    for i in range(len(fresh)-1):
        for j in range(i+1,len(fresh)):
            if intersect_space(fresh[i],fresh[j]):
                relation=overlaps_space(fresh[i],fresh[j])
                if relation==1:
                    drop.add(i)
                elif relation==2:
                    drop.add(j)
    fresh=[space for i,space in enumerate(fresh) if i not in drop]

    # All checks run against the old spaces only, before anything is appended.
    # NOTE(review): a new space is also discarded when it *contains* an old one
    # (relation 2). Confirm that keeping the smaller old space is intended.
    survivors=[space for space in fresh
               if not any(intersect_space(space,old) and overlaps_space(space,old)>0
                          for old in spaces)]
    spaces.extend(survivors)


def train_or_test(MODEL,x,isTrain,scale):
    """Roll out one packing episode per batch instance with ``MODEL``.

    At each of the n steps the model first picks which box to place next
    (``DecodingSequence``). It then picks a position code ``y*6 + orientation``
    (``DecodingPosition``). The box goes into the lowest (smallest z0, first on
    ties) empty space at that y in which the chosen orientation fits. The list
    of empty spaces is then updated.

    Args:
        MODEL: network exposing clacb, FrontierEmbedding, DecodingSequence and
            DecodingPosition (project type from ``model``).
        x: box dimensions, indexed as x[b, i, 0..2] = (l, w, h) of box i in
            instance b; presumably an integer tensor of shape (batch, n, 3).
        isTrain: the decoders receive ``not isTrain`` as their last flag. When
            isTrain is True they also return the probability of their choice,
            and its log is accumulated into ``pro``.
        scale: forwarded to MODEL.clacb.

    Returns:
        ``ru``, or ``(ru, pro)`` when isTrain.
        ``ru`` = total box volume / (L * W * max top-surface height) per instance.
        ``pro`` = summed log-probabilities of the chosen boxes and positions.

    Relies on module-level globals set in the ``__main__`` block:
    device, L, W, H, lwhmax, lwhmin.
    """
    batch_size=x.size()[0]
    pro=torch.zeros(batch_size).to(device)
    # Per-instance list of empty spaces [x0,y0,z0,x1,y1,z1]; initially the whole container.
    KPlist=[[[0,0,0,L,W,H]] for _ in range(batch_size)]
    maskseq=[[] for _ in range(batch_size)]
    # Channel 1 holds the top-surface height map over the L x W footprint, stored
    # divided by the current max height. Channel 0 is never written here.
    frontier=torch.zeros(batch_size,2,L,W).to(device)
    xnorm=(x/lwhmax).to(device)
    positionset=set(range(6*W))
    inputb=MODEL.clacb(xnorm,scale)
    orientation_list=[0,1,2,3,4,5]
    n=xnorm.size()[1]
    height=torch.zeros(batch_size).to(device)
    lastheight=torch.zeros(batch_size).to(device)
    # Total box volume per instance divided by the footprint area L*W.
    volume=(x[:,:,0]*x[:,:,1]*x[:,:,2]).sum(dim=1)/L/W
    volume=volume.to(device=device,dtype=torch.get_default_dtype())
    for _ in range(n):
        inputf=MODEL.FrontierEmbedding(frontier.detach().clone())
        if isTrain:
            seqidx,seqpro=MODEL.DecodingSequence(inputb,inputf,maskseq,False)
            pro=pro+torch.log(seqpro)
        else:
            seqidx=MODEL.DecodingSequence(inputb,inputf,maskseq,True)
        unmaskposition=[set() for _ in range(batch_size)]
        # positiondict[jk][y0*6+k] -> (x0, z0, z1) of every empty space starting at
        # y0 in which orientation k of the chosen box fits. It is keyed by the full
        # position code (not only y0) so the space picked below is guaranteed to
        # fit the orientation the model selects.
        positiondict=[{} for _ in range(batch_size)]
        for jk in range(batch_size):
            maskseq[jk]=maskseq[jk]+[seqidx[jk].item()]
            l=x[jk,seqidx[jk],0]
            w=x[jk,seqidx[jk],1]
            h=x[jk,seqidx[jk],2]
            for sx0,sy0,sz0,sx1,sy1,sz1 in KPlist[jk]:
                for k in orientation_list:
                    l2,w2,h2=orientation_transform(l,w,h,k)
                    if sx0+l2<=sx1 and sy0+w2<=sy1 and sz0+h2<=sz1:
                        pos=sy0*6+k
                        unmaskposition[jk].add(pos)
                        positiondict[jk].setdefault(pos,[]).append((sx0,sz0,sz1))
        maskposition=[list(positionset.difference(unmaskposition[jk])) for jk in range(batch_size)]
        if isTrain:
            posidx,pospro=MODEL.DecodingPosition(seqidx,inputb,inputf,maskseq,maskposition,False)
            pro=pro+torch.log(pospro)
        else:
            posidx=MODEL.DecodingPosition(seqidx,inputb,inputf,maskseq,maskposition,True)
        ypos=torch.div(posidx,6,rounding_mode='trunc')
        ori=posidx.int()%6
        for jk in range(batch_size):
            # Undo the normalisation applied at the end of the previous step.
            frontier[jk,:,:,:]=frontier[jk,:,:,:]*lastheight[jk]
            # Lowest candidate (first on ties) among spaces where the chosen orientation fits.
            xmin,zmin,zmax=min(positiondict[jk][posidx[jk].item()],key=lambda s:s[1])
            l=x[jk,seqidx[jk],0]
            w=x[jk,seqidx[jk],1]
            h=x[jk,seqidx[jk],2]
            l2,w2,h2=orientation_transform(l,w,h,ori[jk])
            y0=ypos[jk].item()
            if zmax==H:
                # The space is open up to H, so the box raises the top surface there.
                frontier[jk,1,xmin:xmin+l2,y0:y0+w2]=zmin+h2
                height[jk]=torch.max(frontier[jk,1,:,:])
            frontier[jk,:,:,:]=frontier[jk,:,:,:]/height[jk]
            placed=[xmin,y0,zmin,xmin+l2.item(),y0+w2.item(),zmin+h2.item()]
            _update_spaces(KPlist[jk],placed)
        # Snapshot (not alias) so next step's denormalisation reads this step's heights.
        lastheight=height.clone()
    ru=volume/height
    if isTrain:
        return ru,pro
    return ru
                      
if __name__=='__main__':
    # Training script. MODEL is trained with a policy-gradient loss that uses
    # MODELG's rollout as a baseline. MODELG is overwritten with MODEL's weights
    # whenever a paired one-sided t-test favours MODEL.
    # The names below are module globals that train_or_test reads.
    seed=0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    N=100                # boxes per instance
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size=64
    maxepoch=10000
    L=120                # container extent along x
    W=100                # container extent along y
    lwhmax=50            # box edges are drawn uniformly from [lwhmin, lwhmax]
    lwhmin=10
    H=N*lwhmax           # height bound: all N boxes stacked at the maximum edge
    lr=1e-4
    originaldim=3
    embedingdim=128
    num_head=8
    FFNdim=512
    AFFNnum=3
    C=10
    inchannel=2
    outchannel=4
    outdim=6*W           # matches the 6*W position codes (y*6 + orientation)
    eval_every=1
    detect_anomaly=False
    ckpt_path=f'3DBPP/KPattend2_L{L}W{W}N{N}.pth'
    MODEL=model.Attend2Pack(originaldim,embedingdim,num_head,FFNdim,AFFNnum,inchannel,outchannel,C,outdim,L,W).to(device)
    # The baseline starts as a copy of MODEL's initial weights, made via the checkpoint file.
    torch.save(MODEL.state_dict(),ckpt_path)
    MODELG= model.Attend2Pack(originaldim,embedingdim,num_head,FFNdim,AFFNnum,inchannel,outchannel,C,outdim,L,W).to(device)
    MODELG.load_state_dict(torch.load(ckpt_path))
    optim=torch.optim.Adam(MODEL.parameters(),lr=lr)
    # Call the API; assigning to it would silently replace the function and enable
    # nothing. Set detect_anomaly=True only when debugging (it slows training).
    torch.autograd.set_detect_anomaly(detect_anomaly)
    start=time.time()
    scale=1/math.sqrt(embedingdim//num_head)
    for epoch in range(1,maxepoch+1):
        x=torch.randint(low=lwhmin,high=lwhmax+1,size=[batch_size,N,3])
        CW,proarr=train_or_test(MODEL,x,True,scale)
        # The baseline rollout only supplies a reward; no autograd graph is needed.
        with torch.no_grad():
            CWG=train_or_test(MODELG,x,False,scale)
        loss=torch.mean((CWG-CW)*proarr)
        optim.zero_grad()
        loss.backward()
        optim.step()
        print(time.time()-start)
        if(epoch%eval_every==0):
            x=torch.randint(low=lwhmin,high=lwhmax+1,size=[batch_size,N,3])
            with torch.no_grad():
                ru2=train_or_test(MODEL,x,False,scale)
                ru=train_or_test(MODELG,x,False,scale)
            # scipy needs host arrays; CUDA tensors cannot be converted implicitly.
            # alternative='less': small p means the baseline's ru is below MODEL's.
            _,p=ttest_rel(ru.cpu().numpy(),ru2.cpu().numpy(),alternative='less')
            if(p<=0.05):
                torch.save(MODEL.state_dict(),ckpt_path)
                MODELG.load_state_dict(torch.load(ckpt_path))
            if(p>0.95):
                for lrr in optim.param_groups:
                    lrr['lr']*=0.95
            end = time.time()
            print('epoch:',epoch,',greedy_ru:',format(torch.mean(ru).item()*100,'.2f'),'%,sample_ru:',format(torch.mean(ru2).item()*100,'.2f'),'%,time:',format(end-start,'.2f'),
                  ',loss:',format(loss.item(),'.2f'),',pro:',format(torch.mean(proarr).item(),'.2f'),',p:',format(p*100,'.2f'),'%')
            start=time.time()


