import os

from accelerate import Accelerator
import random
import numpy as np
import torch
import transformers


# Shared Accelerator used by is_main_process / print_0 / get_device below.
# It is constructed at module level, so importing this module creates it as a
# side effect.
_accelerate = Accelerator()

def is_main_process():
    """Report whether the current process is the main process.

    Returns:
        The ``is_main_process`` flag of the module-level Accelerator.
    """
    main_flag = _accelerate.is_main_process
    return main_flag


def print_0(*args, **kwargs):
    """Print only on the main process, prefixed with the accelerator device.

    Positional and keyword arguments are forwarded to the builtin ``print``.
    ``flush`` is always forced to True; a caller-supplied value is overridden.
    """
    if not is_main_process():
        return
    # Copy rather than mutate, then force flushing so multi-process logs
    # appear promptly.
    print_kwargs = {**kwargs, 'flush': True}
    print(_accelerate.device, *args, **print_kwargs)


def init_seed(seed, *, deterministic=True):
    """Seed every RNG used here and optionally enable deterministic torch kernels.

    Args:
        seed: Integer seed applied to ``random``, NumPy, torch (CPU and all
            CUDA devices) and ``transformers.set_seed``.
        deterministic: When True (the default, matching previous behaviour),
            also disable cuDNN benchmarking, force deterministic cuDNN, and call
            ``torch.use_deterministic_algorithms(True)``. Note that the last
            call makes torch raise on ops without a deterministic implementation.
            Pass False to seed the RNGs only.
    """
    # Environment variables go first. PyTorch's reproducibility notes require
    # CUBLAS_WORKSPACE_CONFIG to be in place before cuBLAS is first used.
    # If cuBLAS was already used earlier in the process, setting it here
    # may have no effect.
    # PYTHONHASHSEED is read only at interpreter start-up, so it affects
    # child processes, not hashing in the current interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if deterministic:
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, including the current one, so a separate
    # torch.cuda.manual_seed call is redundant.
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)

    if deterministic:
        from torch.backends import cudnn
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)


def get_device():
    """Return the device the module-level Accelerator assigned to this process."""
    device = _accelerate.device
    return device

def print_output(outputs):
    if not isinstance(outputs, dict):
        return
    for key, val in outputs.items():
        print_0(type(val), isinstance(val, tuple))
        if isinstance(val, tuple):
            for _idx, _val in enumerate(val):
                if isinstance(_val, torch.Tensor):
                    print_0(key, _idx, _val.shape)
                elif isinstance(_val, tuple):
                    for __idx, __val in enumerate(_val):
                        if isinstance(__val, torch.Tensor):
                            print(key, _idx, __idx, __val.shape)
        else:
            print_0(key, val.shape)


