from abc import ABC

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from .Kernel import GaussianKernel
import matplotlib.pyplot as plt

from data import denormalize, normalize
from utils import load_data


class MessagePassingGNN(nn.Module):
    """
    Message Passing Graph Neural Network for encoding/decoding in embedding space.

    Every ordered node pair (receiver, sender) is treated as an edge. Nodes and
    edges are encoded once; then ``pstep`` rounds of message passing update the
    node embeddings, and a final MLP maps each node embedding to ``output_dim``.
    """
    def __init__(self, args, input_particle_dim=None, input_relation_dim=None, output_dim=None,
                 action=True, tanh=False, residual=False, use_gpu=False):
        """
        Args:
            args: configuration namespace; ``nf_effect`` is always read, and
                ``attr_dim``, ``state_dim``, ``action_dim``, ``relation_dim``
                are read only to fill in dimensions that were left as None
            input_particle_dim: width of the per-node input; defaults to
                attr_dim + state_dim (+ action_dim when ``action`` is True)
            input_relation_dim: width of the per-edge input; defaults to
                relation_dim + state_dim + 2 * attr_dim, which is what
                ``forward`` actually concatenates
            output_dim: width of the per-node output; defaults to state_dim
            action: whether ``forward`` appends ``actions`` to the node input
            tanh: append a Tanh to the output head
            residual: add each update to the previous node embedding instead
                of replacing it
            use_gpu: stored on the instance; not otherwise used by this class
        """
        super(MessagePassingGNN, self).__init__()

        self.args = args
        self.action = action

        if input_particle_dim is None:
            input_particle_dim = args.attr_dim + args.state_dim
            input_particle_dim += args.action_dim if action else 0

        if input_relation_dim is None:
            # forward() feeds [rel_attrs, state differences, receiver attrs, sender attrs]
            # to the edge encoder, so the default has to count the two attr copies too.
            input_relation_dim = args.relation_dim + args.state_dim + 2 * args.attr_dim

        if output_dim is None:
            output_dim = args.state_dim

        hidden_dim = args.nf_effect
        self.residual = residual
        self.use_gpu = use_gpu

        # Node encoder (transforms node features)
        self.node_encoder = nn.Sequential(
            nn.Linear(input_particle_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Edge encoder (transforms edge features)
        self.edge_encoder = nn.Sequential(
            nn.Linear(input_relation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Message function (input: edge embedding + receiver code + sender code)
        self.message_fn = nn.Sequential(
            nn.Linear(2 * hidden_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Update function (input: current node embedding + aggregated messages)
        self.update_fn = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Output function (final prediction from node features)
        output_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        ]

        if tanh:
            output_layers.append(nn.Tanh())

        self.output_fn = nn.Sequential(*output_layers)

    def forward(self, attrs, states, actions, rel_attrs, pstep):
        """
        Message passing GNN forward pass

        Args:
            attrs: B x N x attr_dim (node attributes)
            states: B x N x state_dim (node states)
            actions: B x N x action_dim (node actions); ignored when the module
                was built with ``action=False`` or when None is passed
            rel_attrs: B x N x N x relation_dim (edge attributes)
            pstep: number of message passing steps

        Returns:
            node_predictions: B x N x output_dim
        """
        B, N = attrs.size(0), attrs.size(1)

        # Prepare node features
        node_inputs = [attrs, states]
        if self.action and actions is not None:
            node_inputs.append(actions)
        node_features = torch.cat(node_inputs, dim=2)

        # Initial node encoding
        node_embeddings = self.node_encoder(node_features.reshape(B * N, -1)).reshape(B, N, -1)

        # Edge features: relation attributes, pairwise state differences
        # (receiver minus sender) and both endpoints' attributes.
        rel_states = states[:, :, None, :] - states[:, None, :, :]
        # expand() broadcasts without materialising N copies; cat() makes the real copy.
        receiver_attr = attrs[:, :, None, :].expand(-1, -1, N, -1)
        sender_attr = attrs[:, None, :, :].expand(-1, N, -1, -1)
        edge_features = torch.cat([rel_attrs, rel_states, receiver_attr, sender_attr], dim=3)
        edge_embeddings = self.edge_encoder(edge_features.reshape(B * N * N, -1)).reshape(B, N, N, -1)

        # Message passing steps
        for _ in range(pstep):
            # Create messages for every (receiver, sender) pair
            receiver_code = node_embeddings[:, :, None, :].expand(-1, -1, N, -1)
            sender_code = node_embeddings[:, None, :, :].expand(-1, N, -1, -1)
            message_inputs = torch.cat([edge_embeddings, receiver_code, sender_code], 3)
            messages = self.message_fn(message_inputs.reshape(B * N * N, -1)).reshape(B, N, N, -1)

            # Aggregate messages (sum over sender dimension)
            agg_messages = messages.sum(2)

            # Update node representations
            update_inputs = torch.cat([node_embeddings, agg_messages], 2)
            node_updates = self.update_fn(update_inputs.reshape(B * N, -1)).reshape(B, N, -1)

            # Apply residual connection if specified
            if self.residual:
                node_embeddings = node_embeddings + node_updates
            else:
                node_embeddings = node_updates

        # Generate output predictions
        node_predictions = self.output_fn(node_embeddings.reshape(B * N, -1)).reshape(B, N, -1)

        return node_predictions


class GaussianKernel(nn.Module):
    """
    Gaussian (RBF) similarity between two batches of vectors.

    ``adaptive_bandwidth`` and ``scaling_factor`` are only stored here;
    ``forward`` always uses the fixed ``bandwidth``.
    """
    def __init__(self, bandwidth=1.0, adaptive_bandwidth=False, scaling_factor=0.4):
        super(GaussianKernel, self).__init__()
        self.scaling_factor = scaling_factor  # multiplier intended for a data-driven bandwidth
        self.adaptive_bandwidth = adaptive_bandwidth
        self.bandwidth = bandwidth

    def forward(self, x, y):
        """
        Evaluate exp(-0.5 * ||x_i - y_j||^2 / bandwidth) for every pair (i, j).

        Args:
            x: [..., N, D] tensor
            y: [..., M, D] tensor

        Returns:
            [..., N, M] tensor of kernel values
        """
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
        sq_x = (x ** 2).sum(dim=-1, keepdim=True)        # [..., N, 1]
        sq_y = (y ** 2).sum(dim=-1, keepdim=True)        # [..., M, 1]
        cross = torch.matmul(x, y.transpose(-2, -1))     # [..., N, M]

        pairwise_sq = sq_x + sq_y.transpose(-2, -1) - 2 * cross

        # Round-off can push tiny distances slightly below zero
        pairwise_sq = pairwise_sq.clamp(min=0.0)

        return torch.exp(-0.5 * pairwise_sq / self.bandwidth)


class ControllableEmbedding(nn.Module, ABC):
    def __init__(self, args, residual=False, use_gpu=False):
        """
        Build the state encoder/decoder pair and bind the identification,
        rollout and single-step routines selected by ``args.fit_type``.

        :param args: configuration namespace; reads ``stat_path``, ``g_dim``,
            ``nf_effect``, ``attr_dim``, ``state_dim``, ``relation_dim`` and
            ``fit_type``, plus the optional ``bandwidth``,
            ``adaptive_bandwidth`` and ``scaling_factor`` when ``fit_type`` is
            ``'Gaussian_reweight'``
        :param residual: forwarded to both MessagePassingGNN instances
        :param use_gpu: stored and forwarded to both MessagePassingGNN instances
        """
        super(ControllableEmbedding, self).__init__()

        self.args = args
        self.stat = load_data(['attrs', 'states', 'actions'], args.stat_path)

        g_dim = args.g_dim
        self.nf_effect = args.nf_effect
        self.use_gpu = use_gpu
        self.residual = residual

        # Encoder: states -> codes of width g_dim. Actions are deliberately
        # excluded from the encoder input.
        self.state_encoder = MessagePassingGNN(
            args,
            input_particle_dim=args.attr_dim + args.state_dim,
            input_relation_dim=args.state_dim + args.relation_dim + args.attr_dim * 2,
            output_dim=g_dim,
            action=False,
            tanh=True,
            residual=residual,
            use_gpu=use_gpu)

        # Decoder: same layout, with the g_dim code taking the place of the state.
        self.state_decoder = MessagePassingGNN(
            args,
            input_particle_dim=args.attr_dim + args.g_dim,
            input_relation_dim=args.g_dim + args.relation_dim + args.attr_dim * 2,
            output_dim=args.state_dim,
            action=False,
            tanh=False,
            residual=residual,
            use_gpu=use_gpu)

        # Coefficients of the linear system; populated by system_identify.
        self.A = None
        self.B = None

        # fit_type -> (system_identify, simulate, step). An unrecognised
        # fit_type leaves all three attributes unset.
        routines = (
            ('Hom', self.fit_Hom, self.rollout_Hom, self.linear_forward_Hom),
            ('dense', self.fit_dense, self.rollout_dense, self.linear_forward_dense),
            ('diagonal', self.fit_diagonal, self.rollout_diagonal, self.linear_forward_diagonal),
            ('Gaussian_reweight', self.fit_with_Gaussian_reweight, self.rollout_Hom,
             self.linear_forward_Hom),
        )
        for fit_name, identify, rollout, single_step in routines:
            if args.fit_type == fit_name:
                self.system_identify = identify
                self.simulate = rollout
                self.step = single_step
                break

        if args.fit_type == 'Gaussian_reweight':
            # Kernel settings fall back to defaults when args does not define them.
            self.kernel = GaussianKernel(
                bandwidth=getattr(args, 'bandwidth', 1.0),
                adaptive_bandwidth=getattr(args, 'adaptive_bandwidth', True),
                scaling_factor=getattr(args, 'scaling_factor', 0.4)
            )

    def to_s(self, attrs, gcodes, rel_attrs, pstep):
        """ state decoder """

        if self.args.env in ['Soft', 'Swim']:
            states = self.state_decoder(attrs=attrs, states=gcodes, actions=None, rel_attrs=rel_attrs, pstep=pstep)
            return regularize_state_Soft(states, rel_attrs, self.stat)

        return self.state_decoder(attrs=attrs, states=gcodes, actions=None, rel_attrs=rel_attrs, pstep=pstep)

    def to_g(self, attrs, states, rel_attrs, pstep):
        """ state encoder """
        return self.state_encoder(attrs=attrs, states=states, actions=None, rel_attrs=rel_attrs, pstep=pstep)

    @staticmethod
    def get_aug(G, rel_attrs):
        """
        :param G: B x T x N x D
        :param rel_attrs:  B x N x N x R
        :return:
        """
        B, T, N, D = G.size()
        R = rel_attrs.size(-1)

        sumG_list = []
        for i in range(R):
            ''' B x T x N x N '''
            adj = rel_attrs[:, :, :, i][:, None, :, :].repeat(1, T, 1, 1)
            sumG = torch.bmm(
                adj.reshape(B * T, N, N),
                G.reshape(B * T, N, D)
            ).reshape(B, T, N, D)
            sumG_list.append(sumG)

        return torch.cat(sumG_list, -1)

    def fit_Hom(self, G, H, U, I_factor, rel_attrs):
        """
        shared relation weight
        :param G:
        :param H:
        :param U:
        :param I_factor:
        :param rel_attrs:
        :return:
        """
        bs, T, N, D = G.size()

        ''' B x T x N x R D '''
        aug_G = self.get_aug(G, rel_attrs=rel_attrs)
        ''' B x T x N x R a_dim'''
        aug_U = self.get_aug(U, rel_attrs=rel_attrs)

        ''' B x (R D) x D '''
        R = rel_attrs.size(-1)
        A_dim = D * R
        a_dim = U.size(3)
        A_B = torch.zeros(bs, A_dim + a_dim * R, D)
        if aug_G.is_cuda:
            A_B = A_B.cuda()

        fit_err = 0.
        for i in range(N):
            tmp_G_U = torch.cat([aug_G[:, :, i], aug_U[:, :, i]], 2)

            ''' B x (RD + R * a_dim) x D'''
            tmp_A_B = torch.bmm(
                self.batch_pinv(tmp_G_U, I_factor),
                H[:, :, i]
            )
            A_B += tmp_A_B / N

            ''' B x T x D'''
            tmp_fit_err = H[:, :, i] - torch.bmm(tmp_G_U, tmp_A_B)
            fit_err += torch.sqrt((tmp_fit_err ** 2).mean()) / N

        self.A = A_B[:, :A_dim]
        self.B = A_B[:, A_dim:]

        return self.A, self.B, fit_err

    def linear_forward_Hom(self, g, u, rel_attrs):
        """
        :param g: B x N x D
        :param u: B x N x a_dim
        :param rel_attrs: B x N x N x R
        :return:
        """
        ''' B x N x R D '''
        aug_g = self.get_aug(G=g[:, None, :, :], rel_attrs=rel_attrs)[:, 0]
        ''' B x N x R a_dim'''
        aug_u = self.get_aug(G=u[:, None, :, :], rel_attrs=rel_attrs)[:, 0]

        new_g = torch.bmm(aug_g, self.A) + torch.bmm(aug_u, self.B)
        return new_g

    def rollout_Hom(self, g, u_seq, T, rel_attrs):
        """
        :param g: B x N x D
        :param u_seq: B x T x N x a_dim
        :param rel_attrs: B x N x N x R
        :param T:
        :return:
        """
        g_list = []
        for t in range(T):
            g = self.linear_forward_Hom(g, u_seq[:, t], rel_attrs)
            g_list.append(g[:, None, :, :])
        return torch.cat(g_list, 1)

    def fit_dense(self, G, H, U, I_factor, rel_attrs=None):
        """
        :param G: B x T x N x D
        :param H: B x T x N x D
        :param U: B x T x N x a_dim
        :param I_factor: scalor
        :return: A, B
        s.t.
        H = catG @ A + catU @ B
        """
        bs, T, N, D = G.size()
        G = G.reshape(bs, T, -1)
        H = H.reshape(bs, T, -1)
        U = U.reshape(bs, T, -1)

        G_U = torch.cat([G, U], 2)
        A_B = torch.bmm(
            self.batch_pinv(G_U, I_factor),
            H
        )
        self.A = A_B[:, :N * D]
        self.B = A_B[:, N * D:]

        fit_err = H - torch.bmm(G_U, A_B)
        fit_err = torch.sqrt((fit_err ** 2).mean())

        return self.A, self.B, fit_err

    def linear_forward_dense(self, g, u, rel_attrs=None):
        B, N, D = g.size()
        a_dim = u.size(-1)
        g = g.reshape(B, 1, N * D)
        u = u.reshape(B, 1, N * a_dim)
        new_g = torch.bmm(g, self.A) + torch.bmm(u, self.B)
        return new_g.reshape(B, N, D)

    def rollout_dense(self, g, u_seq, T, rel_attrs=None):
        g_list = []
        for t in range(T):
            g = self.linear_forward_dense(g, u_seq[:, t])
            g_list.append(g[:, None, :, :])
        return torch.cat(g_list, 1)

    def fit_diagonal(self, G, H, U, I_factor, rel_attrs=None):
        bs, T, N, D = G.size()
        a_dim = U.size(3)

        G_U = torch.cat([G, U], 3)

        '''B x (D + a_dim) x D'''
        A_B = torch.bmm(
            self.batch_pinv(G_U.reshape(bs, T * N, D + a_dim), I_factor),
            H.reshape(bs, T * N, D)
        )
        self.A = A_B[:, :D]
        self.B = A_B[:, D:]

        fit_err = H.reshape(bs, T * N, D) - torch.bmm(G_U.reshape(bs, T * N, D + a_dim), A_B)
        fit_err = torch.sqrt((fit_err ** 2).mean())

        return self.A, self.B, fit_err

    def linear_forward_diagonal(self, g, u, rel_attrs=None):
        new_g = torch.bmm(g, self.A) + torch.bmm(u, self.B)
        return new_g

    def rollout_diagonal(self, g, u_seq, T, rel_attrs=None):
        g_list = []
        for t in range(T):
            g = self.linear_forward_diagonal(g, u_seq[:, t])
            g_list.append(g[:, None, :, :])
        return torch.cat(g_list, 1)

    def fit_with_Gaussian_reweight(self, G, H, U, I_factor, rel_attrs):
        """
        Enhanced Gaussian reweighting method with adaptive bandwidth and improved numerical stability
        
        This method builds a linear model by weighting neighboring states based on their
        similarity in the latent space using a mean field approximation. The approach adapts to the local
        geometry of the state space, producing more accurate system identification.
        
        Args:
            G: B x T x N x D tensor (current latent states)
            H: B x T x N x D tensor (next latent states)
            U: B x T x N x a_dim tensor (control inputs)
            I_factor: float (regularization factor for pseudo-inverse)
            rel_attrs: B x N x N x R tensor (relation attributes)
            
        Returns:
            A: B x RD x D tensor (state transition matrices)
            B: B x R*a_dim x D tensor (control matrices)
            fit_err: float (fitting error)
        """
        B, T, N, D = G.size()
        
        # Calculate pairwise distances in latent space more efficiently
        # Store G_reshaped once instead of creating two separate tensors
        G_reshaped = G.view(B, T*N, D)
        
        # Compute squared distances using batch matrix multiplication
        # (a-b)^2 = a^2 - 2ab + b^2
        G_norm = torch.sum(G_reshaped**2, dim=2, keepdim=True)  # B x TN x 1
        G_dot = torch.bmm(G_reshaped, G_reshaped.transpose(1, 2))  # B x TN x TN
        squared_dist = G_norm + G_norm.transpose(1, 2) - 2 * G_dot  # B x TN x TN
        
        # Reshape to original dimensions
        squared_dist = squared_dist.view(B, T, N, T, N).permute(0, 1, 3, 2, 4).reshape(B, T, T, N, N)
        
        # Handle numerical instabilities (ensure distances are non-negative)
        squared_dist = torch.clamp(squared_dist, min=0.0)
        
        # Determine adaptive bandwidth if supported
        if self.kernel.adaptive_bandwidth:
            # Compute median distance for meaningful samples only
            median_dist = torch.median(squared_dist.view(B, T*T*N*N), dim=1)[0]
            # Use scaled median as bandwidth (Scott's rule inspired scaling)
            bandwidth = median_dist.view(B, 1, 1, 1, 1) * (self.kernel.scaling_factor * (T*N)**(-0.2))
        else:
            # Use fixed bandwidth from kernel configuration
            bandwidth = self.kernel.bandwidth
        
        # Apply Gaussian kernel with numerically stable computation
        kernel_weights = torch.exp(-0.5 * squared_dist / bandwidth)
        
        # Normalize weights with improved numerical stability
        epsilon = 1e-10  # Small constant to prevent division by zero
        weight_sum = kernel_weights.sum(dim=(1, 3, 4), keepdim=True) + epsilon
        kernel_weights = kernel_weights / weight_sum
        
        # Get augmented state and control representations using relation structure
        aug_G = self.get_aug(G, rel_attrs=rel_attrs)  # B x T x N x RD
        aug_U = self.get_aug(U, rel_attrs=rel_attrs)  # B x T x N x R*a_dim
        
        # Initialize system matrices
        R = rel_attrs.size(-1)
        A_dim = D * R
        a_dim = U.size(3)
        A_B = torch.zeros(B, A_dim + a_dim * R, D)
        if G.is_cuda:
            A_B = A_B.cuda()
        
        # Track fitting error
        fit_err = 0.0
        
        # Identify system matrices for each node
        for i in range(N):
            # Concatenate state and control for this node
            node_data = torch.cat([aug_G[:, :, i], aug_U[:, :, i]], dim=2)  # B x T x (RD + R*a_dim)
            
            # Build weighted regression problem with efficient batching
            weighted_data_list = []
            weighted_targets_list = []
            
            # Reshape kernel weights for this node for efficient computation
            node_weights = kernel_weights[:, :, :, i, :].reshape(B, T, T, 1, N)
            
            # For each source timestep
            for t in range(T):
                # Compute aggregate weight for this timestep across all targets
                # This captures the importance of this timestep in the regression
                time_weights = node_weights[:, t].sum(dim=(2, 3))  # B x T
                
                # Apply weights to data and targets
                weighted_data = node_data[:, t:t+1] * time_weights.unsqueeze(2)  # B x 1 x (RD + R*a_dim)
                weighted_targets = H[:, t:t+1, i] * time_weights.unsqueeze(2)  # B x 1 x D
                
                weighted_data_list.append(weighted_data)
                weighted_targets_list.append(weighted_targets)
            
            # Combine all weighted samples
            weighted_data = torch.cat(weighted_data_list, dim=1)  # B x T x (RD + R*a_dim)
            weighted_targets = torch.cat(weighted_targets_list, dim=1)  # B x T x D
            
            # Solve weighted least squares problem with regularized pseudo-inverse
            node_AB = torch.bmm(
                self.batch_pinv(weighted_data, I_factor + 1e-8),  # Add small epsilon for stability
                weighted_targets
            )
            
            # Accumulate node contribution to system matrices
            A_B += node_AB / N
            
            # Compute fitting error for this node
            node_pred = torch.bmm(weighted_data, node_AB)
            node_err = weighted_targets - node_pred
            fit_err += torch.sqrt((node_err**2).mean()) / N
        
        # Extract A and B matrices from combined solution
        self.A = A_B[:, :A_dim]  # State transition matrix
        self.B = A_B[:, A_dim:]  # Control matrix
        
        return self.A, self.B, fit_err

    @staticmethod
    def batch_pinv(x, I_factor):
        """
        :param x: B x N x D (N > D)
        :param I_factor:
        :return:
        """
        B, N, D = x.size()

        if N < D:
            x = torch.transpose(x, 1, 2)
            N, D = D, N
            trans = True
        else:
            trans = False

        x_t = torch.transpose(x, 1, 2)

        use_gpu = torch.cuda.is_available()
        I = torch.eye(D)[None, :, :].repeat(B, 1, 1)
        if use_gpu:
            I = I.cuda()

        x_pinv = torch.bmm(
            torch.inverse(torch.bmm(x_t, x) + I_factor * I),
            x_t
        )

        if trans:
            x_pinv = torch.transpose(x_pinv, 1, 2)

        return x_pinv


def regularize_state_Soft(states, rel_attrs, stat):
    """
    Average state entries that connected nodes share, in denormalized space.

    Each node's 8 leading state entries are treated as four 2-wide slots, and
    the entries 8 positions later as a second group with the same layout
    (NOTE(review): presumably corner positions and corner velocities of a box,
    confirm against the data layout). For every connected pair (i, j) the
    relation id selects which slots of j coincide with which slots of i; the
    matching values of j are added to i and the sums are finally divided by
    the number of contributions per slot.

    The hard-coded slot count and +8 offset mean this effectively requires
    state_dim == 16. Only ``rel_attrs[0]`` is inspected, so the first sample's
    relation layout is applied to the whole batch.

    :param states: B x N x state_dim
    :param rel_attrs: B x N x N x relation_dim
    :param stat: sequence whose element 1 is passed to normalize/denormalize
    :return new states: B x N x state_dim
    :raises AssertionError: if the relation layout is inconsistent (self
        relation id not 0 mod 9, more than one active relation, reverse edge
        with the wrong id) or an off-diagonal relation id is not in 1..8 mod 9
    """
    states_denorm = denormalize([states], [stat[1]], var=True)[0]
    states_denorm_acc = denormalize([states.clone()], [stat[1]], var=True)[0]

    rel_attrs = rel_attrs[0]

    rel_attrs_np = rel_attrs.detach().cpu().numpy()

    def get_rel_id(x):
        # Index of the first active relation channel
        return np.where(x > 0)[0][0]

    # relation id (mod 9) of edge i<-j  ->
    #   (id expected on the reverse edge, ((slot of i, slot of j), ...))
    # Ids 1-4 link two slot pairs, ids 5-8 link a single one.
    slot_links = {
        1: (2, ((1, 3), (0, 2))),
        2: (1, ((3, 1), (2, 0))),
        3: (4, ((0, 1), (2, 3))),
        4: (3, ((1, 0), (3, 2))),
        5: (8, ((0, 3),)),
        8: (5, ((3, 0),)),
        6: (7, ((1, 2),)),
        7: (6, ((2, 1),)),
    }

    B, N, state_dim = states.size()
    # Number of values accumulated into each of the 8 leading entries per node
    count = torch.zeros((1, N, 1, 8), dtype=torch.float32, device=states.device)

    for i in range(N):
        for j in range(N):

            if i == j:
                assert get_rel_id(rel_attrs_np[i, j]) % 9 == 0  # rel_attrs[i, j, 0] == 1
                # The node's own value counts once for every slot
                count[:, i, :, :] += 1
                continue

            assert torch.sum(rel_attrs[i, j]) <= 1

            if torch.sum(rel_attrs[i, j]) == 0:
                continue

            rel_id = int(get_rel_id(rel_attrs_np[i, j]) % 9)
            if rel_id not in slot_links:
                # Previously this error was built but never raised, so the
                # loop carried on with the slot indices of an earlier pair.
                raise AssertionError(
                    "Unknown rel_attr id %d between nodes %d and %d" % (rel_id, i, j))

            reverse_id, linked_slots = slot_links[rel_id]
            assert get_rel_id(rel_attrs_np[j, i]) % 9 == reverse_id

            for own_slot, other_slot in linked_slots:
                # Each slot is two entries wide
                x = own_slot * 2
                y = other_slot * 2
                count[:, i, :, x:x + 2] += 1
                states_denorm_acc[:, i, x:x + 2] += states_denorm[:, j, y:y + 2]
                states_denorm_acc[:, i, x + 8:x + 10] += states_denorm[:, j, y + 8:y + 10]

    # Both halves of the state share the same per-slot counts
    states_denorm = states_denorm_acc.view(B, N, 2, state_dim // 2) / count
    states_denorm = states_denorm.view(B, N, state_dim)

    return normalize([states_denorm], [stat[1]], var=True)[0]