import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
import torch_geometric as pyg
from torch_geometric.data import Data, Batch
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from tqdm import tqdm

class GraphSAGEUnifiedModel(nn.Module):
    """Stack of GraphSAGE convolutions followed by a linear readout.

    Every convolution is followed by a LeakyReLU and dropout. When
    ``graph_level_task`` is true, node embeddings are mean-pooled per graph
    (using ``inputs.batch``) before the readout, so the output has one row
    per graph instead of one row per node.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of the readout output.
        num_layers: Number of ``SAGEConv`` layers; must be at least 1.
        hidden_channels: Output width of every convolution. Falls back to
            ``out_channels`` when left as ``None``.
        bias: Whether the convolutions and the readout use a bias term.
        dropout: Dropout probability applied after every activation.
        graph_level_task: If true, apply global mean pooling before readout.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, num_layers, hidden_channels=None, bias=True, dropout=0.0,
                 graph_level_task=False):
        super().__init__()

        # Validate before building any layer: with zero layers the first
        # convolution would be created but never used, and the readout
        # width would not match the raw input width.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # The signature advertises hidden_channels as optional, so give the
        # ``None`` default a usable meaning instead of failing inside SAGEConv.
        if hidden_channels is None:
            hidden_channels = out_channels

        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.graph_level_task = graph_level_task

        # First layer maps in_channels -> hidden, the rest hidden -> hidden.
        layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)
        self.graph_convs = nn.ModuleList([
            SAGEConv(in_channels=width, out_channels=hidden_channels, bias=bias)
            for width in layer_inputs
        ])

        if graph_level_task:
            self.pool = global_mean_pool
        self.readout = nn.Linear(hidden_channels, out_channels, bias=bias)
        self.activation = nn.LeakyReLU()

    def init_params(self):
        """Re-initialise parameters in place.

        Parameters whose name contains ``'weight'`` and that have at least
        two dimensions get Kaiming-normal values (fan-in, leaky-ReLU gain);
        parameters whose name contains ``'bias'`` are set to zero. This is
        not called by ``__init__``; callers opt in explicitly.
        """
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='leaky_relu')
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, inputs):
        """Encode ``inputs`` and return the readout.

        Args:
            inputs: Object exposing ``x`` and ``edge_index`` (and ``batch``
                when ``graph_level_task`` is true), e.g. a PyG ``Data``.

        Returns:
            The readout tensor: one row per graph when pooling is enabled,
            otherwise one row per node.
        """
        h = inputs.x
        edge_index = inputs.edge_index

        # None of the ops below is in-place, so the input does not need a
        # defensive copy.
        for conv in self.graph_convs:
            h = conv(x=h, edge_index=edge_index)
            h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)

        if self.graph_level_task:
            h = self.pool(h, inputs.batch)

        return self.readout(h)


class P12_GraphSAGEUnified_Baseline(nn.Module):
    """Graph-level classifier over the union of per-biomarker path graphs.

    ``forward`` takes a ``graph_dict`` whose values are 3-tuples with the
    node-feature tensor first. The node features of all biomarkers are
    concatenated into a single graph in which each biomarker forms a
    directed path, the graph is encoded by ``GraphSAGEUnifiedModel`` with
    mean pooling, and a small MLP maps the pooled embedding to
    ``num_classes`` logits.

    Args:
        input_dim: Size of each node feature vector.
        hidden_channels: Width of the GraphSAGE layers and pooled embedding.
        num_layers: Number of GraphSAGE layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used in the encoder and the classifier.
    """

    def __init__(self,
                 input_dim,
                 hidden_channels=64,
                 num_layers=3,
                 num_classes=1,
                 dropout=0.1):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.input_dim = input_dim

        # Encoder: GraphSAGE stack with global mean pooling.
        self.unified_graph_sage = GraphSAGEUnifiedModel(
            in_channels=input_dim,
            out_channels=hidden_channels,
            num_layers=num_layers,
            hidden_channels=hidden_channels,
            dropout=dropout,
            graph_level_task=True
        )

        # Classification head on top of the pooled graph embedding.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels // 2, num_classes)
        )

    @staticmethod
    def _normalize_features(features):
        """Return ``features`` as a non-empty 2-D tensor, or ``None`` to skip it."""
        if features is None:
            return None
        if features.dim() == 3 and features.size(0) == 1:
            # Drop the leading singleton dimension.
            features = features.squeeze(0)
        elif features.dim() > 2:
            # Keep only the last two dimensions; this requires every other
            # dimension to have size 1.
            features = features.reshape(features.shape[-2], features.shape[-1])
        elif features.dim() == 1:
            # A bare feature vector is treated as a single node.
            features = features.unsqueeze(0)
        return features if features.size(0) > 0 else None

    def _collect_node_features(self, graph_dict):
        """Return the normalised feature tensors of all usable biomarkers, in dict order."""
        collected = []
        for node_features, _adj_matrix, _ in graph_dict.values():
            features = self._normalize_features(node_features)
            if features is not None:
                collected.append(features)
        return collected

    def _model_device(self):
        """Return the device of the first parameter, or CPU if there is none."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def _create_unified_edge_index(self, graph_dict):
        """Build the edge index and batch vector of the unified graph.

        Each usable biomarker with ``n`` nodes contributes the directed path
        ``offset -> offset + 1 -> ... -> offset + n - 1``.

        Returns:
            ``(edge_index, batch)``: ``edge_index`` is a ``(2, E)`` long
            tensor; ``batch`` is an all-zero long tensor with one entry per
            node of the unified graph (a single entry when the graph has no
            nodes, matching the placeholder node produced by
            ``_create_unified_node_features``).
        """
        node_counts = [features.size(0) for features in self._collect_node_features(graph_dict)]

        sources = []
        node_offset = 0
        for num_nodes in node_counts:
            # Sources of the path edges: every node except the last one.
            sources.append(torch.arange(node_offset, node_offset + num_nodes - 1, dtype=torch.long))
            node_offset += num_nodes

        src = torch.cat(sources) if sources else torch.empty(0, dtype=torch.long)
        if src.numel() == 0:
            # No path edges (no nodes, or only single-node biomarkers):
            # keep a single self-loop on node 0 as a placeholder edge.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            edge_index = torch.stack([src, src + 1])

        # All nodes belong to the same graph (index 0). The length must
        # match the number of feature rows even when there are no edges.
        batch = torch.zeros(max(node_offset, 1), dtype=torch.long)

        return edge_index, batch

    def _create_unified_node_features(self, graph_dict):
        """Concatenate the node features of all usable biomarkers row-wise.

        Returns:
            A ``(total_nodes, feat)`` tensor, or a single all-zero row of
            width ``input_dim`` on the model's device when no biomarker
            provides any node.
        """
        node_features_list = self._collect_node_features(graph_dict)

        if not node_features_list:
            return torch.zeros(1, self.input_dim, device=self._model_device())

        return torch.cat(node_features_list, dim=0)

    def forward(self, graph_dict):
        """Return the classification logits for one ``graph_dict``.

        Returns:
            A tensor of shape ``(num_classes,)``.
        """
        unified_node_features = self._create_unified_node_features(graph_dict)
        unified_edge_index, batch = self._create_unified_edge_index(graph_dict)

        # Wrap everything in a PyG Data object and move it to the model device.
        data = Data(x=unified_node_features, edge_index=unified_edge_index, batch=batch)
        data = data.to(self._model_device())

        # The batch vector is all zeros, so pooling yields exactly one row.
        graph_repr = self.unified_graph_sage(data)

        # Drop only the graph dimension; a bare squeeze() would also drop a
        # feature dimension of size 1.
        output = self.classifier(graph_repr.squeeze(0))

        return output


def train_graphsage_unified_baseline(model, train_loader, val_loader, test_loader, 
                                    num_epochs=50, lr=0.001, device='cuda'):
    """Train ``model``, keep the weights with the best validation AUC, and test them.

    The model is optimised with Adam on ``BCEWithLogitsLoss``. After every
    epoch, ROC-AUC and average precision are computed on the training and
    validation predictions; whenever the validation AUC improves, a snapshot
    of the weights is stored. The best snapshot (if any) is restored before
    the final evaluation on ``test_loader``.

    Args:
        model: Module called as ``model(graph_dict)`` and returning logits.
        train_loader: Iterable of ``(graph_dict, label, _)`` triples.
        val_loader: Same structure, used for model selection.
        test_loader: Same structure, used for the final evaluation.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and labels are moved to.

    Returns:
        Dict with keys ``'best_val_auc'``, ``'test_auc'``, ``'test_auprc'``
        and ``'test_accuracy'``.

    Note:
        The metric functions come from scikit-learn; ``roc_auc_score`` is
        expected to reject a split whose labels contain a single class.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    
    best_val_auc = 0.0
    best_model_state = None
    
    print(f"Training Unified GraphSAGE baseline for {num_epochs} epochs...")
    
    for epoch in tqdm(range(num_epochs)):
        # Training
        model.train()
        train_loss = 0.0
        train_preds = []
        train_labels = []
        
        # Wrap the loader itself so the inner bar knows its total.
        for graph_dict, label, _ in tqdm(train_loader, leave=False):
            label = label.to(device)
            
            optimizer.zero_grad()
            
            output = model(graph_dict)
            
            target = label.float().view_as(output)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            # Flatten per-sample tensors so they can be concatenated no
            # matter how many singleton dimensions the loader added.
            train_preds.append(torch.sigmoid(output).detach().cpu().reshape(-1))
            train_labels.append(label.detach().cpu().reshape(-1))
        
        # Validation
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for graph_dict, label, _ in val_loader:
                label = label.to(device)
                
                output = model(graph_dict)
                
                target = label.float().view_as(output)
                loss = criterion(output, target)
                val_loss += loss.item()
                val_preds.append(torch.sigmoid(output).cpu().reshape(-1))
                val_labels.append(label.cpu().reshape(-1))
        
        # Calculate metrics on 1-D arrays (stays 1-D even for one sample).
        train_preds = torch.cat(train_preds).numpy()
        train_labels = torch.cat(train_labels).numpy()
        val_preds = torch.cat(val_preds).numpy()
        val_labels = torch.cat(val_labels).numpy()
        
        train_auc = roc_auc_score(train_labels, train_preds)
        val_auc = roc_auc_score(val_labels, val_preds)
        train_auprc = average_precision_score(train_labels, train_preds)
        val_auprc = average_precision_score(val_labels, val_preds)
        
        # Print progress
        print(f"Epoch {epoch+1}/{num_epochs}: "
                f"Train Loss: {train_loss/len(train_loader):.4f}, "
                f"Train AUC: {train_auc:.4f}, Train AUPRC: {train_auprc:.4f}, "
                f"Val Loss: {val_loss/len(val_loader):.4f}, "
                f"Val AUC: {val_auc:.4f}, Val AUPRC: {val_auprc:.4f}")
        
        # Save best model. state_dict() returns references to the live
        # parameter tensors, which the optimizer keeps updating in place, so
        # each tensor must be cloned to freeze the snapshot.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
    
    # Load best model for testing
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    # Test evaluation
    test_auc, test_auprc, test_accuracy = evaluate_graphsage_unified_model(model, test_loader, device)
    
    return {
        'best_val_auc': best_val_auc,
        'test_auc': test_auc,
        'test_auprc': test_auprc,
        'test_accuracy': test_accuracy
    }


def evaluate_graphsage_unified_model(model, data_loader, device):
    """Compute ROC-AUC, average precision and accuracy of ``model`` on a loader.

    Args:
        model: Module called as ``model(graph_dict)`` and returning logits.
        data_loader: Iterable of ``(graph_dict, label, _)`` triples.
        device: Unused; labels are only consumed on the CPU. Kept so that
            existing callers do not have to change.

    Returns:
        Tuple ``(auc, auprc, accuracy)``. Accuracy thresholds the sigmoid
        probabilities at 0.5.
    """
    model.eval()
    predictions = []
    labels = []
    
    with torch.no_grad():
        for graph_dict, label, _ in data_loader:
            output = model(graph_dict)
            
            # Flatten first so scalar, 1-D and higher-rank tensors all
            # contribute plain Python numbers to the lists.
            predictions.extend(torch.sigmoid(output).reshape(-1).cpu().tolist())
            labels.extend(label.reshape(-1).cpu().tolist())
    
    # Calculate metrics
    auc = roc_auc_score(labels, predictions)
    auprc = average_precision_score(labels, predictions)
    accuracy = accuracy_score(labels, [1 if p > 0.5 else 0 for p in predictions])
    
    return auc, auprc, accuracy


def run_graphsage_unified_baseline(train_dataset, val_dataset, test_dataset, model_params):
    """Build, train and evaluate the unified GraphSAGE baseline.

    Args:
        train_dataset: Dataset wrapped in a shuffling loader with batch size 1.
        val_dataset: Dataset wrapped in a non-shuffling loader with batch size 1.
        test_dataset: Dataset wrapped in a non-shuffling loader with batch size 1.
        model_params: Mapping with the required keys ``"input_dim"``,
            ``"hidden_dim"``, ``"num_layers"``, ``'epochs'``, ``'lr'`` and
            ``'seed'``. Optional keys: ``'dropout'`` (default ``0.1``) and
            ``'device'`` (default: CUDA when available, otherwise CPU).

    Returns:
        The result dict produced by ``train_graphsage_unified_baseline``.
    """
    input_dim = model_params["input_dim"]
    hidden_channels = model_params["hidden_dim"]
    num_layers = model_params["num_layers"]
    num_epochs = model_params['epochs']
    lr = model_params['lr']
    seed = model_params['seed']
    dropout = model_params.get('dropout', 0.1)
    # Fall back to the CPU instead of failing on hosts without CUDA.
    device = model_params.get('device') or ('cuda' if torch.cuda.is_available() else 'cpu')

    # Apply the seed before the model is built and the loaders are created,
    # so weight initialisation and shuffling are reproducible.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    
    print("Running Unified GraphSAGE baseline with:")
    print(f"  - Hidden channels: {hidden_channels}")
    print(f"  - Number of layers: {num_layers}")
    print(f"  - Input dimension: {input_dim}")
    print(f"  - Epochs: {num_epochs}")
    print(f"  - Learning rate: {lr}")
    print(f"  - Device: {device}")
    print(f"  - Seed: {seed}")
    
    # Create model
    model = P12_GraphSAGEUnified_Baseline(
        input_dim=input_dim,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        num_classes=1,
        dropout=dropout
    )
    print("Initializing data loaders...")

    # The model consumes one graph_dict at a time, hence batch_size=1.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1, 
        shuffle=True,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1, shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1, shuffle=False
    )

    print("Data loaders initialized")

    print("Training model...")
    # Train model
    results = train_graphsage_unified_baseline(
        model, train_loader, val_loader, test_loader,
        num_epochs=num_epochs, lr=lr, device=device
    )
    
    print("\nUnified GraphSAGE Baseline Results:")
    print(f"  - Test AUC: {results['test_auc']:.4f}")
    print(f"  - Test AUPRC: {results['test_auprc']:.4f}")
    print(f"  - Test Accuracy: {results['test_accuracy']:.4f}")
    
    return results 