import torch.nn as nn
import torch
from omegaconf import DictConfig, OmegaConf
from haipr.models.module import HAIPRModule
from haipr.data import HAIPRData
from haipr.utils import loss_funcs
from typing import Dict, Any, List, Optional, Tuple
import numpy as np
import logging
import lightning.pytorch as pl
from haipr.data import build_graph_from_coords, extract_calpha_coords


try:
    from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool, global_add_pool
    from torch_geometric.data import Data, Batch
    TORCH_GEOMETRIC_AVAILABLE = True
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False

    GCNConv = None
    GATConv = None
    global_mean_pool = None
    global_max_pool = None
    global_add_pool = None
    Data = None
    Batch = None

import biotite.structure as struc
import biotite.structure.io.pdb as pdb

logger = logging.getLogger(__name__)


class MLPGNN(nn.Module):
    """Graph-level predictor that transforms node features without message passing.

    Node features are projected to ``hidden_dim`` and then passed through
    ``num_layers`` (Linear -> Dropout -> activation) stages applied to every
    node independently; ``edge_index`` is never read. Node embeddings are
    pooled per graph and fed to a linear prediction head.

    Args:
        input_dim: Size of each node feature vector (``batch.x.shape[-1]``).
        output_dim: Number of outputs per graph.
        num_layers: Number of hidden Linear/Dropout stages.
        hidden_dim: Width of the hidden node representation.
        dropout: Dropout probability applied after each hidden Linear layer.
        pooling: Graph readout, one of ``"mean"``, ``"max"`` or ``"sum"``.
        activation: Activation module; a fresh ``nn.ReLU()`` when ``None``.

    Raises:
        ImportError: If torch_geometric is not installed (its pooling ops are used).
        ValueError: If ``pooling`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        num_layers: int = 2,
        hidden_dim: int = 128,
        dropout: float = 0.1,
        pooling: str = "mean",
        activation: Optional[nn.Module] = None,
    ):
        # Without torch_geometric the pooling functions are None, which would
        # otherwise only surface as an opaque TypeError inside forward().
        if not TORCH_GEOMETRIC_AVAILABLE:
            raise ImportError(
                "torch_geometric is required for MLPGNN. "
                "Install it with: pip install torch-geometric"
            )
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.pooling = pooling
        # Build a new module per instance instead of sharing one default object.
        self.activation = activation if activation is not None else nn.ReLU()
        self.output_dim = output_dim

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Flat [Linear, Dropout, Linear, Dropout, ...] layout is kept so the
        # state_dict keys (layers.0.weight, layers.2.weight, ...) stay stable.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.Dropout(dropout))

        pool_fns = {
            "mean": global_mean_pool,
            "max": global_max_pool,
            "sum": global_add_pool,
        }
        if pooling not in pool_fns:
            raise ValueError(f"Unknown pooling method: {pooling}")
        self.pool = pool_fns[pooling]

        self.pred_head = nn.Linear(hidden_dim, output_dim)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return one prediction row per graph in ``batch``.

        Only ``batch.x`` (node features) and ``batch.batch`` (node-to-graph
        assignment) are read.
        """
        x = self.activation(self.input_proj(batch.x))

        # Even entries of ``layers`` are Linear layers, odd entries their Dropout.
        for linear, drop in zip(self.layers[0::2], self.layers[1::2]):
            x = self.activation(drop(linear(x)))

        graph_emb = self.pool(x, batch.batch)
        return self.pred_head(graph_emb)


class GCNModel(nn.Module):
    """Graph-level predictor built from stacked ``GCNConv`` layers.

    The first convolution maps ``input_dim`` to ``hidden_dim``; the remaining
    ``num_layers - 1`` convolutions keep ``hidden_dim``. Activation and
    dropout are applied between convolutions (not after the last one). Node
    embeddings are then pooled per graph and passed to a linear head.

    Args:
        input_dim: Size of each node feature vector.
        output_dim: Number of outputs per graph.
        num_layers: Number of GCN convolutions (at least one is always built).
        hidden_dim: Width of the hidden node representation.
        dropout: Dropout probability between convolutions.
        pooling: Graph readout, one of ``"mean"``, ``"max"`` or ``"sum"``.
        activation: Activation module; a fresh ``nn.ReLU()`` when ``None``.

    Raises:
        ImportError: If torch_geometric is not installed.
        ValueError: If ``pooling`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_layers: int = 2,
        hidden_dim: int = 128,
        dropout: float = 0.1,
        pooling: str = "mean",
        activation: Optional[nn.Module] = None,
    ):
        if not TORCH_GEOMETRIC_AVAILABLE:
            raise ImportError(
                "torch_geometric is required for GCNModel. "
                "Install it with: pip install torch-geometric"
            )
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.pooling = pooling
        # Build a new module per instance instead of sharing one default object.
        self.activation = activation if activation is not None else nn.ReLU()

        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(input_dim, hidden_dim))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))

        self.dropout = nn.Dropout(dropout)

        pool_fns = {
            "mean": global_mean_pool,
            "max": global_max_pool,
            "sum": global_add_pool,
        }
        if pooling not in pool_fns:
            raise ValueError(f"Unknown pooling method: {pooling}")
        self.pool = pool_fns[pooling]

        self.pred_head = nn.Linear(hidden_dim, output_dim)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return one prediction row per graph in ``batch``.

        Reads ``batch.x``, ``batch.edge_index`` and ``batch.batch``.
        """
        x = batch.x
        edge_index = batch.edge_index
        last = len(self.convs) - 1

        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # No activation/dropout after the final convolution.
            if i < last:
                x = self.dropout(self.activation(x))

        graph_emb = self.pool(x, batch.batch)
        return self.pred_head(graph_emb)


class GATModel(nn.Module):
    """Graph-level predictor built from stacked ``GATConv`` attention layers.

    Layout for ``num_layers >= 2``: the first and intermediate layers use
    ``num_heads`` concatenated heads (output ``hidden_dim * num_heads``), and
    the last layer uses a single head with output ``hidden_dim``. For
    ``num_layers <= 1`` a single layer averages its ``num_heads`` heads
    (``concat=False``), so its output is ``hidden_dim`` as well and matches
    the prediction head.

    Args:
        input_dim: Size of each node feature vector.
        output_dim: Number of outputs per graph.
        num_layers: Number of GAT layers (at least one is always built).
        hidden_dim: Per-head hidden width.
        num_heads: Number of attention heads in the multi-head layers.
        dropout: Attention dropout inside ``GATConv`` and dropout between layers.
        pooling: Graph readout, one of ``"mean"``, ``"max"`` or ``"sum"``.
        activation: Activation module; a fresh ``nn.ReLU()`` when ``None``.

    Raises:
        ImportError: If torch_geometric is not installed.
        ValueError: If ``pooling`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_layers: int = 2,
        hidden_dim: int = 128,
        num_heads: int = 4,
        dropout: float = 0.1,
        pooling: str = "mean",
        activation: Optional[nn.Module] = None,
    ):
        if not TORCH_GEOMETRIC_AVAILABLE:
            raise ImportError(
                "torch_geometric is required for GATModel. "
                "Install it with: pip install torch-geometric"
            )
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.pooling = pooling
        # Build a new module per instance instead of sharing one default object.
        self.activation = activation if activation is not None else nn.ReLU()

        self.convs = nn.ModuleList()

        if num_layers <= 1:
            # A lone concatenating layer would emit hidden_dim * num_heads
            # features and break pred_head; average the heads instead.
            self.convs.append(GATConv(input_dim, hidden_dim,
                              heads=num_heads, dropout=dropout, concat=False))
        else:
            self.convs.append(GATConv(input_dim, hidden_dim,
                              heads=num_heads, dropout=dropout, concat=True))
            for _ in range(num_layers - 2):
                self.convs.append(GATConv(hidden_dim * num_heads, hidden_dim,
                                  heads=num_heads, dropout=dropout, concat=True))
            self.convs.append(GATConv(hidden_dim * num_heads, hidden_dim,
                              heads=1, dropout=dropout, concat=False))

        self.dropout = nn.Dropout(dropout)

        pool_fns = {
            "mean": global_mean_pool,
            "max": global_max_pool,
            "sum": global_add_pool,
        }
        if pooling not in pool_fns:
            raise ValueError(f"Unknown pooling method: {pooling}")
        self.pool = pool_fns[pooling]

        self.pred_head = nn.Linear(hidden_dim, output_dim)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return one prediction row per graph in ``batch``.

        Reads ``batch.x``, ``batch.edge_index`` and ``batch.batch``.
        """
        x = batch.x
        edge_index = batch.edge_index
        last = len(self.convs) - 1

        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # No activation/dropout after the final attention layer.
            if i < last:
                x = self.dropout(self.activation(x))

        graph_emb = self.pool(x, batch.batch)
        return self.pred_head(graph_emb)


class GNNPredictor(HAIPRModule):
    """HAIPR module that predicts sample properties from residue graphs.

    Every sample is turned into a graph over one shared set of residue
    coordinates, taken from ``dataset.get_pdb_coords()`` or, as a fallback,
    from ``extract_calpha_coords(self.pdb_path)``. Edges come from
    ``build_graph_from_coords`` with ``distance_cutoff``, and node features
    are the sample's per-residue embeddings. The network is ``MLPGNN``
    (``"mlp_gnn"``), ``GCNModel`` (``"gcn"``) or ``GATModel`` (``"gat"``).

    When ``input_dim`` is unknown at construction time, an
    ``nn.Linear(1, output_dim)`` placeholder is installed. :meth:`setup_model`
    replaces it with the real network.
    """

    # Accepted values for ``model_type``.
    _MODEL_TYPES = ("mlp_gnn", "gcn", "gat")

    def __init__(
        self,
        model_type: str = "gcn",
        input_dim: Optional[int] = None,
        output_dim: int = 1,
        num_layers: int = 2,
        hidden_dim: int = 128,
        dropout: float = 0.1,
        num_heads: int = 4,
        pooling: str = "mean",
        distance_cutoff: float = 10.0,
        num_classes: int = 0,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 32,
        loss: str = "mse",
        **kwargs,
    ):
        """Configure the predictor; the network is built now if ``input_dim`` is known.

        Args:
            model_type: One of ``"mlp_gnn"``, ``"gcn"``, ``"gat"``.
            input_dim: Node feature size; ``None`` defers model creation to
                :meth:`setup_model`.
            output_dim: Number of outputs per graph.
            num_layers, hidden_dim, dropout, num_heads, pooling: Forwarded to
                the selected network (``num_heads`` only to ``GATModel``).
            distance_cutoff: Passed to ``build_graph_from_coords``.
            num_classes, learning_rate, weight_decay, batch_size, **kwargs:
                Forwarded to ``HAIPRModule``.
            loss: Key into ``loss_funcs``; unknown keys fall back to MSELoss
                with a warning.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        # Fail fast: previously an unknown type only surfaced in setup_model.
        if model_type not in self._MODEL_TYPES:
            raise ValueError(f"Unknown model_type: {model_type}")

        # No default in .get(): with one, the "not found" warning was dead code.
        criterion = loss_funcs.get(loss)
        if criterion is None:
            logger.warning(
                f"Loss function '{loss}' not found. Defaulting to MSELoss.")
            criterion = nn.MSELoss()

        if input_dim is not None:
            model = self._create_network(
                model_type=model_type,
                input_dim=input_dim,
                output_dim=output_dim,
                num_layers=num_layers,
                hidden_dim=hidden_dim,
                dropout=dropout,
                num_heads=num_heads,
                pooling=pooling,
            )
        else:
            # Placeholder recognised (in_features == 1) and replaced in setup_model.
            model = nn.Linear(1, output_dim)

        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            batch_size=batch_size,
            **kwargs,
        )

        self.model_type = model_type
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.pooling = pooling
        self.distance_cutoff = distance_cutoff
        self.loss = loss
        self.embedder_instance = None
        self.embedder_config = None
        self.embedding_manager = None
        self.pdb_path = None
        self.data = None
        self.cfg = None

    @staticmethod
    def _create_network(
        model_type: str,
        input_dim: int,
        output_dim: int,
        num_layers: int,
        hidden_dim: int,
        dropout: float,
        num_heads: int,
        pooling: str,
    ) -> nn.Module:
        """Instantiate the graph network selected by ``model_type``.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        common = dict(
            input_dim=input_dim,
            output_dim=output_dim,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout=dropout,
            pooling=pooling,
        )
        if model_type == "mlp_gnn":
            return MLPGNN(**common)
        if model_type == "gcn":
            return GCNModel(**common)
        if model_type == "gat":
            return GATModel(num_heads=num_heads, **common)
        raise ValueError(f"Unknown model_type: {model_type}")

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Bind data/config, resolve the PDB path and build the real network.

        The PDB path is taken from ``data.pdb`` or ``cfg.benchmark.pdb``. If
        ``cfg`` has an ``embedder`` section, the embedder is initialised,
        which may also set ``input_dim``. If the model is still the
        placeholder, or ``input_dim`` is unknown, ``input_dim`` is inferred
        from the first loaded feature sample and the network is created.

        Raises:
            ValueError: If no PDB path is available or ``input_dim`` cannot be
                determined.
        """
        self.data = data
        self.cfg = cfg

        if hasattr(data, "pdb") and data.pdb:
            self.pdb_path = data.pdb
        elif hasattr(cfg, "benchmark") and hasattr(cfg.benchmark, "pdb"):
            self.pdb_path = cfg.benchmark.pdb

        if self.pdb_path is None:
            raise ValueError(
                "PDB path is required for GNN models. Set it in config.benchmark.pdb or data.pdb")

        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            if getattr(self, "embedding_manager", None) is None:
                self._initialize_embedder()

        is_placeholder = (
            isinstance(self.model, nn.Linear) and self.model.in_features == 1
        )
        if not is_placeholder and self.input_dim is not None:
            return

        if self.input_dim is None and getattr(data, "features_loaded", False):
            self.input_dim = self._infer_input_dim(data)
            logger.info(f"Inferred input_dim={self.input_dim} from data")

        if self.input_dim is None:
            raise ValueError(
                "Could not determine input_dim. Either set it explicitly in __init__, "
                "initialize embedder, or ensure features are loaded.")

        self.model = self._create_network(
            model_type=self.model_type,
            input_dim=self.input_dim,
            output_dim=self.output_dim,
            num_layers=self.num_layers,
            hidden_dim=self.hidden_dim,
            dropout=self.dropout,
            num_heads=self.num_heads,
            pooling=self.pooling,
        )
        logger.info(
            f"{self.model_type.upper()} model initialized with input_dim={self.input_dim}")

    @staticmethod
    def _infer_input_dim(data: HAIPRData) -> Optional[int]:
        """Return the last-axis size of the first feature sample in ``data``.

        Non-array samples fall back to ``len(X)``; ``None`` if ``X`` has no length.
        """
        X, _ = data[0:1]
        if isinstance(X, (torch.Tensor, np.ndarray)):
            return int(X.shape[-1])
        return len(X) if hasattr(X, "__len__") else None

    def _initialize_embedder(self):
        """Create the protenc encoder described by ``self.embedder_config``.

        If ``input_dim`` is still unknown, this also tries to set it: first
        from ``protenc.get_model_info``, then from the width of a one-residue
        sample embedding.

        Raises:
            NotImplementedError: If the configured embedder is not ``"protenc"``.
            ValueError: If the protenc encoder cannot be created.
        """
        if not self.embedder_config:
            logger.warning(
                "No embedder configuration found. Cannot initialize embedder."
            )
            return

        if self.embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{self.embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            import protenc

            model_name = self.embedder_config.model
            batch_size = getattr(self.embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(
                self.embedder_config, "data_parallel", False)

            logger.info(f"Initializing protenc embedder: {model_name}")

            try:
                model_info = protenc.get_model_info(model_name)
                embed_dim = model_info.get("embed_dim")
                if embed_dim is not None and self.input_dim is None:
                    self.input_dim = embed_dim
                    logger.info(
                        f"Set input_dim={self.input_dim} from protenc model info")
            except Exception as e:
                # Best effort: the dimension can still be inferred below.
                logger.warning(
                    f"Could not get model info for embedding dimension: {e}")

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )

            logger.info("Protenc embedder initialized successfully")

            if self.input_dim is None:
                try:
                    # Embed a one-residue sequence just to read the feature width.
                    for embed_output in self.embedder_instance(
                        ["M"],
                        average_sequence=False,
                        return_format="numpy",
                    ):
                        if isinstance(embed_output, np.ndarray):
                            self.input_dim = int(embed_output.shape[-1])
                            logger.info(
                                f"Inferred input_dim={self.input_dim} from sample embedding")
                            break
                except Exception as e:
                    logger.warning(
                        f"Could not infer input_dim from sample embedding: {e}")

        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

    def _resolve_shared_coords(self, dataset: Optional[HAIPRData] = None) -> np.ndarray:
        """Return the residue coordinates shared by every sample graph.

        When ``dataset`` is given, valid ``(n_residues, 3)`` chains from
        ``dataset.get_pdb_coords()`` are stacked in sorted chain-id order.
        Invalid chains are skipped with a warning. If no valid chain remains,
        or no dataset is given, coordinates are parsed from ``self.pdb_path``
        with ``extract_calpha_coords``.

        Raises:
            ValueError: If the fallback is needed but ``self.pdb_path`` is unset.
        """
        flat_coords: List[np.ndarray] = []
        chain_map = dataset.get_pdb_coords() if dataset is not None else None
        if isinstance(chain_map, dict) and chain_map:
            for cid in sorted(chain_map):
                coords = np.asarray(chain_map[cid])
                if coords.ndim != 2 or coords.shape[1] != 3:
                    logger.warning(
                        f"Chain {cid}: coordinates must have shape (n_residues, 3), "
                        f"got {coords.shape}. Skipping chain."
                    )
                    continue
                flat_coords.append(coords)

        if flat_coords:
            return np.vstack(flat_coords)

        if self.pdb_path is None:
            raise ValueError(
                "PDB path is required for graph construction but is not set."
            )
        try:
            shared_coords = extract_calpha_coords(self.pdb_path)
        except Exception as e:
            logger.error(f"Failed to extract shared C-alpha coordinates: {e}")
            raise
        if dataset is not None:
            logger.warning(
                "Using unfiltered shared PDB coordinates because "
                "get_pdb_coords() returned no valid chain coordinates."
            )
        return shared_coords

    def _build_graphs_from_sequences(
        self,
        sequences: List[str],
        embeddings: np.ndarray | List[np.ndarray],
    ) -> List[Data]:
        """Build one graph per sequence over the shared structure coordinates.

        Coordinates are resolved the same way as during training, preferring
        ``self.data.get_pdb_coords()`` when a dataset is bound. This keeps
        inference graphs consistent with the graphs the model was trained on.
        """
        dataset = self.data if hasattr(self.data, "get_pdb_coords") else None
        shared_coords = self._resolve_shared_coords(dataset)
        return self._build_graphs_from_coords(
            [shared_coords] * len(sequences), embeddings)

    def _build_graphs_from_coords(
        self,
        coords_list: List[np.ndarray],
        embeddings: np.ndarray | List[np.ndarray],
    ) -> List[Data]:
        """Pair each coordinate set with its embedding and build a graph.

        ``embeddings`` is either a list with one per-residue array per sample,
        or an array indexed by sample. When the array is 3-D, each sample is
        truncated to its residue count (``embeddings[i, :n_residues]``).

        Raises:
            ValueError: On a sample-count mismatch, malformed coordinates, or
                an embedding length that differs from the residue count.
        """
        # zip() would otherwise silently drop unmatched samples.
        if len(embeddings) != len(coords_list):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(coords_list)} "
                "coordinate sets; counts must match."
            )

        if isinstance(embeddings, list):
            embedding_list = embeddings
        elif len(embeddings.shape) == 3:
            embedding_list = [
                embeddings[i, :len(coords)] for i, coords in enumerate(coords_list)
            ]
        else:
            embedding_list = [embeddings[i] for i in range(len(coords_list))]

        graphs: List[Data] = []
        for i, (coords, emb) in enumerate(zip(coords_list, embedding_list)):
            coords = np.asarray(coords)
            if coords.ndim != 2 or coords.shape[1] != 3:
                raise ValueError(
                    f"Sample {i}: coordinates must have shape (n_residues, 3), "
                    f"got {coords.shape}"
                )

            seq_len = coords.shape[0]
            if len(emb) != seq_len:
                raise ValueError(
                    f"Sample {i}: Embedding length ({len(emb)}) doesn't match "
                    f"coordinate length ({seq_len}). Cannot build graph with "
                    "mismatched dimensions."
                )

            # torch.tensor() on an existing tensor warns; copy tensors explicitly.
            if isinstance(emb, torch.Tensor):
                node_features = emb.detach().to(dtype=torch.float32, copy=True)
            else:
                node_features = torch.tensor(np.asarray(emb), dtype=torch.float32)

            graphs.append(build_graph_from_coords(
                coords,
                distance_cutoff=self.distance_cutoff,
                node_features=node_features,
            ))

        return graphs

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, Any]:
        """Build graphs and label tensors for the samples at ``indices``.

        Returns:
            ``{"graphs": List[Data], "labels": float32 tensor}``. For
            regression (``num_classes == 0``), 1-D labels are reshaped to
            ``(n, 1)``.

        Raises:
            RuntimeError: If dataset features are not loaded.
            ValueError: If features are not 3-D per-residue embeddings.
        """
        if not dataset.features_loaded or dataset.features is None:
            raise RuntimeError(
                "Dataset features not loaded. Call dataset.prepare_features() first."
            )

        n_dims = len(dataset.features.shape)
        if n_dims == 2:
            raise ValueError(
                "Dataset features appear to be averaged (2D). "
                "GNN models require per-residue embeddings (3D). "
                "Ensure embedder.average_sequence=False when preparing features."
            )
        if n_dims != 3:
            raise ValueError(
                f"Unexpected features shape: {dataset.features.shape}. "
                "Expected 3D array (n_samples, n_residues, embedding_dim)."
            )

        embeddings = dataset.features[indices]
        shared_coords = self._resolve_shared_coords(dataset)
        graphs = self._build_graphs_from_coords(
            [shared_coords] * len(indices), embeddings)

        # NOTE(review): labels are float32 even when num_classes > 0; confirm
        # HAIPRModule/criterion casts them for classification losses.
        labels_tensor = torch.tensor(
            dataset.get_labels(indices), dtype=torch.float32)
        if self.num_classes == 0 and labels_tensor.dim() == 1:
            labels_tensor = labels_tensor.unsqueeze(1)

        return {"graphs": graphs, "labels": labels_tensor}

    def _get_per_residue_embeddings(self, sequences: List[str]) -> List[np.ndarray]:
        """Return one per-residue embedding per sequence.

        ``embedding_manager`` is used when set; otherwise a local protenc
        embedder is used, created from ``embedder_config`` if needed.

        Raises:
            RuntimeError: If no embedding source is configured or the local
                embedder could not be created.
        """
        if getattr(self, "embedding_manager", None) is not None:
            logger.debug(
                f"Using EmbeddingManager for {len(sequences)} sequences")
            embeddings = self.embedding_manager.get_embeddings(sequences)
            if isinstance(embeddings, np.ndarray):
                return [embeddings[i] for i in range(len(sequences))]
            return embeddings

        if getattr(self, "embedder_instance", None) is None:
            if not getattr(self, "embedder_config", None):
                raise RuntimeError(
                    "No embedding source available. Need either embedding_manager "
                    "(from inference) or embedder_config (for local embedder)."
                )
            self._initialize_embedder()

        # Previously this silently returned [] and failed later in batching.
        if self.embedder_instance is None:
            raise RuntimeError(
                "Local embedder could not be initialized; cannot compute "
                "per-residue embeddings."
            )

        logger.debug(f"Using local embedder for {len(sequences)} sequences")
        return list(self.embedder_instance(
            sequences,
            average_sequence=False,
            return_format="numpy",
        ))

    def prepare_batch_features(
        self, batch_items: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Embed ``item["sequence"]`` for each item and build the inference graphs."""
        sequences = [item["sequence"] for item in batch_items]
        embeddings = self._get_per_residue_embeddings(sequences)
        graphs = self._build_graphs_from_sequences(sequences, embeddings)
        return {"inputs": {"graphs": graphs}}

    def forward(self, batch_for_predictor: Dict[str, Any]):
        """Run the network on graphs found under ``["inputs"]["graphs"]`` or ``["graphs"]``.

        A list of graphs is batched with ``Batch.from_data_list``; the batch
        is moved to ``self.device`` before the forward pass.
        """
        batch = batch_for_predictor
        if "inputs" in batch:
            graphs = batch["inputs"]["graphs"]
        else:
            graphs = batch.get("graphs", batch)

        if isinstance(graphs, list):
            batched_graph = Batch.from_data_list(graphs)
        else:
            batched_graph = graphs

        return self.model(batched_graph.to(self.device))

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
        trainer_instance: Optional[pl.Trainer] = None,
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """Train on ``train_indices`` and report best-validation results.

        Returns:
            ``{"metrics": ..., "predictions": {"indices", "predictions",
            "true_values"[, "probabilities"]}}``, built from
            ``best_val_predictions`` / ``best_val_metrics``.

        Raises:
            ValueError: If neither ``trainer_instance`` nor ``cfg.trainer`` is given.
            RuntimeError: If no validation predictions were recorded.
        """
        if trainer_instance is None:
            if cfg is None or not hasattr(cfg, "trainer"):
                raise ValueError(
                    "fit_model needs either trainer_instance or cfg.trainer.")
            # resolve=True so ${...} interpolations reach pl.Trainer as values.
            trainer_kwargs = OmegaConf.to_container(cfg.trainer, resolve=True)
            trainer_instance = pl.Trainer(**trainer_kwargs)

        self.data = dataset

        train_items = self._as_graph_items(
            self.prepare_training_features(dataset, train_indices))
        val_items = self._as_graph_items(
            self.prepare_training_features(dataset, val_indices))

        from torch.utils.data import DataLoader

        # Pinning is only meaningful with CUDA; otherwise it just warns.
        pin_memory = torch.cuda.is_available()

        train_loader = DataLoader(
            train_items,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self._collate_graphs,
            num_workers=self.train_num_workers,
            pin_memory=pin_memory,
        )
        val_loader = DataLoader(
            val_items,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self._collate_graphs,
            num_workers=self.val_num_workers,
            pin_memory=pin_memory,
        )

        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics

        if predictions is None:
            raise RuntimeError("No predictions available after training")

        pred_dict = {
            "indices": (
                val_indices.tolist()
                if hasattr(val_indices, "tolist")
                else list(val_indices)
            ),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }
        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        return {"metrics": metrics, "predictions": pred_dict}

    @staticmethod
    def _as_graph_items(features: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Pair graphs with labels as ``{"graphs": [graph], "labels": label}`` items.

        A plain list serves as the DataLoader dataset. Unlike a class defined
        inside ``fit_model``, it can be pickled into spawned worker processes.
        """
        return [
            {"graphs": [graph], "labels": label}
            for graph, label in zip(features["graphs"], features["labels"])
        ]

    @staticmethod
    def _collate_graphs(batch):
        """Collate graph items into one ``Batch`` plus stacked labels.

        Each item maps ``"graphs"`` to a graph or a list/tuple of graphs, and
        ``"labels"`` to a label tensor. This is a staticmethod so DataLoader
        workers do not need to pickle the whole module. Without
        torch_geometric, the raw graph and label lists are returned.
        """
        graphs = []
        labels = []

        for item in batch:
            item_graphs = item["graphs"]
            if isinstance(item_graphs, (list, tuple)):
                graphs.extend(item_graphs)
            else:
                graphs.append(item_graphs)
            labels.append(item["labels"])

        if TORCH_GEOMETRIC_AVAILABLE and Batch is not None:
            batched_graph = Batch.from_data_list(graphs)
            labels_tensor = torch.stack(labels) if labels else torch.tensor([])
        else:
            batched_graph = graphs
            labels_tensor = labels

        return {"inputs": {"graphs": batched_graph}, "labels": labels_tensor}

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Delegate to ``HAIPRModule.predict_sequences``."""
        return super().predict_sequences(sequences, params)

    def load_context(self, context):
        """Load base context, then reuse ``self.cfg.embedder`` when it is present."""
        super().load_context(context)

        if hasattr(self, "cfg") and hasattr(self.cfg, "embedder"):
            self.embedder_config = self.cfg.embedder
            logger.info("Set embedder config from loaded config")

    def save_model(self, save_dir: str) -> str:
        """Save the network's ``state_dict`` to ``<save_dir>/model.pt`` and return that path."""
        import os

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model.pt")
        torch.save(self.model.state_dict(), save_path)
        logger.info(f"Saved GNN model to {save_path}")
        return save_path
