import argparse
import numpy as np
import sys
sys.path.append('../share')
sys.path.append('../model')

import torch
import torch.backends.cudnn as cudnn
from models_mage import MAGECityPolyGen
from models_pospred import MAGECityPosition
from function import pad_poly, in_poly_idx

import os
import cv2
from shapely.geometry import Polygon


def _padded_to_polygons(padded):
    """Convert each row ``padded[0, i]`` of vertex coordinates into a shapely Polygon.

    Vertices whose first coordinate is exactly 0 are dropped; this treats
    zero rows as padding (presumably how ``pad_poly`` pads — TODO confirm).
    The tensor is moved to the CPU once, instead of once per scalar.
    """
    arr = padded[0].detach().cpu().numpy()
    return [Polygon(list(verts[verts[:, 0] != 0])) for verts in arr]


def infgen(model, modelpos, samples, pos, sample_inter, remain_flag, device, discard_prob = 0.5, discre = 100, max_build = 250, use_sample=False,finetune = False):
    """Autoregressively add polygons to a scene until no acceptable grid cell remains.

    Each step proceeds as follows:
      * ``modelpos`` scores every grid cell.
      * Cells whose sigmoid probability is below ``discard_prob``, or that are
        zeroed in ``remain_flag``, are ignored.
      * The best remaining cell (argmax, or a multinomial draw when
        ``use_sample``) is passed to ``model.infgen`` to produce a polygon.
      * If that polygon intersects a polygon already in the context, or one in
        ``sample_inter``, the cell is masked out and the next candidate is tried.
      * An accepted polygon is appended to the context, and the mask is
        multiplied by ``in_poly_idx`` of that polygon.

    All work runs under ``torch.no_grad()``; the returned tensors are detached.

    Args:
        model: polygon generator called as ``model.infgen(samples, pos, pred_pos)``.
        modelpos: position model called as ``modelpos(samples, pos, None, generate=True)``;
            row 0 of its output is treated as per-cell logits.
        samples: padded polygons already in the window, indexed
            ``samples[0, polygon, vertex, coord]``.
        pos: position tensor paired with ``samples`` (both grow along dim 1).
        sample_inter: iterable of vertex tensors used only for collision checks.
        remain_flag: candidate-cell mask, flattened over its first two dims.
            The caller's tensor is not modified.
        device: device that new context tensors are moved to.
        discard_prob: minimum probability for a cell to be considered.
        discre: grid resolution; flat index ``i`` maps to ``(i // discre, i % discre)``.
        max_build: upper bound on the number of polygons (existing + new).
        use_sample: sample the cell instead of taking the argmax.
        finetune: when polygons were generated, re-run ``model.infgen`` on the
            original ``samples``/``pos`` with the positions of all new polygons.

    Returns:
        list of detached tensors, one per generated polygon.
    """
    with torch.no_grad():
        inlen = samples.shape[1]
        samples_iter = samples.clone()
        pos_iter = pos.clone()
        # flatten() returns a view of a contiguous tensor; clone so the masking
        # below does not write through to the caller's remain_flag.
        remain_flag = remain_flag.flatten(0, 1).clone()

        # Collision geometry is built once. The straddling polygons never change,
        # and the context only grows by appending, so new polygons are added incrementally.
        inter_polys = [Polygon(p.cpu().numpy()) for p in sample_inter]
        exist_polys = _padded_to_polygons(samples_iter)

        gen_iter = []
        for _ in range(max_build - inlen):
            pred = modelpos(samples_iter, pos_iter, None, generate=True)
            prob_pred = torch.sigmoid(pred[0, :])
            prob_pred = torch.where(prob_pred < discard_prob, torch.zeros_like(prob_pred), prob_pred)

            predpoly = None
            while True:
                prob_pred = prob_pred * remain_flag
                if torch.max(prob_pred) < discard_prob:
                    break  # no acceptable cell left: stop generating
                if use_sample:
                    idx_iter = torch.multinomial(prob_pred, 1).squeeze(0)
                else:
                    idx_iter = torch.argmax(prob_pred)
                remain_flag[idx_iter] = 0  # never propose this cell again

                # Flat cell index -> (row, col), shaped (1, 1, 2) for model.infgen.
                pred_pos = torch.stack([idx_iter // discre, idx_iter % discre]).reshape(1, 1, 2)
                candidate = model.infgen(samples_iter, pos_iter, pred_pos)[0]

                polygon = Polygon(candidate[0].detach().cpu().numpy())
                collides = any(polygon.intersects(p) for p in exist_polys) or any(
                    polygon.intersects(p) for p in inter_polys
                )
                if not collides:
                    predpoly = candidate
                    break

            if predpoly is None:
                break

            poly_add, pos_add = pad_poly(predpoly)
            poly_add = poly_add.to(device)
            pos_add = pos_add.to(device)
            samples_iter = torch.cat([samples_iter, poly_add], dim=1).detach()
            pos_iter = torch.cat([pos_iter, pos_add], dim=1).detach()
            exist_polys.extend(_padded_to_polygons(poly_add))

            new_poly = predpoly.squeeze(0).detach()
            gen_iter.append(new_poly)
            remain_flag = remain_flag * in_poly_idx(new_poly.cpu(), discre=discre).to(device)

        # With nothing generated there is nothing to refine; skip the model call
        # instead of passing it an empty position slice.
        if finetune and gen_iter:
            gen_iter = [genpoly.squeeze(0).detach() for genpoly in model.infgen(samples, pos, pos_iter[:, inlen:])]

    return gen_iter


def get_args_parser():
    """Build the command-line parser used by this script.

    The parser is created with ``add_help=False``, so no ``-h`` option is
    registered. ``pin_mem`` defaults to True and ``--no_pin_mem`` turns it off.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
    add = parser.add_argument

    # Run / batch settings.
    add('--batch_size', type=int, default=1,
        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    add('--epochs', type=int, default=1)

    add('--device', default='cuda',
        help='device to use for training / testing')
    add('--seed', type=int, default=0)

    # DataLoader options.
    add('--num_workers', type=int, default=20)
    add('--pin_mem', action='store_true',
        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    add('--no_pin_mem', dest='pin_mem', action='store_false')

    # Checkpoints, input scenes and output directory (all plain strings).
    path_options = (
        ('--stage1_model_path', '../results/model/stage1_2d_1000meter.pth'),
        ('--stage2_model_path', '../results/model/stage2_2d_1000meter.pth'),
        ('--scene_path', '../datasets/scene/poly_2000_scene.npy'),
        ('--save_path', "../results/test/states_poly_city_complement/"),
    )
    for flag, default_path in path_options:
        add(flag, type=str, default=default_path)

    parser.set_defaults(pin_mem=True)
    return parser

def _load_models(args, device):
    """Build both models, load their checkpoints, move them to ``device`` and set eval mode.

    Returns ``(modelpos, model)``. ``modelpos`` is the stage-1 ``MAGECityPosition``,
    loaded with ``strict=False``; ``model`` is the stage-2 ``MAGECityPolyGen``,
    loaded strictly. ``map_location=device`` lets GPU-saved checkpoints load
    when a CPU device is requested.
    """
    modelpos = MAGECityPosition(embed_dim=512, depth=6, num_heads=8,
                                decoder_embed_dim=16, decoder_depth=3,
                                decoder_num_heads=8, discre=100, patch_size=10, patch_num=10,
                                device=device, ablation=False, patchify=True)
    modelpos.load_state_dict(torch.load(args.stage1_model_path, map_location=device), strict=False)
    modelpos.to(device)
    modelpos.eval()

    model = MAGECityPolyGen(num_heads=8, device=device,
                            depth=12, embed_dim=512, decoder_embed_dim=512,
                            decoder_depth=8, decoder_num_heads=8,
                            max_build=250, max_poly=20, discre=100)
    model.load_state_dict(torch.load(args.stage2_model_path, map_location=device))
    model.to(device)
    model.eval()
    return modelpos, model


def _blank_canvas(size=2000):
    """Return a white ``size`` x ``size`` 3-channel uint8 image."""
    return np.ones((size, size, 3), np.uint8) * 255


def _poly_points(poly):
    """Convert a vertex array/tensor to the int32 ``(N, 1, 2)`` layout that cv2 drawing expects."""
    if torch.is_tensor(poly):
        poly = poly.detach().cpu().numpy()
    return np.asarray(poly).reshape((-1, 1, 2)).astype(np.int32)


def _draw_polys(img, polys, fill_color):
    """Fill every polygon with ``fill_color`` and outline it with a 1-px black line, in place."""
    for poly in polys:
        pts = _poly_points(poly)
        cv2.fillPoly(img, [pts], color=fill_color)
        cv2.polylines(img, [pts], True, (0, 0, 0), 1)


def _window_mask(sx, sy, x, y, device):
    """Return the 100x100 candidate mask for the window whose origin is ``(sx, sy)``.

    Grid cells are 10 units wide. Along axis 0:
      * if ``x - sx >= 0``, the first ``(x - sx) // 10`` rows are zeroed;
      * otherwise, the rows from index ``(x - sx) // 10`` (a negative index)
        to the end are zeroed.
    Axis 1 is handled the same way using ``y - sy``.
    """
    mask = torch.ones([100, 100]).to(device)
    dx, dy = (x - sx) // 10, (y - sy) // 10
    if x - sx >= 0:
        mask[:dx, :] = 0
    else:
        mask[dx:, :] = 0
    if y - sy >= 0:
        mask[:, :dy] = 0
    else:
        mask[:, dy:] = 0
    return mask


def _split_window(polys, sx, sy, size=1000):
    """Shift polygons into the local frame of the window at ``(sx, sy)`` and classify them.

    Returns ``(inside, crossing)``:
      * ``inside`` holds copies of polygons whose vertices all lie strictly
        inside the window;
      * ``crossing`` holds copies of polygons with only some vertices inside.
    Polygons with no vertex inside are omitted.
    """
    inside, crossing = [], []
    for poly in polys:
        in_win = (sx < poly[:, 0]) & (poly[:, 0] < sx + size) & (sy < poly[:, 1]) & (poly[:, 1] < sy + size)
        if in_win.all():
            target = inside
        elif in_win.any():
            target = crossing
        else:
            continue
        local = poly.clone()
        local[:, 0] -= sx
        local[:, 1] -= sy
        target.append(local)
    return inside, crossing


def main(args):
    """Clear the central region of each scene and regenerate it window by window.

    For each of the first 200 scenes (fewer if the file holds fewer):
      * Polygons lying strictly inside (500, 1500) on both axes are removed.
      * ``<n>.jpg`` is written, showing the removed polygons in (0, 255, 0)
        and the rest in (238, 159, 153).
      * ``infgen`` fills five 1000x1000 windows in turn. Polygons generated in
        one window become context for the later windows.
      * ``<n>_<k>.jpg`` is written after window ``k``; generated polygons are
        drawn in (255, 255, 0).

    Windows whose padded context has 250 or more polygons are skipped.
    """
    device = torch.device(args.device)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    modelpos, model = _load_models(args, device)

    scenes = np.load(args.scene_path, allow_pickle=True)

    if args.batch_size != 1:
        raise ValueError(f'only batch_size == 1 is supported, got {args.batch_size}')

    dir_path = args.save_path
    os.makedirs(dir_path, exist_ok=True)

    x, y = 500, 500
    start_points = [(0, 0), (1000, 0), (0, 1000), (1000, 1000), (x, y)]

    image_num = 0
    for i in range(min(200, len(scenes))):
        center, outer = [], []
        for poly in scenes[i]:
            in_center = (500 < poly[:, 0]) & (poly[:, 0] < 1500) & (500 < poly[:, 1]) & (poly[:, 1] < 1500)
            (center if in_center.all() else outer).append(poly)

        img = _blank_canvas()
        _draw_polys(img, outer, (238, 159, 153))
        _draw_polys(img, center, (0, 255, 0))

        image_num += 1
        print(image_num)
        # os.path.join works whether or not --save_path ends with a separator.
        cv2.imwrite(os.path.join(dir_path, f'{image_num}.jpg'), img)

        n_original = len(outer)
        context = [torch.tensor(poly) for poly in outer]

        for k, (sx, sy) in enumerate(start_points):
            remain_flag = _window_mask(sx, sy, x, y, device)
            sample_in, sample_inter = _split_window(context, sx, sy)

            sample_in, pos_in = pad_poly(sample_in)
            if sample_in.shape[1] >= 250:
                continue

            gen_poly = infgen(model=model, modelpos=modelpos, samples=sample_in.to(device), pos=pos_in.to(device),
                              remain_flag=remain_flag, sample_inter=sample_inter, device=device, use_sample=False)
            for poly in gen_poly:
                # Keep the context on the CPU, like the scene polygons, so later
                # windows never mix devices.
                shifted = poly.detach().cpu().clone()
                shifted[:, 0] += sx
                shifted[:, 1] += sy
                context.append(shifted)

            img = _blank_canvas()
            _draw_polys(img, context[:n_original], (238, 159, 153))
            _draw_polys(img, context[n_original:], (255, 255, 0))
            cv2.imwrite(os.path.join(dir_path, f'{image_num}_{k}.jpg'), img)


if __name__ == '__main__':
    # Parse the command line and run the scene-completion pipeline.
    main(get_args_parser().parse_args())

