import os
import random
import numpy as np
import torch
from typing import List, Callable
from graphs.combo_graph import create_combo_graph
from functools import partial
from tqdm import tqdm
from multiprocessing import Pool, current_process
from omegaconf import DictConfig

# Graph types understood by mp_generate_data. In the visible code only the
# keys are consulted (to reject unknown graph types); each value is presumably
# the set of config keys the corresponding builder expects — TODO confirm.
required_params = {
    'combo': {
        'layers',
        'm',
        'num_pass',
        'p',
        'pass_per_layer',
        'width',
    },
}

def make_seed(*parts, mod=2**32):
    """Fold ``parts`` into one deterministic integer seed in ``[0, mod)``.

    This is a polynomial rolling hash over ``int(p)`` for every part. The first
    part doubles as the base seed: it also perturbs the multiplier, so
    different base seeds use different hash multipliers.

    Args:
        *parts: One or more values convertible with ``int()``.
        mod: Modulus of the result. Defaults to ``2**32``, the seed range
            accepted by ``np.random.seed``. Any positive modulus is supported.

    Returns:
        An ``int`` in ``[0, mod)``.

    Raises:
        ValueError: If no parts are given or ``mod`` is not positive.
    """
    if not parts:
        raise ValueError("make_seed() requires at least one part")
    if mod <= 0:
        raise ValueError(f"mod must be positive, got {mod}")
    multiplier = 1_000_000 + int(parts[0]) - 1335
    h = 0
    for p in parts:
        # ``% mod`` rather than ``& (mod - 1)``: the bitmask is only a modulo
        # for powers of two. For those (including the default) both forms are
        # identical, so existing seeds are unchanged.
        h = (h * multiplier + int(p)) % mod
    return h

def _init_worker(base_seed: int, chunk_idx: int):
    """Pool initializer that seeds every RNG used in this worker process.

    The seed combines ``base_seed``, ``chunk_idx`` and the worker's pool
    identity, so sibling workers draw from different random streams. Seeds
    Python ``random``, NumPy's global RNG, torch, and all CUDA devices.
    """
    # ``_identity`` is a private multiprocessing attribute. In a pool worker it
    # is a non-empty tuple; in the main process it is empty, so this must only
    # run as a Pool initializer.
    # NOTE(review): the counter behind it keeps increasing across pools made
    # by the same parent, so seeds depend on pool-creation order.
    identity = current_process()._identity[0]
    derived = make_seed(base_seed, chunk_idx, identity)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(derived)

def _generate_single_graph_mp(graph_fns: List[Callable], index: int) -> dict:
    return graph_fns[index]()

def mp_generate_data(
    graph_type: str,
    n: int,
    cfg: DictConfig,
    chunk_idx: int,
    seed: int,
    num_workers: int = None,
    eval_mode: bool = False,
) -> list:
    """Generate ``n`` graphs in parallel using a ``multiprocessing.Pool``.

    Each task picks a graph factory according to the sampling ratios and calls
    it inside a worker. Workers are seeded by ``_init_worker`` from ``seed``
    and ``chunk_idx``.

    Args:
        graph_type: Must be a key of ``required_params``. Only ``'combo'``
            (built with ``create_combo_graph``) is wired up.
        n: Number of graphs to generate.
        cfg: In ``eval_mode``, the keyword arguments for a single factory.
            Otherwise it must provide ``params``, a list of keyword-argument
            mappings (one factory each), and ``ratios``, a parallel list of
            sampling weights. Weights that do not sum to 1 are renormalized.
        chunk_idx: Chunk index, mixed into every worker's seed.
        seed: Base seed, mixed into every worker's seed.
        num_workers: Number of worker processes. ``None`` (or 0) means half the
            CPU count, but at least one.
        eval_mode: If true, every graph comes from the single factory built
            from ``cfg``.

    Returns:
        A list of ``n`` factory results (annotated as ``dict`` by the worker
        helper), in task order.

    Raises:
        ValueError: For an unknown ``graph_type`` or non-positive ratio total.
    """
    if graph_type not in required_params:
        raise ValueError(f"Unknown graph type: {graph_type}")
    if graph_type == 'combo':
        _graph_fn = create_combo_graph
    else:
        # Listed in required_params, but no builder is wired up here.
        raise ValueError(f"Unknown graph type: {graph_type}")

    if eval_mode:
        graph_fns = [partial(_graph_fn, **cfg)]
        ratios = [1.0]
    else:
        graph_fns = [partial(_graph_fn, **param) for param in cfg.params]
        ratios = list(cfg.ratios)
        total = sum(ratios)
        if total <= 0:
            raise ValueError(f"Ratios {cfg.ratios} must sum to a positive value, got {total}")
        if not np.isclose(total, 1):
            ratios = [r / total for r in ratios]
            # Report the original total, not the post-normalization sum (~1.0).
            print(f"Ratios {cfg.ratios} sum to {total}; normalizing to {ratios}")

    # Honour an explicit num_workers. Otherwise use half the cores: cpu_count()
    # may return None, and on one core ``// 2`` gives 0, which Pool rejects.
    processes = num_workers or max(1, (os.cpu_count() or 2) // 2)

    generate_single_graph_mp = partial(_generate_single_graph_mp, graph_fns)
    print(f"Generating {n} graphs with {processes} workers")
    # NOTE(review): factory indices come from the parent's global ``random``
    # state, not from ``seed``. Also, which worker (and hence which seeded
    # stream) runs a task depends on scheduling, so per-graph randomness is
    # likely not reproducible run to run — confirm if reproducibility matters.
    worker_args = random.choices(range(len(graph_fns)), k=n, weights=ratios)

    with Pool(
        processes=processes,
        initializer=_init_worker,
        initargs=(seed, chunk_idx),
    ) as pool:
        results = pool.map(generate_single_graph_mp, worker_args)
    return results