import torch
import torch.nn as nn

class FeatureGroupGNAN(nn.Module):
    def __init__(self, feature_groups, out_channels, num_layers, batch_size = 8, hidden_channels=None, bias=False, dropout=0.0,
                 device='cpu', rho_per_feature=True, normalize_rho=True, is_graph_task=False, readout_n_layers=1,
                 final_agg='sum', init_std=0.1, return_laplacian=False):
        super().__init__()

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.init_std = init_std
        self.dropout = dropout
        self.rho_per_feature = rho_per_feature
        self.normalize_rho = normalize_rho
        self.is_graph_task = is_graph_task
        self.readout_n_layers = readout_n_layers
        self.feature_groups = feature_groups
        self.batch_size = batch_size
        self.fs = nn.ModuleList()
        self.return_laplacian = return_laplacian

        for group_index, group in enumerate(feature_groups):
            curr_f = self._create_layers(
                in_channels=group, 
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                bias=bias,
                num_layers=num_layers,
            )
            
            self.fs.append(nn.Sequential(*curr_f).to(device))

        m_bias = True
        if is_graph_task:
            m_bias = False

        rho_layers = []
        in_dim = 1
        for _ in range(num_layers):
            rho_layers.append(nn.Linear(in_dim, hidden_channels, bias=bias))
            # rho_layers.append(nn.BatchNorm1d(self.hidden_channels))
            # rho_layers.append(nn.GroupNorm(num_groups=hidden_channels, num_channels=hidden_channels))
            rho_layers.append(nn.LayerNorm(self.hidden_channels, bias=bias))
            rho_layers.append(nn.LeakyReLU())
            rho_layers.append(nn.Dropout(dropout))
            in_dim = hidden_channels
        rho_layers.append(nn.Linear(hidden_channels, out_channels, bias=bias))
        self.rho = nn.Sequential(*rho_layers).to(device)

        self._init_params()

    def _init_params(self):
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='leaky_relu')
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def _create_layers(self, in_channels, out_channels, hidden_channels, bias, num_layers):
        if num_layers == 1:
                curr_f = [nn.Linear(in_channels, out_channels, bias=bias)]
        else:

            curr_f = [
                nn.Linear(in_channels, hidden_channels, bias=bias),
                # nn.BatchNorm1d(hidden_channels),
                # nn.GroupNorm(num_groups=hidden_channels, num_channels=hidden_channels),
                nn.LayerNorm(hidden_channels, bias=bias),
                nn.ReLU(), 
                nn.Dropout(p=self.dropout)
            ]

            for j in range(1, num_layers - 1):
                curr_f.append(
                    nn.Linear(hidden_channels, hidden_channels, bias=bias)
                )
                # curr_f.append(nn.BatchNorm1d(hidden_channels))
                # curr_f.append(nn.GroupNorm(num_groups=hidden_channels, num_channels=hidden_channels))
                curr_f.append(nn.LayerNorm(hidden_channels, bias=bias))
                curr_f.append(nn.ReLU())
                curr_f.append(nn.Dropout(p=self.dropout))
            curr_f.append(nn.Linear(hidden_channels, out_channels, bias=bias))
        return curr_f

    def compute_laplacian_from_learned_dist(self, dist_matrix, sigma=1.0, normalized=True, device='cpu'):
        A = torch.exp(- (dist_matrix ** 2) / sigma**2) # Gaussian kernel (affinity)
        # A = torch.tril(A)  # only lower triangle

        A = 0.5 * (A + A.T) # try with symmetric A

        D = torch.diag(A.sum(dim=1))

        if normalized:
            d = D.diag().clamp(min=1e-3)
            d_inv_sqrt = torch.diag(1.0 / torch.sqrt(D.diag() + 1e-6))
            L = torch.eye(A.size(0)).to(device) - d_inv_sqrt @ A @ d_inv_sqrt
        else:
            L = D - A
        
        eigvals = torch.linalg.eigvalsh(L.cpu()).to(device)

        return L, eigvals

    def forward(self, x_batch, dist_batch, batch_vector):
        N, _ = x_batch.shape
        fx = torch.empty(N, len(self.feature_groups), self.out_channels).to(self.device)
        start_idx = 0

        for group_index, group_size in enumerate(self.feature_groups):  
            feature_cols = x_batch[:, start_idx : start_idx + group_size]


            if group_size == 1:
                feature_cols = feature_cols.view(-1, 1)

            feature_cols = self.fs[group_index](feature_cols)

            fx[:, group_index] = feature_cols
            start_idx += group_size

        fx_perm = torch.permute(fx, (2, 0, 1))

        dist_embed = self.rho(dist_batch.flatten().view(-1, 1))
        dist_embed = dist_embed.view(N, N, self.out_channels)
        mask = (dist_batch >= 0)
        dist_embed[~mask] = 0.0
        m_dist = dist_embed.permute(2, 0, 1)
        
        mf = torch.matmul(m_dist, fx_perm)
        mf = mf.sum(dim=2)
        mf = mf.permute(1, 0)

        if self.is_graph_task:
            num_graphs = batch_vector.max().item() + 1
            out_graph = torch.zeros(num_graphs, mf.size(1), device=mf.device)
            graph_index = batch_vector.view(-1, 1).expand(-1, mf.size(1))
            out_graph.scatter_add_(0, graph_index, mf)

            if self.return_laplacian:
                laplacian_data = {}

                with torch.no_grad():
                    unique_graph_ids = batch_vector.unique()
                    for graph_id in unique_graph_ids:
                        node_mask = (batch_vector == graph_id)

                        learned_dist_scalar = dist_embed.mean(dim=-1)  # shape: [N, N]
                        learned_dist_scalar[~mask] = 0.0

                        # if node_mask.sum() < 2:
                        #     continue # skip graphs with fewer than 2 nodes

                        sub_dist_matrix = learned_dist_scalar[node_mask][:, node_mask]
                        L, eigvals = self.compute_laplacian_from_learned_dist(
                            sub_dist_matrix,
                            sigma=1.0,
                            normalized=True,
                            device=self.device
                        )
                        laplacian_data[int(graph_id)] = {
                            "laplacian": L,
                            "eigenvalues": eigvals,
                        }

                        return out_graph, laplacian_data

            return out_graph

        return mf

# DeepSet aggregator (set -> vector)
class DeepSet(nn.Module):
    """Permutation-invariant set encoder: per-element MLP, mean pool, linear readout."""

    def __init__(self, in_dim, out_dim, hidden_dim, n_layers, device):
        """Build the element-wise MLP ``phi`` and the ``readout`` layer.

        Args:
            in_dim: feature width of each set element.
            out_dim: width of the returned vector.
            hidden_dim: width of every ``phi`` layer.
            n_layers: number of (Linear, ReLU) pairs in ``phi``; with 0
                ``phi`` is the identity and ``readout`` consumes ``in_dim``.
            device: device both sub-modules are moved to.
        """
        super().__init__()
        layers = []
        width = in_dim
        for _ in range(n_layers):
            layers += [nn.Linear(width, hidden_dim), nn.ReLU()]
            # Only the first layer sees the raw input width; later layers
            # consume the previous layer's hidden output.
            width = hidden_dim
        self.phi = nn.Sequential(*layers).to(device)
        self.readout = nn.Linear(width, out_dim).to(device)

    def forward(self, x):  # x: [batch, set_size, in_dim]
        """Encode each element, average over the set dimension, and project."""
        x = self.phi(x)         # [batch, set_size, hidden_dim]
        x = x.mean(dim=1)       # [batch, hidden_dim]
        return self.readout(x)  # [batch, out_dim]

# GMAN with two GNANs and original hidden collapse
class GMAN(nn.Module):
    """Two-branch model: one FeatureGroupGNAN per biomarker group, summed at the end.

    The ``'single'`` branch output is additionally passed through a DeepSet
    (as a set of size one); the ``'not_single'`` branch output is used as is.
    ``biomarker_groups`` is accepted by the constructor and ``batch_dim`` by
    ``forward`` but neither is read by the code in this class.
    """

    def __init__(
        self,
        feature_groups,
        out_channels,
        n_layers,
        batch_size,
        biomarker_groups,          # {'single': [...], 'not_single': [...]}
        hidden_channels,
        bias=True,
        dropout=0.0,
        device='cpu',
        rho_per_feature=False,
        normalize_rho=True,
        is_graph_task=False,
        n_biom_group_layers=3,
    ):
        """Create both GNAN branches, the DeepSet, and re-initialise all weights.

        Both branches share the same configuration and emit
        ``hidden_channels`` values per row (``out_channels`` is not forwarded
        to them).
        """
        super().__init__()
        self.device = device

        # Identical settings for both branches; each branch still owns its
        # own parameters.
        gnan_config = dict(
            feature_groups=feature_groups,
            out_channels=hidden_channels,
            num_layers=n_layers,
            batch_size=batch_size,
            hidden_channels=hidden_channels,
            bias=bias,
            dropout=dropout,
            device=device,
            rho_per_feature=rho_per_feature,
            normalize_rho=normalize_rho,
            is_graph_task=is_graph_task,
        )
        # 'single' is built before 'not_single', matching the key order.
        self.gnans = nn.ModuleDict({
            branch: FeatureGroupGNAN(**gnan_config).to(device)
            for branch in ('single', 'not_single')
        })

        # Set encoder applied on top of the 'single' branch.
        self.deep_set = DeepSet(
            in_dim=hidden_channels,
            out_dim=hidden_channels,
            hidden_dim=hidden_channels,
            n_layers=n_biom_group_layers,
            device=device,
        ).to(device)

        # Kaiming-normal for matrix-shaped weights, zeros for biases.
        for param_name, tensor in self.named_parameters():
            is_matrix_weight = 'weight' in param_name and tensor.dim() >= 2
            if is_matrix_weight:
                nn.init.kaiming_normal_(tensor, nonlinearity='leaky_relu')
            elif 'bias' in param_name:
                nn.init.constant_(tensor, 0)

    def _run_branch(self, branch, inputs):
        """Move one branch's tensors to the model device and run its GNAN."""
        payload = inputs[branch]
        x_batch = payload['x_batch'].to(self.device)
        dist_batch = payload['dist_batch'].to(self.device)
        batch_vector = payload['batch_vector'].to(self.device)
        return self.gnans[branch](x_batch, dist_batch, batch_vector)

    def forward(self, inputs, batch_dim, return_group_outputs = False):
        """Combine both branches.

        Args:
            inputs: mapping with keys ``'single'`` and ``'not_single'``, each
                a mapping holding ``'x_batch'``, ``'dist_batch'`` and
                ``'batch_vector'`` tensors.
            batch_dim: unused here.
            return_group_outputs: if true, return the two branch outputs
                stacked along dim 1 instead of the final reduction.

        Returns:
            Either the stacked branch outputs, or the element-wise sum of the
            two branches reduced over the last dimension.
        """
        # 'single' branch: GNAN followed by the DeepSet on a singleton set.
        single_out = self._run_branch('single', inputs)
        single_out = self.deep_set(single_out.unsqueeze(1))

        # 'not_single' branch: GNAN only.
        not_single_out = self._run_branch('not_single', inputs)

        stacked = torch.stack([single_out, not_single_out], dim=1)
        if return_group_outputs:
            return stacked

        total = single_out + not_single_out
        return total.sum(dim=-1)