import torch
import torch.nn as nn
import os
import sys

from torch.nn import TransformerEncoder, TransformerEncoderLayer
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import *

# Parse the command-line options; `parser` is expected to be provided by the
# star-import of `config` above.
args = parser.parse_args()

# Absolute path of this file and of the directory that contains it.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)

# CUDA support: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transformer model
class TransformerModel(nn.Module):
    """Transformer encoder that predicts solution values at query positions.

    Every sequence element carries two domain features followed by one
    solution feature. Both parts are embedded separately to ``ninp`` features,
    concatenated to ``2 * ninp`` features, passed through a masked
    ``TransformerEncoder`` and decoded to ``n_out`` features per position.

    Args:
        n_out: Number of output features per sequence position.
        ninp: Embedding width of each of the two encoders; the transformer
            itself works on ``2 * ninp`` features.
        nhead: Number of attention heads (must divide ``2 * ninp``).
        nhid: Hidden width of the encoders, of the default decoder and of the
            transformer feed-forward layers.
        nlayers: Number of stacked transformer encoder layers.
        dropout: Dropout probability used inside the transformer layers.
        decoder: Optional factory called as ``decoder(2 * ninp, nhid, n_out)``
            that returns the decoding module. When ``None`` a two-layer GELU
            MLP is used.
    """

    # Placeholder written into the solution channel of every query position.
    # NOTE(review): 1 may also be a legitimate solution value -- confirm that
    # the data scaling keeps this placeholder unambiguous.
    UNKNOWN = 1.0

    def __init__(self, n_out, ninp, nhead, nhid, nlayers, dropout=0.5, decoder=None):
        super().__init__()

        self.model_type = 'Transformer'
        self.ninp = ninp
        d_model = ninp * 2

        encoder_layers = TransformerEncoderLayer(
            d_model, nhead, nhid, dropout, activation='gelu', batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.domain_encoder = nn.Sequential(
            nn.Linear(2, nhid), nn.GELU(), nn.Linear(nhid, ninp)
        )
        self.solution_encoder = nn.Sequential(
            nn.Linear(1, nhid), nn.GELU(), nn.Linear(nhid, ninp)
        )
        if decoder is not None:
            self.decoder = decoder(d_model, nhid, n_out)
        else:
            self.decoder = nn.Sequential(
                nn.Linear(d_model, nhid), nn.GELU(), nn.Linear(nhid, n_out)
            )

    # Generate target source mask.
    @staticmethod
    def generate_D_q_matrix(size, query_size, device=None):
        """Build the additive attention mask for observations and queries.

        Every position may attend to the first ``size - query_size`` positions
        (the observations) and to itself; all other entries are blocked.

        Args:
            size: Total sequence length.
            query_size: Number of trailing query positions.
            device: Optional device on which to allocate the mask. Defaults to
                PyTorch's current default device.

        Returns:
            A ``(size, size)`` float32 tensor holding ``0.0`` where attention
            is allowed and ``-inf`` where it is blocked.
        """
        obs_size = size - query_size
        mask = torch.full(
            (size, size), float('-inf'), dtype=torch.float32, device=device
        )
        # Observation columns are visible from every row ...
        mask[:, :obs_size] = 0.0
        # ... and each position can always see itself, so no row is fully blocked.
        mask.fill_diagonal_(0.0)
        return mask

    def forward(self, src, single_eval_pos):
        """Predict the outputs for the positions from ``single_eval_pos`` on.

        Args:
            src: Tensor of shape ``(batch, seq_len, 3)``; the last axis holds
                two domain features followed by one solution feature. The
                tensor is not modified.
            single_eval_pos: Index of the first query position. Positions
                before it are treated as observations.

        Returns:
            Tensor of shape ``(batch, seq_len - single_eval_pos, n_out)`` with
            the decoded outputs of the query positions, located on the same
            device as the model parameters.
        """
        # Follow the model's own placement instead of a module-level global so
        # the forward pass works wherever the model has been moved.
        model_device = next(self.parameters()).device
        src = src.to(device=model_device, dtype=torch.float32)

        seq_len = src.size(1)
        src_mask = self.generate_D_q_matrix(
            seq_len, seq_len - single_eval_pos, device=model_device
        )

        # Divide domain and solution.
        src_domain = src[:, :, :2]
        # Copy only the solution channel so the caller's tensor stays untouched
        # by the in-place write below.
        src_solution = src[:, :, 2:].clone()

        # Insert "UNKNOWN" token.
        src_solution[:, single_eval_pos:] = self.UNKNOWN

        # Encode and fuse sources.
        fused = torch.cat(
            (self.domain_encoder(src_domain), self.solution_encoder(src_solution)),
            dim=2,
        )

        # Encode and decode the source.
        output = self.decoder(self.transformer_encoder(fused, src_mask))

        return output[:, single_eval_pos:, :]