import argparse
import numpy as np
import os
import sys
sys.path.append('../share')
sys.path.append('../model')

import torch
import torch.backends.cudnn as cudnn
from models_mage import MAGECityPolyGen
from dataloader import PolyDataset2D
from random_mask import random_masking_test

import cv2

def get_args_parser(add_help=False):
    """Build the command-line parser for the stage-2 polygon reconstruction test script.

    Args:
        add_help: Whether to register ``-h/--help`` on the parser. Defaults to
            ``False`` (the historical behaviour), which keeps the parser usable
            as a parent via ``argparse.ArgumentParser(parents=[...])`` without
            an option clash. Pass ``True`` to get a standalone parser that
            supports ``--help``.

    Returns:
        argparse.ArgumentParser: the configured parser. Memory pinning is on by
        default; ``--no_pin_mem`` turns it off.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=add_help)
    parser.add_argument('--batch_size', default=1, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    # NOTE(review): --epochs is accepted but not read by the visible main().
    parser.add_argument('--epochs', default=1, type=int)

    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    parser.add_argument('--num_workers', default=20, type=int)
    # --pin_mem / --no_pin_mem share dest 'pin_mem'; set_defaults below makes pinning the default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')

    parser.add_argument('--stage2_model_path', type=str, default='../results/model/stage2_2d.pth')
    parser.add_argument('--save_path', type=str, default="../results/test/states_poly_reconstruct/")
    parser.add_argument('--data_path', type=str, default='../datasets/statespoly')
    # NOTE(review): --translation is accepted but not read by the visible main().
    parser.add_argument('--translation', type=float, default=0.5)

    parser.set_defaults(pin_mem=True)

    return parser


def _draw_polygons(canvas, samples, info, indices):
    """Fill and outline selected polygons of the first batch item onto ``canvas`` in place.

    Args:
        canvas: uint8 BGR image that OpenCV draws into.
        samples: polygon batch indexed as ``samples[0, k]``; the leading
            ``info[0][k + 1]`` rows of that slice are used as (x, y) vertices.
        info: per-sample metadata; ``info[0][k + 1]`` is read as the vertex
            count of polygon ``k``.
        indices: iterable of polygon indices to draw.
    """
    for poly_idx in indices:
        n_vertices = int(info[0][poly_idx + 1])
        pts = np.array(samples[0, poly_idx][:n_vertices, :].cpu(), np.int32)
        # Keep int32: OpenCV's fillPoly/polylines expect CV_32S point arrays of shape (N, 1, 2).
        pts = pts.reshape((-1, 1, 2))
        cv2.fillPoly(canvas, [pts], color=(238, 159, 153))
        cv2.polylines(canvas, [pts], True, (0, 0, 0), 1)


def main(args):
    """Render ground-truth vs. model-completed polygon layouts and save them as images.

    For up to the first 50 validation batches, a subset of polygons is kept
    (``random_masking_test`` with ``remain_num`` = 6). The model then generates
    the rest onto a canvas that already holds the kept polygons. The saved
    image is ``[ground truth | separator | kept + generated]``, written as
    ``<save_path>/<image_num>.jpg``.

    Args:
        args: namespace from ``get_args_parser().parse_args()``; reads
            ``device``, ``seed``, ``save_path``, ``data_path``, ``batch_size``,
            ``num_workers``, ``pin_mem`` and ``stage2_model_path``.
    """
    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    cudnn.benchmark = True

    save_dir = args.save_path if args.save_path is not None else os.curdir
    os.makedirs(save_dir, exist_ok=True)

    dataset_valid = PolyDataset2D(args.data_path, train=False, split_ratio=0.8)

    data_loader_valid = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    model = MAGECityPolyGen(num_heads=8, device=device,
                            depth=12, embed_dim=512, decoder_embed_dim=512,
                            decoder_depth=8, decoder_num_heads=8,
                            max_build=60, max_poly=20, discre=50)

    # map_location lets a checkpoint saved on a GPU load on CPU (or on a different GPU).
    state_dict = torch.load(args.stage2_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    remain_num = 6      # number of polygons kept visible per sample
    max_batches = 50    # only the first 50 batches are rendered
    canvas_size = 500
    image_num = 0

    for valid_step, (samples, pos, info) in enumerate(data_loader_valid):
        if valid_step >= max_batches:
            break

        poly_reserve, pos_reserve, _poly_tar, pos_tar, len_tar, ids_keep = random_masking_test(
            samples, pos, info, remain_num, max_build=60)

        poly_reserve = poly_reserve.to(device)
        pos_reserve = pos_reserve.to(device)
        pos_tar = pos_tar.to(device)

        img_t = np.full((canvas_size, canvas_size, 3), 255, np.uint8)
        img_p = np.full((canvas_size, canvas_size, 3), 255, np.uint8)

        # Left panel: every polygon of the first sample (count taken from info[0, 0]).
        _draw_polygons(img_t, samples, info, range(int(info[0, 0])))
        # Right panel starts with only the kept polygons; the model draws the rest.
        _draw_polygons(img_p, samples, info, ids_keep)

        with torch.no_grad():
            img_p = model.generate(poly_reserve, pos_reserve, pos_tar, len_tar, img_p)

        separator = np.zeros((img_t.shape[0], 5, 3), np.uint8)
        img = cv2.hconcat([img_t, separator, img_p])

        image_num += 1
        print('img_num:', image_num)

        # os.path.join works whether or not save_path ends with a separator.
        out_file = os.path.join(save_dir, f'{image_num}.jpg')
        if not cv2.imwrite(out_file, img):
            print(f'warning: failed to write {out_file}', file=sys.stderr)
  


if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the reconstruction loop.
    cli_args = get_args_parser().parse_args()
    main(cli_args)