import torch.nn.functional as F
import torch
import torch.nn as nn

import sys
sys.path.append('../share')
sys.path.append('../model')
from models_pospred import MAGECityPosition


class MAGECityPolyGenProbing(nn.Module):
    """Classification probe on top of a ``MAGECityPosition`` encoder.

    The encoder output (minus its first token) is flattened, passed through
    a two-layer head (``Linear -> ReLU -> Linear``) and scored against the
    target classes with cross-entropy.

    Args:
        device: Forwarded to ``MAGECityPosition``.
        max_build: Together with ``hidden`` fixes the head's input width
            (``hidden * max_build`` features per sample).
        hidden: Width of the head's hidden layer.
        num_class: Number of output classes.
        fine_tune: When ``False`` the encoder is kept in evaluation mode,
            including after ``.train()`` is called on this module. When
            ``True`` the encoder follows this module's mode.
    """

    def __init__(self, device='cuda', max_build=60, hidden=512, num_class=15, fine_tune=False):
        super().__init__()
        # Remembered so that train() can keep the encoder in eval mode.
        self.fine_tune = fine_tune
        # --------------------------------------------------------------------------
        self.mae = MAGECityPosition(embed_dim=256, depth=6, num_heads=8,
                                    decoder_embed_dim=16, decoder_depth=3,
                                    decoder_num_heads=8, discre=50, patch_size=5, patch_num=10,
                                    device=device, ablation=False, patchify=True)
        if not fine_tune:
            self.mae.eval()
            # NOTE(review): this only switches the mode; the encoder
            # parameters still have whatever requires_grad they were created
            # with. Confirm the training script excludes them from the
            # optimiser if a frozen encoder is intended.

        # NOTE(review): the head expects hidden * max_build input features
        # while the encoder is built with embed_dim=256 -- confirm the
        # flattened encoder output really has that width.
        self.decoder = nn.Linear(hidden * max_build, hidden, bias=True)
        self.decoder_embed = nn.Linear(hidden, num_class, bias=True)
        self.crossentropyloss = nn.CrossEntropyLoss()

        self.initialize_weights()

    def train(self, mode: bool = True):
        """Set the training mode, keeping the encoder in eval mode if frozen.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``self.mae.eval()`` done in ``__init__`` the first time the
        training loop calls ``model.train()``.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        if not self.fine_tune:
            self.mae.eval()
        return self

    def initialize_weights(self):
        """Apply ``_init_weights`` to this module and every submodule."""
        # NOTE(review): apply() also visits the submodules of self.mae, so
        # its nn.Linear layers are re-initialised here as well; pretrained
        # encoder weights presumably get loaded after construction.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform the weight and zero the bias of ``nn.Linear`` layers."""
        if not isinstance(m, nn.Linear):
            return
        # we use xavier_uniform following official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, pos, tarclass):
        """Encode the input, classify it and score it against ``tarclass``.

        Args:
            x: Encoder input; its first dimension is used as the batch size.
            pos: Passed through to ``self.mae.forward_encoder``.
            tarclass: Targets for ``nn.CrossEntropyLoss``; assumed to be
                class indices with one entry per sample -- TODO confirm.

        Returns:
            Tuple ``(loss, acc)``: the cross-entropy loss and the fraction of
            samples whose arg-max prediction equals ``tarclass``.
        """
        bsz = x.shape[0]

        latent = self.mae.forward_encoder(x, pos)
        # Drop the first token (presumably a class token) and merge the
        # remaining token and feature dimensions into one vector per sample.
        features = latent[:, 1:].flatten(1, 2)
        logits = self.decoder_embed(F.relu(self.decoder(features)))

        loss = self.crossentropyloss(logits, tarclass)

        # A non-zero difference marks a wrong prediction.
        predicted = torch.argmax(logits, dim=-1)
        acc = 1 - torch.count_nonzero(predicted - tarclass) / bsz

        return loss, acc

    