import torch
import numpy as np
import random
import pynvml
import logging
import json
from easydict import EasyDict as edict


# Shared named logger for this package; handlers/levels are not configured in
# this section (presumably set up by the entry script — confirm).
logger = logging.getLogger('MMSA')


def dict_to_str(src_dict):
    """Format a mapping of names to numbers as one log-friendly string.

    Each entry is rendered as ``" <key>: <value:.4f> "`` (leading and trailing
    space) and the entries are concatenated in the mapping's iteration order,
    so neighbouring entries end up separated by two spaces.

    Args:
        src_dict: mapping whose values support ``%f`` formatting.

    Returns:
        The formatted string, or ``""`` for an empty mapping.
    """
    # str.join builds the result in one pass instead of the quadratic
    # repeated ``+=`` on an immutable string.
    return "".join(" %s: %.4f " % (key, value) for key, value in src_dict.items())

def setup_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and force deterministic cuDNN.

    Args:
        seed: seed value passed unchanged to all three generators.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    # Turn off cuDNN autotuning and request deterministic algorithms so that
    # runs with the same seed take the same kernel choices.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def assign_gpu(gpu_ids, memory_limit=1e16):
    """Return the torch device to use, auto-picking a GPU when none is given.

    Args:
        gpu_ids: list of requested GPU indices. If it is empty and CUDA is
            available, the index of the GPU reporting the least used memory
            (NVML ``meminfo.used``) is appended to this list **in place**.
        memory_limit: only GPUs whose used memory is strictly below this value
            can be selected; if none qualifies, GPU 0 is chosen.

    Returns:
        ``torch.device('cuda:<gpu_ids[0]>')`` when ``gpu_ids`` is non-empty and
        CUDA is available, otherwise ``torch.device('cpu')``.
    """
    if len(gpu_ids) == 0 and torch.cuda.is_available():
        # find the GPU with the least used memory
        pynvml.nvmlInit()
        try:
            n_gpus = pynvml.nvmlDeviceGetCount()
            dst_gpu_id, min_mem_used = 0, memory_limit
            for g_id in range(n_gpus):
                handle = pynvml.nvmlDeviceGetHandleByIndex(g_id)
                mem_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used
                if mem_used < min_mem_used:
                    min_mem_used = mem_used
                    dst_gpu_id = g_id
        finally:
            # Every nvmlInit() must be balanced by nvmlShutdown(), even when a
            # query above raises; otherwise NVML resources leak on each call.
            pynvml.nvmlShutdown()
        logger.info(f'Found gpu {dst_gpu_id}, used memory {min_mem_used}.')
        # NOTE(review): NVML enumerates devices in PCI bus order, while CUDA
        # defaults to fastest-first and honours CUDA_VISIBLE_DEVICES, so this
        # index may not name the same device in torch unless
        # CUDA_DEVICE_ORDER=PCI_BUS_ID and all GPUs are visible — confirm.
        gpu_ids.append(dst_gpu_id)
    using_cuda = len(gpu_ids) > 0 and torch.cuda.is_available()
    device = torch.device('cuda:%d' % int(gpu_ids[0]) if using_cuda else 'cpu')
    return device

def count_parameters(model):
    """Return the total number of trainable scalar parameters in *model*.

    Only parameter tensors with ``requires_grad`` set are included; each one
    contributes its element count (``numel()``). Returns 0 if there are none.
    """
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


class ConfigParser:
    """Load a JSON config file and merge its sections into per-run parameters."""

    def __init__(self, config_path):
        """Read and parse the JSON file at *config_path* into ``self.config``."""
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # locale encoding, which differs e.g. on Windows.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

    def get_params(self, dataset_name, model_name, alignment='aligned'):
        """Merge dataset-level and model-level parameters into one EasyDict.

        Sections are merged in this order, later keys overriding earlier ones:
        ``datasetCommonParams[dataset_name][alignment]``,
        ``config[model_name]['commonParams']``, then
        ``config[model_name]['datasetParams'][dataset_name]``.
        Finally ``featurePath`` is prefixed with
        ``datasetCommonParams['dataset_root_dir']`` and a ``/`` separator.

        Args:
            dataset_name: dataset key inside the config.
            model_name: top-level config key of the model.
            alignment: sub-key of the dataset's common params to read;
                defaults to ``'aligned'``.

        Returns:
            An EasyDict holding the merged parameters.

        Raises:
            KeyError: if any section above or ``featurePath`` is missing.
        """
        dataset_section = self.config['datasetCommonParams']
        dataset_common = dataset_section[dataset_name][alignment]
        model_common = self.config[model_name]['commonParams']
        model_dataset = self.config[model_name]['datasetParams'][dataset_name]

        # A fresh dict, so self.config itself is never modified.
        params = {**dataset_common, **model_common, **model_dataset}
        params['featurePath'] = f"{dataset_section['dataset_root_dir']}/{params['featurePath']}"

        return edict(params)

def get_subset_distribution(dataloader):
    """Print the sign distribution of the non-zero ``'M'`` labels in *dataloader*.

    Each batch must be a mapping where ``batch['labels']['M']`` is a torch
    tensor of labels. Labels equal to zero are excluded. The remaining labels
    are counted as positive (> 0) or negative (< 0), and a summary is printed
    to stdout. Nothing is printed when there are no non-zero labels.

    Args:
        dataloader: iterable of batches (e.g. a ``torch.utils.data.DataLoader``).
    """
    # Flatten every batch so label tensors shaped (B,), (B, 1) or () are all
    # handled the same way.
    chunks = [
        batch['labels']['M'].detach().cpu().numpy().reshape(-1)
        for batch in dataloader
    ]
    if not chunks:
        return
    labels = np.concatenate(chunks)

    # Vectorised boolean mask instead of a Python-level index loop.
    non_zeros_labels = labels[labels != 0]
    total_nz = non_zeros_labels.size
    if total_nz == 0:
        return

    positive_nz = np.sum(non_zeros_labels > 0)
    negative_nz = np.sum(non_zeros_labels < 0)
    pos_ratio_nz = positive_nz / total_nz * 100
    neg_ratio_nz = negative_nz / total_nz * 100

    print("\nSubset label distribution:")
    print(f"Positive samples (>0): {positive_nz} ({pos_ratio_nz:.2f}%)")
    print(f"Negative samples (<0): {negative_nz} ({neg_ratio_nz:.2f}%)")
    print(f"Total non-zero samples: {total_nz}")
    print(f"Majority class ratio: {max(pos_ratio_nz, neg_ratio_nz):.2f}%")
