import torch
import torch.nn as nn
from attrdict import AttrDict
import math

from .modules import CrossAttnEncoder, Decoder, PoolingEncoder
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from .neuraloperator.neuralop.models import FNO1d

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding appended to the input as extra channels.

    The registered buffer ``pe`` has shape ``(max_seq_len, d_model)``. Even
    columns hold ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``,
    where ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: number of encoding channels appended to the input. Odd values
            are supported; the last column is then a sine column.
        max_seq_len: number of rows (positions) in the encoding table.
    """

    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Concatenate positional channels onto ``x`` along dim -2.

        Args:
            x: tensor of shape ``(..., C, L)`` whose last axis is the sequence.
                ``L`` is expected not to exceed ``max_seq_len``; a longer input
                makes the ``expand`` below fail.

        Returns:
            Tensor of shape ``(..., C + d_model, L)``. Positions ``0 .. L-2``
            use the first ``L-1`` rows of the table. The final position always
            uses the table's LAST row, so the last element of every sequence
            gets the same encoding regardless of ``L``.
        """
        seq_len = x.size(-1)
        rows = torch.cat([self.pe[:seq_len - 1], self.pe[-1:]], dim=0)  # (L, d_model)
        pe = rows.transpose(0, 1).expand(*x.shape[:-2], -1, -1)  # (..., d_model, L)
        return torch.cat([x, pe], dim=-2)

class PPNORCANP(nn.Module):
    """Cross-attention neural process whose points are embedded by a 1-D FNO.

    Each point's vector ``[learnable tokens, x, y]`` is treated as a short 1-D
    signal. It is augmented with positional channels and lifted by ``FNO1d``
    to ``dim_out_chan`` channels. The mean over every position except the last
    becomes the point's x-feature. The last position (where ``y`` sits) becomes
    its y-feature. These features feed a cross-attention encoder and a decoder.

    Args:
        dim_x: input dimension.
        dim_y: output dimension forwarded to the decoder. ``predict`` appends a
            single zero channel to the targets and reads a single y position,
            so it effectively assumes ``yc`` has a trailing size of 1.
        dim_hid: hidden size of the encoder and decoder.
        enc_v_depth, enc_qk_depth: depths forwarded to ``CrossAttnEncoder``.
        enc_pre_depth, enc_post_depth: accepted but currently unused.
        dec_depth: depth forwarded to ``Decoder``.
        dim_out_chan: FNO output channels (the feature size seen by attention).
        projection_chan: FNO hidden and projection channels.
        token_len: number of learnable tokens prepended to every point.
        max_seq_len: number of rows in the positional-encoding table.
        pos_dim: number of positional-encoding channels.
        n_modes: number of Fourier modes kept by the FNO.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3,
            dim_out_chan=32,
            projection_chan=32,
            token_len=2,
            max_seq_len=101,
            pos_dim=10,
            n_modes=8):

        super().__init__()
        self.learnable_token = nn.Parameter(torch.randn(token_len))
        self.positional_encoding = PositionalEncoding(pos_dim, max_seq_len)
        # The FNO sees 1 data channel plus the positional channels appended by
        # self.positional_encoding.
        self.sfno = FNO1d(
            n_modes_height=n_modes,
            in_channels=1 + pos_dim,
            out_channels=dim_out_chan,
            hidden_channels=projection_chan,
            projection_channels=projection_chan,
            factorization='dense')
        self.dim_out_chan = dim_out_chan

        self.enc1 = CrossAttnEncoder(
                dim_x=self.dim_out_chan,
                dim_y=self.dim_out_chan,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        self.dec = Decoder(
                dim_x=self.dim_out_chan,
                dim_y=dim_y,
                dim_enc=dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Encode context and target points with the FNO, then decode targets.

        Args:
            xc: context inputs, ``[B, num_context, dim_x]``.
            yc: context outputs, ``[B, num_context, 1]``.
            xt: target inputs, ``[B, num_target, dim_x]``.
            num_samples: unused.

        Returns:
            The output of ``self.dec`` for the FNO-encoded targets.
        """
        num_ctx, num_tar = xc.shape[1], xt.shape[1]

        # Targets get a placeholder y of 0; contexts carry their real y.
        yt_fake = xt.new_zeros(*xt.shape[:-1], 1)
        context_xy = torch.cat([xc, yc], dim=-1)      # [B, num_context, dim_x+1]
        target_xy = torch.cat([xt, yt_fake], dim=-1)  # [B, num_target, dim_x+1]

        # Context rows come FIRST: the splits below use sizes
        # [num_ctx, num_tar] and rely on this order to route features.
        xy = torch.cat([context_xy, target_xy], dim=1).unsqueeze(-2)
        tokens = self.learnable_token.expand(*xy.shape[:-1], len(self.learnable_token))
        xy = torch.cat([tokens, xy], dim=-1)  # [B, N, 1, L], L = token_len + dim_x + 1

        # Run every point through the FNO independently as a 1-channel signal.
        shape = xy.shape
        feats = self.positional_encoding(xy.reshape(-1, 1, shape[-1]))
        feats = self.sfno(x=feats)
        feats = feats.reshape(*shape[:-2], self.dim_out_chan, shape[-1])  # [B, N, C, L]

        x_feat = feats[..., :-1].mean(dim=-1)  # pooled over token and x positions
        y_feat = feats[..., -1]                # position that held y
        xc_feat, xt_feat = torch.split(x_feat, [num_ctx, num_tar], dim=1)
        yc_feat = y_feat[:, :num_ctx]

        encoded = self.enc1(xc_feat, yc_feat, xt_feat)
        return self.dec(encoded, xt_feat)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss or the evaluation log-likelihoods.

        ``batch`` must provide ``xc``, ``yc``, ``x`` and ``y``. Predictions are
        made at every point of ``batch.x``. The decoder output is presumably a
        distribution exposing ``log_prob`` (confirm against ``Decoder``).

        Args:
            batch: attribute-style batch container.
            num_samples: unused.
            reduce_ll: at evaluation time, average the log-likelihoods
                (True) or return them per point (False).

        Returns:
            AttrDict with ``loss`` in training mode, otherwise ``ctx_ll`` and
            ``tar_ll``. The first ``num_ctx`` points of ``batch.x`` are scored
            as context (presumably ``batch.x`` lists context points first).
        """
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
        else:
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[..., :num_ctx].mean()
                outs.tar_ll = ll[..., num_ctx:].mean()
            else:
                outs.ctx_ll = ll[..., :num_ctx]
                outs.tar_ll = ll[..., num_ctx:]

        return outs
