import pathlib
import random
import shlex
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the global PyTorch, Python ``random`` and NumPy generators.

    All three generators receive the same ``seed`` value.

    Args:
        seed (int): Seed value passed to every generator.
    """
    seeders = (torch.manual_seed, random.seed, np.random.seed)
    for seeder in seeders:
        seeder(seed)


def show_sample_images(
    data_loader: torch.utils.data.DataLoader,
    nrows: int = 2,
    ncols: int = 5,
    image_shape: tuple = (28, 28),
) -> None:
    """Display the first ``nrows * ncols`` samples of a loader's dataset.

    Samples are fetched by index straight from ``data_loader.dataset``, so
    the loader's batching, shuffling and sampler are not applied. Each
    sample is drawn as a grayscale image titled with its label. Grid cells
    beyond the end of the dataset are left blank.

    Args:
        data_loader (torch.utils.data.DataLoader): Loader whose ``dataset``
            supports ``len()`` and integer indexing and yields
            ``(image, label)`` pairs.
        nrows (int): Number of rows in the grid.
        ncols (int): Number of columns in the grid.
        image_shape (tuple): Shape each image is reshaped to before
            plotting. Defaults to ``(28, 28)``. Each image must contain
            exactly that many elements.
    """
    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid, so
    # iterating ``axes.flat`` always works (a bare Axes has no .flatten()).
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )
    dataset = data_loader.dataset
    n_available = len(dataset)

    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        # Leave surplus cells empty instead of indexing past the dataset end.
        if i >= n_available:
            continue
        x, y = dataset[i]
        if isinstance(x, torch.Tensor):
            # Tensors that require grad or live on a GPU cannot be
            # converted to NumPy directly.
            x = x.detach().cpu()
        ax.imshow(np.asarray(x).reshape(image_shape), cmap="gray")
        ax.set_title(f"Label: {y}")

    fig.tight_layout()
    plt.show()


def get_bash_env_vars(script_path: pathlib.Path) -> dict:
    """
    Get the environment of a bash shell after it has sourced a script.

    The script is sourced by a fresh ``/bin/bash -c`` process, which
    inherits this process's environment. The exported environment of that
    shell is then returned. Only exported variables appear. Plain shell
    variables set by the script without ``export`` are not included.

    Args:
        script_path (pathlib.Path | str): The path to the bash script.

    Returns:
        dict: Mapping of variable name to value (both ``str``) for every
        exported variable, including those inherited from the caller.

    Raises:
        FileNotFoundError: If ``script_path`` is not an existing file.
        subprocess.CalledProcessError: If sourcing the script (or ``env``)
            exits with a non-zero status. Previously this silently yielded
            an empty dict.
    """
    script_path = pathlib.Path(script_path)
    # Raise instead of assert: asserts are stripped when Python runs with -O.
    if not script_path.is_file():
        raise FileNotFoundError(f"The bash script {script_path} does not exist.")

    # An absolute path stops `source` from searching $PATH for a bare file
    # name. shlex.quote keeps spaces/metacharacters in the path from being
    # interpreted by the shell.
    quoted_path = shlex.quote(str(script_path.absolute()))
    # `env -0` terminates entries with NUL instead of newline, so values that
    # themselves contain newlines parse correctly. This requires an `env`
    # supporting -0 (e.g. GNU coreutils).
    cmd = f"source {quoted_path} && env -0"
    completed = subprocess.run(
        ["/bin/bash", "-c", cmd], stdout=subprocess.PIPE, check=True
    )
    output = completed.stdout.decode("utf-8")

    env_vars = {}
    for entry in output.split("\0"):
        # Split on the first '=' only: values may legitimately contain '='.
        key, sep, value = entry.partition("=")
        if sep:
            env_vars[key] = value

    return env_vars