import argparse
import numpy as np
import os
import sys
sys.path.append('../share')
sys.path.append('../model')


import torch
import torch.backends.cudnn as cudnn

from models_mage import MAGECityPolyGen3D
from models_pospred import MAGECityPosition3D
from function import pad_poly, in_poly_idx

import cv2
from shapely.geometry import Polygon


def infgen(model, modelpos, device, discard_prob = 0.5, use_sample=False, discre = 50, max_build = 60):
    """Autoregressively generate building polygons and their heights.

    The first polygon is seeded at a uniformly random cell of a
    ``discre x discre`` grid and produced by ``model.genfirst``. Every later
    polygon is placed at a cell proposed by ``modelpos`` and produced by
    ``model.infgen``; a proposal is rejected (and another cell tried) when the
    new polygon intersects any polygon generated so far.

    Args:
        model: polygon generator exposing ``genfirst(pos)`` and
            ``infgen(samples, pos, h, pred_pos)``; both return
            ``(predpoly, hpred)``.
        modelpos: position predictor, called as
            ``modelpos(samples, pos, h, None, generate=True)``; row 0 of its
            output is passed through a sigmoid to get per-cell scores.
        device: torch device that the running tensors are moved to.
        discard_prob: generation stops as soon as the best remaining cell
            score falls below this value.
        use_sample: if False take the arg-max cell, otherwise draw the cell
            with ``torch.multinomial`` using the scores as weights.
        discre: grid resolution; cells are indexed ``0 .. discre*discre - 1``.
        max_build: upper bound on the number of generated polygons.

    Returns:
        ``(gen_iter, h)``: two parallel lists holding, per accepted polygon,
        ``predpoly.squeeze(0).detach()`` and ``hpred.detach()``.

    NOTE(review): nothing here disables autograd; results are only
    ``.detach()``-ed, so callers may want to wrap the call in
    ``torch.no_grad()``.
    """

    # 1 = cell may still be proposed, 0 = cell already used or blocked.
    remain_flag = torch.ones(discre*discre).to(device)
        
    endflag = 0  # set to 1 when no remaining cell scores >= discard_prob
    gen_iter = []  # accepted polygons
    h = []  # height predictions, parallel to gen_iter

    num_poly = 0  # counted below but never read
    
    for npoly in range(max_build):
        if npoly == 0:     
            # Seed: pick a random flat cell index for the first building.
            idx_iter = torch.randint(discre*discre, (1,)).squeeze()
                                
            remain_flag[idx_iter] = 0

            # Flat index -> (idx // discre, idx % discre), shaped (1, 1, 2).
            pred_pos = torch.cat([idx_iter.unsqueeze(0)//discre, idx_iter.unsqueeze(0)%discre],dim=0).unsqueeze(0).unsqueeze(0)
            predpoly, hpred = model.genfirst(pred_pos)
            predpoly = predpoly[0]
            
            # pad_poly is a project helper; presumably it pads the polygon to
            # a fixed vertex count and returns matching position data -- TODO confirm.
            poly_add, pos_add = pad_poly(predpoly)
            poly_add = poly_add.to(device)
            pos_add = pos_add.to(device)
            samples_iter = poly_add.detach()
            pos_iter = pos_add.detach()
            h_iter = hpred.unsqueeze(0).to(device).detach()
            num_poly+=1

            gen_iter.append(predpoly.squeeze(0).detach())
            h.append(hpred.detach())
            # in_poly_idx is a project helper; presumably a per-cell 0/1 mask
            # that clears cells covered by the new polygon -- TODO confirm.
            remain_flag = remain_flag*in_poly_idx(predpoly.squeeze(0).detach().cpu()).to(device)
        else:
            # Score every grid cell as the location of the next building.
            pred = modelpos(samples_iter, pos_iter, h_iter, None, generate = True)
            prob_pred = torch.sigmoid(pred[0, :])
            # Hard cut at 0.5, independent of discard_prob.
            prob_pred = torch.where(prob_pred < 0.5, torch.tensor(0.0).to(device), prob_pred) 
            # Retry loop: keep proposing cells until a non-overlapping polygon
            # is produced or no acceptable cell is left.
            while True:
                # Re-mask each pass: rejected cells were zeroed in remain_flag.
                prob_pred = prob_pred*remain_flag
                if torch.max(prob_pred) < discard_prob:
                    endflag = 1
                    break      
                if use_sample == False:
                    idx_iter = torch.argmax(prob_pred)
                else:            
                    idx_iter = torch.multinomial(prob_pred, 1).squeeze(0)
                                    
                # Never propose this cell again, whether or not it is accepted.
                remain_flag[idx_iter] = 0

                pred_pos = torch.cat([idx_iter.unsqueeze(0)//discre, idx_iter.unsqueeze(0)%discre],dim=0).unsqueeze(0).unsqueeze(0)
                predpoly, hpred = model.infgen(samples_iter, pos_iter, h_iter, pred_pos)
                predpoly = predpoly[0]

                polygon = Polygon(np.array(predpoly[0].clone().detach().cpu()))
                intersect_flag = 0
                # Rebuild each existing polygon from samples_iter and test overlap.
                for pe in range(samples_iter.shape[1]):
                    polyexist = [] 
                    for k in range(samples_iter.shape[2]):
                        # Vertices whose first coordinate is exactly 0 are skipped;
                        # presumably these are padding rows -- TODO confirm. A real
                        # vertex with first coordinate 0 would be dropped as well.
                        if samples_iter[0, pe, k, 0] != 0:
                            point = samples_iter[0, pe, k, :].clone().detach().cpu().numpy()
                            polyexist.append(point)
                    # NOTE(review): shapely needs enough points to form a ring;
                    # verify behaviour if fewer than three vertices survive.
                    polyexist = Polygon(polyexist)
                    if polygon.intersects(polyexist):
                        intersect_flag = 1
                        break

                if intersect_flag == 0:
                    break

            if endflag == 1:
                break
            
            # Accept the polygon: append it to the running context tensors
            # (concatenated along dim 1) that are fed to both models.
            poly_add, pos_add = pad_poly(predpoly)
            poly_add = poly_add.to(device)
            pos_add = pos_add.to(device)
            samples_iter = torch.cat([samples_iter, poly_add], dim = 1).detach()
            pos_iter = torch.cat([pos_iter, pos_add], dim = 1).detach()
            h_iter = torch.cat([h_iter, hpred.unsqueeze(0).to(device)], dim = 1).detach()
            num_poly+=1

            gen_iter.append(predpoly.squeeze(0).detach())
            h.append(hpred.detach())
            remain_flag = remain_flag*in_poly_idx(predpoly.squeeze(0).detach().cpu()).to(device)

    return gen_iter, h


def obj_gen(polys, h, num, path= ''):
    """Extrude 2D footprints into prisms and write them as a Wavefront OBJ file.

    Each polygon becomes a prism: every footprint vertex is emitted twice, once
    at z = 0 and once at z = height, followed by one quad per side wall plus a
    bottom and a top face. OBJ vertex indices are 1-based and continue across
    polygons.

    Args:
        polys: sequence of polygons; each polygon is a sequence of (x, y)
            pairs (lists, arrays or tensor rows).
        h: per-polygon heights, indexed like ``polys``; every entry must
            provide ``.item()`` (e.g. a single-element tensor or numpy array).
        num: value used as the file stem; the file is ``path + str(num) + '.obj'``.
        path: prefix prepended verbatim to the file name (include the trailing
            separator when it is a directory).

    Polygons without vertices are skipped so that no empty ``f`` records are
    written.
    """
    vertices = []
    faces = []
    num_vert = 0  # vertices emitted so far; offset for the next polygon's indices
    for i, poly in enumerate(polys):
        leng = len(poly)
        if leng == 0:
            continue

        # Read the height once per polygon instead of once per vertex.
        height = h[i].item()
        for k in range(leng):
            base = list(poly[k])
            vertices.append(base + [0])        # bottom copy -> odd index 2k+1
            vertices.append(base + [height])   # top copy    -> even index 2k+2
            if k == 0:
                # Wall closing the ring: first vertex pair with the last pair.
                faces.append([1+num_vert, 2+num_vert, 2*leng+num_vert, 2*leng-1+num_vert])
            else:
                # Wall between vertex k and its predecessor k-1.
                faces.append([2*k+1+num_vert, 2*k+2+num_vert, 2*k+num_vert, 2*k-1+num_vert])
        faces.append([2*j+1+num_vert for j in range(leng)])  # bottom cap
        faces.append([2*j+2+num_vert for j in range(leng)])  # top cap
        num_vert += 2*leng

    # Context manager guarantees the file is flushed and closed.
    with open(path+str(num)+'.obj', "w") as obj_file:
        for v in vertices:
            obj_file.write(f"v {v[0]} {v[1]} {v[2]}\n")
        for f in faces:
            # Each index is followed by a space, matching the established output.
            obj_file.write("f " + "".join(f"{idx} " for idx in f) + "\n")

def get_args_parser():
    """Create the argparse parser for the generation script.

    The parser is built without the automatic ``-h/--help`` option.
    ``pin_mem`` defaults to True and is cleared by ``--no_pin_mem``.

    Returns:
        argparse.ArgumentParser: parser exposing batch size, epochs, device,
        seed, data-loader settings, split ratio, the two checkpoint paths and
        the output directory.
    """
    # (flags, keyword arguments) in the order the options are registered.
    option_table = [
        (('--batch_size',), dict(
            default=1, type=int,
            help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')),
        (('--epochs',), dict(default=1, type=int)),
        (('--device',), dict(
            default='cuda',
            help='device to use for training / testing')),
        (('--seed',), dict(default=0, type=int)),
        (('--num_workers',), dict(default=20, type=int)),
        (('--pin_mem',), dict(
            action='store_true',
            help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')),
        (('--no_pin_mem',), dict(action='store_false', dest='pin_mem')),
        (('--split_ratio',), dict(type=float, default=0.8)),
        (('--stage1_model_path',), dict(type=str, default='../results/model/stage1_3d.pth')),
        (('--stage2_model_path',), dict(type=str, default='../results/model/stage2_3d.pth')),
        (('--save_path',), dict(type=str, default="../results/test/states_poly_gen_3d/")),
    ]

    cli = argparse.ArgumentParser(prog='testing', add_help=False)
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    cli.set_defaults(pin_mem=True)
    return cli

def main(args):
    """Load both pretrained models and generate 100 city layouts.

    For every layout the generated footprints are validated with shapely,
    rendered into a 500x500 image and exported as an OBJ mesh. Outputs go to
    ``<save_path>/img/<t>.jpg`` and ``<save_path>/obj/<t>.obj``.

    Args:
        args: parsed command-line namespace; reads ``device``, ``seed``,
            ``batch_size``, ``stage1_model_path``, ``stage2_model_path`` and
            ``save_path``.

    Raises:
        AssertionError: if ``args.batch_size`` is not 1.
    """
    device = torch.device(args.device)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # the weights are moved to the target device by .to(device) below.
    modelpos = MAGECityPosition3D(embed_dim=256, depth=6, num_heads=8,
                                 decoder_embed_dim=16, decoder_depth=3,
                                 decoder_num_heads=8, discre = 50, patch_size = 5, patch_num = 10,
                                 device = device)

    pretrained_modelpos = torch.load(args.stage1_model_path, map_location='cpu')
    modelpos.load_state_dict(pretrained_modelpos)
    modelpos.to(device)
    modelpos.eval()

    model = MAGECityPolyGen3D(num_heads=8, device=device,
                        depth=12, embed_dim=512, decoder_embed_dim=512,
                        decoder_depth=8, decoder_num_heads=8,
                        max_build=60, max_poly=20, discre=50)

    pretrained_model = torch.load(args.stage2_model_path, map_location='cpu')
    model.load_state_dict(pretrained_model)
    model.to(device)
    model.eval()

    image_num = 0

    assert args.batch_size == 1
    use_s = False

    # Create the output directories once. The trailing '' keeps a trailing
    # separator, and join also copes with a save_path given without one.
    dir_path_obj = os.path.join(args.save_path, "obj", "")
    dir_path_img = os.path.join(args.save_path, "img", "")
    os.makedirs(dir_path_obj, exist_ok=True)
    os.makedirs(dir_path_img, exist_ok=True)

    for t in range(100):
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            gen_poly_ini, h_pred = infgen(model = model, modelpos = modelpos, device = device, use_sample=use_s)

        img = np.ones((500,500,3),np.uint8)*255

        # Filter polygons and heights together: obj_gen looks heights up by
        # polygon position, so dropping a polygon without its height would
        # shift every following height onto the wrong building.
        gen_poly = []
        gen_h = []
        for poly, height in zip(gen_poly_ini, h_pred):
            if Polygon(poly.cpu()).is_valid:
                gen_poly.append(poly)
                gen_h.append(height)

        for poly in gen_poly:
            # OpenCV drawing functions expect int32 points shaped (N, 1, 2).
            pts = np.array(poly.cpu(), np.int32).reshape((-1,1,2))

            cv2.fillPoly(img, [pts], color=(255, 255, 0))
            cv2.polylines(img,[pts],True,(0,0,0),1)

        image_num += 1
        print(image_num)

        obj_gen(gen_poly, gen_h, t, path = dir_path_obj)

        cv2.imwrite(f'{dir_path_img}'+str(t) +'.jpg',img)


if __name__ == '__main__':
    # Script entry point: parse the command line and run the generation loop.
    cli_parser = get_args_parser()
    main(cli_parser.parse_args())

