from spaghettini import quick_register
import os
import oyaml as yaml
from pprint import pprint
import numpy as np
import random as random
import torch
from datetime import datetime
from argparse import Namespace
from typing import MutableMapping, Callable

from torch.nn.utils import remove_spectral_norm

import wandb

from pytorch_lightning import seed_everything

# Evaluated once at import time: True iff PyTorch reports a usable CUDA device.
# to_cuda() consults this flag to decide whether to call .cuda() on its inputs.
USE_GPU = torch.cuda.is_available()


def to_cuda(xs):
    """Move an object, or each element of a list/tuple, to CUDA when a GPU is available.

    Args:
        xs: a single object exposing ``.cuda()`` (e.g. a tensor or module), or a
            list/tuple (including subclasses such as namedtuples) of such objects.

    Returns:
        For a single object: ``xs.cuda()`` if ``USE_GPU`` else ``xs`` unchanged.
        For a list/tuple: a new ``list`` of the (possibly moved) elements; tuples
        are returned as lists.
    """
    # isinstance (rather than an exact type() check) also accepts list/tuple
    # subclasses, which previously fell through to ``.cuda()`` and crashed.
    if not isinstance(xs, (list, tuple)):
        return xs.cuda() if USE_GPU else xs
    return [item.cuda() if USE_GPU else item for item in xs]


def average_values_in_list_of_dicts(list_of_dicts):
    """Average values key-by-key across a list of dicts.

    Each key is averaged over only the dicts that contain it, using
    ``np.array(values).mean()`` (no axis, so equal-shape array values are averaged
    over all of their elements). Keys whose collected values numpy cannot average
    (e.g. ``None`` or ragged sequences) are skipped with a printed notice.

    Args:
        list_of_dicts: iterable of dicts, e.g. per-step metric outputs.

    Returns:
        dict mapping each averageable key to its mean (typically a numpy scalar).
        An empty input yields an empty dict.
    """
    values_per_key = dict()
    for curr_output_dict in list_of_dicts:
        for k, v in curr_output_dict.items():
            values_per_key.setdefault(k, []).append(v)

    averaged_scalar_metrics_dict = dict()
    for k, values in values_per_key.items():
        try:
            averaged_scalar_metrics_dict[k] = np.array(values).mean()
        except Exception:
            # Deliberate best-effort skip; unlike a bare ``except:`` this no longer
            # swallows KeyboardInterrupt / SystemExit.
            print("Skipping any non-scalar metric that was logged. ")

    return averaged_scalar_metrics_dict


def prepend_string_to_dict_keys(prepend_key, dictinary):
    """Return a new dict whose keys are ``f"{prepend_key}{key}"``; values are kept as-is."""
    return {f"{prepend_key}{key}": value for key, value in dictinary.items()}


def print_experiment_config(path="."):
    """Pretty-print the experiment config stored as ``template.yaml`` in ``path``.

    Args:
        path: directory containing ``template.yaml`` (defaults to the current
            working directory).
    """
    yaml_path = os.path.join(path, "template.yaml")
    # Close the handle deterministically instead of leaving it to garbage collection.
    with open(yaml_path) as config_file:
        config_dict = yaml.safe_load(config_file)
    pprint(config_dict)


def sendline_and_get_response(s, line):
    """Send ``line`` to an interactive session, wait for the prompt, then print and return the output.

    Args:
        s: session object exposing ``sendline``, ``prompt`` and a bytes ``before``
            attribute (presumably a pexpect/pxssh session - confirm against callers).
        line: the command text to send.

    Returns:
        str: the UTF-8 decoded text captured in ``s.before`` after the prompt.
    """
    s.sendline(line)
    s.prompt()
    reply = s.before.decode("utf-8")
    pprint(reply)
    # Returning the reply (previously discarded) lets callers actually "get" the response;
    # callers that ignore the return value are unaffected.
    return reply


def getnow(return_int=False):
    """Return the current local time as a ``YYYY_MM_DD_HH_MM_SS`` timestamp string.

    Args:
        return_int: when True, return the same digits as an ``int``
            (e.g. ``20240131235959``); ``int()`` accepts the underscore separators.
    """
    stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    return int(stamp) if return_int else stamp


def get_num_params_of_pytorch_model(module):
    """Count the elements of all trainable (``requires_grad=True``) parameters of a torch module.

    Args:
        module: a ``torch.nn.Module``.

    Returns:
        int: total number of scalar elements across trainable parameters (0 if none).
    """
    # ``numel()`` is exact and always an int. ``np.prod(p.size())`` returns the float
    # 1.0 for 0-dim parameters, which silently turned the total into a float.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def get_num_of_allocated_tensors():
    """Count gc-tracked objects that are tensors or that hold a tensor in ``.data``.

    This is a debugging aid. Only objects returned by ``gc.get_objects()`` (those
    tracked by Python's garbage collector) are inspected.

    Returns:
        int: number of matching objects.
    """
    import gc

    def _is_tensor_like(obj):
        # Inspecting arbitrary live objects can raise (e.g. from a custom
        # ``__getattr__``); such objects are treated as non-tensors. Catching
        # Exception rather than using a bare except keeps Ctrl-C working.
        try:
            return torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data))
        except Exception:
            return False

    return sum(1 for obj in gc.get_objects() if _is_tensor_like(obj))


def set_seed(seed=None):
    """Seed Python's ``random``, NumPy's global RNG and torch with the same value.

    Args:
        seed: the seed to use. If None, one is derived from the current
            timestamp (``getnow(return_int=True)``) reduced modulo 2**32.
    """
    if seed is None:
        seed = getnow(return_int=True) % 2 ** 32
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


@quick_register
def seed_workers(worker_id):
    """DataLoader ``worker_init_fn`` giving each worker distinct Python/NumPy RNG seeds.

    Inside a DataLoader worker, ``torch.initial_seed()`` returns the per-worker seed
    PyTorch has already assigned (base seed + worker id). Deriving the Python and
    NumPy seeds from it makes them differ across workers, and makes them
    reproducible whenever the torch seed is fixed. A wall-clock seed
    (``set_seed()`` with no argument) has one-second resolution, so workers started
    together would all share it. The torch generator is left as-is because PyTorch
    already seeds it per worker.

    Args:
        worker_id: worker index supplied by the DataLoader. It is unused because
            the per-worker seed is read from torch.
    """
    # NumPy's legacy global seeding only accepts values in [0, 2**32).
    worker_seed = torch.initial_seed() % 2 ** 32
    random.seed(worker_seed)
    np.random.seed(worker_seed)


@quick_register
def freeze_thaw(epoch, freeze, thaw):
    """Return 0.0 for epochs in the half-open window [freeze, thaw), otherwise 1.0.

    Decorated with spaghettini's ``quick_register``.
    """
    frozen = freeze <= epoch < thaw
    return 0. if frozen else 1.


def set_hyperparams(config_path, logger):
    """Parse the YAML file at ``config_path`` and pass the result to ``logger.log_hyperparams``.

    Args:
        config_path: path to a YAML config file.
        logger: any object with a ``log_hyperparams`` method.
    """
    with open(config_path, 'r') as config_file:
        logger.log_hyperparams(yaml.safe_load(config_file))


def set_hyperparams_pure_wandb(config_path):
    """Load a YAML config and push it, flattened, into ``wandb.config``.

    The parsed config passes through three helpers adapted from PyTorch Lightning,
    in order: ``_convert_params`` (Namespace/None -> dict), ``_flatten_dict``
    (nested keys joined with "/"), then ``_sanitize_callable_params``. Existing
    wandb config values may be overwritten (``allow_val_change=True``).

    Args:
        config_path: path to a YAML config file.
    """
    with open(config_path, 'r') as config_file:
        hparams = yaml.safe_load(config_file)
        for transform in (_convert_params, _flatten_dict, _sanitize_callable_params):
            hparams = transform(hparams)
        wandb.config.update(hparams, allow_val_change=True)


def _convert_params(params):
    """Normalise hyperparameters into a mapping (adapted from the PyTorch Lightning codebase).

    ``None`` becomes an empty dict, and an ``argparse.Namespace`` becomes its
    ``vars()`` dict. Anything else is returned unchanged (same object).
    """
    if params is None:
        return {}
    if isinstance(params, Namespace):
        return vars(params)
    return params


def _flatten_dict(params, delimiter: str = '/'):
    """Flatten nested mappings/Namespaces into one level (adapted from the PyTorch Lightning codebase).

    Nested keys are stringified and joined with ``delimiter``. ``argparse.Namespace``
    values are expanded via ``vars()``. ``None`` leaf values become the string
    ``"None"``. Empty nested mappings contribute no keys. A non-mapping top-level
    input is stored under the empty key ``""``: stringified, or kept as ``None``.
    """

    def _walk(node, path=None):
        path = list(path) if path else []
        if not isinstance(node, MutableMapping):
            # Only reachable at the top level: recursion happens solely on mapping children.
            yield path + [node if node is None else str(node)]
            return
        for raw_key, child in node.items():
            name = str(raw_key)
            if isinstance(child, Namespace):
                child = vars(child)
            if isinstance(child, MutableMapping):
                yield from _walk(child, path + [name])
            else:
                yield path + [name, str(None) if child is None else child]

    flat = {}
    for *parts, leaf in _walk(params):
        flat[delimiter.join(parts)] = leaf
    return flat


def _sanitize_callable_params(params):
    """Replace callable values with something loggable (adapted from the PyTorch Lightning codebase).

    Each callable value is invoked once with no arguments. Its result is used
    unless that result is itself callable, in which case the original callable's
    ``__name__`` is used instead. If the call (or the name lookup) raises, the
    callable's ``__name__`` is used, or ``None`` if it has none. Non-callable
    values pass through untouched.
    """

    def _resolve(candidate):
        if not isinstance(candidate, Callable):
            return candidate
        # One call only: a callable result is not called again, to avoid unbounded recursion.
        try:
            result = candidate()
            return candidate.__name__ if isinstance(result, Callable) else result
        except Exception:
            return getattr(candidate, "__name__", None)

    return {name: _resolve(value) for name, value in params.items()}


def enlarge_matplotlib_defaults(plt_object, font_size=15):
    """Raise the default font size through ``plt_object.rc('font', size=...)``.

    Args:
        plt_object: object exposing matplotlib's ``rc`` interface
            (presumably ``matplotlib.pyplot``; confirm against callers).
        font_size: font size to set; defaults to 15, the previously hard-coded value.
    """
    plt_object.rc('font', size=font_size)


def recursively_remove_spectral_norm(module):
    """Remove spectral normalisation from ``module`` and all of its submodules, in place.

    ``module.apply`` visits every submodule and then ``module`` itself, calling
    ``torch.nn.utils.remove_spectral_norm`` on each. For modules without a spectral
    norm hook on the default parameter name, that call raises ``ValueError``, which
    is ignored. Any other error propagates instead of being silently swallowed.

    NOTE(review): only the hook-based ``torch.nn.utils.spectral_norm`` is handled;
    parametrization-based spectral norm would simply be skipped. Confirm whether
    that case matters here.
    """
    def remove_sn(m):
        try:
            remove_spectral_norm(module=m)
        except ValueError:
            # Expected for modules that never had spectral norm applied.
            pass

    module.apply(remove_sn)
