import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import softmax

# Prefer the first visible CUDA device; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def train(model, optimizer, data):
    """Run a single full-batch optimisation step on the training nodes.

    Args:
        model: module invoked as ``model(data.x, data.edge_index)`` returning
            per-node class scores.
        optimizer: optimizer over ``model``'s parameters; stepped once.
        data: object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``
            (used to index the rows that contribute to the loss).

    Returns:
        The cross-entropy loss on the training rows, as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    mask = data.train_mask
    batch_loss = F.cross_entropy(logits[mask], data.y[mask])
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def test(model, data):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        accs = [
            (pred[data.train_mask] == data.y[data.train_mask]).float().mean().item(),
            (pred[data.val_mask] == data.y[data.val_mask]).float().mean().item(),
            (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()
        ]
    return accs

class AttentionDiffusion(nn.Module):
    """Multi-head attention diffusion operator returning ``agg(X) - X``.

    For every edge ``(row, col)`` a scaled dot-product score between per-head
    query/key projections of ``X[row]`` and ``X[col]`` is computed and
    normalised with a softmax over all edges sharing the same ``row`` index.
    Messages are the first ``dk`` channels of the neighbour features
    ``X[col]`` (the same slice for every head), aggregated into ``row``,
    averaged over heads and projected back to ``in_dim`` by ``self.linear``.

    Args:
        in_dim: feature dimension of ``X``; must be a positive multiple of ``heads``.
        heads: number of attention heads.

    Raises:
        ValueError: if ``in_dim`` is not a positive multiple of ``heads``.
    """

    def __init__(self, in_dim, heads=4):
        super().__init__()
        # Fail at construction rather than later inside ``.view`` in forward.
        if heads <= 0 or in_dim % heads != 0:
            raise ValueError(
                f"in_dim ({in_dim}) must be a positive multiple of heads ({heads})"
            )
        self.heads = heads
        self.dk = in_dim // heads
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        # Projects the head-averaged dk-dimensional message back to in_dim.
        self.linear = nn.Linear(self.dk, in_dim)

    def forward(self, X, edge_index):
        """Return ``linear(mean_heads(A_h X[:, :dk])) - X``.

        Args:
            X: node features of shape ``(N, in_dim)``.
            edge_index: pair ``(row, col)`` of edge endpoint index tensors;
                messages flow from ``col`` into ``row``.

        Returns:
            Tensor of shape ``(N, in_dim)``.
        """
        N = X.size(0)
        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)  # (E, heads)
        # Exactly one trailing axis so the weights broadcast over the dk channels.
        # The original unsqueezed twice, which could not broadcast and crashed.
        attn = softmax(scores, row).unsqueeze(-1)  # (E, heads, 1)
        # Every head reads the same first dk neighbour channels; broadcasting over
        # the head axis replaces an explicit expand.
        V = X[col, : self.dk].unsqueeze(1)  # (E, 1, dk)
        # Match X's dtype/device so index_add_ does not fail for non-float32 input.
        out = X.new_zeros(N, self.heads, self.dk)
        out.index_add_(0, row, attn * V)
        out = self.linear(out.mean(dim=1))
        # Residual form; this is exactly (A - I)X only when heads == 1 and
        # self.linear is the identity map.
        return out - X
    
class LambdaDiffusion(nn.Module):
    """Attention diffusion blended with a memory term pulling features toward ``X0``.

    ``forward`` returns::

        w * (agg(X) - X) + (1 - w) * (X0 - X),   w = sigmoid(lambda_diag)

    where ``agg`` is multi-head attention aggregation over ``edge_index``
    (queries/keys from ``WQ``/``WK``; messages are the first ``dk`` channels of
    the neighbour features, averaged over heads and projected by ``linear``).
    ``X0`` is a detached snapshot stored by ``init_X0``.

    Args:
        in_dim: feature dimension; must be a positive multiple of ``heads``.
        heads: number of attention heads.
        use_adj_mask: stored as an attribute; not read inside this class.

    Raises:
        ValueError: if ``in_dim`` is not a positive multiple of ``heads``.
    """

    def __init__(self, in_dim, heads=4, use_adj_mask=True):
        super().__init__()
        # Fail at construction rather than later inside ``.view`` in forward.
        if heads <= 0 or in_dim % heads != 0:
            raise ValueError(
                f"in_dim ({in_dim}) must be a positive multiple of heads ({heads})"
            )
        self.heads = heads
        self.dk = in_dim // heads
        self.use_adj_mask = use_adj_mask
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        self.linear = nn.Linear(self.dk, in_dim)
        self.initialized = False
        # Registered eagerly rather than lazily in forward. Optimizers are usually
        # built from model.parameters() before the first forward pass, so a lazily
        # created parameter would never be trained. Eager registration also makes
        # it follow .to(device) and appear in state_dict from the start.
        self.lambda_diag = nn.Parameter(torch.tensor(1.0))
        self.lambda_diag_initialized = True

    def init_X0(self, X):
        """Store a detached copy of ``X`` as the memory anchor ``X0``.

        Only writes while ``self.initialized`` is False; set it back to False
        to make the next call overwrite the anchor.
        """
        if not self.initialized:
            self.X0 = X.detach().clone()
            self.initialized = True

    def init_lambda_diag(self, device):
        """(Re)create ``lambda_diag`` with value 1.0 on ``device``.

        Note: this replaces the Parameter object, so optimizers created before
        the call keep referencing the old one.
        """
        self.lambda_diag = nn.Parameter(torch.tensor(1.0, device=device))
        self.lambda_diag_initialized = True

    def forward(self, X, edge_index):
        """Return ``dX/dt`` for node features ``X`` of shape ``(N, in_dim)``.

        Args:
            X: node features of shape ``(N, in_dim)``.
            edge_index: pair ``(row, col)`` of edge index tensors; messages flow
                from ``col`` into ``row``.

        Returns:
            Tensor of shape ``(N, in_dim)``.
        """
        # No-op once an anchor has been stored.
        self.init_X0(X)
        N = X.size(0)
        if not self.lambda_diag_initialized:
            self.init_lambda_diag(X.device)

        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)  # (E, heads)
        attn = softmax(scores, row).unsqueeze(-1)  # (E, heads, 1)

        # Every head reads the same first dk neighbour channels.
        V = X[col, : self.dk].unsqueeze(1)  # (E, 1, dk)
        # Match X's dtype/device so index_add_ does not fail for non-float32 input.
        agg = X.new_zeros(N, self.heads, self.dk)
        agg.index_add_(0, row, attn * V)
        agg = self.linear(agg.mean(dim=1))

        diffusion_term = agg - X
        memory_term = self.X0 - X
        lambda_weights = torch.sigmoid(self.lambda_diag)
        return lambda_weights * diffusion_term + (1 - lambda_weights) * memory_term
    

def rk4_step_fully(func, x, t, dt, step_number):
    """Advance ``x`` by one classic fourth-order Runge-Kutta step of size ``dt``.

    ``func`` is evaluated as ``func(state, step_number, t=time)`` at the four
    RK4 stages (start, two midpoints, end); the same ``step_number`` is passed
    to every stage.
    """
    half = 0.5 * dt
    slope_start = func(x, step_number, t=t)
    slope_mid_a = func(x + half * slope_start, step_number, t=t + half)
    slope_mid_b = func(x + half * slope_mid_a, step_number, t=t + half)
    slope_end = func(x + dt * slope_mid_b, step_number, t=t + dt)
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return x + (dt / 6.0) * weighted

def solve_diffusion_rk4_fully(func, x0, t_span, dt):
    """Integrate ``dX/dt = func(X, step_number, t=t)`` with fixed-step RK4.

    Args:
        func: right-hand side, called as ``func(X, step_number, t=t)`` where
            ``step_number`` is the 0-based index of the current RK4 step.
        x0: initial state; it is cloned, so the caller's tensor is not aliased.
        t_span: ``(t_start, t_end)`` integration interval.
        dt: step size.

    Returns:
        The state after ``floor((t_end - t_start) / dt)`` RK4 steps.

    The step count is computed with a small tolerance. Otherwise round-off in
    the division (e.g. ``0.3 / 0.1 == 2.9999999999999996``) would silently
    drop the final step.
    """
    step_tolerance = 1e-9
    t_start, t_end = t_span[0], t_span[1]
    num_steps = int((t_end - t_start) / dt + step_tolerance)
    x = x0.clone()
    for step in range(num_steps):
        # Derive t from the step index instead of accumulating dt, so time does not drift.
        t = t_start + step * dt
        x = rk4_step_fully(func, x, t, dt, step_number=step)
    return x

class GRAND_ASC(nn.Module):
    """Node classifier: encoder -> RK4-integrated LambdaDiffusion -> decoder.

    The encoder output ``x0`` is the initial condition of the ODE and also the
    memory anchor of every diffusion module. The dynamics are integrated over
    ``[0, T]`` with ``num_steps`` RK4 steps of size ``T / num_steps``; step ``c``
    uses ``self.diffusions[c]``.

    Args:
        in_dim: input feature dimension.
        hidden_dim: latent dimension the diffusion operates in.
        out_dim: number of output classes.
        num_steps: number of RK4 steps (and diffusion modules).
        heads: attention heads per diffusion module.
        T: integration horizon.
        relation_type: ``'heterophilic'`` uses ``nn.Linear`` encoder/decoder;
            ``'homophilic'`` uses ``GCNConv`` encoder/decoder.
        input_dropout: dropout probability applied to the raw input features.

    Raises:
        ValueError: if ``relation_type`` is not one of the two supported values.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_steps, heads, T, relation_type, input_dropout=0.0):
        super().__init__()
        if relation_type not in ("heterophilic", "homophilic"):
            raise ValueError("Relation type not recognised.")
        self.relation_type = relation_type
        self.input_dropout = nn.Dropout(input_dropout)
        # One diffusion module per RK4 step.
        self.diffusions = nn.ModuleList([
            LambdaDiffusion(hidden_dim, heads=heads) for _ in range(num_steps)
        ])
        if relation_type == "heterophilic":
            self.encoder = nn.Linear(in_dim, hidden_dim)
            self.decoder = nn.Linear(hidden_dim, out_dim)
        else:
            self.encoder = GCNConv(in_dim, hidden_dim)
            self.decoder = GCNConv(hidden_dim, out_dim)

        self.num_steps = num_steps
        self.T = T
        self.dt = T / num_steps

    def _apply_layer(self, layer, x, edge_index):
        """Call an encoder/decoder layer; GCNConv layers additionally take edge_index."""
        if self.relation_type == "heterophilic":
            return layer(x)
        if self.relation_type == "homophilic":
            return layer(x, edge_index)
        raise ValueError("bad param")

    def forward(self, x, edge_index):
        """Return per-node log-probabilities of shape ``(N, out_dim)``."""
        x = self.input_dropout(x)
        x0 = F.relu(self._apply_layer(self.encoder, x, edge_index))

        # Re-anchor every diffusion's memory term to THIS pass's encoder output.
        # init_X0 only writes while `initialized` is False. Without the reset,
        # X0 would stay frozen at the first forward pass (untrained encoder,
        # training-time dropout) for all later epochs and at evaluation.
        for diffusion in self.diffusions:
            diffusion.initialized = False
            diffusion.init_X0(x0)

        def diffusion_func(X, step_number, t=None):
            # Each solver step uses its own diffusion module; t is unused.
            return self.diffusions[step_number](X, edge_index)

        x = solve_diffusion_rk4_fully(diffusion_func, x0, [0, self.T], self.dt)

        x = self._apply_layer(self.decoder, x, edge_index)
        return F.log_softmax(x, dim=1)
