import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))


class DictGNN(nn.Module):
    """GNN model for GOOD datasets with dict input format.

    Every graph in a batch is encoded on its own by a three-layer GCN stack,
    mean-pooled over its nodes, passed through a readout MLP, and then fed to
    a task head and an environment head.
    """

    def __init__(self,
                 input_dim=8,
                 hidden_dim=128,
                 num_classes=2,
                 num_envs=2,
                 dropout=0.3):
        """Build the GCN encoder, the readout MLP and both prediction heads.

        Args:
            input_dim: number of features per node.
            hidden_dim: width of the first two GCN layers; the third layer,
                the pooled graph embedding and the readout output all use
                ``hidden_dim // 2``.
            num_classes: output width of the task head.
            num_envs: output width of the environment head.
            dropout: dropout probability used after the first two GCN layers
                and inside the task head.
        """
        super().__init__()

        # Graph convolutional layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim // 2)

        # Batch norms; each is applied over the nodes of a single graph
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim // 2)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Task classifier
        self.task_classifier = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 4, num_classes)
        )

        # Environment classifier. NOTE(review): forward() feeds it detached
        # features, so it does not act adversarially on the encoder (there is
        # no gradient reversal); env_loss only trains this head.
        self.env_classifier = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, num_envs)
        )

        # Readout MLP applied to the pooled graph embedding
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 2)
        )

    def _encode_graph(self, x, edge_index):
        """Encode one graph into a ``[1, hidden_dim // 2]`` mean-pooled embedding.

        A graph with no nodes yields a zero row instead of the NaN that a
        mean over an empty node dimension would produce.
        """
        if x.size(0) == 0:
            return torch.zeros(1, self.conv3.out_channels, device=x.device)

        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        last = len(stages) - 1
        h = x
        for i, (conv, bn) in enumerate(stages):
            h = conv(h, edge_index)
            # Normalization is skipped for single-node graphs, presumably
            # because BatchNorm1d rejects a one-row batch in training mode.
            if h.size(0) > 1:
                h = bn(h)
            h = F.relu(h)
            if i < last:  # no dropout after the final GCN layer
                h = self.dropout(h)

        # Graph-level pooling (mean over nodes)
        return h.mean(dim=0, keepdim=True)

    def process_graph_batch(self, x_list, edge_index_list):
        """Encode each graph in the batch and stack the embeddings.

        Args:
            x_list: sequence of node-feature tensors, one per graph.
            edge_index_list: sequence of edge-index tensors paired with
                ``x_list`` by position (``zip`` stops at the shorter one).

        Returns:
            Tensor ``[num_graphs, hidden_dim // 2]`` on the model's device;
            ``[0, hidden_dim // 2]`` for an empty batch.
        """
        device = next(self.parameters()).device
        graph_features = [
            self._encode_graph(x.to(device), edge_index.to(device))
            for x, edge_index in zip(x_list, edge_index_list)
        ]
        if not graph_features:
            return torch.zeros(0, self.conv3.out_channels, device=device)
        return torch.cat(graph_features, dim=0)

    @staticmethod
    def _prepare_targets(targets, device):
        """Flatten ``[N, 1]``-style targets to ``[N]`` and move them to ``device``."""
        if targets.dim() > 1:
            targets = targets.squeeze()
            # squeeze() turns a [1, 1] target into a 0-d scalar, which
            # cross_entropy rejects against [1, C] logits; keep the batch axis.
            if targets.dim() == 0:
                targets = targets.reshape(1)
        return targets.to(device)

    def forward(self, batch_dict, labels=None, env_labels=None, adversarial_weight=0.1):
        """Run the model; with labels, also compute losses and accuracy.

        Args:
            batch_dict: dict from which only two keys are read:
                x: list of node feature tensors [num_nodes, features]
                edge_index: list of edge_index tensors [2, num_edges]
                Targets are not read from it; pass them via ``labels`` and
                ``env_labels``.
            labels: optional task targets, shape [batch_size] or [batch_size, 1].
            env_labels: optional environment targets, same shape rules.
            adversarial_weight: weight of ``env_loss`` in ``total_loss``.

        Returns:
            dict with 'task_logits', 'env_logits' and 'features'. When
            ``labels`` is given, it also holds 'task_loss', 'total_loss' and
            'accuracy', plus 'env_loss' when ``env_labels`` is given.
        """
        x_list = batch_dict['x']
        edge_index_list = batch_dict['edge_index']

        graph_features = self.readout(self.process_graph_batch(x_list, edge_index_list))

        task_logits = self.task_classifier(graph_features)
        # Detached input: env_loss trains only the env head and sends no
        # gradient into the encoder or readout.
        env_logits = self.env_classifier(graph_features.detach())

        output = {
            'task_logits': task_logits,
            'env_logits': env_logits,
            'features': graph_features
        }

        if labels is None:
            return output

        labels = self._prepare_targets(labels, task_logits.device)
        task_loss = F.cross_entropy(task_logits, labels)
        output['task_loss'] = task_loss

        if env_labels is not None:
            env_labels = self._prepare_targets(env_labels, env_logits.device)
            env_loss = F.cross_entropy(env_logits, env_labels)
            output['env_loss'] = env_loss
            # Total loss = task_loss + λ * env_loss
            output['total_loss'] = task_loss + adversarial_weight * env_loss
        else:
            output['total_loss'] = task_loss

        with torch.no_grad():
            preds = task_logits.argmax(dim=1)
            output['accuracy'] = (preds == labels).float().mean()

        return output


class SimpleGNN(nn.Module):
    """Simple two-layer GCN graph classifier for dict-formatted batches.

    ``extract_features`` encodes every graph in the batch independently and
    mean-pools its nodes; ``forward`` applies the task classifier on top.
    """

    def __init__(self, input_dim, hidden_dim=128, num_classes=2, num_envs=2):
        """Build the two GCN layers and the classifier heads.

        Args:
            input_dim: number of features per node.
            hidden_dim: width of the first GCN layer; the second layer and the
                pooled graph embedding use ``hidden_dim // 2``.
            num_classes: output width of the task classifier.
            num_envs: output width of the environment classifier.
        """
        super().__init__()

        # GNN layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim // 2)

        # Batch normalization (first layer only)
        self.bn1 = nn.BatchNorm1d(hidden_dim)

        # Dropout
        self.dropout = nn.Dropout(0.3)

        # Main classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, num_classes)
        )

        # Environment classifier. NOTE(review): constructed (so it is part of
        # the state dict) but not used by forward() in this class.
        self.env_classifier = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, num_envs)
        )

    @staticmethod
    def _as_graph_list(value):
        """Normalize a batch field to a list of per-graph entries.

        ``None`` gives ``[]``, a list/tuple gives a list, and any other value
        (e.g. a single tensor) is treated as one graph.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def _encode_graph(self, x, edge_index):
        """Encode one graph into a ``[1, hidden_dim // 2]`` mean-pooled embedding.

        A graph with no nodes yields a zero row.
        """
        if x.size(0) == 0:
            return torch.zeros(1, self.conv2.out_channels, device=x.device)
        if edge_index is None:
            # PyG's MessagePassing.propagate does not accept None as
            # edge_index; use an empty [2, 0] integer tensor for edgeless graphs.
            edge_index = torch.empty(2, 0, dtype=torch.long, device=x.device)

        # First layer; normalization is skipped for single-node graphs,
        # presumably because BatchNorm1d rejects a one-row batch in training.
        h = self.conv1(x, edge_index)
        if h.size(0) > 1:
            h = self.bn1(h)
        h = F.relu(h)
        h = self.dropout(h)

        # Second layer
        h = F.relu(self.conv2(h, edge_index))

        # Graph-level pooling
        return h.mean(dim=0, keepdim=True)

    def extract_features(self, batch):
        """Encode every graph in ``batch`` into a pooled feature vector.

        Args:
            batch: dict with 'x' (list/tuple of node-feature tensors, or a
                single tensor treated as one graph) and optionally
                'edge_index' (paired with 'x' by position). Graphs with no
                matching edge_index entry are treated as edgeless.

        Returns:
            Tensor ``[num_graphs, hidden_dim // 2]``. A non-dict batch, or
            one with no graphs, yields a single zero row.
        """
        device = next(self.parameters()).device
        if not isinstance(batch, dict):
            return torch.zeros(1, self.conv2.out_channels, device=device)

        x_graphs = self._as_graph_list(batch.get('x'))
        if not x_graphs:
            return torch.zeros(1, self.conv2.out_channels, device=device)
        edge_graphs = self._as_graph_list(batch.get('edge_index'))

        features = []
        for i, x in enumerate(x_graphs):
            edges = edge_graphs[i] if i < len(edge_graphs) else None
            edge_index = edges.to(device) if edges is not None else None
            features.append(self._encode_graph(x.to(device), edge_index))
        return torch.cat(features, dim=0)

    def forward(self, batch, method='erm', alpha=0.1):
        """Classify the graphs in ``batch``.

        ``method`` and ``alpha`` are accepted for interface compatibility and
        are currently ignored.

        Returns:
            dict with 'logits' ``[num_graphs, num_classes]`` and the pooled
            'features' they were computed from.
        """
        features = self.extract_features(batch)
        logits = self.classifier(features)

        output = {
            'logits': logits,
            'features': features
        }

        return output


# Backward-compatible alias: code that imports ``GNNModel`` from this module
# gets ``SimpleGNN``.
GNNModel = SimpleGNN