import wandb
import tqdm
import numpy as np
import tensorflow as tf
import einops

import jax
import jaxlib

# All devices JAX can see, as returned by ``jax.devices()``.
device_list = jax.devices()
# Number of those devices; ``split_to_devices`` splits each leaf's leading axis
# into this many groups.
num_devices = len(device_list)

def split_to_devices(x):
    """Add a leading per-device axis to every leaf of a pytree.

    Each leaf's leading dimension, of size ``num_devices * b``, is rearranged
    into two axes ``(num_devices, b, ...)`` with the device index as the outer
    axis. Device ``i`` therefore gets the contiguous rows ``[i*b, (i+1)*b)``.
    TensorFlow tensors are converted to NumPy arrays before the rearrange.

    Args:
        x: A pytree (nested dicts/lists/tuples) whose leaves are ``tf.Tensor``
            objects or arrays accepted by ``einops.rearrange``.

    Returns:
        A pytree with the same structure, where each leaf has a new leading
        axis of size ``num_devices``.

    Raises:
        einops.EinopsError: if a leaf's leading dimension is not divisible by
            ``num_devices``.
    """

    def _split_leaf(leaf):
        if isinstance(leaf, tf.Tensor):
            # Public accessor; ``_numpy()`` is a private TensorFlow API.
            leaf = leaf.numpy()
        return einops.rearrange(leaf, "(d b) ... -> d b ...", d=num_devices)

    # ``jax.tree_map`` is deprecated (and removed in newer JAX releases);
    # ``jax.tree_util.tree_map`` is the stable spelling with the same behavior.
    return jax.tree_util.tree_map(_split_leaf, x)

def average_dict(some_dict):
    """Replace each list or JAX-array value of a (possibly nested) dict by its mean.

    Values are handled by type:
      * ``dict`` (including subclasses): recursed into, giving a nested dict.
      * ``jax.Array``: mean of all elements, computed in float64 and returned
        as an ``np.float64``.
      * ``list``: ``np.mean`` of the list.
      * anything else: copied through unchanged.

    The input dict is not modified.

    Args:
        some_dict: Mapping from keys to lists, JAX arrays, nested dicts, or
            other values.

    Returns:
        A new dict with the same keys and the averaged values.
    """
    new_dict = {}
    for key, value in some_dict.items():
        if isinstance(value, dict):
            new_dict[key] = average_dict(value)
        elif isinstance(value, jax.Array):
            # ``jax.Array`` is the public array type. The old check against
            # ``jaxlib.xla_extension.ArrayImpl`` relied on a private module path.
            # Casting to float64 keeps the np.float64 result the previous
            # ``tolist()`` round-trip produced, without building a Python list.
            new_dict[key] = np.mean(np.asarray(value, dtype=np.float64))
        elif isinstance(value, list):
            new_dict[key] = np.mean(value)
        else:
            new_dict[key] = value
    return new_dict

def average_dicts(some_dicts):
    """Element-wise average of a list of dicts that share the same keys.

    The first dict determines the keys visited and how each one is treated:
      * ``dict`` values: averaged recursively across all inputs.
      * numeric scalars (Python ``int``/``float``, which includes
        ``np.float64``, or any NumPy number such as ``np.float32``): averaged
        with ``np.mean``.
      * any other type: a "not recognized" message is printed and the key is
        left out of the result.

    Args:
        some_dicts: Sequence of dicts with identical (possibly nested) keys.

    Returns:
        A new dict of averaged values. If ``some_dicts`` is empty, returns an
        empty dict.

    Raises:
        KeyError: if a later dict lacks a recognized key present in the first.
    """
    if not some_dicts:
        return {}
    averaged = {}
    for key, sample in some_dicts[0].items():
        if isinstance(sample, dict):
            averaged[key] = average_dicts([d[key] for d in some_dicts])
        elif isinstance(sample, (int, float, np.number)):
            # Previously only exactly ``np.float64`` was accepted, so plain
            # Python numbers and float32 metrics were silently dropped.
            averaged[key] = np.mean([d[key] for d in some_dicts])
        else:
            print("val type", type(sample), "not recognized")
    return averaged


def print_dict_keys(d, indent=0):
    """Print the keys of a (possibly nested) dict as an indented outline.

    Keys are printed one per line. The keys of a nested dict value are listed
    directly under their parent key, with one extra leading space per level.

    Args:
        d: The dict whose keys to print.
        indent: Number of leading spaces for this level's keys.
    """
    prefix = " " * indent
    for name, child in d.items():
        print(f"{prefix}{name}")
        if not isinstance(child, dict):
            continue
        print_dict_keys(child, indent=indent + 1)