import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.append('../share')
from util.transformer import Block
from poly_embed import PolyEmbed
from models_autoregress import AutoPoly
from shapely.geometry import Polygon

import cv2

class MAGECityPolyGen(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, embed_dim=512, depth=12, num_heads=8, decoder_embed_dim=512, decoder_depth=8, 
                 decoder_num_heads=8, mlp_ratio=4., drop_ratio = 0.1, pos_weight = 20,
                 max_poly=20, max_build = 60, discre = 50, device = 'cuda', fix_mask_token = False, 
                 norm_layer=nn.LayerNorm):
        """Build the encoder, decoder, autoregressive head and loss modules.

        Args:
            embed_dim: Token width of the encoder.
            depth: Number of encoder ``Block`` layers.
            num_heads: Attention heads per encoder block.
            decoder_embed_dim: Token width of the decoder.
            decoder_depth: Number of decoder ``Block`` layers.
            decoder_num_heads: Attention heads per decoder block.
            mlp_ratio: Forwarded to every ``Block``.
            drop_ratio: Used for the dropout, attention-dropout and drop-path
                ratios of every ``Block``.
            pos_weight: Positive-class weight of the end-flag BCE loss.
            max_poly: Maximum number of vertices per polygon.
            max_build: Length the decoder sequence is padded to.
            discre: Upper bound (exclusive) used when clamping noisy positions.
            device: Device on which helper tensors are created.
            fix_mask_token: If true, the decoder pads with the learned
                ``mask_token`` directly instead of the encoded first token.
            norm_layer: Normalisation layer constructor.
        """
        super().__init__()

        # Plain hyper-parameters kept for later use by the other methods.
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.num_heads = num_heads
        self.max_poly = max_poly
        self.max_build = max_build
        self.discre = discre
        self.device = device
        self.fix_mask_token = fix_mask_token

        # Keyword arguments shared by every transformer block.
        block_kwargs = dict(qkv_bias=True, norm_layer=norm_layer, drop_ratio=drop_ratio,
                            attn_drop_ratio=drop_ratio, drop_path_ratio=drop_ratio)

        # ----------------------------- encoder -----------------------------
        self.fc_embedding = PolyEmbed(ouput_dim=embed_dim, device=device)

        # NOTE(review): with fix_mask_token the token has decoder width, yet
        # forward_encoder still concatenates it with encoder-width tokens, so
        # this presumably requires embed_dim == decoder_embed_dim - confirm.
        token_width = decoder_embed_dim if self.fix_mask_token else embed_dim
        self.mask_token = nn.Parameter(torch.zeros(1, 1, token_width))

        encoder_layers = [Block(embed_dim, num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(depth)]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)

        # ----------------------------- decoder -----------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        decoder_layers = [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList(decoder_layers)

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_embed_dim, bias=True)

        # Not referenced by the other methods of this class.
        self.midlossfc = nn.Linear(decoder_embed_dim, 2, bias=True)

        # --------------------- autoregressive head / losses ----------------
        self.automodel = AutoPoly(latent_dim=decoder_embed_dim, device=device)

        self.mseloss = nn.MSELoss()
        self.bceloss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def pos_embed_cxy(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        position = pos.cpu().numpy()
        emb_h = self.get_1d_embed(embed_dim // 2, position[:, :, 0])  # (H*W, D/2)
        emb_w = self.get_1d_embed(embed_dim // 2, position[:, :, 1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=2) # (H*W, D)
        emb = torch.tensor(emb).to(self.device)
        return emb

    def get_1d_embed(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)
        batch_n, num_b = pos.shape

        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        out = out.reshape(batch_n, num_b, embed_dim // 2)

        emb_sin = np.sin(out) # (M, D/2)
        emb_cos = np.cos(out) # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=2)  # (M, D)
        return emb
  

    def forward_encoder(self, x, pos):
        bsz = x.shape[0]

        x= F.relu(self.fc_embedding(x.flatten(0,1))).view(bsz, -1, self.embed_dim)

        x = x + self.pos_embed_cxy(self.embed_dim, pos)

        x = torch.cat([self.mask_token.repeat(x.shape[0], 1, 1), x], dim = 1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_decoder(self, x, posall):
        x = F.relu(self.decoder_embed(x))

        if self.fix_mask_token:
            mask_tokens = self.mask_token.repeat(x.shape[0], 1, 1)
        else:
            mask_tokens = x[:, 0:1, :]
        x_ = x[:, 1:, :]

        mask_tokens = mask_tokens.repeat(1, self.max_build-x_.shape[1], 1)
        x_ = torch.cat([x_, mask_tokens], dim = 1)   

        x_ = x_ + self.pos_embed_cxy(self.decoder_embed_dim, posall)

        for blk in self.decoder_blocks:
            x_ = blk(x_)

        x_ = self.decoder_norm(x_)
        out = self.decoder_pred(x_)

        return out

    def compute_loss(self, out, polyin, len_tar):
        hyp_bsz = out.shape[0]
        
        poly_out = out[:, :, :2]
        poly_len = out[:, :, 2]
        loss_l1 = self.mseloss(torch.cat([poly_out[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0),
                               torch.cat([polyin[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0))

        poly_len_tar = torch.zeros(poly_len.shape).scatter_(1, len_tar.unsqueeze(-1), torch.ones(len(len_tar), 1)).to(self.device)
        loss_len = self.bceloss(poly_len, poly_len_tar)

        return loss_l1, loss_len
    
    
    def add_random_noise(self, pos, poly):
        shape = pos.shape

        probabilities = torch.rand(shape)

        prob_0_1_mask = (probabilities < 0.1).float()
        noise_0_1 = (torch.randint(0, 2, size=shape).float() * 2 - 1)*2
        noise_0_1 = noise_0_1 * prob_0_1_mask

        prob_0_4_mask = ((probabilities >= 0.1) & (probabilities < 0.5)).float()
        noise_0_4 = torch.randint(0, 2, size=shape).float() * 2 - 1
        noise_0_4 = noise_0_4 * prob_0_4_mask

        prob_0_5_mask = (probabilities >= 0.5).float()
        noise_0_5 = torch.zeros(shape)
        noise_0_5 = noise_0_5 * prob_0_5_mask

        noise = (noise_0_1 + noise_0_4 + noise_0_5).to(self.device)

        noisy_pos = torch.clamp(pos + noise, 0, self.discre-1)
        noisy_poly = torch.clamp(poly + noise.unsqueeze(2).repeat(1,1,self.max_poly, 1)*(poly != 0), 0, 500)
        return noisy_pos, noisy_poly

    def forward(self, poly, pos, postar, polytar, len_tar, noise = False):
        bsz, remain_num, _, _ = poly.shape
        if noise:
            postar, polytar = self.add_random_noise(postar, polytar)

        latent = self.forward_encoder(poly, pos)

        posall = torch.cat([pos, postar], dim = 1) 
        pred_latent = self.forward_decoder(latent, posall) 

        latentautoin = torch.cat([pred_latent[i, remain_num:remain_num+len(len_tar[i])] for i in range(bsz)], dim = 0)
        polyautoin = torch.cat([polytar[i, :len(len_tar[i])] for i in range(bsz)], dim = 0)

        out = self.automodel(latentautoin, polyautoin)

        len_tar = torch.cat([len_tar[i] for i in range(bsz)], dim = 0)
        loss_l1, loss_len = self.compute_loss(out, polyautoin, len_tar) 

        return loss_l1, loss_len, out

    def infgen(self, poly, pos, postar):
        bsz, remain_num, _, _ = poly.shape
        assert bsz == 1
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, self.max_build-remain_num, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        latent = self.forward_encoder(poly, pos)

        posall = torch.cat([pos, postarin], dim = 1)
        assert posall.shape[1]==self.max_build
        pred_latent = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, remain_num:remain_num+postar_len]
        output_poly = []
        for k in range(latentin.shape[0]): 
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = out[0, -1, :2].unsqueeze(0).unsqueeze(0)

            for _ in range(self.max_poly-1):
                out = self.automodel(latentin_, inputpoly)
                if torch.sigmoid(out[0, -1, 2])>0.5 and inputpoly.shape[1]>2:
                    break
                inputpoly = torch.cat([inputpoly, out[0, -1, :2].unsqueeze(0).unsqueeze(0)], dim = 1)
            output_poly.append(inputpoly)

        return output_poly
                
    def genfirst(self, postar):
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, self.max_build, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        x = self.mask_token

        for blk in self.blocks:
            x = blk(x)
        latent = self.norm(x)

        posall = postarin
        assert posall.shape[1]==self.max_build
        pred_latent = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, 0:postar_len]
        output_poly = []
        for k in range(latentin.shape[0]): 
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = out[0, -1, :2].unsqueeze(0).unsqueeze(0)

            for _ in range(self.max_poly-1):
                out = self.automodel(latentin_, inputpoly)
                if torch.sigmoid(out[0, -1, 2])>0.5 and inputpoly.shape[1]>2:
                    break
                inputpoly = torch.cat([inputpoly, out[0, -1, :2].unsqueeze(0).unsqueeze(0)], dim = 1)
            output_poly.append(inputpoly)

        return output_poly


    def generate(self, poly, pos, postar, len_tar, img_p):
        """Greedily generate the target polygons and draw them onto ``img_p``.

        Only batch size 1 is supported. One polygon is generated for each
        entry of ``len_tar[0]``, filled with colour (255, 255, 0) and outlined
        in black with OpenCV.

        Args:
            poly: 4-D tensor of the given polygons.
            pos: Positions of the given polygons.
            postar: Positions of the remaining decoder slots.
            len_tar: Sequence whose first element has one entry per target polygon.
            img_p: Image passed to ``cv2.fillPoly`` / ``cv2.polylines``; drawn
                on in place.

        Returns:
            ``img_p`` after drawing.
        """
        bsz, remain_num, _, _ = poly.shape
        assert bsz == 1

        latent = self.forward_encoder(poly, pos)

        posall = torch.cat([pos, postar], dim = 1)
        pred_latent = self.forward_decoder(latent, posall) 

        latentautoin = pred_latent[0, remain_num:remain_num+len(len_tar[0])]
        for idx in range(latentautoin.shape[0]):
            latentin = latentautoin[idx:idx+1]

            out = self.automodel(latentin, gen = True)
            vertex = out[0, -1, :2]
            inputpoly = vertex.unsqueeze(0).unsqueeze(0)
            # detach() so the conversion also works outside torch.no_grad().
            points = [vertex.detach().cpu().numpy()]

            # At most max_poly vertices in total, as in infgen/genfirst.
            for _ in range(self.max_poly - 1):
                out = self.automodel(latentin, inputpoly)
                # Stop once the end flag fires, but never before 3 vertices.
                if torch.sigmoid(out[0, -1, 2])>0.5 and inputpoly.shape[1]>2:
                    break
                vertex = out[0, -1, :2]
                inputpoly = torch.cat([inputpoly, vertex.unsqueeze(0).unsqueeze(0)], dim = 1)
                points.append(vertex.detach().cpu().numpy())

            # OpenCV drawing functions need 32-bit integer points; a plain
            # astype(int) would give int64 on most platforms.
            pts = np.array(points, np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(img_p, [pts], color=(255, 255, 0))
            cv2.polylines(img_p, [pts], True, (0, 0, 0), 1)

        return img_p
    
    
class MAGECityPolyGenRoad(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, embed_dim=512, depth=12, num_heads=8, decoder_embed_dim=512, decoder_depth=8, 
                 decoder_num_heads=8, mlp_ratio=4., drop_ratio = 0.1, pos_weight = 20,
                 max_poly=20, max_build = 60, discre = 50, device = 'cuda', max_road_len = 38, 
                 norm_layer=nn.LayerNorm, append_road = True):
        """Build the road-conditioned encoder, decoder, vertex model and losses.

        Args:
            embed_dim: Token width of the encoder.
            depth: Number of encoder ``Block`` layers.
            num_heads: Attention heads per encoder block.
            decoder_embed_dim: Token width of the decoder.
            decoder_depth: Number of decoder ``Block`` layers.
            decoder_num_heads: Attention heads per decoder block.
            mlp_ratio: Forwarded to every ``Block``.
            drop_ratio: Used for the dropout, attention-dropout and drop-path
                ratios of every ``Block``.
            pos_weight: Positive-class weight of the end-flag BCE loss.
            max_poly: Maximum number of vertices per polygon.
            max_build: Length the decoder sequence is padded to.
            discre: Stored on the instance; not used by the methods visible here.
            device: Device on which helper tensors are created.
            max_road_len: Passed to the road ``PolyEmbed`` as
                ``max_position_embeddings``.
            norm_layer: Normalisation layer constructor.
            append_road: If true, encoded road tokens are also appended to
                the decoder sequence.
        """
        super().__init__()

        # Plain hyper-parameters kept for later use by the other methods.
        self.append = append_road
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.num_heads = num_heads
        self.max_poly = max_poly
        self.max_build = max_build
        self.discre = discre
        self.device = device

        # Keyword arguments shared by every transformer block.
        block_kwargs = dict(qkv_bias=True, norm_layer=norm_layer, drop_ratio=drop_ratio,
                            attn_drop_ratio=drop_ratio, drop_path_ratio=drop_ratio)

        # ----------------------------- encoder -----------------------------
        self.fc_embedding = PolyEmbed(ouput_dim=embed_dim, device=device)
        self.road_embedding = PolyEmbed(ouput_dim=embed_dim,
                                        max_position_embeddings=max_road_len, device=device)

        # Learned tokens: road marker (encoder), mask token, road marker (decoder).
        self.road_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.road_decoder_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        encoder_layers = [Block(embed_dim, num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(depth)]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)

        # ----------------------------- decoder -----------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_road = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        decoder_layers = [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList(decoder_layers)

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_embed_dim, bias=True)
        # Not referenced by the other methods of this class.
        self.midlossfc = nn.Linear(decoder_embed_dim, 2, bias=True)

        # --------------------- autoregressive head / losses ----------------
        self.automodel = AutoPoly(latent_dim=decoder_embed_dim, device=device)

        self.mseloss = nn.MSELoss()
        self.bceloss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def pos_embed_cxy(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        position = pos.cpu().numpy()
        emb_h = self.get_1d_embed(embed_dim // 2, position[:, :, 0])  # (H*W, D/2)
        emb_w = self.get_1d_embed(embed_dim // 2, position[:, :, 1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=2) # (H*W, D)
        emb = torch.tensor(emb).to(self.device)
        return emb

    def get_1d_embed(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)
        batch_n, num_b = pos.shape

        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        out = out.reshape(batch_n, num_b, embed_dim // 2)

        emb_sin = np.sin(out) # (M, D/2)
        emb_cos = np.cos(out) # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=2)  # (M, D)
        return emb
  

    def forward_encoder(self, x, pos, road):
        bsz, len_build = x.shape[:2]

        x= F.relu(self.fc_embedding(x.flatten(0,1))).view(bsz, -1, self.embed_dim)
        x_road= F.relu(self.road_embedding(road.flatten(0,1))).view(bsz, -1, self.embed_dim)

        x = x + self.pos_embed_cxy(self.embed_dim, pos)
        x_road = x_road + self.road_token.repeat(bsz, x_road.shape[1], 1)

        x = torch.cat([self.mask_token.repeat(x.shape[0], 1, 1), x, x_road], dim = 1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if self.append:
            return x[:, :1+len_build], x[:, 1+len_build:]
        else:
            return x[:, :1+len_build]

    def forward_decoder(self, x, posall, latent =  None):
        x = F.relu(self.decoder_embed(x))
        if self.append:
            x_road = F.relu(self.decoder_embed_road(latent))

        mask_tokens = x[:, 0:1, :]

        x_ = x[:, 1:, :]

        mask_tokens = mask_tokens.repeat(1, self.max_build-x_.shape[1], 1)
        x_ = torch.cat([x_, mask_tokens], dim = 1)   

        x_ = x_ + self.pos_embed_cxy(self.decoder_embed_dim, posall)

        if self.append:
            len_tem = x_.shape[1]
            x_road = x_road + self.road_decoder_token.repeat(x_road.shape[0], x_road.shape[1], 1)
            x_ = torch.cat([x_, x_road], dim = 1)

        for blk in self.decoder_blocks:
            x_ = blk(x_)

        x_ = self.decoder_norm(x_)
        if self.append:
            out = self.decoder_pred(x_[:, :len_tem])
        else:
            out = self.decoder_pred(x_)

        return out

    def compute_loss(self, out, polyin, len_tar):
        hyp_bsz = out.shape[0]
        
        poly_out = out[:, :, :2]
        poly_len = out[:, :, 2]
        loss_l1 = self.mseloss(torch.cat([poly_out[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0),
                               torch.cat([polyin[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0))

        poly_len_tar = torch.zeros(poly_len.shape).scatter_(1, len_tar.unsqueeze(-1), torch.ones(len(len_tar), 1)).to(self.device)
        loss_len = self.bceloss(poly_len, poly_len_tar)

        return loss_l1, loss_len
    

    def forward(self, poly, pos, postar, polytar, len_tar, road):
        bsz, remain_num, _, _ = poly.shape
        posall = torch.cat([pos, postar], dim = 1) 

        if self.append:
            latent, latent_road = self.forward_encoder(poly, pos, road)
            pred_latent = self.forward_decoder(latent, posall, latent_road) 
        else:
            latent = self.forward_encoder(poly, pos, road)
            pred_latent = self.forward_decoder(latent, posall) 

        latentautoin = torch.cat([pred_latent[i, remain_num:remain_num+len(len_tar[i])] for i in range(bsz)], dim = 0)
        polyautoin = torch.cat([polytar[i, :len(len_tar[i])] for i in range(bsz)], dim = 0)

        out = self.automodel(latentautoin, polyautoin)

        len_tar = torch.cat([len_tar[i] for i in range(bsz)], dim = 0)
        loss_l1, loss_len = self.compute_loss(out, polyautoin, len_tar) 

        return loss_l1, loss_len, out

    def infgen(self, poly, pos, postar, road):
        bsz, remain_num, _, _ = poly.shape
        assert bsz == 1
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, self.max_build-remain_num, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar
        posall = torch.cat([pos, postarin], dim = 1)
        assert posall.shape[1]==self.max_build

        if self.append:
            latent, latent_road = self.forward_encoder(poly, pos, road)
            pred_latent = self.forward_decoder(latent, posall, latent_road) 
        else:
            latent = self.forward_encoder(poly, pos, road)
            pred_latent = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, remain_num:remain_num+postar_len]
        output_poly = []
        for k in range(latentin.shape[0]): 
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = out[0, -1, :2].unsqueeze(0).unsqueeze(0)

            for i in range(self.max_poly-1):
                out = self.automodel(latentin_, inputpoly)
                if torch.sigmoid(out[0, -1, 2])>0.5 and inputpoly.shape[1]>2:
                    break
                inputpoly = torch.cat([inputpoly, out[0, -1, :2].unsqueeze(0).unsqueeze(0)], dim = 1)
            output_poly.append(inputpoly)

        return output_poly
                
                
class MAGECityPolyGenSample(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, embed_dim=512, depth=12, num_heads=8, decoder_embed_dim=512, decoder_depth=8, 
                 decoder_num_heads=8, mlp_ratio=4., drop_ratio = 0.1, pos_weight = 20,
                 max_poly=20, max_build = 60, discre = 50, device = 'cuda', quant = 250,
                 norm_layer=nn.LayerNorm):
        """Build the encoder, decoder, classification-style vertex model and losses.

        Args:
            embed_dim: Token width of the encoder.
            depth: Number of encoder ``Block`` layers.
            num_heads: Attention heads per encoder block.
            decoder_embed_dim: Token width of the decoder.
            decoder_depth: Number of decoder ``Block`` layers.
            decoder_num_heads: Attention heads per decoder block.
            mlp_ratio: Forwarded to every ``Block``.
            drop_ratio: Used for the dropout, attention-dropout and drop-path
                ratios of every ``Block``.
            pos_weight: Positive-class weight of the end-flag BCE loss.
            max_poly: Maximum number of vertices per polygon.
            max_build: Length the decoder sequence is padded to.
            discre: Stored on the instance; not used by the methods visible here.
            device: Device on which helper tensors are created.
            quant: Number of classes per coordinate; the vertex model outputs
                ``2 * quant + 1`` channels (two coordinates plus an end flag).
            norm_layer: Normalisation layer constructor.
        """
        super().__init__()

        # Plain hyper-parameters kept for later use by the other methods.
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.num_heads = num_heads
        self.max_poly = max_poly
        self.max_build = max_build
        self.discre = discre
        self.quant = quant
        self.device = device

        # Keyword arguments shared by every transformer block.
        block_kwargs = dict(qkv_bias=True, norm_layer=norm_layer, drop_ratio=drop_ratio,
                            attn_drop_ratio=drop_ratio, drop_path_ratio=drop_ratio)

        # ----------------------------- encoder -----------------------------
        self.fc_embedding = PolyEmbed(ouput_dim=embed_dim, device=device)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        encoder_layers = [Block(embed_dim, num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(depth)]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)

        # ----------------------------- decoder -----------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        decoder_layers = [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, **block_kwargs)
                          for _ in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList(decoder_layers)

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_embed_dim, bias=True)

        # --------------------- autoregressive head / losses ----------------
        self.automodel = AutoPoly(latent_dim=decoder_embed_dim, out_dim=2*self.quant+1, device=device)

        self.mseloss = nn.MSELoss()
        self.bceloss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        self.cross_entropy = nn.CrossEntropyLoss()

        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def pos_embed_cxy(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        position = pos.cpu().numpy()
        emb_h = self.get_1d_embed(embed_dim // 2, position[:, :, 0])  # (H*W, D/2)
        emb_w = self.get_1d_embed(embed_dim // 2, position[:, :, 1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=2) # (H*W, D)
        emb = torch.tensor(emb).to(self.device)
        return emb

    def get_1d_embed(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)
        batch_n, num_b = pos.shape

        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        out = out.reshape(batch_n, num_b, embed_dim // 2)

        emb_sin = np.sin(out) # (M, D/2)
        emb_cos = np.cos(out) # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=2)  # (M, D)
        return emb
  

    def forward_encoder(self, x, pos):
        bsz = x.shape[0]

        x= F.relu(self.fc_embedding(x.flatten(0,1))).view(bsz, -1, self.embed_dim)

        x = x + self.pos_embed_cxy(self.embed_dim, pos)

        x = torch.cat([self.mask_token.repeat(x.shape[0], 1, 1), x], dim = 1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_decoder(self, x, posall):
        x = F.relu(self.decoder_embed(x))

        mask_tokens = x[:, 0:1, :]
        x_ = x[:, 1:, :]

        mask_tokens = mask_tokens.repeat(1, self.max_build-x_.shape[1], 1)
        x_ = torch.cat([x_, mask_tokens], dim = 1)   

        x_ = x_ + self.pos_embed_cxy(self.decoder_embed_dim, posall)

        for blk in self.decoder_blocks:
            x_ = blk(x_)

        x_ = self.decoder_norm(x_)
        out = self.decoder_pred(x_)

        return out

    def compute_loss(self, out, polyin, len_tar):
        hyp_bsz = out.shape[0]
        
        poly_out = out[:, :, :2*self.quant]
        poly_len = out[:, :, 2*self.quant]
        
        poly_out = torch.cat([poly_out[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0)
        
        poly_tar = torch.cat([polyin[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0)
        poly_tar = (poly_tar.reshape(-1)/(500/self.quant)).to(int)
        loss_l1 = self.cross_entropy(poly_out.reshape(-1,  self.quant), poly_tar)

        poly_len_tar = torch.zeros(poly_len.shape).scatter_(1, len_tar.unsqueeze(-1), torch.ones(len(len_tar), 1)).to(self.device)
        loss_len = self.bceloss(poly_len, poly_len_tar)

        return loss_l1, loss_len
    

    def forward(self, poly, pos, postar, polytar, len_tar):
        bsz, remain_num, _, _ = poly.shape

        latent = self.forward_encoder(poly, pos)

        posall = torch.cat([pos, postar], dim = 1) 
        pred_latent = self.forward_decoder(latent, posall) 

        latentautoin = torch.cat([pred_latent[i, remain_num:remain_num+len(len_tar[i])] for i in range(bsz)], dim = 0)
        polyautoin = torch.cat([polytar[i, :len(len_tar[i])] for i in range(bsz)], dim = 0)

        out = self.automodel(latentautoin, polyautoin)

        len_tar = torch.cat([len_tar[i] for i in range(bsz)], dim = 0)
        loss_l1, loss_len = self.compute_loss(out, polyautoin, len_tar) 

        return loss_l1, loss_len, out

    def infgen(self, poly, pos, postar, random_ = False, k_beam = 5):
        bsz, remain_num, _, _ = poly.shape
        assert bsz == 1
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, self.max_build-remain_num, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        latent = self.forward_encoder(poly, pos)

        posall = torch.cat([pos, postarin], dim = 1)
        assert posall.shape[1]==self.max_build
        pred_latent = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, remain_num:remain_num+postar_len]
        output_poly = []
        
        for k in range(latentin.shape[0]):     
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = []
            prob = []
            endflag = torch.zeros([k_beam])
            predvertexprob = out[0,-1,:2*self.quant].reshape(2,self.quant).softmax(-1)
            for j in range(k_beam):           
                index = torch.argmax(predvertexprob, dim = -1)
                prob.append(torch.sqrt(predvertexprob[0, index[0]]*predvertexprob[1, index[1]]))
                inputpoly.append(index.unsqueeze(0).unsqueeze(0).float()*(500/self.quant))
                predvertexprob[0, index[0]] = 0
                predvertexprob[1, index[1]] = 0
            for i in range(19):
                outtem = torch.zeros([k_beam, k_beam, 2])
                probtem = torch.zeros([k_beam, k_beam])
                for k in range(k_beam):
                    if endflag[k] == 0:
                        out = self.automodel(latentin_, inputpoly[k])
                    
                        if torch.sigmoid(out[0, -1, 2*self.quant])>0.5 and inputpoly[k].shape[1]>2:
                            endflag[k] = 1
                            continue

                        predvertexprob = out[0,-1,:2*self.quant].reshape(2,self.quant).softmax(-1)
                        for indpoly in inputpoly[k][0]:
                            predvertexprob[0, int(indpoly[0]/2)] = 0
                            predvertexprob[1, int(indpoly[1]/2)] = 0
                        for r in range(k_beam):           
                            index = torch.argmax(predvertexprob, dim = -1)
                            probtem[k, r] = torch.sqrt(torch.sqrt(predvertexprob[0, index[0]]*predvertexprob[1, index[1]])*prob[k])
                            outtem[k, r] = index
                            predvertexprob[0, index[0]] = 0
                            predvertexprob[1, index[1]] = 0
                probtem = probtem.flatten(0,1)
                outtem = outtem.flatten(0,1)
                inputpolynew = []
                probnew = []
                for w in range(k_beam):
                    if endflag[w] == 1:
                        inputpolynew.append(inputpoly[w])
                        probnew.append(prob[w])
                        continue
                    maxid = torch.argmax(probtem)
                    inputpolynew.append(torch.cat([inputpoly[maxid//k_beam], outtem[maxid].unsqueeze(0).unsqueeze(0).to(self.device)*(500/self.quant)], dim = 1))
                    probnew.append(probtem[maxid].clone())
                    probtem[maxid] = 0
                inputpoly = inputpolynew
                prob = probnew

            if random_:
                output_poly.append(inputpoly[torch.randint(0, k_beam, (1,))])

            else:
                for pid in range(k_beam):
                    poly = inputpoly[torch.argmax(torch.tensor(prob))]
                    if Polygon(np.array(poly[0].detach().cpu())).is_valid or pid == k_beam-1:
                        output_poly.append(poly)
                        break
                    else:
                        prob[torch.argmax(torch.tensor(prob))] = 0
                

        return output_poly
                
    def genfirst(self, postar, k_beam = 5):
        """Generate polygons from positions alone with a beam-style search.

        The encoder runs on the mask token only; ``postar`` is zero-padded to
        ``max_build`` slots for the decoder. For every target latent,
        ``k_beam`` candidate polygons are grown vertex by vertex: each live
        candidate proposes ``k_beam`` continuations (coordinate classes
        already used by that candidate are excluded) and the best-scoring
        proposals replace the live candidates. A candidate is frozen once the
        end-flag logit fires and it has more than two vertices.

        Args:
            postar: Target positions.
            k_beam: Number of candidates kept per target.

        Returns:
            List with one ``(1, n_vertices, 2)`` tensor per target position:
            the best-scoring candidate that shapely reports as valid (falling
            back to the last one tried), with coordinates expressed as
            ``class_index * (500 / quant)``.
        """
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, self.max_build, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        # Encoder pass over the single mask token.
        x = self.mask_token

        for blk in self.blocks:
            x = blk(x)
        latent = self.norm(x)

        posall = postarin
        assert posall.shape[1]==self.max_build
        pred_latent = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, 0:postar_len]
        output_poly = []

        # Width of one coordinate class in coordinate units.
        step = 500/self.quant

        for t in range(latentin.shape[0]):

            latentin_ = latentin[t:t+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = []
            prob = []
            endflag = torch.zeros([k_beam])
            predvertexprob = out[0,-1,:2*self.quant].reshape(2,self.quant).softmax(-1)

            # Seed the candidates with the k_beam most likely first vertices.
            for _ in range(k_beam):
                index = torch.argmax(predvertexprob, dim = -1)
                prob.append(torch.sqrt(predvertexprob[0, index[0]]*predvertexprob[1, index[1]]))
                inputpoly.append(index.unsqueeze(0).unsqueeze(0).float()*step)
                predvertexprob[0, index[0]] = 0
                predvertexprob[1, index[1]] = 0
            # At most max_poly vertices in total (was a hard-coded 19 steps).
            for _ in range(self.max_poly-1):
                outtem = torch.zeros([k_beam, k_beam, 2])
                probtem = torch.zeros([k_beam, k_beam])
                for b in range(k_beam):
                    if endflag[b] == 0:
                        out = self.automodel(latentin_, inputpoly[b])

                        if torch.sigmoid(out[0, -1, 2*self.quant])>0.5 and inputpoly[b].shape[1]>2:
                            endflag[b] = 1
                            continue

                        predvertexprob = out[0,-1,:2*self.quant].reshape(2,self.quant).softmax(-1)
                        # Exclude coordinate classes this candidate already used.
                        # Coordinates are class_index * step, so divide by step
                        # (the old fixed divisor 2 was only right for quant=250).
                        for vertex in inputpoly[b][0]:
                            predvertexprob[0, int(round(float(vertex[0])/step))] = 0
                            predvertexprob[1, int(round(float(vertex[1])/step))] = 0
                        for r in range(k_beam):
                            index = torch.argmax(predvertexprob, dim = -1)
                            probtem[b, r] = torch.sqrt(torch.sqrt(predvertexprob[0, index[0]]*predvertexprob[1, index[1]])*prob[b])
                            outtem[b, r] = index
                            predvertexprob[0, index[0]] = 0
                            predvertexprob[1, index[1]] = 0
                probtem = probtem.flatten(0,1)
                outtem = outtem.flatten(0,1)
                inputpolynew = []
                probnew = []
                for w in range(k_beam):
                    # Frozen candidates keep their slot unchanged.
                    if endflag[w] == 1:
                        inputpolynew.append(inputpoly[w])
                        probnew.append(prob[w])
                        continue
                    # Otherwise take the best remaining proposal overall.
                    maxid = torch.argmax(probtem)
                    inputpolynew.append(torch.cat([inputpoly[maxid//k_beam], outtem[maxid].unsqueeze(0).unsqueeze(0).to(self.device)*step], dim = 1))
                    probnew.append(probtem[maxid].clone())
                    probtem[maxid] = 0
                inputpoly = inputpolynew
                prob = probnew

            # Try candidates from best to worst until one is a valid polygon.
            for pid in range(k_beam):
                best = inputpoly[torch.argmax(torch.tensor(prob))]
                if Polygon(np.array(best[0].detach().cpu())).is_valid or pid == k_beam-1:
                    output_poly.append(best)
                    break
                else:
                    prob[torch.argmax(torch.tensor(prob))] = 0

        return output_poly

                

class MAGECityPolyGen3D(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, embed_dim=512, depth=12, num_heads=8, decoder_embed_dim=512, decoder_depth=8,
                 decoder_num_heads=8, mlp_ratio=4., drop_ratio=0.1, pos_weight=20,
                 max_poly=20, max_build=60, discre=50, device='cuda',
                 norm_layer=nn.LayerNorm):
        """Build the encoder, the decoder, the prediction heads and the loss modules.

        Args:
            embed_dim: channel width of the encoder tokens.
            depth: number of encoder ``Block`` layers.
            num_heads: head count passed to every encoder ``Block``.
            decoder_embed_dim: channel width of the decoder tokens.
            decoder_depth: number of decoder ``Block`` layers.
            decoder_num_heads: head count passed to every decoder ``Block``.
            mlp_ratio: forwarded unchanged to every ``Block``.
            drop_ratio: forwarded to every ``Block`` as ``drop_ratio``,
                ``attn_drop_ratio`` and ``drop_path_ratio``.
            pos_weight: ``pos_weight`` of the ``BCEWithLogitsLoss`` stored as ``bceloss``.
            max_poly: stored as ``self.max_poly``.
            max_build: stored as ``self.max_build``; the decoder pads its token
                sequence to this length.
            discre: stored as ``self.discre``.
            device: stored as ``self.device`` and forwarded to ``PolyEmbed`` and ``AutoPoly``.
            norm_layer: constructor used for the normalisation layers.
        """
        super().__init__()

        def build_stack(width, heads, n_layers):
            # One Block per layer, all sharing the same hyper-parameters.
            layers = []
            for _ in range(n_layers):
                layers.append(
                    Block(width, heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                          drop_ratio=drop_ratio, attn_drop_ratio=drop_ratio,
                          drop_path_ratio=drop_ratio))
            return nn.ModuleList(layers)

        # --------------------------------------------------------------------------
        # Plain configuration attributes
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.device = device
        self.max_poly = max_poly
        self.discre = discre
        self.max_build = max_build
        self.num_heads = num_heads

        # --------------------------------------------------------------------------
        # Encoder: polygon + height embeddings, learnable mask token, Block stack
        self.fc_embedding = PolyEmbed(ouput_dim=embed_dim, device=device)
        self.h_embedding = nn.Linear(1, embed_dim)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = build_stack(embed_dim, num_heads, depth)
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # Decoder: projection to decoder width, Block stack, output heads
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.decoder_blocks = build_stack(decoder_embed_dim, decoder_num_heads, decoder_depth)
        self.decoder_norm = norm_layer(decoder_embed_dim)

        self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_embed_dim, bias=True)
        self.h_pred = nn.Linear(decoder_embed_dim, 1)

        self.midlossfc = nn.Linear(decoder_embed_dim, 2, bias=True)

        # --------------------------------------------------------------------------
        # Autoregressive polygon model conditioned on the decoder latent
        self.automodel = AutoPoly(latent_dim=decoder_embed_dim, device=device)

        # Loss modules
        self.l1loss = nn.L1Loss()
        self.mseloss = nn.MSELoss()
        self.bceloss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def pos_embed_cxy(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        position = pos.cpu().numpy()
        emb_h = self.get_1d_embed(embed_dim // 2, position[:, :, 0])  # (H*W, D/2)
        emb_w = self.get_1d_embed(embed_dim // 2, position[:, :, 1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=2) # (H*W, D)
        emb = torch.tensor(emb).to(self.device)
        return emb

    def get_1d_embed(self, embed_dim, pos):
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)
        batch_n, num_b = pos.shape

        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
        out = out.reshape(batch_n, num_b, embed_dim // 2)

        emb_sin = np.sin(out) # (M, D/2)
        emb_cos = np.cos(out) # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=2)  # (M, D)
        return emb
  

    def forward_encoder(self, x, pos, h):
        bsz = x.shape[0]

        x= F.relu(self.fc_embedding(x.flatten(0,1))+self.h_embedding(h.flatten(0,1))).view(bsz, -1, self.embed_dim)

        x = x + self.pos_embed_cxy(self.embed_dim, pos)

        x = torch.cat([self.mask_token.repeat(x.shape[0], 1, 1), x], dim = 1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_decoder(self, x, posall):
        x = F.relu(self.decoder_embed(x))

        mask_tokens = x[:, 0:1, :]
        x_ = x[:, 1:, :]

        mask_tokens = mask_tokens.repeat(1, self.max_build-x_.shape[1], 1)
        x_ = torch.cat([x_, mask_tokens], dim = 1)   

        x_ = x_ + self.pos_embed_cxy(self.decoder_embed_dim, posall)

        for blk in self.decoder_blocks:
            x_ = blk(x_)

        x_ = self.decoder_norm(x_)
        out = self.decoder_pred(x_)
        h_pred = self.h_pred(x_)

        return out, h_pred

    def compute_loss(self, out, polyin, len_tar):
        hyp_bsz = out.shape[0]
        
        poly_out = out[:, :, :2]
        poly_len = out[:, :, 2]
        loss_l1 = self.mseloss(torch.cat([poly_out[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0),
                               torch.cat([polyin[i, :len_tar[i]] for i in range(hyp_bsz)], dim = 0))

        poly_len_tar = torch.zeros(poly_len.shape).scatter_(1, len_tar.unsqueeze(-1), torch.ones(len(len_tar), 1)).to(self.device)
        loss_len = self.bceloss(poly_len, poly_len_tar)

        return loss_l1, loss_len
    
    
    def forward(self, poly, pos, h, polytar, postar, htar, len_tar):
        ## len_tar object_numpy [[]]
        ## poly_tar [x.shape[0], max_build-remain_num, 20, 2]

        bsz, remain_num, _, _ = poly.shape

        latent = self.forward_encoder(poly, pos, h)

        posall = torch.cat([pos, postar], dim = 1)
        pred_latent, h_pred = self.forward_decoder(latent, posall) 

        loss_height = self.l1loss(torch.cat([h_pred[i, remain_num:remain_num+len(len_tar[i])] for i in range(bsz)], dim=0),
                                  torch.cat([htar[i, :len(len_tar[i])] for i in range(bsz)], dim=0))

        latentautoin = torch.cat([pred_latent[i, remain_num:remain_num+len(len_tar[i])] for i in range(bsz)], dim = 0)
        polyautoin = torch.cat([polytar[i, :len(len_tar[i])] for i in range(bsz)], dim = 0)

        out = self.automodel(latentautoin, polyautoin)

        len_tar = torch.cat([len_tar[i] for i in range(bsz)], dim = 0)
        loss_l1, loss_len = self.compute_loss(out, polyautoin, len_tar) 

        return loss_l1, loss_height, loss_len, out

    def infgen(self, poly, pos, h, postar):
        bsz, remain_num, _, _ = poly.shape
        assert bsz == 1
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, 60-remain_num, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        latent = self.forward_encoder(poly, pos, h)

        posall = torch.cat([pos, postarin], dim = 1)
        assert posall.shape[1]==self.max_build
        pred_latent, h_pred = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, remain_num:remain_num+postar_len]
        output_poly = []
        output_h = h_pred[0, remain_num:remain_num+postar_len]
        for k in range(latentin.shape[0]): 
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = out[0, -1, :2].unsqueeze(0).unsqueeze(0)

            for i in range(19):
                out = self.automodel(latentin_, inputpoly)
                if torch.sigmoid(out[0, -1, 2])>0.5:
                        # print('end:', torch.sigmoid(out[0, -1, 2]))
                    break
                inputpoly = torch.cat([inputpoly, out[0, -1, :2].unsqueeze(0).unsqueeze(0)], dim = 1)
            output_poly.append(inputpoly)

        return output_poly, output_h
    
    def genfirst(self, postar):
        postar_len = postar.shape[1]
        postarin = torch.zeros([1, 60, 2]).to(self.device)
        postarin[:, :postar_len, :] = postar

        x = self.mask_token

        for blk in self.blocks:
            x = blk(x)
        latent = self.norm(x)

        posall = postarin
        assert posall.shape[1]==self.max_build
        pred_latent, h_pred = self.forward_decoder(latent, posall) 

        latentin = pred_latent[0, 0:1]
        output_poly = []
        output_h = h_pred[0, 0:1] 
        for k in range(latentin.shape[0]): 
            latentin_ = latentin[k:k+1]
            out = self.automodel(latentin_, gen = True)
            inputpoly = out[0, -1, :2].unsqueeze(0).unsqueeze(0)

            for i in range(19):
                out = self.automodel(latentin_, inputpoly)
                if torch.sigmoid(out[0, -1, 2])>0.5:
                        # print('end:', torch.sigmoid(out[0, -1, 2]))
                    break
                inputpoly = torch.cat([inputpoly, out[0, -1, :2].unsqueeze(0).unsqueeze(0)], dim = 1)
            output_poly.append(inputpoly)

        return output_poly, output_h

