import torch
import torch.nn as nn
import os.path as osp

from .physics import GaussianAttention


class AttentionBlock(nn.Module):
    """Residual block: Gaussian attention followed by a two-layer GELU MLP.

    Each branch's output is layer-normalized before being added back to the
    residual stream (``z + norm(branch(z))``); the branch *input* is not
    normalized.

    Args:
        num_gaussians: Size of the trailing dimension of ``z``. Used as the
            LayerNorm width and as the MLP input/output width.
        hidden_dim: Hidden width of the MLP. Also forwarded to
            ``GaussianAttention``.
        heads: Forwarded to ``GaussianAttention``.
        pos_dim: Forwarded to ``GaussianAttention``.
        dropout: Forwarded to ``GaussianAttention`` only; the MLP applies no
            dropout.
        shape: Forwarded to ``GaussianAttention``. ``None`` (the default)
            means ``[64, 64]``.
        *args, **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_gaussians, hidden_dim, heads=8, pos_dim=2, dropout=0.0,
                 shape=None, *args, **kwargs):
        super().__init__()
        # Build a fresh default per instance rather than sharing one mutable
        # list object between every block constructed with the default.
        if shape is None:
            shape = [64, 64]
        self.attn = GaussianAttention(num_gaussians, hidden_dim, heads, pos_dim, dropout, shape)
        self.norm1 = nn.LayerNorm(num_gaussians)
        self.mlp = nn.Sequential(
            nn.Linear(num_gaussians, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_gaussians),
        )
        self.norm2 = nn.LayerNorm(num_gaussians)

    def forward(self, z, mu, sigma, weight):
        """Apply the attention and MLP residual branches to ``z``.

        Args:
            z: Features whose trailing dimension is ``num_gaussians``.
            mu, sigma, weight: Passed unchanged to the attention module. Their
                expected shapes are defined by ``GaussianAttention``, which is
                not visible here.

        Returns:
            Tensor with the same shape as ``z``.
        """
        z = z + self.norm1(self.attn(z, mu, sigma, weight))
        z = z + self.norm2(self.mlp(z))
        return z


class GPO(nn.Module):
    """Stack of ``AttentionBlock`` layers that iteratively refines a Gaussian field.

    Mapping between inputs, Gaussian parameters, latent features and outputs is
    delegated to ``gs_field``. That object must provide ``encode``,
    ``compute_gaussian`` and ``decode_z``, plus ``parameters`` for
    :meth:`freeze_gs`.

    Each layer does the following:

    1. Refines the latent ``z``.
    2. Decodes ``z`` into an output.
    3. Re-encodes the Gaussian parameters from the positions concatenated with
       that output.

    Args:
        in_dim: Input channel count. No layer here uses it; it is stored as
            ``self.in_dim`` for reference.
        out_dim: Output channel count. No layer here uses it; it is stored as
            ``self.out_dim`` for reference.
        gs_field: Gaussian encode/decode object. It may be ``None`` at
            construction, but it must be set before calling :meth:`forward`,
            :meth:`evaluate` or :meth:`freeze_gs`.
        pos_dim: Number of leading input channels treated as positions.
        hidden_dim, num_gaussians, num_heads, shape: Forwarded to every
            ``AttentionBlock``. ``shape=None`` means ``[64, 64]``.
        num_layers: Number of refinement layers. At least one layer is
            required to produce an output.
    """

    def __init__(self, in_dim, out_dim, gs_field=None, pos_dim=2,
                 hidden_dim=128, num_layers=4, num_gaussians=128, num_heads=4, shape=None):
        super().__init__()
        # Absolute path of this source file, kept on the instance.
        self.__file__ = osp.abspath(__file__)

        # Fresh list per model instead of a shared mutable default argument.
        if shape is None:
            shape = [64, 64]

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.gs_field = gs_field
        self.num_layers = num_layers
        self.pos_dim = pos_dim

        self.layers = nn.ModuleList([
            AttentionBlock(
                num_gaussians=num_gaussians,
                hidden_dim=hidden_dim,
                heads=num_heads,
                pos_dim=pos_dim,
                dropout=0.0,
                shape=shape,
            )
            for _ in range(num_layers)
        ])

    def _require_gs_field(self):
        """Return ``self.gs_field``, raising a clear error if it was never set."""
        if self.gs_field is None:
            raise RuntimeError("GPO.gs_field is None; assign a Gaussian field before use")
        return self.gs_field

    def freeze_gs(self):
        """Disable gradients for every parameter of ``gs_field``.

        Raises:
            RuntimeError: If ``gs_field`` is ``None``.
        """
        for param in self._require_gs_field().parameters():
            param.requires_grad = False

    def _rollout(self, x, trace=None):
        """Run the encode -> attend -> decode -> re-encode loop.

        This loop is shared by :meth:`forward` and :meth:`evaluate`.

        Args:
            x: Input whose leading ``pos_dim`` channels are used as positions.
            trace: Optional list. When given, one ``(field, mu, sigma, weight)``
                tuple is appended for the initial encoding and one after each
                layer.

        Returns:
            ``(out, z, mu, sigma, weight)`` after the last layer.

        Raises:
            RuntimeError: If ``gs_field`` is ``None``.
            ValueError: If the model has no layers, so no output is ever
                decoded.
        """
        gs_field = self._require_gs_field()
        if not self.layers:
            raise ValueError("GPO needs at least one layer to produce an output")

        x_pos = x[..., :self.pos_dim]
        mu, sigma, weight = gs_field.encode(x)
        if trace is not None:
            # NOTE(review): the initial entry records only the last input
            # channel (x[..., -1], trailing dim dropped), while later entries
            # record the full decoded output. Confirm consumers expect this.
            trace.append((x[..., -1], mu, sigma, weight))

        z = gs_field.compute_gaussian(x_pos, mu, sigma, weight=weight)
        for layer in self.layers:
            z = layer(z, mu, sigma, weight)
            out = gs_field.decode_z(z)
            # Re-derive the Gaussian parameters from positions + current prediction.
            mu, sigma, weight = gs_field.encode(torch.cat([x_pos, out], dim=-1))
            if trace is not None:
                trace.append((out, mu, sigma, weight))

        return out, z, mu, sigma, weight

    def forward(self, x):
        """Predict the output field for ``x``.

        Returns:
            ``(out, (z, mu, sigma, weight))``. ``out`` is the last decoded
            output and ``z`` the last latent features. ``mu``, ``sigma`` and
            ``weight`` are the Gaussian parameters re-encoded from ``out``.
        """
        out, z, mu, sigma, weight = self._rollout(x)
        return out, (z, mu, sigma, weight)

    def evaluate(self, x):
        """Run like :meth:`forward`, but also return per-step Gaussian snapshots.

        Returns:
            ``(out, gs_list)``. ``gs_list`` holds one ``(field, mu, sigma,
            weight)`` tuple for the initial encoding of ``x``, followed by one
            per layer.
        """
        gs_list = []
        out = self._rollout(x, trace=gs_list)[0]
        return out, gs_list
