import random
import numpy
import torch
from typing import Dict


def set_seed(seed: int):
    """Seed the random number generators used by this code for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG (``numpy.random.seed``)
    and PyTorch's generator (``torch.manual_seed``), in that order. When CUDA
    is available, every CUDA device's generator is seeded as well.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def calculate_means_dict(array_of_dicts):
    """
    Calculate the mean of each key from an array of dictionaries using PyTorch.

    The key set (and its order) is taken from the first dictionary; every
    other dictionary must contain those keys, otherwise ``KeyError`` is
    raised. Values are converted to PyTorch's default floating dtype before
    averaging, so integer and boolean values are accepted as well as floats.

    Parameters:
        array_of_dicts (list of dict): A list of dictionaries with the same keys.

    Returns:
        dict: A dictionary mapping each key to its mean as a Python float.
        An empty dict is returned when ``array_of_dicts`` is not a list or
        is an empty list.
    """
    # Non-list inputs (e.g. None) and empty lists have nothing to average.
    if not isinstance(array_of_dicts, list) or not array_of_dicts:
        return {}

    # Fix the column order once, from the first dictionary's keys.
    keys = list(array_of_dicts[0].keys())

    # One row per dictionary, one column per key. The explicit floating dtype
    # is needed because Tensor.mean() rejects integer/bool tensors; for float
    # inputs it is the same dtype torch.tensor would infer anyway.
    data_tensor = torch.tensor(
        [[d[key] for key in keys] for d in array_of_dicts],
        dtype=torch.get_default_dtype(),
    )

    # Average over rows, i.e. across the dictionaries.
    mean_tensor = data_tensor.mean(dim=0)

    return dict(zip(keys, mean_tensor.tolist()))


def segment_data_fn(data, obs_dim):
    """Split data into its first ``obs_dim`` features and the rest, and flatten it.

    Args:
        data: Tensor with any number of leading dimensions; the last
            dimension is split at ``obs_dim``.
        obs_dim: Number of leading entries of the last dimension returned
            as ``obs_all``.

    Returns:
        tuple: ``(obs_all, actions_all, segments, sa)`` where ``obs_all`` is
        ``data[..., :obs_dim]``, ``actions_all`` is ``data[..., obs_dim:]``,
        ``segments`` is ``data`` itself (unchanged), and ``sa`` is ``data``
        flattened to 2-D with shape ``(-1, data.shape[-1])``.
    """
    obs_all = data[..., :obs_dim]
    actions_all = data[..., obs_dim:]
    segments = data
    # reshape rather than view: view() raises on non-contiguous inputs such
    # as transposed or sliced tensors. For contiguous inputs reshape still
    # returns a view sharing storage, so existing behavior is unchanged.
    sa = segments.reshape(-1, segments.shape[-1])

    return obs_all, actions_all, segments, sa


def nest_dict(
    d: Dict, separator: str = "."
) -> Dict:  # From CPL code: https://github.com/jhejna/cpl
    """Expand a flat dict with delimited keys into nested dicts.

    Each key is split on ``separator``. Every part except the last names an
    intermediate dict, which is created on first use. The last part receives
    the value. Keys are processed in iteration order, so a later key may
    overwrite an entry written by an earlier one.

    Example:
        >>> nest_dict({"a.b": 1, "a.c": 2, "d": 3})
        {'a': {'b': 1, 'c': 2}, 'd': 3}
    """
    root = {}
    for flat_key in d:
        *parents, leaf = flat_key.split(separator)
        node = root
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = d[flat_key]
    return root
