import collections.abc
import logging
import os
from collections import defaultdict
from pathlib import Path
from types import SimpleNamespace
from typing import Union, List

import math
import numpy as np
import psutil
import torch
import torch.distributions as tdist
import wandb
from tqdm import tqdm

from hyperparams.load import get_config

# Module-level logger registered under the name 'custom'
# (NOTE(review): handler/level setup is not visible in this file).
logger = logging.getLogger('custom')
# Project configuration, loaded once at import time via hyperparams.load.
config = get_config()


def shell_command_for_download(target_dir, name=None):
    """ Logs (debug level) an rsync command that downloads target_dir.

    Nothing is logged unless the config defines both 'local_machine' and
    'remote_machine'. The function always returns None.

    :param target_dir: source path used on the remote side of the command
    :param name: label shown in the log message; falls back to the last path
        component of target_dir
    """
    required = ('local_machine', 'remote_machine')
    if not all(hasattr(config, attr) for attr in required):
        return

    label = name or Path(target_dir).name
    user = config.remote_machine['username']
    host = config.remote_machine['remotehost']
    destination = config.local_machine['download_dir']
    command = f'rsync -hazP {user}@{host}:{target_dir} {destination}'
    logger.debug(f'Download {label}:\n{command}')


def get_args_as_string(args: Union[SimpleNamespace, dict]):
    """ Formats arguments as text with one "key: value" entry per line.

    Every entry is terminated by two spaces and a newline.

    :param args: namespace or dictionary holding the arguments
    :return: formatted string (empty if there are no arguments)
    """
    mapping = args if not isinstance(args, SimpleNamespace) else vars(args)
    return ''.join(f'{key}: {value}  \n' for key, value in mapping.items())


def torch_load(dir_, map_location=None):
    """ Loads a file with torch.load, or returns None if it does not exist.

    :param dir_: path to the file
    :param map_location: forwarded to torch.load; if falsy, the CPU device is
        used when CUDA is unavailable and None otherwise
    :return: loaded object, or None if dir_ is not an existing file
    """
    if not os.path.isfile(dir_):
        return None
    if not map_location:
        map_location = (
            None if torch.cuda.is_available() else torch.device('cpu'))
    return torch.load(dir_, map_location=map_location)


def rec_defaultdict():
    """ Builds a defaultdict whose missing keys become nested defaultdicts.

    The factory is this function itself, so the nesting has arbitrary depth.
    """
    nested = defaultdict(rec_defaultdict)
    return nested


def log_mean_exp(value, dim=0, keepdim=False):
    """ Computes log(mean(exp(value))) along dim, based on torch.logsumexp.

    :param value: input tensor
    :param dim: dimension that is reduced
    :param keepdim: whether the reduced dimension is retained
    """
    log_sum = torch.logsumexp(value, dim, keepdim=keepdim)
    log_count = math.log(value.size(dim))
    return log_sum - log_count


def get_dist(name):
    """ Maps a distribution name to its torch.distributions class.

    :param name: 'normal' or 'laplace'
    :raises ValueError: for any other name
    """
    known = (
        ('normal', tdist.Normal),
        ('laplace', tdist.Laplace),
    )
    for key, dist_cls in known:
        if name == key:
            return dist_cls
    raise ValueError('Distribution not defined.')


def to_np(x):
    """ Converts a torch tensor to a numpy array.

    None and numpy arrays are returned unchanged; tensors are detached and
    moved to the CPU before conversion.

    :raises TypeError: if x is neither None, an array nor a tensor
    """
    if x is None or isinstance(x, np.ndarray):
        return x
    if not torch.is_tensor(x):
        raise TypeError('Unknown data type')
    return x.detach().cpu().numpy()


def to_torch(x, dtype=None):
    """ Recursively converts numpy arrays / tensors inside x via _to_torch.

    - None is returned unchanged.
    - Arrays and tensors are passed to _to_torch.
    - A dict is updated in place and the same object is returned.
    - A tuple or list yields a new list.

    :param x: None, array, tensor, or (nested) dict / list / tuple of those
    :param dtype: optional type forwarded to _to_torch
    :raises TypeError: for any other type
    """
    if x is None:
        return None
    if torch.is_tensor(x) or isinstance(x, np.ndarray):
        return _to_torch(x, dtype)
    if isinstance(x, dict):
        # Only existing keys are reassigned, so iterating meanwhile is safe
        for key, item in x.items():
            x[key] = to_torch(item, dtype)
        return x
    if isinstance(x, (tuple, list)):
        return [to_torch(item, dtype) for item in x]
    raise TypeError('Unknown data type')


def _to_torch(x, dtype=None):
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    if dtype:
        x = x.type(dtype)
    return x


def set_seeds(seed=23, backend_seeds=False):
    """ Seeds the random number generators of torch (CPU + all GPUs) and numpy.

    :param seed: seed value
    :param backend_seeds: if True and CUDA is available, cuDNN is additionally
        switched to deterministic mode with benchmarking disabled
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # multi-GPU models
    np.random.seed(seed)

    if not backend_seeds or not torch.cuda.is_available():
        return

    # May reduce performance
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _choose_gpu(device_list: List[int]):
    print(
        f'Expected one GPU, but found {len(device_list)} GPUs. Choosing first '
        f'GPU without any processes.'
    )
    for n in device_list:
        cur = torch.cuda.list_gpu_processes(0)
        if 'no processes are running' in cur:
            return n
    raise RuntimeError(
        'Encountered bug: All available GPUs already have running processes.'
    )


def setup_device():
    """ Selects the torch device for this process and prints hardware info.

    If CUDA is available, the GPU index is obtained from _choose_gpu when more
    than one GPU is found and is 0 otherwise. Without CUDA the CPU device is
    used. Name and memory of the selected GPU (if any) as well as total and
    available CPU memory are printed.

    Side effects: sets the environment variable CUDA_DEVICE_ORDER and, when a
    GPU is selected, CUDA_VISIBLE_DEVICES.

    :return: torch.device, either 'cuda:<n>' or 'cpu'
    """
    print('\nSetting up device:')

    # Set gpu-device numbering to be identical with nvidia-smi
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

    if torch.cuda.is_available():
        # Get free GPU device
        device_list = list(range(torch.cuda.device_count()))
        n_gpus = len(device_list)
        if n_gpus > 1:
            # Several GPUs: pick one without running processes
            n = _choose_gpu(device_list)
        else:
            n = 0

        # Set GPU device
        device = torch.device(f'cuda:{n}')
        # NOTE(review): CUDA typically reads this variable when it is
        # initialized; confirm that setting it at this point has the intended
        # effect and that `device` still refers to the chosen GPU afterwards.
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{n}'

        # Print statistics
        name = torch.cuda.get_device_properties(device).name
        # Scaled by 1e-9 and 1e-6 to match the GB / MB labels printed below
        mem_total = torch.cuda.get_device_properties(device).total_memory * 10 ** -9
        mem_allocated = torch.cuda.memory_allocated(device) * 10 ** -6
        print('Using one GPU:\n'
              f'- device: {device}\n'
              f'- name: {name}\n'
              f'- total memory: {mem_total:,.2f} GB\n'
              f'- allocated memory: {mem_allocated:,.2f} MB')

    else:
        print('Warning: Cuda not found, using CPU device.')
        device = torch.device('cpu')

    # Host memory is reported in both cases
    print(
        'CPU:\n'
        f'- total memory: {psutil.virtual_memory().total / 10 ** 9:,.2f} GB\n'
        f'- available memory: {psutil.virtual_memory().available / 10 ** 9:,.2f} GB\n'
    )

    return device


def get_distance(s: torch.tensor, q: torch.tensor, device=None, verbose=False,
                 bs=8192):
    """ Squared Euclidean distance between every row of s and every row of q.

    s is processed in chunks of bs rows: each chunk is moved to device,
    compared against q there, and the partial result is moved back to the CPU.

    :param s: (A x B x D) or (N x D)
    :param q: M x D
    :param device: device used for the computation; CPU if falsy
    :param verbose: whether to show progress bar
    :param bs: number of rows of the (flattened) s per chunk
    :return: (A x B x M) or (N x M)
    """
    device = device or torch.device('cpu')

    # A 3-dim s is flattened to two dims so that chunks consist of single
    # rows; the leading shape is remembered and restored at the end.
    lead_shape = None
    if s.dim() == 3:
        lead_shape = (s.size(0), s.size(1))
        s = s.contiguous().view(-1, s.size(-1))

    d = s.size(-1)
    m = q.size(0)
    assert q.size(-1) == d

    batches = s.split(bs)
    q_dev = q.to(device)

    def _chunk_distance(rows):
        # Expand both operands to (rows x M x D) and reduce over D
        n_rows = rows.size(0)
        rows = rows[:, None, ...].expand(n_rows, m, d)
        targets = q_dev[None, ...].expand(n_rows, m, d)
        return torch.pow(rows - targets, 2).sum(-1)

    if verbose:
        batches = tqdm(batches, position=0, leave=True,
                       desc='Calculating distances')
    parts = [_chunk_distance(batch.to(device)).cpu() for batch in batches]
    matrix = torch.cat(parts)  # N x M

    if lead_shape is not None:
        # Restore the two leading dimensions of the original input
        matrix = matrix.view(lead_shape[0], lead_shape[1], -1)

    return matrix


def to_device(inp, device):
    """ Recursively moves all tensors contained in inp to device.

    :param inp: tensor, or (nested) dict / list / tuple containing tensors
    :param device: target device passed to Tensor.to
    :return: structure with the same layout as inp; dicts are rebuilt as new
        dicts, lists and tuples are returned as new lists, tensors are moved,
        and every other value is returned unchanged
    """
    if isinstance(inp, dict):
        return {k: to_device(v, device) for k, v in inp.items()}
    if isinstance(inp, (list, tuple)):
        return [to_device(v, device) for v in inp]
    if torch.is_tensor(inp):
        return inp.to(device)
    # Non-tensor leaves (None, numbers, strings, ...) are passed through
    # instead of being silently replaced by None.
    return inp


# https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356
def update(d, u):
    """ Recursively update dictionary.

    Mapping values of u are merged into the corresponding entry of d; all
    other values overwrite the entry. If the existing entry of d is not a
    mutable mapping (e.g. it is missing, None or a scalar), it is replaced by
    a new dict before merging instead of raising an error.

    :param d: existing dictionary to be updated (modified in place)
    :param u: dictionary that contains updated values; None leaves d as is
    :return: updated d
    """
    if u is not None:
        for k, v in u.items():
            if isinstance(v, collections.abc.Mapping):
                existing = d.get(k)
                if not isinstance(existing, collections.abc.MutableMapping):
                    # Nothing mergeable stored under k yet: start from scratch
                    existing = {}
                d[k] = update(existing, v)
            else:
                d[k] = v
    return d


# inspired by:
# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
def flatten_dict(d: dict, parent_key='', sep='_'):
    """ Transforms dict to depth of one by concatenating keys.

    Supports nodes of types [dict, list, scalar]. Keys of nested dicts are
    joined with sep; list elements get their index appended as an additional
    key component (also joined with sep). Empty nested dicts and empty lists
    contribute no entries.

    :param d: (nested) dictionary
    :param parent_key: prefix for all keys of d; empty for the top level
    :param sep: separator placed between key components
    :return: new dictionary of depth one
    """
    items = []
    for k, v in d.items():
        new_key = f'{parent_key}{sep}{k}' if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        elif isinstance(v, list):
            # Treat the list as a dict keyed by position, then flatten that
            indexed = {
                f'{new_key}{sep}{idx}': item for idx, item in enumerate(v)
            }
            items.extend(flatten_dict(indexed, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def _reduce_dict(data: dict, target: str):
    """ Reduce dictionary to trees with leaf keys that contain "target".

    :param data: (nested) dictionary; values may be dicts, lists or leaves
    :param target: substring searched for in the keys of leaf entries
    :return: reduced structure, or None if the result is an empty dict
    """
    out = {}
    for k, v in data.items():
        # Subtree: reduce recursively and keep it only if something is left
        if isinstance(v, dict):
            if (tmp := _reduce_dict(v, target)) is not None:
                out[k] = tmp

        # List of leaves or subtrees
        elif isinstance(v, list):
            # Elements are reduced individually; non-dict elements are
            # matched against the key of the list itself
            v = [reduce_wrapper(cur_v, target, k) for cur_v in v]
            v = [v2 for v2 in v if v2 is not None]
            if v:
                out[k] = v

        # Single leaf node
        elif target in k:
            # Do not save key "target"
            # NOTE(review): this replaces `out` by the leaf value itself while
            # the loop continues; later `out[k] = ...` assignments would then
            # act on that value. Confirm that matching leaves never have dict
            # or list siblings.
            out = v

    # An empty dict means that nothing matched
    if isinstance(out, dict):
        out = out if out else None

    return out


def _reduce_value(data, target: str, key: str):
    """ Return data if key is in target. """
    return data if target in key else None


def reduce_wrapper(data, target, *args):
    """ Reduce data to target.

    Dictionaries are handled by _reduce_dict; everything else is handled by
    _reduce_value, which receives the additional positional arguments.
    """
    if not isinstance(data, dict):
        return _reduce_value(data, target, *args)
    return _reduce_dict(data, target)


# https://stackoverflow.com/questions/23499017/know-the-depth-of-a-dictionary/23499101#23499101
def dictionary_depth(d):
    """ Returns the nesting depth of d.

    Non-dict values have depth 0, an empty dict has depth 1, and every level
    of nested dicts adds 1.
    """
    if not isinstance(d, dict):
        return 0
    if not d:
        return 1
    return 1 + max(dictionary_depth(child) for child in d.values())


def find_path_of_id(target_id):
    """ Crawls all experiments for current project for ID.

    The whole experiments directory is walked so that duplicate IDs are
    detected instead of silently returning the first hit.

    :param target_id: name of the directory that is searched for
    :return: path of the single directory named target_id
    :raises ValueError: if no directory is named target_id
    :raises RuntimeError: if more than one directory is named target_id
    """
    matches = []
    for root, subdirs, _ in os.walk(config.dirs['experiments']):
        # Names are unique within one directory: at most one hit per level
        if target_id in subdirs:
            matches.append(os.path.join(root, target_id))

    if not matches:
        raise ValueError(f'Have not found "{target_id}"')
    if len(matches) > 1:
        raise RuntimeError(f'ID "{target_id}" is not unique.')
    return matches[0]


# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks/312464#312464
def chunks(lst, n):
    """Lazily yield consecutive slices of lst holding at most n items each."""
    offsets = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in offsets)


def prepare_string_for_latex(string):
    """
    Escapes characters that can cause trouble with latex in matplotlib.

    Each of the characters _ # $ % & { } ~ is prefixed with a backslash.
    Backslashes that are already part of string are left untouched.

    :param string: text to be escaped
    :return: escaped copy of string
    """
    # Built explicitly with an escaped backslash: sequences such as a
    # backslash followed by '_' are invalid escapes in normal string literals.
    escapes = str.maketrans({char: '\\' + char for char in '_#$%&{}~'})
    return string.translate(escapes)


def get_wandb_project_name(dataset_name: str):
    """ Derives the wandb project name ('hmvae_<suffix>') from a dataset name.

    'flowers' and 'cub' map to 'images'; otherwise names containing 'ft' map
    to 'features' and names containing 'synthetic' map to 'synthetic', checked
    in that order.

    :raises ValueError: if no rule matches dataset_name
    """
    prefix = 'hmvae'
    exact_matches = ('flowers', 'cub')
    substring_rules = (('ft', 'features'), ('synthetic', 'synthetic'))

    if dataset_name in exact_matches:
        return f'{prefix}_images'
    for marker, suffix in substring_rules:
        if marker in dataset_name:
            return f'{prefix}_{suffix}'
    raise ValueError(f'{dataset_name} is not a legal dataset name.')


def init_wandb(run_id: str,
               project: str,
               group: str,
               wandb_config: dict,
               **kwargs):
    """ Starts a run on https://wandb.ai/ and stores its hyperparameters.

    A previously active run is finished first. If _get_wandb_id finds an
    existing wandb run for (run_id, group), its ID is passed to wandb.init;
    otherwise None is passed.

    Note that Wandb cannot easily handle several processes simultaneously,
    e.g., this function may need to be called several times during training.

    :param run_id: run_id outside wandb
    :param project: project name in wandb
    :param group: group in wandb (e.g., 'train')
    :param wandb_config: hyperparameters to be saved in wandb
    :param kwargs: kwargs for wandb.init()
    """
    # Close previous process (if existent)
    wandb.finish()

    existing_id = _get_wandb_id(run_id, project, group)
    entity = config.wandb["entity"]
    wandb.init(project=project, entity=entity, id=existing_id, group=group,
               **kwargs)
    wandb.config.update(wandb_config)


def _get_wandb_id(run_id: str, project: str, group: str):
    """ Looks up the wandb-internal ID that belongs to (run_id, group).

    Background: it is challenging to delete data from https://wandb.ai/
    - Not possible to remove all data from an existing run
    - Not possible to reuse the same ID after deletion

    A run matches if its config holds 'run_id' equal to run_id and its group
    equals group.

    :return: ID used inside wandb
        Existing ID if existent
        None if no matching run is found
    :raises Exception: if more than one run matches
    """
    api = wandb.Api()
    runs = api.runs(f'{config.wandb["entity"]}/{project}')

    found_id = None
    try:
        for run in runs:
            if 'run_id' not in run.config:
                # Avoid catching run during its initialization phase
                continue
            same_run = run.config['run_id'] == run_id
            same_group = run.group == group
            if not (same_run and same_group):
                continue
            if found_id:
                raise Exception(
                    f'Found duplicate run for {run_id} in wandb-database'
                )
            found_id = run.id
    except ValueError:
        # Project not existent
        pass

    return found_id


def check_debug_status(debug_mode=False, resume_id=None):
    """
    Checks if code runs in debug-mode.

    Resuming a run whose ID contains 'debug' always counts as debug mode. In
    that case the stored 'args.pt' of the run is rewritten so that its run_id
    equals the name of the run directory.

    :param debug_mode: explicitly requested debug mode
    :param resume_id: ID of the run that is resumed (if any)
    :return: True if running in debug mode, else False
    """
    resumes_debug_run = bool(resume_id) and 'debug' in resume_id
    if not resumes_debug_run:
        return bool(debug_mode)

    # Run_path and run_id may have become inconsistent
    run_path = find_path_of_id(resume_id)
    args_file = os.path.join(run_path, 'args.pt')
    stored_args = torch_load(args_file)
    stored_args.run_id = Path(run_path).name
    torch.save(stored_args, args_file)
    return True
