import torch
import torch.nn as nn
import os
import sys
import numpy as np
import copy
import warnings
warnings.filterwarnings("ignore")
from models.modules.encoder_module import Custom_block
from models.modules.attention_module import AGF_layer
from models.rational.torch import Rational

from utils.config import *

# Command-line options are parsed once, at import time.
# NOTE(review): `parser` is presumably exported by the `utils.config` star
# import above -- confirm.
args = parser.parse_args()

# Absolute path of this source file, and of the directory that contains it.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.split(FILE_ABS_PATH)[0]

# Default compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Transformer model    
class TransformerModel(nn.Module):
    """Encode context (domain, solution) pairs and query domain points, then decode the queries.

    The input sequence is split at ``single_eval_pos``: positions before it are
    context points (domain and solution columns are both read), positions from
    it onwards are query points (only the domain columns are read).  Every
    token is embedded to width ``2 * ninp``, passed through
    ``self.transformer_encoder`` and decoded to ``n_out`` values.
    """

    def __init__(self, n_in: int, n_out: int, args):
        """Build the encoders, the transformer block and the decoder.

        Args:
            n_in: Total number of columns per token (domain + solution).
            n_out: Number of solution columns; the first ``n_in - n_out``
                columns of a token are treated as domain coordinates.
            args: Configuration object.  The attributes read here are
                ``nips``, ``nhead``, ``nhid``, ``nlayers``, ``dropout``,
                ``num_freqs``, ``AGF_depth``, ``alpha``, ``beta``, ``fixI``,
                ``is_grid`` and ``xgrid``.
        """
        super().__init__()

        # NOTE(review): the embedding width is read from `args.nips` (not
        # `args.ninp`) -- confirm this matches the option name in utils.config.
        self.ninp = args.nips
        self.nhead = args.nhead
        self.nhid = args.nhid
        self.nlayers = args.nlayers
        self.dropout = args.dropout
        self.num_freqs = args.num_freqs
        self.AGF_depth = args.AGF_depth
        self.alpha = args.alpha
        self.beta = args.beta
        self.fixI = args.fixI
        self.is_grid = args.is_grid
        self.xgrid = args.xgrid

        self.n_out = n_out
        self.n_in = n_in

        def _mlp(in_features, out_features):
            # Linear -> Rational -> Linear with hidden width `nhid`.
            return nn.Sequential(
                nn.Linear(in_features, self.nhid),
                Rational(),
                nn.Linear(self.nhid, out_features),
            )

        # Tokens entering the transformer are 2 * ninp wide (see `forward`).
        # NOTE(review): the layer receives the module-level `device`.
        block = AGF_layer(d_model=self.ninp*2, nhead=self.nhead, dim_feedforward=self.nhid, dropout=self.dropout, activation='gelu',
                          K=self.AGF_depth, alpha=self.alpha, beta=self.beta, fixI=self.fixI, device=device)
        print("PDE-PFN")
        print(" - En/Decoder        : Fourier feature embedding in domain and solution encoders, Learnable activation")
        print(" - Transformer block : Use Pre-LN, Devided attention, Parameterized SVD Attention Seperated for each head")

        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation and state_dict key order stay stable.
        self.transformer_encoder = Custom_block(block, self.nlayers)

        # Per-input-channel width produced by `fourier_mapping`:
        # 2 in grid mode, otherwise 1 + 2 * num_freqs.
        self.fourier_dim = 2 if self.is_grid else 1+2*self.num_freqs
        n_domain = self.n_in - self.n_out

        # Context domain / context solution embed to `ninp` each; they are
        # concatenated to 2 * ninp in `forward`.
        self.domain_encoder1 = _mlp(n_domain*self.fourier_dim, self.ninp)
        # Query points are always mapped with the sin/cos features (never the
        # grid mapping), and embed straight to 2 * ninp.
        self.domain_encoder2 = _mlp(n_domain*(1+2*self.num_freqs), 2*self.ninp)
        self.solution_encoder = _mlp(self.n_out*self.fourier_dim, self.ninp)
        self.decoder = _mlp(self.ninp*2, self.n_out)

    def fourier_mapping(self, x: torch.Tensor, is_grid: bool = False) -> torch.Tensor:
        """Augment ``x`` with extra features along the last dimension.

        Args:
            x: 3-D tensor ``[B, T, D]``.
            is_grid: When true, ``T`` must equal ``xgrid * xgrid``; ``x`` is
                viewed as a ``[B, xgrid, xgrid, D]`` grid and concatenated with
                its ``torch.fft.fftshift`` over the two grid axes.

        Returns:
            * grid mode: ``[B, T, 2 * D]``;
            * otherwise, with ``num_freqs > 0``: ``[B, T, D + 2 * D * num_freqs]``
              (``x`` followed by, per channel, the sines then the cosines of
              ``x * 2*pi*k/num_freqs`` for ``k = 1..num_freqs``);
            * otherwise ``x`` itself, unchanged.
        """
        if is_grid:
            batch, _, channels = x.shape
            grid = x.reshape(batch, self.xgrid, self.xgrid, channels)
            shifted = torch.fft.fftshift(grid, dim=[1, 2])
            # [B, xgrid, xgrid, 2D] -> [B, T, 2D]
            return torch.cat([grid, shifted], dim=-1).reshape(batch, -1, 2 * channels)

        if self.num_freqs > 0:
            # Frequencies 2*pi*k/num_freqs, k = 1..num_freqs (e.g. num_freqs=5
            # gives [2pi/5, 4pi/5, ..., 2pi]).  Allocated directly on x's
            # device to avoid a host-to-device copy on every call.
            scale = 2.0 * torch.pi / self.num_freqs
            freqs = scale * torch.arange(1, self.num_freqs + 1, device=x.device).float()
            x_proj = x.unsqueeze(-1) * freqs                          # [B, T, D, F]
            sin_cos = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)  # [B, T, D, 2F]
            fourier = sin_cos.flatten(2, 3)                           # [B, T, 2*D*F]
            return torch.cat([x, fourier], dim=-1)                    # [B, T, D + 2*D*F]

        return x

    def forward(self, src: torch.Tensor, single_eval_pos: int):
        """Run the model on a sequence of context points followed by query points.

        Args:
            src: Tensor indexed as ``[batch, position, column]`` with ``n_in``
                columns; the first ``n_in - n_out`` columns are the domain and
                the remaining ones the solution.  Solution columns of query
                positions are never read.
            single_eval_pos: Index separating context positions
                (``[:single_eval_pos]``) from query positions
                (``[single_eval_pos:]``).

        Returns:
            A pair ``(output, ortho_loss_list)``: the decoder output restricted
            to the query positions, and the second value returned by
            ``self.transformer_encoder``.
        """
        # Work on a float32 copy so the caller's tensor is never touched.
        # NOTE(review): this uses the module-level `device`, not the device of
        # this module's parameters -- confirm the model is always moved there.
        src = src.clone().to(device).float()
        n_domain = self.n_in - self.n_out

        # Split into context domain, query domain and context solution.
        ctx_domain = src[:, :single_eval_pos, :n_domain]
        qry_domain = src[:, single_eval_pos:, :n_domain]
        ctx_solution = src[:, :single_eval_pos, n_domain:]

        # Embed each part; only context tensors may use the grid mapping.
        ctx_domain = self.domain_encoder1(self.fourier_mapping(ctx_domain, self.is_grid))
        qry_tokens = self.domain_encoder2(self.fourier_mapping(qry_domain))
        ctx_solution = self.solution_encoder(self.fourier_mapping(ctx_solution, self.is_grid))

        # Context token = [domain | solution] along features (2 * ninp wide),
        # then context and query tokens are joined along the sequence axis.
        ctx_tokens = torch.cat((ctx_domain, ctx_solution), dim=-1)
        tokens = torch.cat((ctx_tokens, qry_tokens), dim=1)

        output, ortho_loss_list = self.transformer_encoder(tokens, single_eval_pos)
        output = self.decoder(output)

        # Only the query positions are predictions.
        return output[:, single_eval_pos:, :], ortho_loss_list
    
    
    
    
    