import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, GATConv
from torch_geometric.nn.models import MLP
from typing import Dict, Any, Optional


def create_model_from_config(config: Dict[str, Any], num_features: int, num_classes: int) -> nn.Module:
    """Instantiate one of the GNN classes from ``config['model']``.

    NOTE: a second function with this exact name is defined further down in
    this module; once the module has been imported, that later definition is
    the one bound to the name.

    Args:
        config: Mapping with a required ``'model'`` entry. That entry must
            provide ``'name'`` (matched case-insensitively against ``'gcn'``,
            ``'gin'``, ``'sage'``, ``'gat'``) and may provide ``'hidden_dim'``
            (default 128), ``'num_layers'`` (default 2), ``'dropout'``
            (default 0.5) and ``'activation'`` (default ``'relu'``).
        num_features: Passed to the model as ``in_dim``.
        num_classes: Passed to the model as ``out_dim``.

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``'model'`` or its ``'name'`` entry is missing.
        ValueError: If the lowercased name is not a supported architecture.
    """
    settings = config['model']
    arch = settings['name'].lower()

    # ``activation`` has no named parameter on the model constructors in this
    # file; it is absorbed by their ``**kwargs`` and ends up in ``model.config``.
    shared_kwargs = dict(
        in_dim=num_features,
        hidden_dim=settings.get('hidden_dim', 128),
        out_dim=num_classes,
        num_layers=settings.get('num_layers', 2),
        dropout=settings.get('dropout', 0.5),
        activation=settings.get('activation', 'relu'),
    )

    # Reject unknown names before touching any model class.
    if arch not in ('gcn', 'gin', 'sage', 'gat'):
        raise ValueError(f"Unsupported model: {arch}")

    registry = {'gcn': GCN, 'gin': GIN, 'sage': GraphSAGE, 'gat': GAT}
    return registry[arch](**shared_kwargs)


class GraphDROModel(nn.Module):
    """Common base for the models in this module.

    Subclasses implement ``_forward_impl``; ``forward`` wraps it and, when
    configured to, switches on gradient tracking for the input features.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and read its ``'requires_grad'`` flag (default True)."""
        super().__init__()
        self.config = config
        self.requires_grad_data = config.get('requires_grad', True)

    def forward(self, edge_index, x):
        """Run ``_forward_impl(edge_index, x)``.

        In training mode with ``requires_grad_data`` set, ``x`` is first
        marked as requiring gradients. This uses the in-place
        ``requires_grad_`` call, so the caller's tensor object is affected.
        """
        track_input_grad = self.requires_grad_data and self.training
        if track_input_grad and not x.requires_grad:
            x = x.requires_grad_(True)
        return self._forward_impl(edge_index, x)

    def _forward_impl(self, edge_index, x):
        """Model-specific forward computation; must be overridden."""
        raise NotImplementedError


class GCN(GraphDROModel):
    """Stack of ``GCNConv`` layers with BatchNorm, ReLU and dropout in between.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of every hidden layer.
        out_dim: Output width of the final layer.
        num_layers: Number of conv layers. Values of 1 or less build a single
            ``in_dim -> out_dim`` layer.
        dropout: Dropout probability applied after each hidden layer.
        **kwargs: Stored as ``self.config`` by the base class (e.g.
            ``requires_grad``).
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, dropout: float = 0.5, **kwargs):
        super().__init__(kwargs)
        self.num_layers = num_layers
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        if num_layers > 1:
            self.convs.append(GCNConv(in_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
            for _ in range(num_layers - 2):
                self.convs.append(GCNConv(hidden_dim, hidden_dim))
                self.bns.append(nn.BatchNorm1d(hidden_dim))
            self.convs.append(GCNConv(hidden_dim, out_dim))
        else:
            # Single-layer model: no hidden layer exists, so no BatchNorm is
            # registered. Previously a BatchNorm1d(hidden_dim) was left in
            # ``self.bns`` although the forward pass never used it, which added
            # parameters that never receive gradients. GraphSAGE in this module
            # already builds its single-layer variant without BatchNorm.
            # Checkpoints saved from an older single-layer GCN carry extra
            # ``bns.0.*`` entries and need ``strict=False`` when loading.
            self.convs.append(GCNConv(in_dim, out_dim))

    def _forward_impl(self, edge_index, x):
        """Apply every hidden layer (conv, BatchNorm, ReLU, dropout), then the output conv."""
        # ``self.bns`` holds exactly one BatchNorm per hidden conv.
        for conv, bn in zip(self.convs[:-1], self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # The last layer's raw output is returned without activation.
        return self.convs[-1](x, edge_index)


class GIN(GraphDROModel):
    """Stack of ``GINConv`` layers, each wrapping a two-layer ``MLP``.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of the hidden layers and of each MLP's middle layer.
        out_dim: Output width of the final layer.
        num_layers: Number of GIN layers. Values of 1 or less build a single
            ``in_dim -> hidden_dim -> out_dim`` layer, so the model always
            ends in ``out_dim`` features.
        dropout: Dropout probability applied after each non-final layer.
        eps: Passed to every ``GINConv`` as its ``eps`` argument.
        **kwargs: Stored as ``self.config`` by the base class.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, dropout: float = 0.5, eps: float = 0.0, **kwargs):
        super().__init__(kwargs)
        self.num_layers = num_layers
        self.dropout = dropout
        self.eps = eps
        self.convs = nn.ModuleList()

        # Always build at least one layer, like the other models in this
        # module; with zero layers the forward pass would fail on convs[-1].
        depth = max(num_layers, 1)
        for i in range(depth):
            # Input and output width are chosen independently so that a
            # single-layer model is both the first AND the last layer.
            # Previously the ``i == 0`` case won and a one-layer GIN emitted
            # ``hidden_dim`` features instead of ``out_dim``.
            layer_in = in_dim if i == 0 else hidden_dim
            layer_out = out_dim if i == depth - 1 else hidden_dim
            mlp = MLP([layer_in, hidden_dim, layer_out])
            self.convs.append(GINConv(mlp, eps=eps))

    def _forward_impl(self, edge_index, x):
        """Apply every non-final layer with ReLU and dropout, then the final layer."""
        # Move each layer's ``eps`` tensor to the input's device if it is
        # elsewhere. NOTE(review): presumably a workaround for ``eps`` not
        # following the module across devices; confirm it is still needed.
        for conv in self.convs:
            if hasattr(conv, 'eps') and conv.eps.device != x.device:
                conv.eps = conv.eps.to(x.device)

        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # The last layer's raw output is returned without activation.
        return self.convs[-1](x, edge_index)


class GraphSAGE(GraphDROModel):
    """Stack of ``SAGEConv`` layers with BatchNorm, ReLU and dropout in between.

    Every ``SAGEConv`` is given its input width as a ``(width, width)`` pair.
    With ``num_layers`` of 1 or less the model is a single
    ``in_dim -> out_dim`` layer with no BatchNorm.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, dropout: float = 0.5, **kwargs):
        super().__init__(kwargs)
        self.num_layers = num_layers
        self.dropout = dropout

        # Collect the layers in plain lists first, then register them.
        conv_layers = [SAGEConv((in_dim, in_dim), hidden_dim)]
        norm_layers = [nn.BatchNorm1d(hidden_dim)]
        for _ in range(num_layers - 2):
            conv_layers.append(SAGEConv((hidden_dim, hidden_dim), hidden_dim))
            norm_layers.append(nn.BatchNorm1d(hidden_dim))

        if num_layers > 1:
            conv_layers.append(SAGEConv((hidden_dim, hidden_dim), out_dim))
        else:
            # Single-layer model: replace everything built so far.
            conv_layers = [SAGEConv((in_dim, in_dim), out_dim)]
            norm_layers = []

        self.convs = nn.ModuleList(conv_layers)
        self.bns = nn.ModuleList(norm_layers)

    def _forward_impl(self, edge_index, x):
        """Run the hidden layers (conv, BatchNorm, ReLU, dropout) and the output conv."""
        final_idx = len(self.convs) - 1
        for idx in range(final_idx):
            x = self.convs[idx](x, edge_index)
            if idx < len(self.bns):
                x = self.bns[idx](x)
            x = F.dropout(F.relu(x), p=self.dropout, training=self.training)
        # No activation after the final layer.
        return self.convs[-1](x, edge_index)


class GAT(GraphDROModel):
    """Stack of ``GATConv`` layers with ELU and dropout in between.

    Hidden layers are built with ``heads`` attention heads and ``hidden_dim``
    output channels, and each following layer is constructed with
    ``hidden_dim * heads`` input channels. The final layer uses one head and
    ``out_dim`` output channels. With ``num_layers`` of 1 or less the model is
    a single one-head ``in_dim -> out_dim`` layer.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, dropout: float = 0.5, heads: int = 8, **kwargs):
        super().__init__(kwargs)
        self.num_layers = num_layers
        self.dropout = dropout
        self.heads = heads

        multi_head_width = hidden_dim * heads
        layers = [GATConv(in_dim, hidden_dim, heads=heads, dropout=dropout)]
        for _ in range(num_layers - 2):
            layers.append(
                GATConv(multi_head_width, hidden_dim, heads=heads, dropout=dropout)
            )

        if num_layers > 1:
            layers.append(GATConv(multi_head_width, out_dim, heads=1, dropout=dropout))
        else:
            # Single-layer model: swap the first layer for an in -> out layer.
            layers[0] = GATConv(in_dim, out_dim, heads=1, dropout=dropout)

        self.convs = nn.ModuleList(layers)

    def _forward_impl(self, edge_index, x):
        """Run the hidden layers (conv, ELU, dropout) and the output conv."""
        final_idx = len(self.convs) - 1
        for idx in range(final_idx):
            x = self.convs[idx](x, edge_index)
            x = F.dropout(F.elu(x), p=self.dropout, training=self.training)
        # No activation after the final layer.
        return self.convs[-1](x, edge_index)


class MLP_Model(GraphDROModel):
    """Feed-forward network over node features; ``edge_index`` is ignored.

    The network is ``Linear -> ReLU -> Dropout`` repeated
    ``1 + max(num_layers - 2, 0)`` times, followed by a final ``Linear`` to
    ``out_dim``. Any ``num_layers`` of 2 or less therefore yields two linear
    layers.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, dropout: float = 0.5, **kwargs):
        super().__init__(kwargs)
        self.num_layers = num_layers
        self.dropout = dropout

        # Widths seen by the hidden blocks: in -> hidden (-> hidden ...).
        extra_hidden = max(num_layers - 2, 0)
        widths = [in_dim] + [hidden_dim] * (1 + extra_hidden)

        stack = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(w_in, w_out), nn.ReLU(), nn.Dropout(dropout)]
        stack.append(nn.Linear(hidden_dim, out_dim))

        self.mlp = nn.Sequential(*stack)

    def _forward_impl(self, edge_index, x):
        """Apply the MLP to ``x``; the graph structure is not used."""
        return self.mlp(x)


class FairGNN(GraphDROModel):
    """Encoder backbone with two heads on top of the encoded features.

    The backbone (``'gcn'``, ``'gin'``, ``'sage'`` or ``'gat'``) maps inputs to
    ``hidden_dim`` features. ``classifier`` is a linear layer to ``out_dim``;
    ``sens_predictor`` is a small MLP to ``sens_dim`` outputs. Extra keyword
    arguments are stored as ``self.config`` and also forwarded to the
    backbone constructor.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 backbone: str = 'gcn', sens_dim: int = 1, **kwargs):
        super().__init__(kwargs)
        self.sens_dim = sens_dim

        # Pick the encoder class first, then build it once.
        encoder_cls = None
        if backbone == 'gcn':
            encoder_cls = GCN
        elif backbone == 'gin':
            encoder_cls = GIN
        elif backbone == 'sage':
            encoder_cls = GraphSAGE
        elif backbone == 'gat':
            encoder_cls = GAT
        if encoder_cls is None:
            raise ValueError(f"Unsupported backbone: {backbone}")
        self.encoder = encoder_cls(in_dim, hidden_dim, hidden_dim, **kwargs)

        self.classifier = nn.Linear(hidden_dim, out_dim)

        half_width = hidden_dim // 2
        self.sens_predictor = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(half_width, sens_dim),
        )

    def _forward_impl(self, edge_index, x):
        """Return ``(logits, sens_logits)`` computed from the encoder output."""
        # The encoder's own ``forward`` wrapper is bypassed on purpose here:
        # input-gradient handling is done once in this class's ``forward``.
        encoded = self.encoder._forward_impl(edge_index, x)
        return self.classifier(encoded), self.sens_predictor(encoded)

    def forward(self, edge_index, x, return_sens=False):
        """Run the model.

        Returns the class logits, or ``(logits, sens_logits)`` when
        ``return_sens`` is true. In training mode with ``requires_grad_data``
        set, ``x`` is first marked in place as requiring gradients.
        """
        if self.requires_grad_data and self.training and not x.requires_grad:
            x = x.requires_grad_(True)

        outputs = self._forward_impl(edge_index, x)

        # Anything that is not a 2-tuple is passed through unchanged.
        if not (isinstance(outputs, tuple) and len(outputs) == 2):
            return outputs

        logits, sens_logits = outputs
        if return_sens:
            return logits, sens_logits
        return logits


def build_model(name: str, in_dim: int, hidden_dim: int, out_dim: int,
                config: Optional[Dict[str, Any]] = None):
    """Construct a model by name.

    Args:
        name: Architecture name, matched case-insensitively against
            ``gcn``, ``gin``, ``sage``, ``gat``, ``mlp`` and ``fairgnn``.
        in_dim: Input feature width.
        hidden_dim: Hidden layer width.
        out_dim: Output width.
        config: Optional settings. ``num_layers`` (default 2), ``dropout``
            (default 0.5) and ``requires_grad`` (default True) are passed to
            every model. ``eps`` (default 0.0) is read for GIN, ``heads``
            (default 8) for GAT, and ``backbone`` (default ``'gcn'``) and
            ``sens_dim`` (default 1) for FairGNN.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``name`` does not match a known architecture.
    """
    settings = {} if config is None else config

    # Options every architecture receives.
    common = {
        'num_layers': settings.get('num_layers', 2),
        'dropout': settings.get('dropout', 0.5),
        'requires_grad': settings.get('requires_grad', True),
    }

    arch = name.lower()
    dims = (in_dim, hidden_dim, out_dim)

    if arch == "gcn":
        return GCN(*dims, **common)
    if arch == "gin":
        return GIN(*dims, eps=settings.get('eps', 0.0), **common)
    if arch == "sage":
        return GraphSAGE(*dims, **common)
    if arch == "gat":
        return GAT(*dims, heads=settings.get('heads', 8), **common)
    if arch == "mlp":
        return MLP_Model(*dims, **common)
    if arch == "fairgnn":
        return FairGNN(
            *dims,
            backbone=settings.get('backbone', 'gcn'),
            sens_dim=settings.get('sens_dim', 1),
            **common,
        )

    # The message reports the name exactly as the caller supplied it.
    raise ValueError(f"Unknown model: {name}")


def create_model_from_config(config: Dict[str, Any], in_dim: int, out_dim: int):
    """Build a model from a config with optional ``'model'`` and ``'data'`` sections.

    Being the later of the two same-named definitions in this module, this is
    the one the name refers to after import.

    The two sections are merged into one options dict for ``build_model``;
    on a key collision the ``'data'`` value wins. The architecture name
    (default ``'gcn'``) and ``hidden_dim`` (default 128) are always read from
    the ``'model'`` section.
    """
    model_section = config.get('model', {})
    data_section = config.get('data', {})
    merged_options = {**model_section, **data_section}

    return build_model(
        model_section.get('name', 'gcn'),
        in_dim,
        model_section.get('hidden_dim', 128),
        out_dim,
        merged_options,
    )


def count_parameters(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


def get_model_info(model):
    """Summarise a model's class name, parameter counts and parameter memory.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        A dict with:
            ``model_class``: the model's class name.
            ``total_parameters``: element count over ALL parameters,
                including frozen ones.
            ``trainable_parameters``: element count over parameters with
                ``requires_grad`` set.
            ``model_size_mb``: bytes occupied by all parameters, in MiB,
                using each parameter's actual element size.
    """
    total_params = 0
    trainable_params = 0
    size_bytes = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        # Use the real dtype width instead of assuming 4-byte floats.
        size_bytes += count * param.element_size()
        if param.requires_grad:
            trainable_params += count

    # Previously ``total_parameters`` was computed from the trainable-only
    # count, so it could never differ from ``trainable_parameters``.
    return {
        'model_class': model.__class__.__name__,
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'model_size_mb': size_bytes / (1024 * 1024),
    }