import numpy as np
import torch
import model
import time
import sys
from scipy.stats import ttest_rel

    
def orientation_transform(l,w,h,num):
    """Return the three dimensions arranged according to orientation code ``num``.

    The inputs are first sorted ascending (smallest, middle, largest), so the
    result depends only on the multiset of dimensions, not on the order in
    which they were passed.

    Codes map to (first, second, third) as follows:
        0 -> (smallest, middle, largest)
        1 -> (smallest, largest, middle)
        2 -> (middle, smallest, largest)
        3 -> (middle, largest, smallest)
        4 -> (largest, smallest, middle)
        any other value -> (largest, middle, smallest)
    """
    small, mid, big = sorted((l, w, h))
    arrangements = (
        (small, mid, big),
        (small, big, mid),
        (mid, small, big),
        (mid, big, small),
        (big, small, mid),
    )
    # Compare with == (rather than indexing by num) so that every value
    # outside 0..4 still falls through to the last arrangement.
    for code, arrangement in enumerate(arrangements):
        if num == code:
            return arrangement
    return big, mid, small
    
def overlaps_space(space1,space2):
    """Classify containment between two boxes given as [x0,y0,z0,x1,y1,z1].

    The volume spanned by the two middle coordinates on each axis is compared
    with the volume of each box:

        1 -> that volume equals the volume of ``space1``
        2 -> it equals the volume of ``space2`` (and not that of ``space1``)
        0 -> neither

    For boxes that actually intersect, the middle-coordinate volume is the
    volume of their intersection, so 1 means ``space1`` lies inside
    ``space2`` and 2 means ``space2`` lies inside ``space1``. Equal boxes
    yield 1. The function does not itself check that the boxes intersect.
    """
    def _volume(box):
        return (box[3]-box[0])*(box[4]-box[1])*(box[5]-box[2])

    def _middle_extent(axis):
        # Distance between the 2nd and 3rd of the four bounds on this axis.
        bounds = sorted((space1[axis], space2[axis], space1[axis+3], space2[axis+3]))
        return bounds[2]-bounds[1]

    shared = _middle_extent(1)*_middle_extent(0)*_middle_extent(2)
    if shared == _volume(space1):
        return 1
    if shared == _volume(space2):
        return 2
    return 0
    
def intersect_space(space1,space2):
    """Return True when two boxes [x0,y0,z0,x1,y1,z1] overlap with positive volume.

    Boxes that merely touch on a face, edge or corner are reported as not
    intersecting, because the separation test uses ``>=``.
    """
    for axis in range(3):
        # Separated on this axis if either box starts at or beyond the
        # other's far bound.
        if space2[axis]>=space1[axis+3] or space1[axis]>=space2[axis+3]:
            return False
    return True
    
def solveES(space1,space2,lwhmin):
    """Split ``space2`` around ``space1`` and return the usable remainders.

    Both arguments are boxes [x0,y0,z0,x1,y1,z1]. Six candidate sub-boxes of
    ``space2`` are produced, one per face of the region spanned by the two
    middle coordinates on each axis (for intersecting boxes this is their
    intersection). Each candidate keeps all bounds of ``space2`` except one,
    which is moved to the corresponding bound of that region.

    Candidates whose smallest side is below ``lwhmin`` are dropped. The
    survivors are returned in this fixed order: below-y, above-x, above-y,
    above-z, below-x, below-z.
    """
    inner_lo = []
    inner_hi = []
    for axis in range(3):
        bounds = sorted((space1[axis], space2[axis], space1[axis+3], space2[axis+3]))
        inner_lo.append(bounds[1])
        inner_hi.append(bounds[2])

    sx0, sy0, sz0 = space2[0], space2[1], space2[2]
    sx1, sy1, sz1 = space2[3], space2[4], space2[5]
    candidates = [
        [sx0, sy0, sz0, sx1, inner_lo[1], sz1],
        [inner_hi[0], sy0, sz0, sx1, sy1, sz1],
        [sx0, inner_hi[1], sz0, sx1, sy1, sz1],
        [sx0, sy0, inner_hi[2], sx1, sy1, sz1],
        [sx0, sy0, sz0, inner_lo[0], sy1, sz1],
        [sx0, sy0, sz0, sx1, sy1, inner_lo[2]],
    ]

    def _too_small(box):
        return min(box[3]-box[0], box[4]-box[1], box[5]-box[2]) < lwhmin

    return [box for box in candidates if not _too_small(box)]

def intersect_space2D(space1,space2):
    """Return True when two rectangles overlap with positive area.

    Only indices 0..3 of each argument are read. The comparisons pair index 0
    with index 2 and index 1 with index 3, i.e. each rectangle is treated as
    [lo_a, lo_b, hi_a, hi_b]. Rectangles that only touch along an edge or at
    a corner are reported as not intersecting.
    """
    separated = (
        space2[0]>=space1[2]
        or space1[0]>=space2[2]
        or space2[1]>=space1[3]
        or space1[1]>=space2[3]
    )
    return not separated
    
def calc_per(y,x,x2,x3):
    """Return a detached second-order term of ``y`` with respect to ``x``.

    Steps, for each leading index ``i`` of ``x2`` / ``x3``:

    1. ``g`` = gradient of ``sum(y)`` w.r.t. ``x``, flattened per batch row.
    2. ``d_i[b] = x2[i, b] . g[b]`` (a directional derivative per batch row).
    3. ``h_i`` = gradient of ``sum_b d_i[b]`` w.r.t. ``x`` (flattened per row).
    4. ``t_i[b] = h_i[b] . x3[i, b]``.

    The result is ``sum_i t_i``, repeated along a new second axis of length
    ``x.size(1)``, detached from the autograd graph.

    Args:
        y: tensor computed from ``x`` with a differentiable graph; its graph
            is retained so the caller can still backpropagate through it.
        x: tensor of shape (batch, n, dim) that requires grad.
        x2: tensor of shape (m, batch, 1, n*dim).
        x3: tensor of shape (m, batch, n*dim, 1).

    Returns:
        Tensor of shape (batch, n) on ``x``'s device with no grad history.

    Raises:
        RuntimeError: propagated from ``torch.autograd.grad`` when ``y`` (or
            its gradient) does not depend differentiably on ``x``.
    """
    batch_size, n, dim = x.size()[0], x.size()[1], x.size()[2]
    m = x2.size()[0]

    # Keep every operand on x's device; mixing a CPU x2/x3 with a CUDA
    # gradient makes matmul/bmm raise a device-mismatch error.
    x2 = x2.to(x.device)
    x3 = x3.to(x.device)

    # create_graph=True is required here: the first-order gradient is
    # differentiated again below.
    first = torch.autograd.grad(outputs=y,inputs=x,grad_outputs=torch.ones_like(y),retain_graph=True,create_graph=True)[0]
    first = first.reshape(batch_size, n*dim, 1)

    total = torch.zeros(batch_size, device=x.device, dtype=first.dtype)
    for i in range(m):
        # reshape(batch_size) instead of squeeze() keeps the batch axis even
        # when batch_size == 1.
        directional = torch.matmul(x2[i], first).reshape(batch_size)
        # The result is detached before returning, so no higher-order graph is
        # needed; retain_graph keeps y's graph alive for the caller and for
        # the next loop iteration.
        second = torch.autograd.grad(outputs=directional,inputs=x,grad_outputs=torch.ones_like(directional),retain_graph=True)[0]
        second = second.reshape(batch_size, 1, n*dim)
        total = total + torch.bmm(second, x3[i]).reshape(batch_size)

    return total.unsqueeze(1).repeat(1,n).detach().clone()


def _item_fits(space,l,w,h):
    """Return True when an l x w x h item placed at the space's minimum corner stays inside it."""
    return l+space[0]<=space[3] and w+space[1]<=space[4] and h+space[2]<=space[5]


def _update_empty_spaces(spaces,item_box,min_side):
    """Carve ``item_box`` out of the list of empty boxes ``spaces`` (modified in place).

    Every space that intersects the item is removed and replaced by the
    remainders returned by ``solveES``; redundant remainders are then dropped.
    All deletions are positional: removing by value (``list.remove``) deletes
    the first *equal* box, which is the wrong element when duplicates exist.
    """
    hit=[idx for idx,space in enumerate(spaces) if intersect_space(item_box,space)]
    fresh=[]
    for idx in hit:
        fresh.extend(solveES(item_box,spaces[idx],min_side))
    for idx in reversed(hit):
        del spaces[idx]

    # Among the new boxes, drop any box contained in another new box. Of two
    # equal boxes the earlier one is dropped, so exactly one copy survives.
    contained=set()
    for a in range(len(fresh)-1):
        for b in range(a+1,len(fresh)):
            if intersect_space(fresh[a],fresh[b]):
                relation=overlaps_space(fresh[a],fresh[b])
                if relation==1:
                    contained.add(a)
                elif relation==2:
                    contained.add(b)
    fresh=[box for idx,box in enumerate(fresh) if idx not in contained]

    # Drop new boxes that are in a containment relation with a space that was
    # already present. NOTE(review): as in the original code, the new box is
    # discarded in both directions (relation 1 and 2), i.e. also when the
    # existing space is the smaller one -- confirm this is intended.
    def _redundant(box):
        return any(
            intersect_space(box,space) and overlaps_space(box,space)>0
            for space in spaces
        )

    survivors=[box for box in fresh if not _redundant(box)]
    spaces.extend(survivors)


def train_or_test(MODEL,x,L,W,H,lwhmax,attn_span,hidden_size,encoder_nb_layers,state_size,isTrain,context_size):
    """Pack every item of every batch row and return the resulting utilisation.

    For each of the ``x.size(1)`` steps the model chooses, per batch row, an
    item (``MODEL.calc_seq_idx``), an orientation (``MODEL.calc_ori_idx``), an
    x position (``MODEL.calc_x_idx``) and a y position (``MODEL.calc_y_idx``).
    The z position is the lowest empty space at that (x, y) origin in which
    the item fits. The per-row list of empty spaces is then updated.

    This function reads two module-level globals that this file defines only
    inside its ``__main__`` block: ``device`` and ``lwhmin``.

    Args:
        MODEL: object providing ``calc_seq_idx``, ``calc_ori_idx``,
            ``calc_x_idx`` and ``calc_y_idx``.
        x: integer tensor of shape (batch, items, 3) holding item dimensions.
        L, W, H: dimensions of the initial empty space ``[0,0,0,L,W,H]``.
        lwhmax: divisor used to normalise item dimensions for the model.
        attn_span, hidden_size, encoder_nb_layers: sizes of the zero-filled
            cache tensors handed to ``MODEL.calc_seq_idx``.
        state_size: width of one row of the packed-items context.
        isTrain: when True, the item is sampled and log-probabilities of all
            choices are accumulated; when False, the arg-max item is taken and
            the model methods are called in their non-sampling mode.
        context_size: number of most recently packed items kept as context.

    Returns:
        ``(pro, ru)`` when ``isTrain`` is True, otherwise ``ru``. ``pro`` is
        the summed log-probability per batch row; ``ru`` is
        ``sum(l/L * w/W * h) / height`` per batch row.
    """
    batch_size=x.size()[0]
    num_item=x.size()[1]
    pro=torch.zeros(batch_size).to(device)
    lastheight=torch.zeros(batch_size).to(device)
    height=torch.zeros(batch_size).to(device)
    total_volume=torch.zeros(batch_size).to(device)
    templist=list(range(batch_size))
    # One list of empty boxes [x0,y0,z0,x1,y1,z1] per batch row.
    KPlist=[[[0,0,0,L,W,H]] for _ in range(batch_size)]
    # Items already packed, per batch row.
    maskseq=[[] for _ in range(batch_size)]
    Lset=set(range(L))
    Wset=set(range(W))
    packed_state=torch.zeros(batch_size,context_size,state_size).to(device)
    item_state0=(x/lwhmax).to(device)
    item_state1=item_state0[:,:,[1,2,0]]
    # The model sees the average of the dimensions and a rotated copy of them;
    # calc_per gets both originals and their offsets from that average.
    item_state=(item_state0+item_state1)/2
    item_state.requires_grad_(True)
    # Allocated on `device` like every other tensor here so calc_per never
    # mixes devices.
    xx2=torch.zeros(2,batch_size,num_item,3).to(device)
    xx3=torch.zeros(2,batch_size,num_item,3).to(device)
    xx2[0,:,:,:]=item_state0
    xx2[1,:,:,:]=item_state1
    xx3[0,:,:,:]=item_state0-item_state
    xx3[1,:,:,:]=item_state1-item_state
    xx2=xx2.view(2,batch_size,1,num_item*3)
    xx3=xx3.view(2,batch_size,num_item*3,1)
    h_cache=[torch.zeros(batch_size,attn_span,hidden_size).to(device) for _ in range(encoder_nb_layers)]
    softmax=torch.nn.Softmax(dim=1)
    for i in range(num_item):
        # ---- choose the next item -------------------------------------
        s_out,actor_encoder_out,h_cache=MODEL.calc_seq_idx(packed_state.clone().detach(),item_state,h_cache)
        s_out=s_out+calc_per(s_out,item_state,xx2,xx3)
        for jk in range(batch_size):
            s_out[jk,maskseq[jk]]=-np.inf
        s_out=softmax(s_out)
        if(isTrain):
            # squeeze(1) keeps the batch axis even when batch_size == 1.
            seqidx=torch.multinomial(s_out,1).squeeze(1)
            seqpro=s_out[templist,seqidx]
            pro=pro+torch.log(seqpro)
        else:
            seqidx=torch.max(s_out,1)[1]
        for jk in range(batch_size):
            maskseq[jk].append(seqidx[jk])
        select_item=item_state[templist,seqidx,:]
        select_item=torch.unsqueeze(select_item,dim=1)

        # ---- choose its orientation -----------------------------------
        if(isTrain):
            oriidx,oripro=MODEL.calc_ori_idx(actor_encoder_out,select_item,False)
            pro=pro+torch.log(oripro)
        else:
            oriidx=MODEL.calc_ori_idx(actor_encoder_out,select_item,True)
        select_ori_item=torch.zeros(select_item.size()).to(device)
        unmaskx=[set() for _ in range(batch_size)]
        unmasky=[set() for _ in range(batch_size)]
        # Oriented integer dimensions of the chosen item, computed once per
        # row and reused for every fit test below.
        dims=[]
        for jk in range(batch_size):
            select_ori_item[jk,:,0],select_ori_item[jk,:,1],select_ori_item[jk,:,2]=orientation_transform(select_item[jk,:,0],select_item[jk,:,1],select_item[jk,:,2],oriidx[jk])
            l2,w2,h2=orientation_transform(x[jk,seqidx[jk],0],x[jk,seqidx[jk],1],x[jk,seqidx[jk],2],oriidx[jk])
            l2,w2,h2=l2.item(),w2.item(),h2.item()
            dims.append((l2,w2,h2))
            total_volume[jk]=total_volume[jk]+l2/L*w2/W*h2
            for space in KPlist[jk]:
                if(_item_fits(space,l2,w2,h2)):
                    unmaskx[jk].add(space[0])

        # ---- choose x: every x that is not the origin of a fitting space
        # is listed in maskx ---------------------------------------------
        maskx=[list(Lset.difference(unmaskx[jk])) for jk in range(batch_size)]
        if(isTrain):
            xidx,xpro=MODEL.calc_x_idx(actor_encoder_out,select_ori_item,False,maskx)
            pro=pro+torch.log(xpro)
        else:
            xidx=MODEL.calc_x_idx(actor_encoder_out,select_ori_item,True,maskx)

        # ---- choose y among fitting spaces that start at the chosen x ---
        for jk in range(batch_size):
            l2,w2,h2=dims[jk]
            for space in KPlist[jk]:
                if(space[0]==xidx[jk] and _item_fits(space,l2,w2,h2)):
                    unmasky[jk].add(space[1])
        masky=[list(Wset.difference(unmasky[jk])) for jk in range(batch_size)]
        if(isTrain):
            yidx,ypro=MODEL.calc_y_idx(actor_encoder_out,select_ori_item,False,masky)
            pro=pro+torch.log(ypro)
        else:
            yidx=MODEL.calc_y_idx(actor_encoder_out,select_ori_item,True,masky)

        # ---- z is the lowest fitting space at the chosen (x, y) ---------
        # zidx starts at H and hmin at 2*H; both keep those values when no
        # space at (x, y) fits the item.
        zidx=torch.zeros(batch_size,dtype=torch.int).to(device)+H
        for jk in range(batch_size):
            l2,w2,h2=dims[jk]
            hmin=2*H
            for space in KPlist[jk]:
                if(space[0]==xidx[jk] and space[1]==yidx[jk]):
                    if(_item_fits(space,l2,w2,h2) and space[2]<zidx[jk]):
                        hmin=h2
                        zidx[jk]=space[2]
            height[jk]=max(height[jk],zidx[jk]+hmin)
            # Undo the previous step's normalisation of the stored positions.
            packed_state[jk,:,3:6]=packed_state[jk,:,3:6]*lastheight[jk]

        # ---- record the placement in the rolling context ----------------
        select_ori_item=select_ori_item.squeeze(1)
        if(i<context_size):
            slot=i
        else:
            # Context is full: shift everything one slot towards index 0.
            packed_state[templist,0:context_size-1,:]=packed_state[templist,1:context_size,:]
            slot=context_size-1
        packed_state[templist,slot,0:3]=select_ori_item.clone().detach()
        packed_state[templist,slot,3]=xidx.float()
        packed_state[templist,slot,4]=yidx.float()
        packed_state[templist,slot,5]=zidx.float()
        for jk in range(batch_size):
            packed_state[jk,:,3:6]=packed_state[jk,:,3:6]/height[jk]
        # clone() is essential: `height` is updated in place on the next step,
        # and a plain alias would make `lastheight` track the *new* height, so
        # the un-normalisation above would use the wrong scale.
        lastheight=height.clone()

        # ---- update the empty spaces ------------------------------------
        for jk in range(batch_size):
            l,w,h=dims[jk]
            px,py,pz=xidx[jk].item(),yidx[jk].item(),zidx[jk].item()
            itKPpace=[px,py,pz,px+l,py+w,pz+h]
            _update_empty_spaces(KPlist[jk],itKPpace,lwhmin)
    ru=total_volume/height
    if(isTrain):
        return pro,ru
    else:
        return ru
    
# Training script: REINFORCE with a greedy rollout baseline. MODEL is trained;
# MODELG is a frozen copy used as the baseline and replaced by MODEL whenever a
# paired t-test shows MODEL's greedy rollouts are significantly better.
if __name__=='__main__':
    import os

    seed=0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    batch_size=64
    N=100
    lwhmax=50
    # `lwhmin` and `device` are also read as globals by train_or_test.
    lwhmin=10
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maxepoch=5000
    L=120
    W=100
    H=lwhmax*N
    lr=1e-4
    state_size=6
    hidden_size=128
    nb_heads=8
    encoder_nb_layers=3
    attn_span=20
    inner_hidden_size=512
    src_head_hidden_size=128
    pos_head_hidden_size=64
    s_res_size=1
    r_res_size=6
    x_res_size=L
    y_res_size=W
    decoder_nb_layers=1
    c_encoder_layers=3
    c_decoder_layers=1
    item_state_size=3
    context_size=min(N,10)
    eval_every=1

    checkpoint=f'3DBPP/lmperrcql_L{L}W{W}N{N}.pth'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint),exist_ok=True)

    def rollout(net,items,is_train):
        """Run train_or_test with the script's fixed hyper-parameters."""
        return train_or_test(net,items,L,W,H,lwhmax,attn_span,hidden_size,encoder_nb_layers,state_size,is_train,context_size)

    def random_items():
        """Draw a batch of items with integer sides in [lwhmin, lwhmax]."""
        return torch.randint(low=lwhmin,high=lwhmax+1,size=[batch_size,N,3])

    MODEL=model.RCQL(state_size,hidden_size,nb_heads,encoder_nb_layers,attn_span,inner_hidden_size,src_head_hidden_size,pos_head_hidden_size,s_res_size,r_res_size,
           x_res_size,y_res_size,decoder_nb_layers,item_state_size)
    MODEL.set_device(device)
    MODELG=model.RCQL(state_size,hidden_size,nb_heads,encoder_nb_layers,attn_span,inner_hidden_size,src_head_hidden_size,pos_head_hidden_size,s_res_size,r_res_size,
           x_res_size,y_res_size,decoder_nb_layers,item_state_size)
    MODELG.set_device(device)
    # Start the baseline as an exact copy of the trained model.
    torch.save(MODEL.state_dict(),checkpoint)
    MODELG.load_state_dict(torch.load(checkpoint))
    optim=torch.optim.Adam(MODEL.parameters(),lr=lr)
    # NOTE: the original line here assigned `True` over
    # `torch.autograd.set_detect_anomaly` instead of calling it, so anomaly
    # detection was never active (and the API function was clobbered). It is
    # left off because it is a slow debugging aid; call
    # `torch.autograd.set_detect_anomaly(True)` here when hunting NaNs.
    sche=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim,T_max=maxepoch)
    start=time.time()
    for epoch in range(1,maxepoch+1):
        x=random_items()
        pro,cw_value=rollout(MODEL,x,True)
        cw_value2=rollout(MODELG,x,False)
        # Policy-gradient loss: (baseline utilisation - sampled utilisation)
        # weighted by the log-probability of the sampled actions.
        loss=(cw_value2-cw_value)*pro
        loss=torch.mean(loss)
        optim.zero_grad()
        loss.backward()
        optim.step()
        sche.step()
        print(time.time()-start)
        if(epoch%eval_every==0):
            x=random_items()
            ru2=rollout(MODEL,x,False)
            ru=rollout(MODELG,x,False)
            # scipy needs host arrays; passing tensors directly fails when
            # they live on a CUDA device.
            t,p=ttest_rel(ru.detach().cpu().numpy(),ru2.detach().cpu().numpy(),alternative='less')
            # One-sided test: baseline mean significantly below the current
            # model's mean -> promote the current model to baseline.
            if(t<0 and p<=0.05):
                torch.save(MODEL.state_dict(),checkpoint)
                MODELG.load_state_dict(torch.load(checkpoint))
            end = time.time()
            print('epoch:',epoch,',greedy_ru:',format(torch.mean(ru).item()*100,'.2f'),',sample_ru:',format(torch.mean(ru2).item()*100,'.2f'),'%,time:',format(end-start,'.2f'),
                  ',loss:',format(loss.item(),'.2f'),',pro:',format(torch.mean(pro).item(),'.2f'),',t:',format(t,'.2f'),',p:',format(p*100,'.2f'),'%')
            start=time.time()


