import random
import os
import torch
import numpy as np


def set_seed(value):
    """Seed every RNG used in training with ``value`` for reproducible runs.

    Seeds Python's ``random``, PyTorch (CPU and all CUDA devices) and NumPy,
    and asks cuDNN to pick deterministic algorithms. The seed is echoed to
    stdout so it shows up in run logs.
    """
    print("Random Seed: ", value)
    # Same order as before: Python -> torch CPU -> torch CUDA, then cuDNN, then NumPy.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(value)
    torch.backends.cudnn.deterministic = True
    np.random.seed(value)


class DictToObject:
    """Expose the entries of a mapping as instance attributes.

    ``DictToObject({"lr": 0.1}).lr == 0.1``. Keys are applied with
    ``setattr``, so they must be strings. Values are stored as-is; nested
    dicts are NOT converted into further ``DictToObject`` instances.
    """

    def __init__(self, dictionary):
        """Copy every ``key -> value`` pair of *dictionary* onto ``self``."""
        for key in dictionary:
            setattr(self, key, dictionary[key])

    def __repr__(self):
        # List the stored attributes so config objects are readable when printed or logged.
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"


def to_device(data, device):
    """Move a tensor, or a (possibly nested) tuple/list/dict of tensors, to *device*.

    Conversion rules, applied recursively:

    - tuple / list -> tuple of converted items (lists come back as tuples,
      preserving this helper's historical return type)
    - dict -> plain dict with the same keys and converted values
    - anything with a ``.to`` method (tensors, modules) ->
      ``data.to(device, non_blocking=True)``
    - any other value (str, int, None, ...) is returned unchanged

    Per the PyTorch docs, ``non_blocking=True`` only makes the copy
    asynchronous when that is possible (e.g. pinned CPU memory -> CUDA);
    otherwise it behaves like an ordinary copy.
    """
    if isinstance(data, (tuple, list)):
        return tuple(to_device(item, device) for item in data)
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    if hasattr(data, "to"):
        return data.to(device, non_blocking=True)
    # Non-tensor leaves (labels, lengths, metadata) have no device; pass them through.
    return data


def get_optimizer_params(model, weight_decay):
    """Split the trainable parameters of *model* into two optimizer groups.

    Tensors with two or more dimensions (matmul and embedding weights) get
    ``weight_decay``; tensors with fewer dimensions (biases, LayerNorm
    parameters) get ``0.0``. Parameters with ``requires_grad=False`` are left
    out entirely. A one-line summary of each group is printed.

    Returns:
        ``[{'params': [...], 'weight_decay': weight_decay},
        {'params': [...], 'weight_decay': 0.0}]``, with parameters in
        ``model.named_parameters()`` order, ready to pass to a
        ``torch.optim`` optimizer.
    """
    decay_params = []
    nodecay_params = []
    for _name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen tensors are never handed to the optimizer
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            nodecay_params.append(param)

    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]

    for label, group in (("decayed", decay_params), ("non-decayed", nodecay_params)):
        n_elements = sum(p.numel() for p in group)
        print(f"num {label} parameter tensors: {len(group)}, with {n_elements:,} parameters")

    return optim_groups