import os
import random
import pickle

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
from rdkit import Chem
from chemprop.features.featurization import mol2graph


class ConfigNamespace:
    """Attribute-style view over a (possibly nested) configuration mapping.

    Every key of the input mapping becomes an attribute of the instance.
    Values that are ``dict`` instances are themselves wrapped in a
    ``ConfigNamespace`` (recursively); all other values, including lists,
    are stored unchanged.
    """

    def __init__(self, d):
        for key, value in d.items():
            # Only plain dicts are recursed into; lists of dicts stay raw.
            wrapped = ConfigNamespace(value) if isinstance(value, dict) else value
            setattr(self, key, wrapped)
                
                
def set_wandb(args, config_dict):
    """Start a Weights & Biases run for this experiment.

    The run name is built from ``args.prop_type``, ``args.model_name``,
    ``args.seed`` and ``args.date_str``. When ``prop_type`` contains
    ``'_x'``, every ``'_x'`` occurrence is removed and an ``_x_`` tag is
    used; otherwise a ``_y_`` tag is used.

    The logged config is ``vars(args)`` merged with ``config_dict``; on
    key collisions the ``config_dict`` value wins.

    Args:
        args: namespace with at least ``prop_type``, ``seed``, ``date_str``,
            ``model_name`` and ``proj_name``.
        config_dict: extra configuration entries to log with the run.
    """
    prop_type = args.prop_type
    seed, date_str = args.seed, args.date_str

    is_x_variant = '_x' in prop_type
    if is_x_variant:
        # str.replace strips *all* '_x' occurrences, not just a suffix.
        prop_type = prop_type.replace('_x', '')
        name = f"{prop_type}_x_{args.model_name}_{seed}_{date_str}"
    else:
        name = f"{prop_type}_y_{args.model_name}_{seed}_{date_str}"

    run_config = {**vars(args), **config_dict}
    # NOTE(review): `_disable_service` is an underscore-prefixed (private)
    # wandb setting — confirm the installed wandb version still accepts it.
    run_settings = wandb.Settings(_disable_service=True)
    wandb.init(
        project=args.proj_name,
        config=run_config,
        name=name,
        settings=run_settings,
    )

def set_seed(args):
    """Seed every RNG used in training from ``args.seed``.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG.
    When CUDA is available it also seeds all CUDA devices and switches cuDNN
    into deterministic, non-benchmarking mode, which trades speed for
    reproducibility. Finally it exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting
    it here does not change string hashing in the current process. It is
    only inherited by child processes launched afterwards.
    """
    value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(value)
    
    
def log_params(model):
    """Print parameter-count totals for ``model`` and list its frozen tensors.

    Counts are numbers of scalar elements (``numel``) printed with thousands
    separators. A parameter counts as trainable when ``requires_grad`` is
    set. After the three totals, the name and shape of every parameter
    tensor with ``requires_grad`` disabled is printed, one per line.
    """
    total_params = 0
    trainable_params = 0
    non_trainable = []
    # Single pass over named parameters; the per-tensor loop collects both
    # the element counts and the list of frozen tensors.
    for param_name, param in model.named_parameters():
        n_elems = param.numel()
        total_params += n_elems
        if param.requires_grad:
            trainable_params += n_elems
        else:
            non_trainable.append((param_name, param.shape))
    non_trainable_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {non_trainable_params:,}")

    print(f"\nNon-trainable parameters ({len(non_trainable)} tensors):")
    for name, shape in non_trainable:
        print(f"  {name}: {tuple(shape)}")
    

def load_dataset(args, split):
    """Load one split of a pickled graph dataset and attach targets to graphs.

    The pickle is read from
    ``<data_dir>/<dataset_name>/<dataset_split_type>/<prop_type>/<prop_type>.pkl``.
    It is indexed by ``split``, and each split entry is indexed by the keys
    ``'graphs'``, ``'smiles'`` and ``'targets'``.

    Each graph object is mutated in place: its ``y`` attribute is set to its
    target as a ``torch.float`` tensor reshaped with ``view(1, -1)``.

    Args:
        args: namespace providing ``data_dir``, ``dataset_name``,
            ``dataset_split_type`` and ``prop_type``.
        split: key of the split to load from the pickled mapping.

    Returns:
        ``(graphs, smiles, targets)`` for the requested split.

    Raises:
        KeyError: if ``split`` or one of the expected keys is missing.
        ValueError: if the split has a different number of graphs and targets.
    """
    data_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type,
                             args.prop_type, f"{args.prop_type}.pkl")
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    # The context manager closes the handle deterministically (the original
    # relied on garbage collection of an anonymous file object).
    with open(data_path, 'rb') as f:
        all_dataset = pickle.load(f)

    split_data = all_dataset[split]
    graphs = split_data['graphs']
    smiles = split_data['smiles']
    targets = split_data['targets']

    # Fail loudly on a corrupt split instead of an IndexError mid-loop or
    # silently ignoring surplus targets.
    if len(graphs) != len(targets):
        raise ValueError(
            f"Split {split!r} in {data_path} has {len(graphs)} graphs "
            f"but {len(targets)} targets"
        )

    for g, target in zip(graphs, targets):
        g.y = torch.tensor(target, dtype=torch.float).view(1, -1)

    return graphs, smiles, targets


def load_chemprop_dataset(args, split):
    """Load one split of a pickled dataset and featurize its SMILES with chemprop.

    The pickle is read from
    ``<data_dir>/<dataset_name>/<dataset_split_type>/<prop_type>/<prop_type>.pkl``.
    It is indexed by ``split``, and the split entry is indexed by
    ``'smiles'`` and ``'targets'``. Any pre-built graphs in the pickle are
    ignored; each SMILES is parsed with RDKit and passed to ``mol2graph``
    as a one-element list.

    Args:
        args: namespace providing ``data_dir``, ``dataset_name``,
            ``dataset_split_type`` and ``prop_type``.
        split: key of the split to load from the pickled mapping.

    Returns:
        ``(graphs, smiles, targets)``, where ``graphs[i]`` is the
        ``mol2graph`` result for ``smiles[i]``.

    Raises:
        KeyError: if ``split`` or one of the expected keys is missing.
        ValueError: if a SMILES string cannot be parsed by RDKit.
    """
    data_path = os.path.join(args.data_dir, args.dataset_name, args.dataset_split_type,
                             args.prop_type, f"{args.prop_type}.pkl")
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(data_path, 'rb') as f:
        all_dataset = pickle.load(f)

    smiles = all_dataset[split]['smiles']
    targets = all_dataset[split]['targets']

    graphs = []
    for i, s in enumerate(smiles):
        mol = Chem.MolFromSmiles(s)
        # RDKit signals a parse failure by returning None rather than
        # raising; surface it here with the offending SMILES instead of
        # letting mol2graph fail on a None molecule.
        if mol is None:
            raise ValueError(
                f"Invalid SMILES at index {i} of split {split!r} in {data_path}: {s!r}"
            )
        graphs.append(mol2graph([mol]))

    return graphs, smiles, targets
