import json
import torch
import itertools
import numpy as np

from skimage.transform import resize
from torch.utils.data import Dataset
###############################################################################
''' CONFIG HANDLING '''
###############################################################################

def _callables_to_names(obj):
    '''Return a copy of *obj* in which every callable is replaced by its ``__name__``.'''
    if callable(obj):
        return obj.__name__
    if isinstance(obj, dict):
        return {key: _callables_to_names(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # json writes tuples as arrays anyway, so a list is an equivalent container.
        return [_callables_to_names(item) for item in obj]
    return obj


def save_config_json(save_config, filename):
    '''
    Saves Configuration dicts in a json file.

    Callables (functions, classes) found anywhere in the configuration are
    written as their ``__name__`` so the file is valid JSON. The passed
    ``save_config`` is NOT modified; a converted copy is serialized instead.
    Note that load_config_json only turns names back into callables in the
    first two nesting levels below the top-level dicts (and in lists at the
    second level).

    Parameters:
            save_config (dict): a dict of dicts that are the configuration.
            filename (str, path-like or file object): name of the file to save to
                (should end in .json), or an already opened, writable text file object.
    Raises:
            AttributeError: if a callable without ``__name__`` (e.g. a functools.partial) is encountered.
    '''
    serializable = _callables_to_names(save_config)

    if hasattr(filename, 'write'):
        # Already an open file-like object: write into it directly.
        json.dump(serializable, filename, indent=4)
    else:
        with open(filename, 'w') as fp:
            json.dump(serializable, fp, indent=4)

def _strip_line_comment(line):
    '''Return *line* cut at the first ``//`` that is not inside a JSON string literal.'''
    in_string = False
    escaped = False
    for i, ch in enumerate(line):
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif line.startswith('//', i):
            return line[:i]
    return line


def _lookup_callable(value, loc):
    '''Return ``loc[value]`` if *value* is a name in *loc* bound to a callable, else *value* unchanged.'''
    if isinstance(value, str) and value in loc and callable(loc[value]):
        return loc[value]
    return value


def load_config_json(filename, loc):
    '''
    Load Configuration dicts from a json file. The file may contain "//" line
    comments; a "//" inside a string value (e.g. a URL) is kept.

    String values that are names in ``loc`` bound to callables are replaced by
    those callables. This is done for values one and two levels below each
    top-level dict, and for list elements at the second level.

    Parameters:
            filename (str, path-like or file object): name of the file to load from
                (should end in .json), or an already opened, readable text file object.
            loc (dict): should be locals() or globals() of your script, so the
                functions defined there can be found by name.
    Returns:
            list: the top-level values of the json object, in file order.
    '''
    if hasattr(filename, 'read'):
        json_str = '\n'.join(_strip_line_comment(line.rstrip('\n')) for line in filename)
    else:
        # JSON text is UTF-8 by specification; don't depend on the locale default.
        with open(filename, 'r', encoding='utf-8') as f:
            json_str = '\n'.join(_strip_line_comment(line.rstrip('\n')) for line in f)
    load_config = json.loads(json_str)

    # Get back function handles. Names are resolved in ``loc`` itself (not via
    # eval, which would look them up in this module's scope instead).
    for section in load_config.values():
        if not isinstance(section, dict):
            continue
        for k1, v1 in section.items():
            if isinstance(v1, dict):
                for k2, v2 in v1.items():
                    if isinstance(v2, list):
                        v1[k2] = [_lookup_callable(item, loc) for item in v2]
                    else:
                        v1[k2] = _lookup_callable(v2, loc)
            else:
                section[k1] = _lookup_callable(v1, loc)

    return list(load_config.values())



def numpy_to_torch_tensor(data, pad=0):
    '''
    Converts a sequence of numpy arrays into a list of torch.FloatTensor tensors.
    Parameters:
            data (iterable of np.ndarray): arrays to convert. When padding is requested
                each array must have exactly 5 axes, because the pad width covers 5 axes.
            pad (int): if > 0, every array is padded with mode 'symmetric' by ``pad``
                elements on both sides of its last three axes; the first two axes are untouched.
    Returns:
            list: one torch.FloatTensor per input array.
    '''
    arrays = data
    if pad > 0:
        width = ((0, 0), (0, 0)) + ((pad, pad),) * 3
        arrays = [np.pad(arr, width, mode='symmetric') for arr in data]
    return [torch.from_numpy(arr).type(torch.FloatTensor) for arr in arrays]
    