import random
import numpy as np
import torch
import dgl
from itertools import product

def assign_hyp_param(args, params):
    """Apply one hyperparameter configuration to ``args`` in place.

    The keys ``lr``, ``epochs`` and ``gain`` are set directly as attributes of
    ``args``. Every other key is stored into the dict found at the attribute
    named ``f"{args.method}_args"``; when ``args`` has no such attribute, the
    key is skipped without error.

    Args:
        args: namespace-like object; must expose ``method`` whenever
            ``params`` contains a key other than ``lr``/``epochs``/``gain``.
        params: mapping of hyperparameter name to the value to apply.
    """
    direct_keys = ('lr', 'epochs', 'gain')
    for name, val in params.items():
        if name in direct_keys:
            setattr(args, name, val)
            continue
        container_name = f"{args.method}_args"
        if not hasattr(args, container_name):
            # No method-specific container: silently ignore this key.
            continue
        getattr(args, container_name)[name] = val

def compose_hyper_params(hyp_params, lr_list, epochs_list, gain_list=None):
    """Expand a hyperparameter grid into a list of flat configurations.

    Args:
        hyp_params: mapping from method-specific hyperparameter name to its
            candidate values. Any non-string iterable (list, tuple, range, ...)
            is treated as the set of candidates. A scalar or a string is
            treated as a single fixed candidate rather than being iterated
            (iterating a string would yield one candidate per character).
        lr_list: candidate learning rates (same scalar handling).
        epochs_list: candidate epoch counts (same scalar handling).
        gain_list: optional candidate ``gain`` values. When falsy (``None`` or
            empty), no ``gain`` key is produced.

    Returns:
        A list of dicts, one per element of the Cartesian product. Keys are
        the ``hyp_params`` keys in insertion order, followed by ``lr``,
        ``epochs`` and, if ``gain_list`` is truthy, ``gain``.
    """
    def _as_candidates(value):
        # Strings are iterable but must count as one value.
        if isinstance(value, (str, bytes)):
            return [value]
        try:
            iter(value)
        except TypeError:
            # Scalars (e.g. the floats str2dict returns for non-list values)
            # would otherwise make product() raise.
            return [value]
        return value

    keys = list(hyp_params)
    value_lists = [_as_candidates(hyp_params[k]) for k in keys]
    keys += ['lr', 'epochs']
    value_lists += [_as_candidates(lr_list), _as_candidates(epochs_list)]
    if gain_list:
        keys.append('gain')
        value_lists.append(_as_candidates(gain_list))
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]

def format_hyp_params(hyp_params):
    """Render ``hyp_params`` as a compact, filename-friendly string.

    Takes ``str(hyp_params)``, deletes single quotes and spaces, and turns
    commas and colons into underscores, e.g. ``{'a': 1}`` -> ``{a_1}``.
    """
    # Single-pass equivalent of the four chained replacements: no replacement
    # output ('' or '_') is itself one of the characters being replaced.
    char_map = str.maketrans({"'": None, ' ': None, ',': '_', ':': '_'})
    return str(hyp_params).translate(char_map)

def format_args(args):
    """Build an underscore-joined run name from the experiment settings.

    The result is ``<dataset>_<IL_stream>_<method>[_budget<memory_budget>]
    _<n_nodes_per_batch>_<backbone>``. The ``budget`` component is present
    only when ``args.method`` is one of ``budgeted_methods`` below.
    """
    budgeted_methods = ('er', 'agem', 'ssmer', 'ssmagem', 'pdgnn')
    components = [args.dataset, args.IL_stream, args.method]
    if args.method in budgeted_methods:
        components.append(f'budget{args.memory_budget}')
    components.extend([args.n_nodes_per_batch, args.backbone])
    return '_'.join(f'{c}' for c in components)

def _parse_list_item(token):
    """Convert one list item to ``int``, else ``float``, else leave it a ``str``."""
    try:
        as_int = int(token)
    except ValueError:
        pass
    else:
        # Only accept the int when it is the canonical spelling; e.g. '007'
        # or '+5' fall through to float, as before.
        if str(as_int) == token:
            return as_int
    try:
        return float(token)
    except ValueError:
        return token


def str2dict(s):
    """Parse a ``;``-separated ``key:value`` string into a dict.

    Accepts strings like ``" 'k1':v1; ...; 'km':vm "``. All spaces and single
    quotes are stripped first, so neither keys nor values can contain them.
    Each segment is split on its FIRST ``:`` only, so values may contain
    colons (e.g. ``path:c:/tmp``). Empty segments, such as one left by a
    trailing ``;``, are ignored.

    A value containing ``[`` is treated as a comma-separated list (for
    hyperparameter tuning). Each item becomes an ``int`` if it round-trips
    exactly through ``int``, otherwise a ``float`` if it parses as one,
    otherwise the raw ``str``. Empty items are dropped, so ``[]`` gives ``[]``.
    A non-list value becomes a ``float`` when it parses as one and otherwise
    stays a ``str``. Scalars are never converted to ``int``.

    Raises:
        ValueError: if a non-empty segment has no ``:`` separator. ValueError
            is also what argparse expects from ``type=`` callables.
    """
    output = dict()
    for segment in s.replace(' ', '').replace("'", '').split(';'):
        if not segment:
            continue
        key, sep, raw_value = segment.partition(':')
        if not sep:
            raise ValueError(f"expected 'key:value', got {segment!r}")
        if '[' in raw_value:
            items = raw_value.replace('[', '').replace(']', '').split(',')
            output[key] = [_parse_list_item(item) for item in items if item]
        else:
            try:
                output[key] = float(raw_value)
            except ValueError:
                output[key] = raw_value
    return output

def set_seed(args=None):
    """Seed every RNG used by the experiments and force deterministic cuDNN.

    The seed is ``args.seed`` when ``args`` is truthy and ``1`` otherwise. It
    is applied to Python's ``random``, NumPy, torch (CPU and all CUDA
    devices) and DGL. cuDNN is switched to deterministic mode with
    benchmarking disabled.
    """
    seed = args.seed if args else 1
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    dgl.random.seed(seed)

def shuffle_list(input_list, random_seed=42):
    """Shuffle ``input_list`` in place, deterministically, and return it.

    A private ``random.Random`` seeded with ``random_seed`` gives exactly the
    permutation that ``random.seed(random_seed); random.shuffle(...)`` would.
    Unlike that approach, it does not reseed the module-level ``random``
    generator, so unrelated random draws elsewhere are unaffected.

    Args:
        input_list: mutable sequence to shuffle in place.
        random_seed: seed that fixes the permutation.

    Returns:
        The same ``input_list`` object, now shuffled.
    """
    random.Random(random_seed).shuffle(input_list)
    return input_list

def shuffle_tensor(input_tensor, random_seed=42, return_indices=False):
    """Return ``input_tensor`` with its first dimension deterministically permuted.

    The permutation comes from ``torch.randperm`` driven by a dedicated CPU
    ``torch.Generator`` seeded with ``random_seed``. This gives the same
    permutation as seeding the global generator, but leaves every global RNG
    untouched. That includes the CUDA generators, which ``torch.manual_seed``
    would also reseed. It also leaves no side effects if an exception is raised.

    Args:
        input_tensor: tensor with at least one dimension; shuffled along dim 0.
        random_seed: seed that fixes the permutation.
        return_indices: if True, also return the permutation indices.

    Returns:
        The shuffled tensor (produced by indexing with the permutation), or
        ``(shuffled, indices)`` when ``return_indices`` is True.
    """
    generator = torch.Generator().manual_seed(random_seed)
    indices = torch.randperm(input_tensor.size(0), generator=generator)
    shuffled = input_tensor[indices]
    return (shuffled, indices) if return_indices else shuffled
