import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree, to_networkx
from networkx.algorithms.shortest_paths.unweighted import single_source_shortest_path_length
from .gnns import *
import scipy.sparse
import numpy as np
from torch.autograd import Variable
import torch.autograd as autograd
import math
import torch.distributions as D


class MaxLogits(nn.Module):
    """Out-of-distribution detector that scores nodes by their largest output.

    A backbone encoder (selected through ``cfg["backbone"]``) is trained with
    the plain supervised loss; at detection time the maximum logit (or, for
    the multi-label datasets, the summed per-task maximum probability) is
    used as the confidence score.
    """

    def __init__(self, d, c, cfg):
        """Build the backbone encoder.

        Args:
            d: Input feature dimension.
            c: Number of output channels (classes / tasks).
            cfg: Configuration mapping; ``cfg["backbone"]`` picks one of
                ``'gcn'``, ``'mlp'``, ``'appnp'`` or ``'gat'``.

        Raises:
            NotImplementedError: If the backbone name is not supported.
        """
        super(MaxLogits, self).__init__()
        backbone = cfg["backbone"]
        if backbone == 'gcn':
            encoder = GCN(in_channels=d,
                          hidden_channels=cfg["hidden_channels"],
                          out_channels=c,
                          num_layers=cfg["num_layers"],
                          dropout=cfg["dropout"],
                          use_bn=cfg["use_bn"])
        elif backbone == 'mlp':
            encoder = MLP(in_channels=d,
                          hidden_channels=cfg["hidden_channels"],
                          out_channels=c,
                          num_layers=cfg["num_layers"],
                          dropout=cfg["dropout"])
        elif backbone == 'appnp':
            encoder = APPNP_Net(d, cfg["hidden_channels"], c, dropout=cfg["dropout"])
        elif backbone == 'gat':
            encoder = GAT(d, cfg["hidden_channels"], c,
                          num_layers=cfg["num_layers"],
                          dropout=cfg["dropout"],
                          use_bn=cfg["use_bn"],
                          heads=cfg["gat_heads"],
                          out_heads=cfg["out_heads"])
        else:
            raise NotImplementedError
        self.encoder = encoder

    def reset_parameters(self):
        """Re-initialise the backbone by delegating to the encoder."""
        self.encoder.reset_parameters()

    def forward(self, dataset, device):
        """Return the encoder output for every node of ``dataset``."""
        features = dataset.x.to(device)
        edges = dataset.edge_index.to(device)
        return self.encoder(features, edges)

    def detect(self, dataset, node_idx, device, cfg):
        """Return one confidence score per node in ``node_idx``.

        For ``'proteins'`` / ``'ppi'`` each output is treated as an
        independent binary task: the larger of ``sigmoid(logit)`` and
        ``1 - sigmoid(logit)`` is summed over tasks. Otherwise the score is
        the maximum logit of the node.
        """
        features = dataset.x.to(device)
        edges = dataset.edge_index.to(device)
        logits = self.encoder(features, edges)[node_idx]

        if cfg["dataset"] not in ('proteins', 'ppi'):
            return logits.max(dim=1).values

        prob = torch.sigmoid(logits).unsqueeze(-1)
        two_way = torch.cat([prob, 1 - prob], dim=-1)
        return two_way.max(dim=-1).values.sum(dim=1)

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        """Supervised loss on the in-distribution training nodes.

        ``dataset_ood`` is accepted for interface parity but is not read.
        """
        idx = dataset_ind.splits['train']
        features = dataset_ind.x.to(device)
        edges = dataset_ind.edge_index.to(device)
        train_logits = self.encoder(features, edges)[idx]

        if cfg["dataset"] in ('proteins', 'ppi'):
            # multi-label targets are passed to the criterion as floats
            targets = dataset_ind.y[idx].to(device).to(torch.float)
            return criterion(train_logits, targets)

        log_prob = F.log_softmax(train_logits, dim=1)
        return criterion(log_prob, dataset_ind.y[idx].squeeze(1).to(device))


class EnergyModel(nn.Module):
    """Energy-based OOD detector on top of a configurable backbone encoder.

    Nodes are scored by their negative free energy
    ``T * logsumexp(logits / T)``; training currently uses the supervised
    loss only.
    """

    def __init__(self, d, c, cfg):
        """Build the backbone encoder.

        Args:
            d: Input feature dimension.
            c: Number of output channels (classes / tasks).
            cfg: Configuration mapping; ``cfg["backbone"]`` picks one of
                ``'gcn'``, ``'mlp'``, ``'sgc'`` or ``'gat'``.

        Raises:
            NotImplementedError: If the backbone name is not supported.
        """
        super(EnergyModel, self).__init__()
        if cfg["backbone"] == 'gcn':
            self.encoder = GCN(in_channels=d,
                        hidden_channels=cfg["hidden_channels"],
                        out_channels=c,
                        num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"],
                        use_bn=cfg["use_bn"])
        elif cfg["backbone"] == 'mlp':
            self.encoder = MLP(in_channels=d, hidden_channels=cfg["hidden_channels"],
                        out_channels=c, num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"])
        elif cfg["backbone"] == 'sgc':
            self.encoder = SGC(in_channels=d, out_channels=c, hops=cfg["hops"])
        elif cfg["backbone"] == 'gat':
            self.encoder = GAT(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"], use_bn=cfg["use_bn"], heads=cfg["gat_heads"], out_heads=cfg["out_heads"])
        else:
            raise NotImplementedError

    def reset_parameters(self):
        """Re-initialise the backbone by delegating to the encoder."""
        self.encoder.reset_parameters()

    def forward(self, dataset, device):
        """Return the encoder output for every node of ``dataset``."""
        x, edge_index = dataset.x.to(device), dataset.edge_index.to(device)
        return self.encoder(x, edge_index)

    def detect(self, dataset, node_idx, device, cfg):
        """Return the negative energy of the nodes in ``node_idx``.

        For ``'proteins'`` / ``'ppi'`` every output is treated as a binary
        task by pairing its logit with a zero logit; the per-task negative
        energies are summed. ``cfg["T"]`` is the temperature.
        """
        logits = self.encoder(dataset.x.to(device), dataset.edge_index.to(device))[node_idx]
        temperature = cfg["T"]
        if cfg["dataset"] in ('proteins', 'ppi'):
            logits = torch.stack([logits, torch.zeros_like(logits)], dim=2)
            return temperature * torch.logsumexp(logits / temperature, dim=-1).sum(dim=1)
        return temperature * torch.logsumexp(logits / temperature, dim=-1)

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        """Supervised loss on the in-distribution training nodes.

        The energy-margin regulariser
        ``mean(relu(E_in - m_in)^2 + relu(m_out - E_out)^2)`` that used to be
        added here (weighted by ``cfg["lamda"]``) is disabled, so the
        returned loss does not depend on OOD data. Consequently the encoder
        is no longer run on ``dataset_ood``: the result was discarded, and
        the extra pass cost a full forward computation.

        Args:
            dataset_ind: In-distribution dataset with ``x``, ``edge_index``,
                ``y`` and ``splits['train']``.
            dataset_ood: Unused; kept so all detectors share one signature.
            criterion: Loss callable applied to (prediction, target).
            device: Device the tensors are moved to.
            cfg: Configuration mapping; only ``cfg["dataset"]`` is read.

        Returns:
            The supervised loss tensor.
        """
        train_in_idx = dataset_ind.splits['train']
        logits_in = self.encoder(dataset_ind.x.to(device), dataset_ind.edge_index.to(device))[train_in_idx]

        if cfg["dataset"] in ('proteins', 'ppi'):
            # multi-label targets are passed to the criterion as floats
            return criterion(logits_in, dataset_ind.y[train_in_idx].to(device).to(torch.float))

        pred_in = F.log_softmax(logits_in, dim=1)
        return criterion(pred_in, dataset_ind.y[train_in_idx].squeeze(1).to(device))



class EnergyProp(nn.Module):
    """Energy-based OOD detector whose scores are smoothed over the graph.

    The per-node negative energy ``T * logsumexp(logits / T)`` is mixed with
    the degree-normalised average of its neighbours for a configurable
    number of propagation steps. Training currently uses the supervised loss
    only.
    """

    def __init__(self, d, c, cfg):
        """Build the backbone encoder.

        Args:
            d: Input feature dimension.
            c: Number of output channels (classes / tasks).
            cfg: Configuration mapping; ``cfg["backbone"]`` picks one of
                ``'gcn'``, ``'mlp'``, ``'sgc'`` or ``'gat'``.

        Raises:
            NotImplementedError: If the backbone name is not supported.
        """
        super(EnergyProp, self).__init__()
        if cfg["backbone"] == 'gcn':
            self.encoder = GCN(in_channels=d,
                        hidden_channels=cfg["hidden_channels"],
                        out_channels=c,
                        num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"],
                        use_bn=cfg["use_bn"])
        elif cfg["backbone"] == 'mlp':
            self.encoder = MLP(in_channels=d, hidden_channels=cfg["hidden_channels"],
                        out_channels=c, num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"])
        elif cfg["backbone"] == 'sgc':
            self.encoder = SGC(in_channels=d, out_channels=c, hops=cfg["hops"])
        elif cfg["backbone"] == 'gat':
            self.encoder = GAT(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"], use_bn=cfg["use_bn"], heads=cfg["gat_heads"], out_heads=cfg["out_heads"])
        else:
            raise NotImplementedError

    def reset_parameters(self):
        """Re-initialise the backbone by delegating to the encoder."""
        self.encoder.reset_parameters()

    def forward(self, dataset, device):
        """Return the encoder output for every node of ``dataset``."""
        x, edge_index = dataset.x.to(device), dataset.edge_index.to(device)
        return self.encoder(x, edge_index)

    def propagation(self, e, edge_index, l=1, alpha=0.5):
        """Smooth a per-node score vector along the edges.

        Each step computes ``e <- alpha * e + (1 - alpha) * A e`` where the
        sparse matrix ``A`` holds ``1 / degree`` weights built from
        ``edge_index``.

        Args:
            e: 1-D tensor with one score per node.
            edge_index: ``(row, col)`` edge tensor.
            l: Number of propagation steps.
            alpha: Weight kept on the node's own score at each step.

        Returns:
            1-D tensor of propagated scores.
        """
        e = e.unsqueeze(1)
        N = e.shape[0]
        row, col = edge_index
        d = degree(col, N).float()
        # one weight per edge; non-finite values are zeroed defensively
        value = torch.nan_to_num(1. / d[col], nan=0.0, posinf=0.0, neginf=0.0)
        adj = SparseTensor(row=col, col=row, value=value, sparse_sizes=(N, N))
        for _ in range(l):
            e = e * alpha + matmul(adj, e) * (1 - alpha)
        return e.squeeze(1)

    def detect(self, dataset, node_idx, device, cfg):
        """Return the propagated negative energy of the nodes in ``node_idx``.

        The energy is computed for all nodes first (propagation needs the
        whole graph) and only then restricted to ``node_idx``. Reads
        ``cfg["T"]``, ``cfg["prop_layers"]`` and ``cfg["alpha"]``.
        """
        x, edge_index = dataset.x.to(device), dataset.edge_index.to(device)
        logits = self.encoder(x, edge_index)
        temperature = cfg["T"]
        if cfg["dataset"] in ('proteins', 'ppi'):
            # pair every logit with a zero logit -> one binary task per output
            logits = torch.stack([logits, torch.zeros_like(logits)], dim=2)
            neg_energy = temperature * torch.logsumexp(logits / temperature, dim=-1).sum(dim=1)
        else:
            neg_energy = temperature * torch.logsumexp(logits / temperature, dim=-1)
        neg_energy_prop = self.propagation(neg_energy, edge_index, cfg["prop_layers"], cfg["alpha"])
        return neg_energy_prop[node_idx]

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        """Supervised loss on the in-distribution training nodes.

        The propagated energy-margin regulariser
        ``mean(relu(E_in - m_in)^2 + relu(m_out - E_out)^2)`` that used to be
        added here (weighted by ``cfg["lamda"]``) is disabled, so the
        returned loss does not depend on OOD data. Consequently the OOD
        dataset is no longer moved to the device or pushed through the
        encoder: the result was discarded, and the extra pass cost a full
        forward computation.

        Args:
            dataset_ind: In-distribution dataset with ``x``, ``edge_index``,
                ``y`` and ``splits['train']``.
            dataset_ood: Unused; kept so all detectors share one signature.
            criterion: Loss callable applied to (prediction, target).
            device: Device the tensors are moved to.
            cfg: Configuration mapping; only ``cfg["dataset"]`` is read.

        Returns:
            The supervised loss tensor.
        """
        x_in, edge_index_in = dataset_ind.x.to(device), dataset_ind.edge_index.to(device)
        train_in_idx = dataset_ind.splits['train']
        logits_train = self.encoder(x_in, edge_index_in)[train_in_idx]

        if cfg["dataset"] in ('proteins', 'ppi'):
            # multi-label targets are passed to the criterion as floats
            return criterion(logits_train, dataset_ind.y[train_in_idx].to(device).to(torch.float))

        pred_in = F.log_softmax(logits_train, dim=1)
        return criterion(pred_in, dataset_ind.y[train_in_idx].squeeze(1).to(device))

# Default hyper-parameters of the GPN model defined further down this module.
gpn_params = {
    "dim_hidden": 64,
    "dropout_prob": 0.5,
    "K": 10,
    "add_self_loops": True,
    "maf_layers": 0,
    "gaussian_layers": 0,
    "use_batched_flow": True,
    "loss_reduction": 'sum',
    "approximate_reg": True,
    "factor_flow_lr": None,
    "flow_weight_decay": 0.0,
    "pre_train_mode": 'flow',
    "alpha_evidence_scale": 'latent-new',
    "alpha_teleport": 0.1,
    "entropy_reg": 0.0001,
    "dim_latent": 32,
    "radial_layers": 10,
    "likelihood_type": None,
}

from models.gpn.layers import APPNPPropagation
from models.gpn.layers import Density, Evidence
from models.gpn.utils import Prediction, apply_mask


def uce_loss(alpha, y, reduction='mean'):
    """Calculates the uncertainty cross-entropy (UCE) loss.

    For every sample this is the expected cross-entropy of the true class
    under ``Dirichlet(alpha)``:
    ``digamma(sum_c alpha_c) - digamma(alpha_y)``.

    Args:
        alpha (Tensor): Dirichlet concentration parameters, shape (N, C).
        y (Tensor): Integer class labels with N elements, e.g. shape (N,)
            or (N, 1). Must be an int64 tensor (required by ``gather``).
        reduction (str, optional): ``'none'``, ``'sum'``; any other value
            averages. Defaults to 'mean'.

    Returns:
        Tensor: Per-sample losses of shape (N,) for ``'none'``, otherwise a
        scalar.
    """
    # reshape instead of squeeze so a single-sample batch keeps its batch axis
    target = y.reshape(-1, 1)
    alpha_0 = alpha.sum(dim=1, keepdim=True)

    # pick alpha of the true class directly; no N x C one-hot matrix needed
    alpha_true = alpha.gather(1, target)
    A = (torch.digamma(alpha_0) - torch.digamma(alpha_true)).squeeze(1)

    if reduction == 'none':
        return A
    elif reduction == 'sum':
        return A.sum()
    else:  # Default to mean
        return A.mean()


def entropy_reg(alpha, coef=0.0001, approximate=False, reduction='mean'):
    """Entropy regularization to encourage uncertainty.

    Returns ``-coef * H(Dirichlet(alpha))`` so that minimising the value
    increases the entropy of the predicted Dirichlet distributions. Both
    code paths compute the same quantity; previously the ``approximate``
    path evaluated ``+coef * (H + lgamma(K))``, i.e. the opposite sign,
    which rewarded low entropy.

    Args:
        alpha (Tensor): Dirichlet concentration parameters, shape (N, K).
        coef (float, optional): Regularization coefficient. Defaults to 0.0001.
        approximate (bool, optional): If True, evaluate the closed-form
            entropy directly from ``lgamma`` / ``digamma`` instead of going
            through ``torch.distributions``. Defaults to False.
        reduction (str, optional): ``'none'``, ``'sum'``; any other value
            averages. Defaults to 'mean'.

    Returns:
        Tensor: Per-sample values of shape (N,) for ``'none'``, otherwise a
        scalar.
    """
    if approximate:
        S = alpha.sum(dim=1)
        K = alpha.size(1)
        # H = log B(alpha) + (S - K) * psi(S) - sum_i (alpha_i - 1) * psi(alpha_i)
        log_beta = torch.lgamma(alpha).sum(dim=1) - torch.lgamma(S)
        entropy = (log_beta
                   + (S - K) * torch.digamma(S)
                   - torch.sum((alpha - 1.0) * torch.digamma(alpha), dim=1))
    else:
        entropy = D.Dirichlet(alpha).entropy()

    A = -coef * entropy

    if reduction == 'none':
        return A
    elif reduction == 'sum':
        return A.sum()
    else:  # Default to mean
        return A.mean()


class GPN(nn.Module):
    """Graph Posterior Network style model producing Dirichlet parameters.

    Node features are encoded into a latent space, a per-class density
    (``Density``) turns latent codes into class-wise evidence, and the
    evidence is diffused over the graph with ``APPNPPropagation``. The
    resulting ``alpha`` parameters give predictions and the uncertainty
    scores used by ``detect``.
    """

    def __init__(self, d, c, cfg):
        """Create the model.

        Args:
            d: Input feature dimension.
            c: Number of classes.
            cfg: Configuration mapping; ``cfg["GPN_detect_type"]`` must be
                ``'Alea'``, ``'Epist'`` or ``'Epist_wo_Net'``.
        """
        super(GPN, self).__init__()
        # Work on a private copy: writing dim_feature / num_classes into the
        # module-level defaults would leak one instance's sizes into every
        # other instance that shares the dict.
        self.params = dict(gpn_params)
        self.params["dim_feature"] = d
        self.params["num_classes"] = c

        # class prior p(c); filled by the first forward pass in training mode
        self.p_c = None

        self._build_modules()

        self.detect_type = cfg["GPN_detect_type"]

        assert self.detect_type in ('Alea', 'Epist', 'Epist_wo_Net')
        assert self.params["pre_train_mode"] in ('encoder', 'flow', None)
        assert self.params['likelihood_type'] in ('UCE', 'nll_train', 'nll_train_and_val', 'nll_consistency', None)

    def _build_modules(self):
        """Instantiate all sub-modules from ``self.params``."""
        p = self.params
        self.input_encoder = nn.Sequential(
            nn.Linear(p["dim_feature"], p["dim_hidden"]),
            nn.ReLU(),
            nn.Dropout(p=p["dropout_prob"]))

        self.latent_encoder = nn.Linear(p["dim_hidden"], p["dim_latent"])

        self.flow = Density(
            dim_latent=p["dim_latent"],
            num_mixture_elements=p["num_classes"],
            radial_layers=p["radial_layers"],
            maf_layers=p["maf_layers"],
            gaussian_layers=p["gaussian_layers"],
            use_batched_flow=bool(p["use_batched_flow"]))

        self.evidence = Evidence(scale=p["alpha_evidence_scale"])

        self.propagation = APPNPPropagation(
            K=p["K"],
            alpha=p["alpha_teleport"],
            add_self_loops=p["add_self_loops"],
            cached=False,
            normalization='sym')

    def reset_parameters(self):
        """Re-create every sub-module with fresh weights.

        The new modules are moved to the device the model lived on before
        the reset; without this a model that had been moved to an
        accelerator would silently end up with CPU weights. Optimizers built
        before the reset still reference the old parameters and must be
        re-created by the caller.
        """
        current = next(self.parameters(), None)
        self._build_modules()
        if current is not None:
            self.to(current.device)

    def forward(self, dataset, device):
        """Return the hard class prediction per node, shape (N, 1)."""
        pred = self.forward_impl(dataset, device)
        return pred.hard.unsqueeze(-1)

    def forward_impl(self, dataset, device):
        """Run the full model and return a ``Prediction`` record.

        In training mode the class prior is recomputed from the dataset's
        training labels and cached; in evaluation mode the cached prior is
        reused (also for datasets other than the training one).

        Raises:
            RuntimeError: If called in evaluation mode before any forward
                pass in training mode has cached the class prior.
        """
        edge_index = dataset.edge_index.to(device) if dataset.edge_index is not None else dataset.adj_t.to(device)
        x = dataset.x.to(device)
        z = self.latent_encoder(self.input_encoder(x))

        # compute feature evidence (with Normalizing Flows)
        # log p(z, c) = log p(z | c) p(c)
        if self.training:
            p_c = self.get_class_probalities(dataset).to(device)
            self.p_c = p_c
        elif self.p_c is None:
            raise RuntimeError(
                "GPN class prior is not available: run a forward pass in "
                "training mode before evaluating.")
        else:
            p_c = self.p_c.to(device)
        log_q_ft_per_class = self.flow(z) + p_c.view(1, -1).log()

        if '-plus-classes' in self.params["alpha_evidence_scale"]:
            further_scale = self.params["num_classes"]
        else:
            further_scale = 1.0

        beta_ft = self.evidence(
            log_q_ft_per_class, dim=self.params["dim_latent"],
            further_scale=further_scale).exp()

        # alpha before propagation: evidence from the node's own features only
        alpha_features = 1.0 + beta_ft

        beta = self.propagation(beta_ft, edge_index)
        alpha = 1.0 + beta

        soft = alpha / alpha.sum(-1, keepdim=True)
        max_soft, hard = soft.max(dim=-1)

        # ---------------------------------------------------------------------------------
        pred = Prediction(
            # Basic predictions
            soft=soft,
            hard=hard,

            # Alpha parameters
            alpha=alpha,

            # Confidence scores
            sample_confidence_aleatoric=max_soft,
            sample_confidence_epistemic=alpha.sum(-1),
            sample_confidence_features=alpha_features.sum(-1),
            sample_confidence_neighborhood=None,
            sample_confidence_structure=None,

            # Prediction confidence
            prediction_confidence_aleatoric=max_soft,
            # alpha of the predicted class for every node
            prediction_confidence_epistemic=alpha.gather(-1, hard.unsqueeze(-1)).squeeze(-1),

            # Additional parameters
            logits=None,
            evidence=beta.sum(-1)
        )
        # ---------------------------------------------------------------------------------

        return pred

    def get_optimizer(self, lr: float, weight_decay: float):
        """Build the optimizers used for training.

        Returns:
            ``(model_optimizer, flow_optimizer)``: the first updates all
            parameters (flow parameters in their own group with the flow
            learning rate / weight decay), the second only the flow.
        """
        factor = self.params["factor_flow_lr"]
        flow_lr = lr if factor is None else factor * lr
        flow_weight_decay = self.params["flow_weight_decay"]
        if flow_weight_decay is None:
            flow_weight_decay = weight_decay

        flow_named = list(self.flow.named_parameters())
        flow_param_names = {f'flow.{name}' for name, _ in flow_named}
        flow_param_weights = [weight for _, weight in flow_named]

        # all params except for flow
        other_weights = [weight for name, weight in self.named_parameters()
                         if name not in flow_param_names]

        flow_optimizer = torch.optim.Adam(flow_param_weights, lr=flow_lr, weight_decay=flow_weight_decay)
        model_optimizer = torch.optim.Adam(
            [{'params': flow_param_weights, 'lr': flow_lr, 'weight_decay': flow_weight_decay},
             {'params': other_weights}],
            lr=lr, weight_decay=weight_decay)

        return model_optimizer, flow_optimizer

    def get_warmup_optimizer(self, lr: float, weight_decay: float):
        """Return the optimizer for the warm-up phase.

        ``pre_train_mode == 'encoder'`` selects the full-model optimizer,
        anything else the flow-only optimizer.
        """
        model_optimizer, flow_optimizer = self.get_optimizer(lr, weight_decay)
        if self.params["pre_train_mode"] == 'encoder':
            return model_optimizer
        return flow_optimizer

    def _uce_objective(self, dataset, split, device):
        """UCE loss plus entropy regulariser on the nodes of one split."""
        idx = dataset.splits[split]
        prediction = self.forward_impl(dataset, device)
        y = dataset.y[idx].to(device)
        alpha = prediction.alpha[idx]
        reduction = self.params["loss_reduction"]
        return (uce_loss(alpha, y, reduction=reduction)
                + entropy_reg(alpha, self.params["entropy_reg"],
                              approximate=self.params["approximate_reg"],
                              reduction=reduction))

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        """Training loss on ``splits['train']``.

        ``dataset_ood``, ``criterion`` and ``cfg`` are accepted for
        interface parity with the other detectors but are not read.
        """
        return self._uce_objective(dataset_ind, 'train', device)

    def valid_loss(self, dataset_ind, device):
        """Same objective as ``loss_compute`` evaluated on ``splits['valid']``."""
        return self._uce_objective(dataset_ind, 'valid', device)

    def detect(self, dataset, node_idx, device, cfg):
        """Return the confidence score selected by ``detect_type``.

        ``'Alea'`` uses the maximum class probability, ``'Epist'`` the sum
        of the propagated alpha, ``'Epist_wo_Net'`` the sum of the
        feature-only alpha.
        """
        pred = self.forward_impl(dataset, device)
        if self.detect_type == 'Alea':
            score = pred.sample_confidence_aleatoric[node_idx]
        elif self.detect_type == 'Epist':
            score = pred.sample_confidence_epistemic[node_idx]
        elif self.detect_type == 'Epist_wo_Net':
            score = pred.sample_confidence_features[node_idx]
        else:
            raise ValueError(f"Unknown detect type {self.detect_type}")

        return score

    def get_class_probalities(self, data):
        """Empirical class frequencies of the training labels.

        Returns:
            Float tensor of length ``num_classes`` on ``data.x``'s device
            that sums to one. Classes without training nodes get 0.
        """
        num_classes = self.params["num_classes"]
        y_train = data.y[data.splits['train']]

        # calculate class_counts L(c) for all classes in one comparison
        classes = torch.arange(num_classes, device=y_train.device)
        counts = (y_train.reshape(-1, 1) == classes).sum(dim=0)
        l_c = counts.to(device=data.x.device, dtype=torch.get_default_dtype())

        return l_c / l_c.sum()

# NOTE: the triple-quoted block below is an inert string literal, not live
# configuration. SGCN.__init__ reads the corresponding settings from the
# cfg["gkde_*"] keys instead; the values here are presumably the former
# defaults -- confirm before relying on them.
'''sgcn_params = dict()
sgcn_params["seed"] = 42
sgcn_params["dim_hidden"] = 16
sgcn_params["dropout_prob"] = 0.5
sgcn_params["use_kernel"] = True
sgcn_params["lambda_1"] = 0.001
sgcn_params["teacher_training"] = True
sgcn_params["use_bayesian_dropout"] = False
sgcn_params["sample_method"] = 'log_evidence'
sgcn_params["num_samples_dropout"] = 10
sgcn_params["loss_reduction"] = None'''

import torch.distributions as D
from models.gpn.utils import Prediction

def loss_reduce(loss, reduction='mean'):
    """Collapse a loss tensor with the requested reduction.

    ``'sum'`` and ``'mean'`` reduce to a scalar; any other value returns
    ``loss`` unchanged.
    """
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss

def bayesian_risk_sosq(alpha, y, reduction='mean'):
    """Expected squared error between one-hot labels and Dirichlet samples.

    Per sample this evaluates ``sum_i (y_i - 2 * y_i * E[p_i] + E[p_i^2])``
    with ``E[p_i] = alpha_i / S`` and
    ``E[p_i^2] = alpha_i * (alpha_i + 1) / (S * (S + 1))``, where ``S`` is
    the row sum of ``alpha``. ``'sum'`` / ``'mean'`` reduce to a scalar;
    any other ``reduction`` returns the per-sample values.
    """
    targets = F.one_hot(y.squeeze(), num_classes=alpha.size(1)).float()

    strength = alpha.sum(dim=1, keepdim=True)
    first_moment = alpha / strength
    second_moment = alpha * (alpha + 1) / (strength * (strength + 1))

    # y is one-hot, hence y^2 == y in the expanded square
    per_sample = (targets - 2 * targets * first_moment + second_moment).sum(dim=1)

    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample


class SGCN(nn.Module):
    """Subjective (evidential) classifier with optional prior and teacher.

    The network outputs per-class evidence ``exp(logit)``; ``alpha`` is
    ``1 + evidence``. Training minimises a Bayesian sum-of-squares risk,
    optionally plus a KL term towards a graph-distance based Dirichlet prior
    and a KL term towards a teacher's class probabilities (both prepared by
    ``create_storage``).

    NOTE: the two layers are plain linear layers (see the comment in
    ``__init__``), so the forward pass itself does not use the graph
    structure.
    """

    def __init__(self, d, c, cfg):
        """Create the model.

        Args:
            d: Input feature dimension.
            c: Number of classes.
            cfg: Configuration mapping providing the ``gkde_*`` settings and
                ``cfg["GPN_detect_type"]`` (``'Alea'`` or ``'Epist'``).
        """
        super(SGCN, self).__init__()
        self.params = {
            "seed": cfg["gkde_seed"],
            "dim_hidden": cfg["gkde_dim_hidden"],
            "dropout_prob": cfg["gkde_dropout_prob"],
            "use_kernel": bool(cfg["gkde_use_kernel"]),
            "lambda_1": cfg["gkde_lambda_1"],
            "teacher_training": bool(cfg["gkde_teacher_training"]),
            "use_bayesian_dropout": bool(cfg["gkde_use_bayesian_dropout"]),
            "sample_method": cfg["gkde_sample_method"],
            "num_samples_dropout": cfg["gkde_num_samples_dropout"],
            "loss_reduction": cfg["gkde_loss_reduction"],
            "dim_feature": d,
            "num_classes": c,
        }

        self.alpha_prior = None
        self.y_teacher = None

        # Define the GCN layers
        # Using this simplified version instead of the GCNConv from gpn
        self.conv1 = nn.Sequential(
            nn.Linear(d, self.params["dim_hidden"]),
            nn.ReLU(),
            nn.Dropout(p=self.params["dropout_prob"])
        )

        self.conv2 = nn.Linear(self.params["dim_hidden"], c)

        self.evidence_activation = torch.exp
        self.epoch = None
        self.teacher = None  # Will be set in create_storage

        self.detect_type = cfg["GPN_detect_type"]

        assert self.detect_type in ('Alea', 'Epist')

    def reset_parameters(self):
        """Re-initialise the layer weights and clear prior / teacher targets."""
        self.alpha_prior = None
        self.y_teacher = None

        # Reset the weights of the layers
        nn.init.xavier_uniform_(self.conv1[0].weight)
        nn.init.zeros_(self.conv1[0].bias)
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

        self.evidence_activation = torch.exp
        self.epoch = None

    def forward(self, dataset, device):
        """Return the hard class prediction per node, shape (N, 1)."""
        pred = self.forward_impl(dataset, device)
        return pred.hard.unsqueeze(-1)

    def forward_impl(self, dataset, device):
        """Run the network and return a ``Prediction`` record.

        In training mode, or when Bayesian dropout is disabled, a single
        pass is made. Otherwise (evaluation with Bayesian dropout) the
        dropout layer is sampled ``num_samples_dropout`` times and the
        samples are averaged either in log-evidence space
        (``sample_method == 'log_evidence'``) or in evidence space
        (``'alpha'``).
        """
        x = dataset.x.to(device)

        # First layer: Linear followed by ReLU
        h = self.conv1[1](self.conv1[0](x))

        if self.training or (not self.params["use_bayesian_dropout"]):
            h = self.conv1[2](h)  # Dropout (identity in eval mode)
            evidence = self.evidence_activation(self.conv2(h))
        else:
            # Bayesian Monte Carlo dropout: dropout has to be active while
            # sampling, so switch to train mode temporarily. This branch is
            # only reached in eval mode; restore it even if sampling fails.
            self.train()
            try:
                samples = [self.conv2(self.conv1[2](h))
                           for _ in range(self.params["num_samples_dropout"])]
            finally:
                self.eval()

            log_evidence = torch.stack(samples, dim=1)

            if self.params["sample_method"] == 'log_evidence':
                evidence = self.evidence_activation(log_evidence.mean(dim=1))
            elif self.params["sample_method"] == 'alpha':
                evidence = self.evidence_activation(log_evidence).mean(dim=1)
            else:
                raise AssertionError

        alpha = 1.0 + evidence
        soft = alpha / alpha.sum(-1, keepdim=True)
        max_soft, hard = soft.max(dim=-1)

        # ---------------------------------------------------------------------------------
        pred = Prediction(
            # Basic predictions
            soft=soft,
            hard=hard,

            # Alpha parameters
            alpha=alpha,

            # Confidence scores
            sample_confidence_aleatoric=max_soft,
            sample_confidence_epistemic=alpha.sum(-1),
            sample_confidence_features=None,
            sample_confidence_neighborhood=None,
            sample_confidence_structure=None,

            # Prediction confidence
            prediction_confidence_aleatoric=max_soft,
            # alpha of the predicted class for every node
            prediction_confidence_epistemic=alpha.gather(-1, hard.unsqueeze(-1)).squeeze(-1),

            # Additional parameters
            logits=None,
            evidence=evidence.sum(-1)
        )
        # ---------------------------------------------------------------------------------

        return pred

    @staticmethod
    def _is_mask(split):
        """Tell whether a split tensor is a boolean / uint8 node mask."""
        return split.dtype in (torch.bool, torch.uint8)

    @classmethod
    def _split_size(cls, split):
        """Number of nodes selected by a split given as mask or index tensor."""
        if cls._is_mask(split):
            return int(split.sum())
        return split.numel()

    @classmethod
    def _split_indices(cls, split):
        """Node ids selected by a split, as a flat Python list."""
        if cls._is_mask(split):
            return torch.nonzero(split, as_tuple=False).reshape(-1).tolist()
        return split.reshape(-1).tolist()

    def _objective(self, dataset_ind, split, device, advance_epoch):
        """Bayesian risk on one split plus the optional KL terms.

        Args:
            dataset_ind: Dataset with ``x``, ``y`` and ``splits[split]``.
            split: Name of the split the risk is computed on.
            device: Device the tensors are moved to.
            advance_epoch: Whether a call in training mode should advance
                the epoch counter that drives the teacher-term warm-up.
        """
        split_idx = dataset_ind.splits[split]
        if self.params["loss_reduction"] in ('sum', None):
            n_nodes = 1.0
            n_split = 1.0
        else:
            n_nodes = dataset_ind.y.size(0)
            # average over the nodes of the split; works for masks and for
            # index tensors (assumes splits[...] selects the same nodes as
            # a matching *_mask attribute would -- TODO confirm)
            n_split = self._split_size(split_idx)

        prediction = self.forward_impl(dataset_ind, device)
        alpha = prediction.alpha

        # bayesian risk of sum of squares
        y = dataset_ind.y[split_idx]
        bay_risk = bayesian_risk_sosq(alpha[split_idx], y.to(device), reduction='sum')
        loss = bay_risk * 1.0 / n_split

        # KL divergence w.r.t. alpha-prior from Gaussian Dirichlet Kernel
        if self.params["use_kernel"]:
            if self.alpha_prior is None:
                raise RuntimeError(
                    "SGCN alpha prior is not initialised: call "
                    "create_storage() before computing the loss.")
            alpha_prior = self.alpha_prior.to(alpha.device).detach()
            KL_prior = D.kl.kl_divergence(D.Dirichlet(alpha), D.Dirichlet(alpha_prior))
            KL_prior = loss_reduce(KL_prior, reduction='sum')
            loss = loss + self.params["lambda_1"] * KL_prior / n_nodes

        # KL divergence for teacher training
        if self.params["teacher_training"] and self.y_teacher is not None:
            # currently only works for full-batch training
            # i.e. epochs == iterations
            if advance_epoch and self.training:
                self.epoch = 0 if self.epoch is None else self.epoch + 1

            y_teacher = self.y_teacher.to(prediction.soft.device).detach()
            # linear warm-up over 200 epochs; no weight before training starts
            epoch = 0 if self.epoch is None else self.epoch
            lambda_2 = min(1.0, epoch * 1.0 / 200)
            KL_teacher = D.kl.kl_divergence(
                D.Categorical(prediction.soft), D.Categorical(y_teacher))
            KL_teacher = loss_reduce(KL_teacher, reduction='sum')
            loss = loss + lambda_2 * KL_teacher / n_nodes

        return loss

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        """Training loss on ``splits['train']``.

        ``dataset_ood``, ``criterion`` and ``cfg`` are accepted for
        interface parity with the other detectors but are not read. A call
        in training mode advances the epoch counter used to warm up the
        teacher term.
        """
        return self._objective(dataset_ind, 'train', device, advance_epoch=True)

    def valid_loss(self, dataset_ind, device):
        """Same objective evaluated on ``splits['valid']``.

        Never advances the epoch counter, so validating does not change the
        teacher warm-up schedule.
        """
        return self._objective(dataset_ind, 'valid', device, advance_epoch=False)

    def create_storage(self, dataset_ind, teacher, device):
        """Initialize the teacher targets and the Dirichlet prior.

        The teacher's softmax output becomes ``y_teacher``. The prior is
        ``1 + evidence`` where every training node adds a Gaussian kernel of
        its shortest-path distance (cut off at 10 hops) to the column of its
        class.

        NOTE(review): assigning an ``nn.Module`` teacher to ``self.teacher``
        registers it as a sub-module of this model -- confirm that is
        intended.
        """
        self.teacher = teacher

        # Teacher predictions are fixed targets: no autograd graph needed.
        with torch.no_grad():
            logits = teacher(dataset_ind, device)
            self.y_teacher = F.softmax(logits, dim=-1).to(device)

        # Create Dirichlet prior from data using graph distances
        # Simulate GDK behavior without using storage
        num_classes = self.params["num_classes"]
        n_nodes = dataset_ind.y.size(0)

        # Compute distance evidence using KDE on graph
        idx_train = self._split_indices(dataset_ind.splits['train'])
        evidence = torch.zeros((n_nodes, num_classes), device=device)
        G = to_networkx(dataset_ind, to_undirected=True)

        for idx_t in idx_train:
            reach = single_source_shortest_path_length(G, source=idx_t, cutoff=10)
            # nodes beyond the cutoff get a huge distance -> kernel value 0
            distances = torch.tensor(
                [reach.get(n, 1e10) for n in range(n_nodes)],
                dtype=evidence.dtype, device=device)
            # int() accepts labels stored as (N,) as well as (N, 1)
            label = int(dataset_ind.y[idx_t])
            evidence[:, label] += kernel_distance(distances, sigma=1.0)

        self.alpha_prior = (1.0 + evidence).to(device)

    def detect(self, dataset, node_idx, device, cfg):
        """Return the confidence score selected by ``detect_type``.

        ``'Alea'`` uses the maximum class probability, ``'Epist'`` the sum
        of alpha.
        """
        pred = self.forward_impl(dataset, device)
        if self.detect_type == 'Alea':
            score = pred.sample_confidence_aleatoric[node_idx]
        elif self.detect_type == 'Epist':
            score = pred.sample_confidence_epistemic[node_idx]
        else:
            raise ValueError(f"Unknown detect type {self.detect_type}")

        return score

def kernel_distance(x, sigma=1.0):
    """Evaluate a Gaussian kernel of width ``sigma`` on the distances ``x``.

    Returns ``exp(-x^2 / (2 * sigma^2)) / (sigma * sqrt(2 * pi))``
    element-wise.
    """
    normaliser = 1.0 / (sigma * math.sqrt(2 * math.pi))
    exponent = -torch.square(x) / (2 * sigma * sigma)
    return normaliser * torch.exp(exponent)