import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent))

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from vcnef_2d import VCNeFModel


class VCNeF2d(nn.Module):
    """Adapter that runs ``VCNeFModel`` on channel-first 2D fields.

    Converts the input from ``(B, C, Nx, Ny)`` to the channel-last layout the
    wrapped model is called with. It also builds the spatial coordinate grid
    and tiles the query time across the batch.

    Args:
        num_channels: Forwarded to ``VCNeFModel`` as ``num_channels``.
        env_dim: Forwarded to ``VCNeFModel`` as ``pde_param_dim``.
        d_model: Forwarded to ``VCNeFModel`` as ``d_model``.
        n_modulation_blocks: Forwarded to ``VCNeFModel`` as
            ``n_modulation_blocks``.
        condition_on_pde_param: Forwarded to ``VCNeFModel`` unchanged.
        x_span: Upper end of the x coordinates; the grid runs from 0 to
            ``x_span`` inclusive.
        y_span: Upper end of the y coordinates; the grid runs from 0 to
            ``y_span`` inclusive.
        n_heads: Forwarded to ``VCNeFModel`` as ``n_heads``. Defaults to 4,
            the value that was previously hard-coded.
        n_transformer_blocks: Forwarded to ``VCNeFModel`` as
            ``n_transformer_blocks``. Defaults to 1, the value that was
            previously hard-coded.
    """

    def __init__(self, num_channels, env_dim, d_model, n_modulation_blocks, condition_on_pde_param, x_span, y_span,
                 *, n_heads=4, n_transformer_blocks=1):
        super().__init__()
        self.x_span = x_span
        self.y_span = y_span

        self.model = VCNeFModel(num_channels=num_channels,
                                condition_on_pde_param=condition_on_pde_param, pde_param_dim=env_dim,
                                d_model=d_model, n_heads=n_heads,
                                n_transformer_blocks=n_transformer_blocks, n_modulation_blocks=n_modulation_blocks)

    def forward(self, x, pde_param, t):
        """Evaluate the wrapped model at query time ``t``.

        Args:
            x: Input field of shape ``(B, C, Nx, Ny)`` (channel-first).
            pde_param: PDE parameters; presumably ``(B, env_dim)``. This is
                passed straight through to the model.
            t: Query time, expected shape ``(1,)``. It is tiled to ``(B, 1)``.

        Returns:
            Tensor in channel-last layout ``(B, Nx, Ny, C)``. It is the model
            output with axis 3 squeezed; the model output is assumed to be
            ``(B, Nx, Ny, 1, C)``.
            NOTE(review): the output layout differs from the channel-first
            input layout. Confirm that callers expect this.
        """
        x = rearrange(x, 'b c h w -> b h w c')  # (B, Nx, Ny, C)
        grid = self.get_grid(x.shape, x.device)  # (B, Nx, Ny, 2)
        # Give every sample in the batch its own copy of the query time.
        t = t.unsqueeze(0).repeat(x.shape[0], 1)  # (B, 1)

        x_hat = self.model(x, grid, pde_param, t)  # (B, Nx, Ny, 1, C)
        # Drop the singleton axis 3 of the model output.
        x_out = x_hat.squeeze(3)  # (B, Nx, Ny, C)

        return x_out

    def get_grid(self, shape, device, dtype=torch.float):
        """Build the (x, y) coordinate grid for a channel-last batch.

        Args:
            shape: Sequence whose first three entries are ``(B, Nx, Ny)``.
                Any trailing entries are ignored.
            device: Device on which the grid is created.
            dtype: Floating dtype of the grid. Defaults to ``torch.float``,
                which matches the previous fixed behaviour.

        Returns:
            Contiguous tensor of shape ``(B, Nx, Ny, 2)``.
            ``[..., 0]`` holds the x coordinate: ``Nx`` evenly spaced points
            in ``[0, x_span]``, varying along axis 1.
            ``[..., 1]`` holds the y coordinate: ``Ny`` evenly spaced points
            in ``[0, y_span]``, varying along axis 2.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        # Create the 1-D axes directly on the target device. This avoids a
        # NumPy round trip and a host-to-device copy on every forward pass.
        gridx = torch.linspace(0.0, float(self.x_span), size_x, dtype=dtype, device=device)
        gridy = torch.linspace(0.0, float(self.y_span), size_y, dtype=dtype, device=device)
        # expand() broadcasts without allocating. torch.cat below then
        # materialises a single (B, Nx, Ny, 2) tensor.
        gridx = gridx.reshape(1, size_x, 1, 1).expand(batchsize, size_x, size_y, 1)
        gridy = gridy.reshape(1, 1, size_y, 1).expand(batchsize, size_x, size_y, 1)
        return torch.cat((gridx, gridy), dim=-1)