import torch
import random
import numpy as np
import os

def make_reproducibility(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators and switch
    cuDNN to deterministic mode.

    Args:
        seed: value passed unchanged to every seeding function; its string
            form is also stored in the ``PYTHONHASHSEED`` environment variable.

    Side effects:
        Mutates global RNG state, the ``torch.backends.cudnn`` flags and
        ``os.environ``.
    """
    # Same seed for: Python's `random`, NumPy's global RNG, PyTorch's default
    # generator, and the CUDA generators (via both CUDA seeding entry points).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this is
    # expected to matter only for processes that inherit this environment.
    os.environ["PYTHONHASHSEED"] = f"{seed}"

def update_sparse_rows_from_csr(orig_crow, orig_col, orig_val, row_indices, upd_crow, upd_col, upd_val):
    """
    Replace selected rows of a CSR matrix with rows from a second CSR block.

    Row ``row_indices[k]`` of the original matrix is replaced by row ``k`` of
    the update block; every other row is copied unchanged.

    Args:
        orig_crow: 1-D row-pointer tensor of the original matrix (``N + 1`` entries).
        orig_col: 1-D column-index tensor of the original matrix.
        orig_val: 1-D value tensor of the original matrix.
        row_indices: 1-D integer tensor with the ``K`` rows to replace.
            Negative entries count from the last row. If a row is listed more
            than once, its last occurrence wins.
        upd_crow: 1-D row-pointer tensor of the update block (``K + 1`` entries).
        upd_col: 1-D column-index tensor of the update block.
        upd_val: 1-D value tensor of the update block.

    Returns:
        Tuple ``(new_crow, new_col, new_val)``. ``new_crow`` has the dtype and
        device of ``orig_crow``; ``new_col`` / ``new_val`` have the dtype and
        device of ``orig_col`` / ``orig_val``.

    Raises:
        ValueError: if ``row_indices`` and ``upd_crow`` describe a different
            number of rows.
        IndexError: if an entry of ``row_indices`` is outside ``[-N, N)``.
    """
    n_rows = orig_crow.size(0) - 1

    targets = row_indices.tolist()
    n_updates = upd_crow.size(0) - 1
    if len(targets) != n_updates:
        raise ValueError(
            f"row_indices has {len(targets)} entries but upd_crow describes {n_updates} rows"
        )

    # 1) map global row -> row of the update block.
    #    Negative indices are normalised so that the length update below and
    #    the copy loop agree on which row is being replaced.
    idx_map = {}
    for k, row in enumerate(targets):
        row = int(row)
        if not -n_rows <= row < n_rows:
            raise IndexError(
                f"row index {row} is out of range for a matrix with {n_rows} rows"
            )
        idx_map[row + n_rows if row < 0 else row] = k

    # 2) per-row nnz counts, with the replaced rows taking the update lengths.
    #    Driven by idx_map so duplicates resolve the same way as in the copy loop.
    orig_lengths = orig_crow[1:] - orig_crow[:-1]       # (N,)
    upd_lengths = upd_crow[1:] - upd_crow[:-1]          # (K,)
    new_lengths = orig_lengths.clone()
    if idx_map:
        dst_rows = torch.tensor(list(idx_map.keys()), dtype=torch.long, device=new_lengths.device)
        src_rows = torch.tensor(list(idx_map.values()), dtype=torch.long, device=upd_lengths.device)
        new_lengths[dst_rows] = upd_lengths[src_rows].to(
            device=new_lengths.device, dtype=new_lengths.dtype
        )

    # 3) build new row pointer
    new_crow = torch.empty_like(orig_crow)
    new_crow[0] = 0
    new_crow[1:] = torch.cumsum(new_lengths, dim=0)

    # Convert the pointers to Python ints once instead of calling .item()
    # several times per row inside the loop.
    new_ptr = new_crow.tolist()
    orig_ptr = orig_crow.tolist()
    upd_ptr = upd_crow.tolist()

    # 4) allocate output storage on the same device as the original data
    total_nnz = new_ptr[-1]
    new_col = orig_col.new_empty((total_nnz,))
    new_val = orig_val.new_empty((total_nnz,))

    # 5) fill in each row
    for i in range(n_rows):
        dst_start, dst_end = new_ptr[i], new_ptr[i + 1]
        if dst_end == dst_start:
            continue

        k = idx_map.get(i)
        if k is None:
            # copy from original
            s, e = orig_ptr[i], orig_ptr[i + 1]
            new_col[dst_start:dst_end] = orig_col[s:e]
            new_val[dst_start:dst_end] = orig_val[s:e]
        else:
            # copy from updated block
            s, e = upd_ptr[k], upd_ptr[k + 1]
            new_col[dst_start:dst_end] = upd_col[s:e]
            new_val[dst_start:dst_end] = upd_val[s:e]

    return new_crow, new_col, new_val

def extract_sparse_rows_from_csr(crow, col, val, num_cols, row_indices, device=None):
    """
    Extract a subset of rows from a CSR matrix as a new sparse CSR tensor.

    Args:
        crow: 1-D row-pointer tensor of the source matrix (``N + 1`` entries).
        col: 1-D column-index tensor of the source matrix.
        val: 1-D value tensor of the source matrix.
        num_cols: number of columns of the source (and of the result).
        row_indices: integer tensor with the ``K`` rows to extract, in output
            order. Negative entries count from the last row.
        device: device of the result. When ``None`` the device of ``val`` is
            used, so all components of the result share one device.

    Returns:
        A sparse CSR tensor of size ``(K, num_cols)`` whose row ``k`` is row
        ``row_indices[k]`` of the source.

    Raises:
        IndexError: if an entry of ``row_indices`` is outside ``[-N, N)``.
    """
    K = row_indices.numel()
    if device is None:
        device = val.device
    n_rows = crow.numel() - 1

    # Normalise negative indices and validate the range up front.
    rows = row_indices.to(device=crow.device, dtype=torch.long).reshape(-1)
    rows = torch.where(rows < 0, rows + n_rows, rows)
    if K > 0 and (int(rows.min()) < 0 or int(rows.max()) >= n_rows):
        raise IndexError(
            f"row_indices contains an index out of range for a matrix with {n_rows} rows"
        )

    # Look up the [start, end) span of every requested row in one indexing op
    # instead of two Python loops with per-row .item() calls.
    starts = crow[rows]
    ends = crow[rows + 1]

    # construct new row pointer array (crow_sub) for the extracted rows
    crow_sub = torch.zeros(K + 1, dtype=crow.dtype, device=device)
    if K > 0:
        crow_sub[1:] = (ends - starts).cumsum(dim=0).to(device)

    # gather corresponding column indices and values for the non-empty rows
    spans = [(s, e) for s, e in zip(starts.tolist(), ends.tolist()) if e > s]
    if spans:
        col_sub = torch.cat([col[s:e] for s, e in spans], dim=0).to(device)
        val_sub = torch.cat([val[s:e] for s, e in spans], dim=0).to(device)
    else:
        # torch.cat rejects an empty list, so build the empty result explicitly
        col_sub = torch.empty((0,), dtype=col.dtype, device=device)
        val_sub = torch.empty((0,), dtype=val.dtype, device=device)

    # create a new sparse CSR tensor from extracted rows
    return torch.sparse_csr_tensor(crow_sub, col_sub, val_sub,
                                   size=(K, num_cols),
                                   device=device)
