import random
from sklearn.model_selection import StratifiedKFold, train_test_split
from typing import Optional, Sequence, Tuple, List
import torch
import numpy as np
from torch import Tensor
from torch_geometric.typing import OptTensor
from torch_geometric.utils import scatter, cumsum
from torch_geometric.data import Data


def get_train_val_test_datasets(
    dataset: Sequence[Data],
    seed: int,
    n_folds: int,
    fold_id: int,
    ratio: float = 0.85,
) -> Tuple[List[Data], List[Data], List[Data]]:
    """
    Returns stratified Train/Val/Test lists of PyG Data objects for a given fold.
    - Outer split: StratifiedKFold (train+val vs. test)
    - Inner split: train_test_split with stratify (train vs. val)

    Class labels are read as ``int(data.y.item())``, so every ``data.y`` must
    contain exactly one element.

    Args:
        dataset: Sequence of PyG Data objects
        seed: Random seed for reproducibility (used by both the outer and the
            inner split)
        n_folds: Number of folds for cross-validation
        fold_id: Which fold is held out as the test set (0 to n_folds-1)
        ratio: Fraction of the train+val part kept for training; the
            remaining ``1 - ratio`` becomes the validation set

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset)

    Raises:
        ValueError: If ``fold_id`` is not in ``[0, n_folds)``. Invalid
            ``n_folds`` / ``ratio`` values are rejected by scikit-learn.
    """
    if not 0 <= fold_id < n_folds:
        # Reject explicitly: a negative id would otherwise wrap around in the
        # list indexing below and silently select a different fold.
        raise ValueError(f"fold_id must be in [0, {n_folds}), got {fold_id}")

    y_all = np.array([int(dataset[i].y.item()) for i in range(len(dataset))], dtype=int)

    # Outer split: StratifiedKFold (train+val vs. test).
    # X is only a placeholder of the right length; only the indices matter.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    all_splits = list(skf.split(np.zeros(len(dataset)), y_all))
    trainval_idx, test_idx = all_splits[fold_id]

    # Inner split: train_test_split with stratify (train vs. val).
    # The index array itself is split, so the results are global dataset indices.
    y_trainval = y_all[trainval_idx]
    test_size = 1.0 - ratio
    train_idx_sub, val_idx_sub = train_test_split(
        trainval_idx,
        test_size=test_size,
        random_state=seed,
        shuffle=True,
        stratify=y_trainval,
    )

    train_dataset = [dataset[i] for i in train_idx_sub]
    val_dataset = [dataset[i] for i in val_idx_sub]
    test_dataset = [dataset[i] for i in test_idx]

    return train_dataset, val_dataset, test_dataset


def to_dense_rect_matrix(
    edge_index: Tensor,  # (2, nnz) -- indices of S
    row_batch: OptTensor,  # batch of input graphs
    col_batch: OptTensor,  # batch of pooled graphs
    edge_attr: OptTensor = None,  # values for S
    max_num_rows: Optional[int] = None,
    max_num_cols: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tensor:
    """
    Convert a batched sparse rectangular matrix S into a dense, zero-padded tensor.

    ``edge_index`` holds global (row, col) coordinates of the non-zeros of a
    block-diagonal matrix: each non-zero links a row of graph ``b`` (per
    ``row_batch``) to a column of the same graph ``b`` (per ``col_batch``).
    Each block is written to ``out[b, :rows_b, :cols_b]``, and values at
    duplicate coordinates are summed.

    Args:
        edge_index: (2, nnz) global row/col indices of the non-zeros.
        row_batch: Graph id of every row. ``None`` treats all rows as a single
            graph whose row count is ``edge_index[0].max() + 1``.
        col_batch: Graph id of every column. ``None`` is handled like
            ``row_batch``, using ``edge_index[1]``.
        edge_attr: Values of the non-zeros, shape (nnz,) or (nnz, *F).
            Defaults to ones of dtype ``torch.float``.
        max_num_rows: Rows per dense block. Entries whose graph-local row
            index is >= this value are dropped. Defaults to the largest
            per-graph row count.
        max_num_cols: Same as ``max_num_rows``, for columns.
        batch_size: Number of graphs. Defaults to ``row_batch.max() + 1``,
            or 1 when ``row_batch`` is empty.

    Returns:
        Dense tensor of shape (batch_size, max_num_rows, max_num_cols, *F).
    """
    r, c = edge_index  # global indices

    # A missing batch vector means "one graph". Infer its size from the largest
    # index in use, the same convention as torch_geometric.utils.to_dense_adj.
    if row_batch is None:
        num_rows = int(r.max()) + 1 if r.numel() > 0 else 0
        row_batch = r.new_zeros(num_rows)
    if col_batch is None:
        num_cols = int(c.max()) + 1 if c.numel() > 0 else 0
        col_batch = c.new_zeros(num_cols)

    if batch_size is None:
        batch_size = int(row_batch.max()) + 1 if row_batch.numel() else 1

    # How many rows/cols each graph owns, plus the prefix sums of those counts
    # (like data.ptr), which give each graph's first global row/col index.
    one_r = row_batch.new_ones(row_batch.size(0))
    one_c = col_batch.new_ones(col_batch.size(0))
    num_r = scatter(one_r, row_batch, dim=0, dim_size=batch_size, reduce="sum")
    num_c = scatter(one_c, col_batch, dim=0, dim_size=batch_size, reduce="sum")
    cum_r = cumsum(num_r)
    cum_c = cumsum(num_c)

    # Graph id is taken from the row side (block-diagonal => row & col agree).
    g = row_batch[r]
    r_loc = r - cum_r[g]
    c_loc = c - cum_c[g]

    if max_num_rows is None:
        max_num_rows = int(num_r.max())
    if max_num_cols is None:
        max_num_cols = int(num_c.max())

    # Drop entries that fall outside the requested (possibly truncated) block.
    mask = (r_loc < max_num_rows) & (c_loc < max_num_cols)
    g, r_loc, c_loc = g[mask], r_loc[mask], c_loc[mask]
    if edge_attr is None:
        edge_attr = torch.ones_like(r, dtype=torch.float)[mask]
    else:
        edge_attr = edge_attr[mask]

    # Flatten (graph, row, col) into one linear index so that a single scatter
    # fills the whole padded batch. Any trailing feature dims of edge_attr are
    # carried through the scatter and kept in the final view.
    flat_size = batch_size * max_num_rows * max_num_cols
    idx = g * max_num_rows * max_num_cols + r_loc * max_num_cols + c_loc
    dense = scatter(edge_attr, idx, dim=0, dim_size=flat_size, reduce="sum")
    return dense.view(batch_size, max_num_rows, max_num_cols, *edge_attr.shape[1:])
