import torch
import math
import time
import numpy as np
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
import copy


def torch_load_cpu(load_path):
    """Deserialize the object saved at ``load_path`` without moving tensors off the CPU.

    ``torch.load`` hands ``map_location`` each storage as initially
    deserialized on the CPU; returning it unchanged keeps it there, so
    checkpoints saved from a GPU can be read on a CPU-only machine.
    """
    def _keep_on_cpu(storage, location):
        return storage

    return torch.load(load_path, map_location=_keep_on_cpu)

def get_inner_model(model):
    """Return the underlying module of a (Distributed)DataParallel wrapper.

    Models that are not wrapped are returned unchanged.
    """
    is_wrapped = isinstance(model, (DataParallel, DDP))
    if is_wrapped:
        return model.module
    return model

def set_seed(seed=None):
    """Seed torch, the current CUDA device, and NumPy's global RNG with one value.

    :param seed: integer seed; when ``None`` the current Unix time, truncated
        to whole seconds, is used instead.
    """
    chosen = int(time.time()) if seed is None else seed
    # Applied in this fixed order: torch CPU, current CUDA device, NumPy.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seeder(chosen)

def move_to(var, device):
    """Call ``.to(device)`` on ``var``, descending recursively into dicts.

    Dicts (including nested dicts) are rebuilt with the same keys; every
    non-dict leaf is converted with its own ``.to(device)`` method.
    """
    if not isinstance(var, dict):
        return var.to(device)
    moved = {}
    for key, value in var.items():
        moved[key] = move_to(value, device)
    return moved

def move_to_cuda(var, device):
    """Call ``.cuda(device)`` on ``var``, descending recursively into dicts.

    Nested dict values are handled by this same function, so every leaf,
    however deeply nested, is moved with ``.cuda(device)`` exactly like a
    top-level value. (Previously nested values went through ``move_to`` and
    thus ``.to(device)``, which diverged from the top-level behavior.)
    """
    if isinstance(var, dict):
        return {k: move_to_cuda(v, device) for k, v in var.items()}
    return var.cuda(device)

def clip_grad_norms(param_groups, max_norm=math.inf):
    """
    Clip the gradients of every param group and report their norms.

    Each group is clipped independently, against the total L2 norm of its own
    ``'params'``, using the in-place ``torch.nn.utils.clip_grad_norm_``.

    :param param_groups: iterable of dicts with a ``'params'`` entry
        (the shape of ``optimizer.param_groups``).
    :param max_norm: maximum allowed L2 norm per group. A value <= 0 disables
        clipping, but the norms are still computed and returned.
    :return: grad_norms, grad_norms_clipped. ``grad_norms`` holds the per-group
        norms measured before clipping. ``grad_norms_clipped`` holds those norms
        capped at ``max_norm`` (a capped entry is ``max_norm`` itself, not a
        tensor); it is ``grad_norms`` unchanged when clipping is disabled.
    """
    # math.inf means "never clip", but the call still computes the norm.
    clip_to = max_norm if max_norm > 0 else math.inf
    grad_norms = [
        torch.nn.utils.clip_grad_norm_(group['params'], clip_to, norm_type=2)
        for group in param_groups
    ]
    grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms] if max_norm > 0 else grad_norms
    return grad_norms, grad_norms_clipped

def get_init_cost(env, dataset, bs, seed):
    """Collect, for each problem, the worst ``gworst_cost`` right after a reset.

    For every problem, ``bs`` independent deep copies are handed to ``env``
    through one ``step`` call. The population returned by ``env.reset()`` is
    then scanned for its largest ``gworst_cost``.

    :param env: environment exposing ``step(actions)`` and ``reset()``.
    :param dataset: iterable of problems.
    :param bs: number of parallel copies per problem.
    :param seed: accepted but not used by this function.
        NOTE(review): ``get_surrogate_gbest`` seeds via ``set_seed``; confirm
        whether this function should as well.
    :return: NumPy array with one initial (worst) cost per problem.
    """
    print('getting init cost...')
    worst_per_problem = []
    for problem in dataset:
        env.step([{'problem': copy.deepcopy(problem), 'sgbest': 0} for _ in range(bs)])
        population = env.reset()
        worst_per_problem.append(np.max([member.gworst_cost for member in population]))
    print('done...')
    return np.array(worst_per_problem)

def get_surrogate_gbest(env, dataset, ids, bs, seed, fes):
    """Run ``env`` to completion on every problem and record the extreme costs.

    The global RNGs are seeded once via ``set_seed(seed)`` before any problem
    is processed. For each problem, ``bs`` copies are loaded and reset. Then
    ``{'fes': fes}`` actions are stepped until every copy reports done.

    :param env: environment exposing ``step(actions)`` and ``reset()``. The
        done flag from ``step`` is expected to support ``.all()``
        (presumably a NumPy bool array — confirm against the env).
    :param dataset: iterable of problems, aligned element-wise with ``ids``.
    :param ids: keys used in the returned dicts.
    :param bs: number of parallel copies per problem.
    :param seed: seed passed to ``set_seed``.
    :param fes: value sent as ``'fes'`` on every step.
    :return: (gbests, gworsts). ``gbests`` maps each id to the minimum
        ``gbest_cost`` of the final population; ``gworsts`` maps each id to
        the maximum ``gworst_cost``.
    """
    print('getting surrogate gbest...')
    gbests = {}
    gworsts = {}
    set_seed(seed)
    for problem_id, problem in zip(ids, dataset):
        env.step([{'problem': copy.deepcopy(problem), 'sgbest': 0} for _ in range(bs)])
        env.reset()
        while True:
            population, _, done_flags, _ = env.step([{'fes': fes} for _ in range(bs)])
            if done_flags.all():
                break
        gbests[problem_id] = np.min([member.gbest_cost for member in population])
        gworsts[problem_id] = np.max([member.gworst_cost for member in population])
    print('done...')
    return gbests, gworsts

import multiprocessing as mp

def get_one_run_surrogate_gbest_subproc(args):
    """Run a single teacher optimizer to completion and report its best/worst cost.

    Written as a worker for ``Pool.map``: every input arrives packed in one
    tuple, and the optimizer classes are imported inside the function.

    :param args: ``(problem, seed, fes, opts_dict)``. ``opts_dict`` must provide
        ``teacher``, ``dim``, ``max_x``, ``min_x`` and ``population_size``, plus
        ``cmaes_sigma`` for the ``'cmaes'`` teacher; ``tea_step`` is optional.
    :return: ``(gbest_cost, gworst_cost)`` of the population returned by the
        final ``step`` call.
    :raises ValueError: if the teacher is not ``'madde'``, ``'cmaes'``, ``'pso'``
        or ``'de'``.
    """
    problem, run_seed, max_fes, cfg = args
    from pbo_env import MadDE, sep_CMA_ES, PSO, DE
    import numpy as np
    import copy
    import torch

    # Seed NumPy and torch for this run only.
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)

    teacher = cfg['teacher']
    dim = cfg['dim']
    max_x = cfg['max_x']
    min_x = cfg['min_x']
    pop_size = cfg['population_size']

    if teacher == 'madde':
        # In 'step' mode the MadDE budget is rescaled by the expression below
        # (rationale not documented here — TODO confirm).
        madde_budget = (
            round((max_fes / pop_size) * (4 + 2 * dim * dim) / 2)
            if cfg.get('tea_step') == 'step'
            else max_fes
        )
        optimizer = MadDE(dim=dim, problem=copy.deepcopy(problem), max_x=max_x, min_x=min_x, max_fes=madde_budget)
    elif teacher == 'cmaes':
        optimizer = sep_CMA_ES(dim=dim, problem=copy.deepcopy(problem), max_x=max_x, min_x=min_x, max_fes=max_fes, sigma=cfg['cmaes_sigma'])
    elif teacher == 'pso':
        optimizer = PSO(ps=pop_size, dim=dim, max_fes=max_fes, min_x=min_x, max_x=max_x, pho=0.2)
    elif teacher == 'de':
        optimizer = DE(dim=dim, ps=pop_size, min_x=min_x, max_x=max_x, max_fes=max_fes)
    else:
        raise ValueError(f"Teacher {teacher} not supported in Subproc task")

    # Load a fresh copy of the problem, reset, then step until done.
    optimizer.step({'problem': copy.deepcopy(problem), 'sgbest': 0})
    optimizer.reset()
    while True:
        population, _, finished, _ = optimizer.step({'fes': max_fes})
        if finished:
            break

    return population.gbest_cost, population.gworst_cost

def get_surrogate_gbest_subproc(dataset, ids, bs, seed, fes, opts):
    """Estimate per-problem best/worst costs by running the teacher in a process pool.

    Every (problem, run) pair becomes one independent task for
    ``get_one_run_surrogate_gbest_subproc``. The per-run results are then
    reduced per problem id.

    :param dataset: iterable of problems, aligned element-wise with ``ids``.
    :param ids: problem identifiers used as keys in the returned dicts.
    :param bs: number of independent runs per problem.
    :param seed: base seed; each run derives a distinct seed from it.
    :param fes: evaluation budget forwarded to every run.
    :param opts: options object (its ``__dict__`` is forwarded) or a plain dict.
    :return: (gbests, gworsts). ``gbests`` maps each id to the minimum
        ``gbest_cost`` over its runs; ``gworsts`` maps each id to the maximum
        ``gworst_cost`` over its runs.
    """
    print(f'getting surrogate gbest using multiprocessing (parallelizing over {len(dataset)} tasks * {bs} runs)...')

    opts_dict = opts.__dict__ if hasattr(opts, '__dict__') else opts

    # Distance between the seed ranges of consecutive problems. 1000 reproduces
    # the historical seeds whenever bs <= 1000. Widening the stride for larger
    # bs stops run b of problem i from reusing a seed of problem i + 1.
    seed_stride = max(bs, 1000)
    tasks_args = []
    task_map = []
    for i, (pid, pro) in enumerate(zip(ids, dataset)):
        for b in range(bs):
            run_seed = seed + i * seed_stride + b
            tasks_args.append((pro, run_seed, fes, opts_dict))
            task_map.append(pid)

    # Skip spawning a pool of worker processes when there is nothing to run.
    if tasks_args:
        with mp.Pool(processes=mp.cpu_count()) as pool:
            results = pool.map(get_one_run_surrogate_gbest_subproc, tasks_args)
    else:
        results = []

    gbests = {}
    gworsts = {}
    for pid, (best, worst) in zip(task_map, results):
        if pid in gbests:
            gbests[pid] = min(gbests[pid], best)
            gworsts[pid] = max(gworsts[pid], worst)
        else:
            gbests[pid] = best
            gworsts[pid] = worst

    print('done...')
    return gbests, gworsts

