import torch
import torch.nn as nn
from collections import defaultdict

class FeatureGroupGNAN(nn.Module):
    """Additive model over node-feature groups, weighted by a learned function of distance.

    Every entry of ``feature_groups`` is the number of consecutive input
    columns that form one group; each group gets its own MLP (``self.fs[i]``)
    mapping those columns to ``out_channels`` values.  A shared MLP
    ``self.rho`` maps each scalar entry of the pairwise distance matrix to
    ``out_channels`` weights.  Per channel, the embedding of node ``i`` is the
    sum over nodes ``j`` and over feature groups of
    ``rho(dist[i, j]) * f_group(x[j])``; pairs whose distance is negative are
    excluded from the sum.

    Args:
        feature_groups: Iterable of group sizes (number of columns per group).
        out_channels: Width of the per-node (or per-graph) output.
        num_layers: Depth of the per-group MLPs; ``1`` means a single linear
            layer.  ``rho`` always has ``num_layers`` hidden layers plus an
            output layer.
        batch_size: Stored on the instance; not used by this class.
        hidden_channels: Hidden width of all MLPs.  ``None`` falls back to
            ``out_channels`` (previously ``None`` always failed while building
            the layers).
        bias: Whether the linear layers carry a bias term.
        dropout: Dropout probability used inside the MLPs.
        device: Device the sub-modules and work tensors are placed on.
        rho_per_feature, normalize_rho, readout_n_layers, init_std: Stored on
            the instance; not used by this class.
        is_graph_task: If true, node embeddings are summed per graph.
        final_agg: Accepted for interface compatibility; not used.
        return_laplacian: If true (and ``is_graph_task``), ``forward`` also
            returns a per-graph Laplacian built from the learned distances.
    """

    def __init__(self, feature_groups, out_channels, num_layers, batch_size = 8, hidden_channels=None, bias=True, dropout=0.0,
                 device='cpu', rho_per_feature=True, normalize_rho=True, is_graph_task=False, readout_n_layers=1,
                 final_agg='sum', init_std=0.1, return_laplacian=False):
        super().__init__()

        # Every code path below builds at least one Linear layer of width
        # ``hidden_channels``, so the ``None`` default needs a concrete value.
        if hidden_channels is None:
            hidden_channels = out_channels

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.init_std = init_std
        self.dropout = dropout  # read by _create_layers, so set it first
        self.rho_per_feature = rho_per_feature
        self.normalize_rho = normalize_rho
        self.is_graph_task = is_graph_task
        self.readout_n_layers = readout_n_layers
        self.feature_groups = feature_groups
        self.batch_size = batch_size
        self.fs = nn.ModuleList()
        self.return_laplacian = return_laplacian

        # One independent MLP per feature group.
        for group in feature_groups:
            group_layers = self._create_layers(
                in_channels=group,
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                bias=bias,
                num_layers=num_layers,
            )
            self.fs.append(nn.Sequential(*group_layers).to(device))

        # Distance network: scalar distance -> out_channels weights.
        rho_layers = []
        in_dim = 1
        for _ in range(num_layers):
            rho_layers.append(nn.Linear(in_dim, hidden_channels, bias=bias))
            rho_layers.append(nn.LeakyReLU())
            rho_layers.append(nn.Dropout(dropout))
            in_dim = hidden_channels
        rho_layers.append(nn.Linear(hidden_channels, out_channels, bias=bias))
        self.rho = nn.Sequential(*rho_layers).to(device)

        self._init_params()

    def _init_params(self):
        """Kaiming-initialise all weight matrices and zero all biases.

        One-dimensional weights (e.g. LayerNorm scales) are left untouched.
        """
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='leaky_relu')
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def _create_layers(self, in_channels, out_channels, hidden_channels, bias, num_layers):
        """Return the list of layers for one feature-group MLP.

        With ``num_layers == 1`` this is a single linear map.  Otherwise it is
        ``Linear -> LayerNorm -> ReLU -> Dropout`` blocks followed by a final
        linear layer to ``out_channels``.
        """
        if num_layers == 1:
            return [nn.Linear(in_channels, out_channels, bias=bias)]

        curr_f = [
            nn.Linear(in_channels, hidden_channels, bias=bias),
            nn.LayerNorm(hidden_channels),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
        ]
        for _ in range(1, num_layers - 1):
            curr_f.append(nn.Linear(hidden_channels, hidden_channels, bias=bias))
            curr_f.append(nn.LayerNorm(hidden_channels))
            curr_f.append(nn.ReLU())
            curr_f.append(nn.Dropout(p=self.dropout))
        curr_f.append(nn.Linear(hidden_channels, out_channels, bias=bias))
        return curr_f

    def compute_laplacian_from_learned_dist(self, dist_matrix, sigma=1.0, normalized=True, device='cpu'):
        """Build a graph Laplacian from a square matrix of (learned) distances.

        The affinity is the Gaussian kernel ``exp(-d**2 / sigma**2)``,
        symmetrised as ``0.5 * (A + A.T)``.

        Args:
            dist_matrix: Square 2-D tensor of pairwise distances.
            sigma: Kernel bandwidth.
            normalized: If true return ``I - D^-1/2 A D^-1/2`` (with ``1e-6``
                added to the degrees before the square root), else ``D - A``.
            device: Device for the identity matrix and the returned
                eigenvalues.

        Returns:
            ``(L, eigvals)``: the Laplacian and its eigenvalues.  The
            eigenvalues are computed on the CPU and then moved to ``device``.
        """
        A = torch.exp(- (dist_matrix ** 2) / sigma**2)  # Gaussian kernel (affinity)
        A = 0.5 * (A + A.T)  # enforce symmetry so eigvalsh is applicable

        D = torch.diag(A.sum(dim=1))

        if normalized:
            d_inv_sqrt = torch.diag(1.0 / torch.sqrt(D.diag() + 1e-6))
            L = torch.eye(A.size(0)).to(device) - d_inv_sqrt @ A @ d_inv_sqrt
        else:
            L = D - A

        eigvals = torch.linalg.eigvalsh(L.cpu()).to(device)

        return L, eigvals

    def forward(self, x_batch, dist_batch, batch_vector, return_node_embeddings=False):
        """Compute node embeddings and, for graph tasks, per-graph sums.

        Args:
            x_batch: 2-D tensor with one row per node; columns are laid out
                group after group in the order of ``feature_groups``.
            dist_batch: Tensor with ``N * N`` entries (``N`` = number of
                nodes) holding pairwise distances.  Negative entries mark
                pairs that must not interact.
            batch_vector: Graph id of every node; used as a scatter index, so
                presumably an int64 tensor of length ``N`` — verify against
                the caller.
            return_node_embeddings: Only honoured when both ``is_graph_task``
                and ``return_laplacian`` are set; then the node embeddings
                are appended to the returned tuple.

        Returns:
            * node embeddings ``(N, out_channels)`` if not a graph task;
            * graph embeddings ``(num_graphs, out_channels)`` for a graph
              task without Laplacians;
            * ``(graph_embeddings, laplacian_data)`` or
              ``(graph_embeddings, laplacian_data, node_embeddings)`` when
              ``return_laplacian`` is set.  ``laplacian_data`` maps every
              graph id in ``batch_vector`` to
              ``{"laplacian": ..., "eigenvalues": ...}``.
        """
        N, _ = x_batch.shape
        fx = torch.empty(N, len(self.feature_groups), self.out_channels).to(self.device)

        # Apply each group's MLP to its own slice of columns.
        start_idx = 0
        for group_index, group_size in enumerate(self.feature_groups):
            feature_cols = x_batch[:, start_idx : start_idx + group_size]
            fx[:, group_index] = self.fs[group_index](feature_cols)
            start_idx += group_size

        fx_perm = torch.permute(fx, (2, 0, 1))  # (channels, N, groups)

        dist_embed = self.rho(dist_batch.flatten().view(-1, 1))
        dist_embed = dist_embed.view(N, N, self.out_channels)
        mask = (dist_batch >= 0)
        dist_embed[~mask] = 0.0  # drop pairs flagged with a negative distance
        m_dist = dist_embed.permute(2, 0, 1)  # (channels, N, N)

        # Per channel: weight every node's group outputs by rho(dist) and sum
        # over source nodes (matmul) and over feature groups (sum).
        mf = torch.matmul(m_dist, fx_perm)
        mf = mf.sum(dim=2)
        mf = mf.permute(1, 0)  # (N, channels)

        if not self.is_graph_task:
            return mf

        num_graphs = batch_vector.max().item() + 1
        out_graph = torch.zeros(num_graphs, mf.size(1), device=mf.device)
        graph_index = batch_vector.view(-1, 1).expand(-1, mf.size(1))
        out_graph.scatter_add_(0, graph_index, mf)

        if not self.return_laplacian:
            return out_graph

        laplacian_data = {}
        with torch.no_grad():
            # Scalar "learned distance" per node pair; identical for every
            # graph, so compute it once outside the loop.
            # NOTE(review): masked pairs become distance 0, i.e. affinity 1 in
            # the Gaussian kernel — confirm this is intended.
            learned_dist_scalar = dist_embed.mean(dim=-1)  # shape: [N, N]
            learned_dist_scalar[~mask] = 0.0

            for graph_id in batch_vector.unique():
                node_mask = (batch_vector == graph_id)
                sub_dist_matrix = learned_dist_scalar[node_mask][:, node_mask]
                L, eigvals = self.compute_laplacian_from_learned_dist(
                    sub_dist_matrix,
                    sigma=1.0,
                    normalized=True,
                    device=self.device
                )
                laplacian_data[int(graph_id)] = {
                    "laplacian": L,
                    "eigenvalues": eigvals,
                }

        # Return only after *all* graphs have been processed (the returns used
        # to sit inside the loop, so only the first graph was ever reported).
        if return_node_embeddings:
            return out_graph, laplacian_data, mf

        return out_graph, laplacian_data

class DeepSet(nn.Module):
    """Permutation-invariant set encoder: ``readout(mean(phi(x), dim=0))``.

    ``phi`` is a stack of ``n_layers`` ``Linear -> ReLU`` blocks applied to
    every set element, the elements are averaged over the first dimension,
    and a final linear ``readout`` maps the pooled vector to ``out_dim``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dim: Width of every ``phi`` layer.
        n_layers: Number of ``Linear -> ReLU`` blocks in ``phi``.
        device: Device the sub-modules are moved to.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_layers, device):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.device = device

        layers = []
        # Track the running width: only the first layer consumes ``in_dim``;
        # every later layer consumes the previous layer's ``hidden_dim``.
        # (Using ``in_dim`` for all layers broke whenever in_dim != hidden_dim.)
        curr_dim = in_dim
        for _ in range(n_layers):
            layers.append(nn.Linear(curr_dim, hidden_dim))
            layers.append(nn.ReLU())
            curr_dim = hidden_dim

        self.phi = nn.Sequential(*layers).to(device)
        # With n_layers == 0, phi is the identity and the readout sees in_dim.
        self.readout = nn.Linear(curr_dim, out_dim).to(device)

    def forward(self, x):
        """Encode a set whose elements are indexed along dimension 0.

        Args:
            x: Tensor whose first dimension enumerates the set elements and
                whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the first dimension removed and the last dimension of
            size ``out_dim``.
        """
        x = self.phi(x)
        x = torch.mean(x, dim=0)  # order-independent pooling over the set
        x = self.readout(x)
        return x


class GMAN(nn.Module):
    """Combine per-biomarker ``FeatureGroupGNAN`` outputs into one score per sample.

    Every biomarker present in the input is encoded by a ``FeatureGroupGNAN``
    (which one depends on ``gnan_mode``).  Encodings of biomarkers that belong
    to the same biomarker group are merged by that group's ``DeepSet``; the
    resulting group vectors are summed over the hidden dimension and over the
    groups.

    Args:
        feature_groups: Passed through to every ``FeatureGroupGNAN``.
        out_channels: Output width of ``self.readout`` (``forward`` does not
            use that layer).
        n_layers: Number of layers of each ``FeatureGroupGNAN``.
        batch_size: Passed through to every ``FeatureGroupGNAN``.
        biomarker_groups: List of biomarker groups, e.g.
            ``[["biom1", "biom2"], ["biom3", "biom4"]]``.
        hidden_channels: Width of every embedding; required (``None`` raises
            ``ValueError``).
        bias: Bias flag of ``self.readout``.
        dropout: Dropout probability passed to the sub-models.
        device: Device for sub-modules and work tensors.
        rho_per_feature, normalize_rho: Passed through to the sub-models.
        is_graph_task: Passed through to the sub-models.
        readout_n_layers, same_GNAN_for_all, mix_feature_group_repres,
            deepset_n_layers: Stored on the instance; not used by this class.
        max_num_GNANs: Upper bound on the number of biomarkers per call.
        n_biom_group_layers: Depth of each group's ``DeepSet``.
        return_laplacian: If true, ``forward`` also returns the Laplacian
            data produced by the sub-models.
        gnan_mode: ``'single'`` (one shared sub-model), ``'per_group'`` (one
            per biomarker group) or ``'per_biomarker'`` (one per biomarker).

    Raises:
        ValueError: If ``hidden_channels`` is ``None`` or ``gnan_mode`` is
            not one of the three supported values.
    """

    def __init__(
            self,
            feature_groups,
            out_channels,
            n_layers,
            batch_size,
            biomarker_groups,  # list of biomarker groups i.e. [["biom1", "biom2"], ["biom3", "biom4"]]
            hidden_channels=None,
            bias=True,
            dropout=0.0,
            device='cpu',
            rho_per_feature=False,
            normalize_rho=True,
            is_graph_task=False,
            readout_n_layers=1,
            max_num_GNANs=3,
            n_biom_group_layers=3,
            same_GNAN_for_all=False,
            mix_feature_group_repres=False,
            return_laplacian=False,
            gnan_mode="per_group", # 'single', 'per_group', or 'per_biomarker'
            deepset_n_layers=2,
    ):
        super().__init__()

        # hidden_channels sizes every sub-module; fail early with a clear
        # message instead of deep inside a Linear constructor.
        if hidden_channels is None:
            raise ValueError("hidden_channels must be set to a positive integer")

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.bias = bias
        self.dropout = dropout
        self.rho_per_feature = rho_per_feature
        self.normalize_rho = normalize_rho
        self.fs = nn.ModuleList()
        self.is_graph_task = is_graph_task
        self.readout_n_layers = readout_n_layers
        self.max_num_GNANs = max_num_GNANs
        self.same_gnan = same_GNAN_for_all
        self.feature_groups = feature_groups
        self.batch_size = batch_size
        self.biomarker_groups = biomarker_groups
        self.n_biom_group_layers = n_biom_group_layers
        self.mix_feature_group_repres = mix_feature_group_repres
        self.return_laplacian = return_laplacian
        self.deepset_n_layers = deepset_n_layers

        if gnan_mode == "single":
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device)
            ])
            # Every biomarker maps to sub-model 0.  ``int`` (rather than a
            # lambda) keeps the module picklable.
            self.biomarker_to_gnan = defaultdict(int)

        elif gnan_mode == "per_group":
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device) for _ in biomarker_groups
            ])
            self.biomarker_to_gnan = {}
            for group_idx, group in enumerate(biomarker_groups):
                for biom in group:
                    self.biomarker_to_gnan[biom] = group_idx

        elif gnan_mode == "per_biomarker":
            all_biomarkers = [b for group in biomarker_groups for b in group]
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device) for _ in all_biomarkers
            ])
            self.biomarker_to_gnan = {biom: i for i, biom in enumerate(all_biomarkers)}

        else:
            raise ValueError(f"Invalid gnan_mode: {gnan_mode}")

        self.readout = nn.Linear(max_num_GNANs, out_channels, bias=bias).to(device)

        self.group_deep_sets = nn.ModuleDict({
            str(i): DeepSet(
                in_dim=self.hidden_channels,
                out_dim=self.hidden_channels,
                hidden_dim=self.hidden_channels,
                n_layers=self.n_biom_group_layers,
                device=self.device,
            ).to(self.device)
            for i in range(len(biomarker_groups))
        })

        # Kaiming-initialise all weight matrices and zero all biases
        # (including those of the sub-models); 1-D weights are left as is.
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='leaky_relu')
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def create_gnan(self):
        """Build one ``FeatureGroupGNAN`` whose output width is ``hidden_channels``."""
        return FeatureGroupGNAN(
            feature_groups=self.feature_groups,
            hidden_channels=self.hidden_channels,
            num_layers=self.n_layers,
            is_graph_task=self.is_graph_task,
            dropout=self.dropout,
            device=self.device,
            normalize_rho=self.normalize_rho,
            out_channels=self.hidden_channels,
            rho_per_feature=self.rho_per_feature,
            batch_size=self.batch_size,
            return_laplacian=self.return_laplacian
        )

    def create_mlp(self, num_layers):
        """Return an MLP of ``num_layers`` hidden blocks plus a linear output layer.

        Each block is ``Linear -> LayerNorm -> LeakyReLU -> Dropout``; all
        widths are ``hidden_channels``.
        """
        layers = []
        out_dim = self.hidden_channels
        for _ in range(num_layers):
            layers.append(nn.Linear(self.hidden_channels, self.hidden_channels))
            layers.append(nn.LayerNorm(self.hidden_channels))
            layers.append(nn.LeakyReLU())
            layers.append(nn.Dropout(self.dropout))
        layers.append(nn.Linear(self.hidden_channels, out_dim, bias=True))
        return nn.Sequential(*layers)

    def forward(self, inputs, batch_dim, return_group_outputs=False):
        """Score a batch from its per-biomarker inputs.

        Args:
            inputs: Mapping from biomarker name to a dict with the keys
                ``'x_batch'``, ``'dist_batch'`` and ``'batch_vector'``
                (the arguments of the sub-model's ``forward``).  At most
                ``max_num_GNANs`` biomarkers may be given.
            batch_dim: Number of rows reserved per biomarker; sub-model
                outputs with fewer rows are zero-padded at the end.
            return_group_outputs: If true, return the per-group vectors
                instead of the final score.

        Returns:
            * ``(groups_output, laplacian_outputs)`` if
              ``return_group_outputs``; ``groups_output`` has shape
              ``(batch_dim, n_present_groups, hidden_channels)``;
            * ``(scores, laplacian_outputs)`` if ``return_laplacian``;
            * ``scores`` of shape ``(batch_dim,)`` otherwise.
        """
        assert len(inputs) <= self.max_num_GNANs, (
            f"got {len(inputs)} biomarkers but max_num_GNANs={self.max_num_GNANs}"
        )
        biom_idx_map = {}
        outputs = torch.zeros(
            self.max_num_GNANs, batch_dim, self.hidden_channels, device=self.device
        )
        laplacian_outputs = {}

        for ind, (k, b) in enumerate(inputs.items()):
            biom_idx_map[k] = ind
            gnan_idx = self.biomarker_to_gnan[k]

            x_batch = b['x_batch'].to(self.device)
            dist_batch = b['dist_batch'].to(self.device)
            batch_vector = b['batch_vector'].to(self.device)

            out = self.gnans[gnan_idx](x_batch, dist_batch, batch_vector)
            # The sub-model only returns a tuple for graph tasks; never
            # tuple-unpack a bare tensor (that would split it along dim 0).
            if self.return_laplacian and isinstance(out, tuple):
                gnan_output, lap_data = out
                laplacian_outputs[k] = lap_data
            else:
                gnan_output = out

            expected_size = outputs[ind].shape[0]
            actual_size = gnan_output.shape[0]

            # NOTE(review): padding is appended at the end, which presumably
            # assumes the missing rows are the trailing ones — confirm.
            if actual_size < expected_size:
                pad_size = expected_size - actual_size
                pad_tensor = torch.zeros((pad_size, gnan_output.shape[1]), device=gnan_output.device)
                gnan_output = torch.cat([gnan_output, pad_tensor], dim=0)

            outputs[ind] = gnan_output

        groups_output = []
        outputs = outputs.permute(1, 0, 2)  # (batch_dim, biomarkers, hidden)

        for group_idx, group in enumerate(self.biomarker_groups):
            # Encodings of the biomarkers of this group that were supplied.
            group_out = [
                outputs[:, biom_idx_map[biomarker]]
                for biomarker in group
                if biomarker in biom_idx_map
            ]

            if len(group_out) > 1:
                group_out = torch.stack(group_out, dim=0)
                group_out = self.group_deep_sets[str(group_idx)](group_out)
                groups_output.append(group_out)
            elif len(group_out) == 1:
                # A single biomarker needs no set pooling.
                groups_output.append(group_out[0])

        groups_output = torch.stack(groups_output, dim=0)

        groups_output = groups_output.permute(1, 0, 2)  # (batch_dim, groups, hidden)
        outputs = groups_output.sum(dim=-1)  # (batch_dim, groups)

        if return_group_outputs:
            return groups_output, laplacian_outputs

        if self.return_laplacian:
            return outputs.sum(dim=-1), laplacian_outputs

        return outputs.sum(dim=-1)



class GMANP12(nn.Module):
    """Combine per-biomarker ``FeatureGroupGNAN`` outputs into one score per sample.

    Every biomarker present in the input is encoded by a ``FeatureGroupGNAN``
    (which one depends on ``gnan_mode``).  Encodings of biomarkers that belong
    to the same biomarker group are merged by that group's ``DeepSet``; the
    resulting group vectors are summed over the hidden dimension and over the
    groups.

    Args:
        feature_groups: Passed through to every ``FeatureGroupGNAN``.
        out_channels: Output width of ``self.readout`` (``forward`` does not
            use that layer).
        n_layers: Number of layers of each ``FeatureGroupGNAN``.
        batch_size: Passed through to every ``FeatureGroupGNAN``.
        biomarker_groups: List of biomarker groups, e.g.
            ``[["biom1", "biom2"], ["biom3", "biom4"]]``.
        hidden_channels: Width of every embedding; required (``None`` raises
            ``ValueError``).
        bias: Bias flag of ``self.readout``.
        dropout: Dropout probability passed to the sub-models.
        device: Device for sub-modules and work tensors.
        rho_per_feature, normalize_rho: Passed through to the sub-models.
        is_graph_task: Passed through to the sub-models.
        readout_n_layers, same_GNAN_for_all, mix_feature_group_repres,
            deepset_n_layers: Stored on the instance; not used by this class.
        max_num_GNANs: Upper bound on the number of biomarkers per call.
        n_biom_group_layers: Depth of each group's ``DeepSet``.
        return_laplacian: If true, ``forward`` also returns the Laplacian
            data produced by the sub-models.
        gnan_mode: ``'single'`` (one shared sub-model), ``'per_group'`` (one
            per biomarker group) or ``'per_biomarker'`` (one per biomarker).

    Raises:
        ValueError: If ``hidden_channels`` is ``None`` or ``gnan_mode`` is
            not one of the three supported values.
    """

    def __init__(
            self,
            feature_groups,
            out_channels,
            n_layers,
            batch_size,
            biomarker_groups,  # list of biomarker groups i.e. [["biom1", "biom2"], ["biom3", "biom4"]]
            hidden_channels=None,
            bias=True,
            dropout=0.0,
            device='cpu',
            rho_per_feature=False,
            normalize_rho=True,
            is_graph_task=False,
            readout_n_layers=1,
            max_num_GNANs=3,
            n_biom_group_layers=3,
            same_GNAN_for_all=False,
            mix_feature_group_repres=False,
            return_laplacian=False,
            gnan_mode="per_group", # 'single', 'per_group', or 'per_biomarker'
            deepset_n_layers=2,
    ):
        super().__init__()

        # hidden_channels sizes every sub-module; fail early with a clear
        # message instead of deep inside a Linear constructor.
        if hidden_channels is None:
            raise ValueError("hidden_channels must be set to a positive integer")

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.bias = bias
        self.dropout = dropout
        self.rho_per_feature = rho_per_feature
        self.normalize_rho = normalize_rho
        self.fs = nn.ModuleList()
        self.is_graph_task = is_graph_task
        self.readout_n_layers = readout_n_layers
        self.max_num_GNANs = max_num_GNANs
        self.same_gnan = same_GNAN_for_all
        self.feature_groups = feature_groups
        self.batch_size = batch_size
        self.biomarker_groups = biomarker_groups
        self.n_biom_group_layers = n_biom_group_layers
        self.mix_feature_group_repres = mix_feature_group_repres
        self.return_laplacian = return_laplacian
        self.deepset_n_layers = deepset_n_layers

        if gnan_mode == "single":
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device)
            ])
            # Every biomarker maps to sub-model 0.  ``int`` (rather than a
            # lambda) keeps the module picklable.
            self.biomarker_to_gnan = defaultdict(int)

        elif gnan_mode == "per_group":
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device) for _ in biomarker_groups
            ])
            self.biomarker_to_gnan = {}
            for group_idx, group in enumerate(biomarker_groups):
                for biom in group:
                    self.biomarker_to_gnan[biom] = group_idx

        elif gnan_mode == "per_biomarker":
            all_biomarkers = [b for group in biomarker_groups for b in group]
            self.gnans = nn.ModuleList([
                self.create_gnan().to(device) for _ in all_biomarkers
            ])
            self.biomarker_to_gnan = {biom: i for i, biom in enumerate(all_biomarkers)}

        else:
            raise ValueError(f"Invalid gnan_mode: {gnan_mode}")

        self.readout = nn.Linear(max_num_GNANs, out_channels, bias=bias).to(device)

        self.group_deep_sets = nn.ModuleDict({
            str(i): DeepSet(
                in_dim=self.hidden_channels,
                out_dim=self.hidden_channels,
                hidden_dim=self.hidden_channels,
                n_layers=self.n_biom_group_layers,
                device=self.device,
            ).to(self.device)
            for i in range(len(biomarker_groups))
        })

        # Kaiming-initialise all weight matrices and zero all biases
        # (including those of the sub-models); 1-D weights are left as is.
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='leaky_relu')
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def create_gnan(self):
        """Build one ``FeatureGroupGNAN`` whose output width is ``hidden_channels``."""
        return FeatureGroupGNAN(
            feature_groups=self.feature_groups,
            hidden_channels=self.hidden_channels,
            num_layers=self.n_layers,
            is_graph_task=self.is_graph_task,
            dropout=self.dropout,
            device=self.device,
            normalize_rho=self.normalize_rho,
            out_channels=self.hidden_channels,
            rho_per_feature=self.rho_per_feature,
            batch_size=self.batch_size,
            return_laplacian=self.return_laplacian
        )

    def create_mlp(self, num_layers):
        """Return an MLP of ``num_layers`` hidden blocks plus a linear output layer.

        Each block is ``Linear -> LayerNorm -> LeakyReLU -> Dropout``; all
        widths are ``hidden_channels``.
        """
        layers = []
        out_dim = self.hidden_channels
        for _ in range(num_layers):
            layers.append(nn.Linear(self.hidden_channels, self.hidden_channels))
            layers.append(nn.LayerNorm(self.hidden_channels))
            layers.append(nn.LeakyReLU())
            layers.append(nn.Dropout(self.dropout))
        layers.append(nn.Linear(self.hidden_channels, out_dim, bias=True))
        return nn.Sequential(*layers)

    def forward(self, inputs, batch_dim, return_group_outputs=False):
        """Score a batch from its per-biomarker inputs.

        Args:
            inputs: Mapping from biomarker name to a dict with the keys
                ``'x_batch'``, ``'dist_batch'`` and ``'batch_vector'``
                (the arguments of the sub-model's ``forward``).  At most
                ``max_num_GNANs`` biomarkers may be given.
            batch_dim: Number of rows reserved per biomarker; sub-model
                outputs with fewer rows are zero-padded at the end.
            return_group_outputs: If true, return the per-group vectors
                instead of the final score.

        Returns:
            * ``(groups_output, laplacian_outputs)`` if
              ``return_group_outputs``; ``groups_output`` has shape
              ``(batch_dim, n_present_groups, hidden_channels)``;
            * ``(scores, laplacian_outputs)`` if ``return_laplacian``;
            * ``scores`` of shape ``(batch_dim,)`` otherwise.
        """
        assert len(inputs) <= self.max_num_GNANs, (
            f"got {len(inputs)} biomarkers but max_num_GNANs={self.max_num_GNANs}"
        )
        biom_idx_map = {}
        outputs = torch.zeros(
            self.max_num_GNANs, batch_dim, self.hidden_channels, device=self.device
        )
        laplacian_outputs = {}

        for ind, (k, b) in enumerate(inputs.items()):
            biom_idx_map[k] = ind
            gnan_idx = self.biomarker_to_gnan[k]

            x_batch = b['x_batch'].to(self.device)
            dist_batch = b['dist_batch'].to(self.device)
            batch_vector = b['batch_vector'].to(self.device)

            out = self.gnans[gnan_idx](x_batch, dist_batch, batch_vector)
            # The sub-model only returns a tuple for graph tasks; never
            # tuple-unpack a bare tensor (that would split it along dim 0).
            if self.return_laplacian and isinstance(out, tuple):
                gnan_output, lap_data = out
                laplacian_outputs[k] = lap_data
            else:
                gnan_output = out

            expected_size = outputs[ind].shape[0]
            actual_size = gnan_output.shape[0]

            # Some people might not have had this biomarker measured.
            # NOTE(review): padding is appended at the end, which presumably
            # assumes the missing rows are the trailing ones — confirm.
            if actual_size < expected_size:
                pad_size = expected_size - actual_size
                pad_tensor = torch.zeros((pad_size, gnan_output.shape[1]), device=gnan_output.device)
                gnan_output = torch.cat([gnan_output, pad_tensor], dim=0)

            outputs[ind] = gnan_output

        groups_output = []
        outputs = outputs.permute(1, 0, 2)  # (batch_dim, biomarkers, hidden)

        for group_idx, group in enumerate(self.biomarker_groups):
            # Encodings of the biomarkers of this group that were supplied.
            group_out = [
                outputs[:, biom_idx_map[biomarker]]
                for biomarker in group
                if biomarker in biom_idx_map
            ]

            if len(group_out) > 1:
                group_out = torch.stack(group_out, dim=0)
                group_out = self.group_deep_sets[str(group_idx)](group_out)
                groups_output.append(group_out)
            elif len(group_out) == 1:
                # A single biomarker needs no set pooling.
                groups_output.append(group_out[0])

        groups_output = torch.stack(groups_output, dim=0)

        groups_output = groups_output.permute(1, 0, 2)  # (batch_dim, groups, hidden)
        outputs = groups_output.sum(dim=-1)  # (batch_dim, groups)

        if return_group_outputs:
            return groups_output, laplacian_outputs

        if self.return_laplacian:
            return outputs.sum(dim=-1), laplacian_outputs

        return outputs.sum(dim=-1)