import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb


class ConfigNamespace:
    """Expose the entries of a (possibly nested) dict as instance attributes.

    Every key of the input mapping becomes an attribute of the instance.
    Values that are ``dict`` instances are wrapped in their own
    ``ConfigNamespace`` so nested settings can be reached with dotted
    access (``cfg.model.hidden``). Any other value is stored unchanged;
    dicts nested inside lists or tuples are NOT converted.
    """

    def __init__(self, d):
        """Populate attributes from the mapping ``d`` (must provide ``.items()``)."""
        for key, value in d.items():
            # Only direct dict values recurse; everything else is kept as-is.
            wrapped = ConfigNamespace(value) if isinstance(value, dict) else value
            setattr(self, key, wrapped)
                
                
def set_wandb(args, config_dict):
    """Start a Weights & Biases run whose name encodes the experiment setup.

    The run name has the form
    ``<prop_type>_<x|y>_<model_name>_<seed>_<date_str>``. When
    ``args.prop_type`` contains ``'_x'``, every occurrence of ``'_x'`` is
    removed from it and the tag ``x`` is used; otherwise the property type
    is kept unchanged and the tag ``y`` is used.

    Args:
        args: Object exposing ``prop_type``, ``seed``, ``date_str``,
            ``model_name`` and ``proj_name`` attributes; it must also
            support ``vars()`` (e.g. an ``argparse.Namespace``).
        config_dict: Mapping merged on top of ``vars(args)`` to form the
            run config; on key collisions the ``config_dict`` value wins.
    """
    raw_type, seed, date_str = args.prop_type, args.seed, args.date_str

    # Pick the axis tag and strip the '_x' marker from the label if present.
    if '_x' in raw_type:
        base, axis = raw_type.replace('_x', ''), 'x'
    else:
        base, axis = raw_type, 'y'
    run_name = f"{base}_{axis}_{args.model_name}_{seed}_{date_str}"

    # Later unpacking overrides earlier keys, so config_dict takes precedence.
    run_config = {**vars(args), **config_dict}
    wandb.init(project=args.proj_name, config=run_config, name=run_name)

def set_seed(args):
    """Seed every random number generator used by the project from ``args.seed``.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch. When CUDA is
    available it also seeds the CUDA generators and switches cuDNN to
    deterministic mode with benchmarking disabled. Finally the seed is
    written to the ``PYTHONHASHSEED`` environment variable.

    Args:
        args: Object exposing an integer ``seed`` attribute.
    """
    seed_value = args.seed

    # CPU-side generators: Python stdlib, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        # Current device first, then all devices.
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)
        cudnn.deterministic, cudnn.benchmark = True, False

    # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
    # presumably only influences child processes -- confirm intent.
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    
def log_params(model):
    """Print and return parameter counts for ``model``.

    Prints three lines (total, trainable and non-trainable element counts,
    formatted with thousands separators) to stdout.

    Args:
        model: Object whose ``parameters()`` yields items providing
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(total_params, trainable_params)`` tuple; the non-trainable
        count is printed but not returned.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    non_trainable_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {non_trainable_params:,}")
    return total_params, trainable_params


