import torch
import torch.nn as nn
from attrdict import AttrDict

from .modules import PoolingEncoder, Decoder
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from .neuraloperator.neuralop.models import FNO1d

class NOCNP(nn.Module):
    """Neural-process-style model whose inputs pass through a 1-D FNO first.

    The input locations are transformed by an ``FNO1d`` and reduced to one
    scalar per point. The transformed context set is then summarised by two
    ``PoolingEncoder`` modules, and a ``Decoder`` produces the predictive
    output at the transformed target locations.

    Args:
        dim_x: Input feature size handed to the encoders and the decoder.
        dim_y: Output feature size handed to the encoders and the decoder.
        dim_hid: Hidden width of the encoders and the decoder; the decoder
            receives ``dim_enc=2 * dim_hid`` because the two encoder outputs
            are concatenated.
        enc_pre_depth: ``pre_depth`` of both encoders.
        enc_post_depth: ``post_depth`` of both encoders.
        dec_depth: ``depth`` of the decoder.
        fno_modes: ``n_modes_height`` of the FNO.
        fno_hidden_channels: ``hidden_channels`` of the FNO.
        fno_projection_channels: ``projection_channels`` of the FNO.
        fno_factorization: ``factorization`` of the FNO.

    NOTE(review): ``predict`` reduces the FNO output to a single feature per
    point before it reaches the encoders/decoder, which are nevertheless
    built with ``dim_x``. This looks like it only lines up for ``dim_x == 1``
    -- confirm against ``PoolingEncoder`` / ``Decoder``.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3,
            fno_modes=16,
            fno_hidden_channels=16,
            fno_projection_channels=32,
            fno_factorization='dense'):

        super().__init__()

        # Channel counts stay fixed at 1: predict() feeds the FNO tensors of
        # shape (-1, 1, d) and reshapes its output back to the input shape.
        self.sfno = FNO1d(
                n_modes_height=fno_modes,
                in_channels=1,
                out_channels=1,
                hidden_channels=fno_hidden_channels,
                projection_channels=fno_projection_channels,
                factorization=fno_factorization)

        # Both encoders share one configuration but own separate weights.
        # Construction order (sfno, enc1, enc2, dec) is kept so that seeded
        # initialisation is unchanged.
        enc_kwargs = dict(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)
        self.enc1 = PoolingEncoder(**enc_kwargs)
        self.enc2 = PoolingEncoder(**enc_kwargs)

        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, num_samples=None):
        """Return the decoder output for targets ``xt`` given context ``(xc, yc)``.

        Args:
            xc: Context inputs; concatenated with ``xt`` along dim 1
                (assumed ``(batch, num_ctx, d)`` -- TODO confirm).
            yc: Context outputs, passed to the encoders unchanged.
            xt: Target inputs (assumed ``(batch, num_tar, d)`` -- TODO confirm).
            num_samples: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``self.dec`` returns. ``forward`` calls ``.log_prob`` on
            it, so it is presumably a distribution object.
        """
        num_ctx, num_tar = xc.shape[1], xt.shape[1]

        # Context and target points go through the FNO in a single call.
        x = torch.cat([xc, xt], dim=1)
        point_shape = x.shape

        # Every point becomes its own single-channel sequence of length d,
        # i.e. the FNO sees a tensor of shape (-1, 1, d).
        x = self.sfno(x=x.reshape(-1, 1, point_shape[-1])).reshape(point_shape)

        # Collapse the last dimension to one value per point.
        x = torch.mean(x, dim=-1, keepdim=True)
        xc, xt = torch.split(x, [num_ctx, num_tar], dim=1)

        encoded = torch.cat([self.enc1(xc, yc), self.enc2(xc, yc)], -1)
        # Repeat the context summary once per target point.
        encoded = torch.stack([encoded] * xt.shape[-2], -2)
        return self.dec(encoded, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss or evaluation log-likelihoods for ``batch``.

        Args:
            batch: Object exposing ``xc``, ``yc``, ``x`` and ``y`` attributes.
                ``x`` is used as the prediction targets and ``y`` as their
                observed values.
            num_samples: Accepted for interface compatibility; not used.
            reduce_ll: In evaluation mode, average the log-likelihoods when
                true, otherwise return them unreduced.

        Returns:
            An ``AttrDict``. In training mode it holds ``loss`` (the negated
            mean log-likelihood). In evaluation mode it holds ``ctx_ll`` and
            ``tar_ll``: the log-likelihoods of the first ``num_ctx`` points
            and of the remaining points respectively (this presumes ``batch.x``
            lists the context points first -- verify against the data loader).
        """
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        # Sum over the last (output-feature) dimension: one value per point.
        ll = py.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        num_ctx = batch.xc.shape[-2]
        ctx_ll = ll[..., :num_ctx]
        tar_ll = ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll = ctx_ll.mean()
            tar_ll = tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll

        return outs
