import torch
import torch.nn as nn
from attrdict import AttrDict

from .modules import CrossAttnEncoder, Decoder, PoolingEncoder
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from .neuraloperator.neuralop.models import FNO1d

class PNORCANP(nn.Module):
    """Attentive conditional neural process with an FNO input embedding.

    Every input location is padded with a shared learnable token, passed
    through a 1-D Fourier neural operator as a single-channel signal and
    averaged over the signal axis, yielding a ``dim_out_chan``-dimensional
    embedding per point.  The embedded context/target locations are then fed
    to a cross-attention encoder, a pooling encoder and a decoder.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3,
            dim_out_chan=32,
            projection_chan=32,
            token_len=10,
            n_modes=12):
        """Build the FNO embedding, both encoders and the decoder.

        Args:
            dim_x: Accepted for interface compatibility; not used when
                building the sub-modules, because the encoders/decoder
                receive the FNO embedding (``dim_out_chan`` features).
            dim_y: Output dimensionality forwarded to encoders and decoder.
            dim_hid: Hidden width forwarded to encoders and decoder.
            enc_v_depth: ``v_depth`` of the cross-attention encoder.
            enc_qk_depth: ``qk_depth`` of the cross-attention encoder.
            enc_pre_depth: ``pre_depth`` of the pooling encoder.
            enc_post_depth: ``post_depth`` of the pooling encoder.
            dec_depth: ``depth`` of the decoder.
            dim_out_chan: Number of FNO output channels, i.e. the size of
                the per-point embedding.
            projection_chan: Used for both ``hidden_channels`` and
                ``projection_channels`` of the FNO.
            token_len: Length of the learnable token appended to every
                input location.
            n_modes: Number of Fourier modes (``n_modes_height``) of the FNO.
        """
        super().__init__()
        self.dim_out_chan = dim_out_chan
        self.token_len = token_len

        # Created before the FNO so the order of RNG draws matches the
        # original construction order.
        self.learnable_token = nn.Parameter(torch.randn(token_len))
        self.sfno = FNO1d(
                n_modes_height=n_modes,
                in_channels=1,
                out_channels=dim_out_chan,
                hidden_channels=projection_chan,
                projection_channels=projection_chan,
                factorization='dense')

        self.enc1 = CrossAttnEncoder(
                dim_x=dim_out_chan,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        self.enc2 = PoolingEncoder(
                dim_x=dim_out_chan,
                dim_y=dim_y,
                dim_hid=dim_hid,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_out_chan,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Embed the locations with the FNO and decode at the targets.

        Args:
            xc: Context locations; assumed ``(batch, num_ctx, dim_x)`` --
                TODO confirm against the data loader.
            yc: Context values, passed unchanged to both encoders.
            xt: Target locations, same layout as ``xc``.
            num_samples: Ignored; kept for interface compatibility.

        Returns:
            Whatever ``self.dec`` returns for the concatenated encodings and
            the embedded target locations.
        """
        num_ctx, num_tar = xc.shape[-2], xt.shape[-2]

        # Embed context and target points in one FNO pass.
        pts = torch.cat([xc, xt], dim=-2).unsqueeze(-2)       # (..., N, 1, dim_x)
        token = self.learnable_token
        padding = token.expand(*pts.shape[:-1], token.shape[0])
        pts = torch.cat([pts, padding], dim=-1)               # (..., N, 1, L)

        lead_shape, seq_len = pts.shape[:-2], pts.shape[-1]
        # Each point becomes an independent single-channel 1-D signal.
        emb = self.sfno(x=pts.reshape(-1, 1, seq_len))        # (M, C, L)
        emb = emb.reshape(*lead_shape, self.dim_out_chan, seq_len)
        emb = emb.mean(dim=-1)                                # (..., N, C)
        xc, xt = torch.split(emb, [num_ctx, num_tar], dim=-2)

        theta1 = self.enc1(xc, yc, xt)
        theta2 = self.enc2(xc, yc)
        # Repeat the pooled encoding for every target as a view instead of
        # stacking `num_tar` copies; torch.cat yields the same result.
        theta2 = theta2.unsqueeze(-2).expand(*theta2.shape[:-1], num_tar, -1)
        encoded = torch.cat([theta1, theta2], -1)
        return self.dec(encoded, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss or evaluation log-likelihoods.

        Args:
            batch: Object exposing ``xc``, ``yc``, ``x`` and ``y``.  ``x`` is
                used as the target set; the evaluation split below treats its
                first ``xc.shape[-2]`` points as the context points.
            num_samples: Ignored; kept for interface compatibility.
            reduce_ll: In eval mode, average the log-likelihoods when True,
                otherwise return them per point.

        Returns:
            AttrDict with ``loss`` in training mode, or ``ctx_ll`` and
            ``tar_ll`` in eval mode.
        """
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        # The decoder output must provide `log_prob`; sum over the y dims.
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll, tar_ll = ctx_ll.mean(), tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll
        return outs
