import os
import ast
import shutil
from copy import deepcopy
import importlib.util
from pathlib import Path
import sys
import git
import socket
import subprocess
from datetime import datetime

import numpy as np
import shortuuid
from loguru import logger
from omegaconf import DictConfig
import wandb
from omegaconf import OmegaConf


def launch_wandb(wandb_cfg: DictConfig, run_cfg: DictConfig):
    """
    Start a wandb run when ``wandb_cfg.enable`` is true; otherwise do nothing.

    The config uploaded to wandb is ``run_cfg`` converted to plain Python
    containers (with interpolations resolved) plus a ``'git'`` entry holding
    the current commit id, or None when it cannot be determined.

    :param wandb_cfg: wandb settings; reads ``enable``, ``project``, ``entity``
        and the optional ``name`` (default None) and ``tags`` (default []).
    :param run_cfg: experiment config to record with the run.
    """
    if not wandb_cfg.enable:
        return
    # resolve=True: record concrete values instead of raw '${...}' interpolation strings.
    dict_cfg = OmegaConf.to_container(run_cfg, resolve=True)
    try:
        dict_cfg['git'] = git_commit_id()
    except (git.InvalidGitRepositoryError, git.NoSuchPathError, ValueError) as e:
        # Outside a git checkout (or in a repo without any commit, where HEAD
        # cannot be resolved) the run should still be logged.
        logger.warning(f'Could not determine git commit id: {e}')
        dict_cfg['git'] = None
    wandb.init(
        config=dict_cfg,
        project=wandb_cfg.project,
        entity=wandb_cfg.entity,
        name=wandb_cfg.get('name', None),
        tags=wandb_cfg.get('tags', [])
    )


def flatten_dict_cfg(cfg):  # [dict | DictConfig]) -> DictConfig:
    '''
    Flatten a nested config into a single level, joining nested keys with '_', e.g.
    >>> flatten_dict_cfg({'a': {'c': 'd', 'e': {'f': 'g', 'h': 'i'}}, 'b': 3})
    {'a_c': 'd', 'a_e_f': 'g', 'a_e_h': 'i', 'b': 3}

    :param cfg: a plain dict (wrapped in a DictConfig first) or a DictConfig.
    :return: a new DictConfig with no nested DictConfig values; other values
        (including lists) are copied as they are.
    '''
    if isinstance(cfg, dict):
        cfg = DictConfig(cfg)
    flat = {}
    for key, value in cfg.items():
        if not isinstance(value, DictConfig):
            flat[key] = value
            continue
        # Recurse, then prefix every flattened child key with the parent key.
        for sub_key, sub_value in flatten_dict_cfg(value).items():
            flat[f'{key}_{sub_key}'] = sub_value
    return DictConfig(flat)


def terminate_wandb(wandb_cfg):
    """Finish the active wandb run, but only when ``wandb_cfg.enable`` is true."""
    if not wandb_cfg.enable:
        return
    wandb.finish()


def hash_str(s):
    """
    Deterministic polynomial rolling hash of a string.

    Only ``ord`` and integer arithmetic are used, so the result is the same in
    every process (unlike the salted built-in ``hash`` for str).

    :param s: string (or any iterable of single characters).
    :return: int in ``[0, 998244353)``; 0 for the empty string.
    """
    modulus, base = 998244353, 10007
    acc = 0
    for code in map(ord, s):
        acc = (acc * base + code) % modulus
    return acc


def current_time():
    """Return the current local time formatted as ``YYYY-MM-DD-HH:MM:SS``."""
    return datetime.now().strftime("%Y-%m-%d-%H:%M:%S")


def get_os_command_out(os_command):
    """
    Run a command (without a shell) and return its captured stdout.

    :param os_command: either a command-line string such as ``'git rev-parse HEAD'``
        (tokenized with ``shlex.split`` on POSIX; passed through unchanged on
        Windows, where the OS parses the command line), an already-tokenized
        list/tuple of arguments, or a path-like pointing at an executable.
    :return: the command's stdout as text. If the command is empty, cannot be
        started, or exits with a non-zero status, an error description string
        starting with ``'Error executing'`` is returned instead; nothing is raised.
    """
    import shlex

    if isinstance(os_command, (list, tuple)):
        args = list(os_command)
    elif isinstance(os_command, os.PathLike):
        args = [os_command]
    elif os.name == 'nt':
        args = os_command
    else:
        # Without shell=True a single string is taken as the program name, so
        # 'ls -l' must be split into ['ls', '-l'].
        args = shlex.split(os_command)
    if not args:
        return f'Error executing "{os_command}": empty command'
    try:
        result = subprocess.run(args, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        return f'Error executing "{os_command}": {e}'
    except OSError as e:
        # e.g. executable not found or not executable
        return f'Error executing "{os_command}": {e}'
    return result.stdout


def update_summary(summary, to_log, prefix='summary/'):
    """
    Append each value of ``to_log`` to a per-key history list in ``summary``.

    :param summary: mapping updated in place; keys are ``f'{prefix}{key}'``
        and values are lists that grow by one entry per call.
    :param to_log: mapping of metric name -> value for this step.
    :param prefix: string prepended to every key of ``to_log``.
    """
    prefixed = ((f'{prefix}{name}', value) for name, value in to_log.items())
    for key, value in prefixed:
        if key not in summary:
            summary[key] = []
        summary[key].append(value)

def find_unused_port():
    """
    Ask the OS for a currently free TCP port and return it as a string.

    The probing socket is closed before returning, so another process could
    grab the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('', 0))  # port 0 lets the OS pick a free port
        _host, port = probe.getsockname()
    return str(port)


def split_list(list_to_split, num_split):
    """
    Split a sliceable sequence into ``num_split`` contiguous, near-equal chunks.

    The first ``len % num_split`` chunks get one extra element, so chunk sizes
    differ by at most one. Chunks may be empty when there are fewer items than
    splits.

    :param list_to_split: any sliceable sequence (list, str, tuple, ...).
    :param num_split: number of chunks to produce.
    :return: list of ``num_split`` slices of ``list_to_split``.
    """
    base, remainder = divmod(len(list_to_split), num_split)
    chunks = []
    start = 0
    for part in range(num_split):
        size = base + 1 if part < remainder else base
        chunks.append(list_to_split[start:start + size])
        start += size
    assert start == len(list_to_split)
    return chunks


def uuid(length=8):
    """
    Return a random id of ``length`` characters drawn from [0-9a-z].

    Reference: https://github.com/wandb/client/blob/master/wandb/util.py#L677
    """
    # 36 symbols -> 36**8 (~2.8e12) possible ids at the default length
    id_generator = shortuuid.ShortUUID(alphabet=list("0123456789abcdefghijklmnopqrstuvwxyz"))
    return id_generator.random(length)


def pathlib_file(file_name):
    """
    Normalize a filename argument to a ``pathlib.Path``.

    :param file_name: a ``str`` (converted) or a ``Path`` (returned as-is).
    :return: the corresponding ``Path``.
    :raises TypeError: for any other type.
    """
    if isinstance(file_name, Path):
        return file_name
    if isinstance(file_name, str):
        return Path(file_name)
    raise TypeError(f'Please check the type of the filename:{file_name}')


def filter_str_with_exclude_patterns(filenames, exclude_patterns):
    """
    Drop every string in ``filenames`` that contains any exclude pattern.

    :param filenames: iterable of strings.
    :param exclude_patterns: None (``filenames`` is returned unchanged), a single
        substring, or an iterable of substrings.
    :return: list of the strings containing none of the patterns.
    """
    if exclude_patterns is None:
        return filenames
    patterns = [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
    # A list (not a generator) inside any(): every pattern is tested, as before.
    return list(filter(lambda name: not any([pat in name for pat in patterns]), filenames))


def filter_str_with_include_patterns(filenames, include_patterns):
    """
    Keep only the strings in ``filenames`` that contain at least one include pattern.

    :param filenames: iterable of strings.
    :param include_patterns: None (``filenames`` is returned unchanged), a single
        substring, or an iterable of substrings (an empty iterable keeps nothing).
    :return: list of the matching strings.
    """
    if include_patterns is None:
        return filenames
    patterns = [include_patterns] if isinstance(include_patterns, str) else include_patterns
    # A list (not a generator) inside any(): every pattern is tested, as before.
    return list(filter(lambda name: any([pat in name for pat in patterns]), filenames))


def filter_with_exclude_patterns(filenames, exclude_patterns):
    """
    Drop every path whose POSIX string form contains any exclude pattern.

    :param filenames: iterable of path objects providing ``as_posix()``.
    :param exclude_patterns: None (``filenames`` is returned unchanged), a single
        substring, or an iterable of substrings.
    :return: list of the paths matching none of the patterns.
    """
    if exclude_patterns is None:
        return filenames
    patterns = [exclude_patterns] if isinstance(exclude_patterns, str) else exclude_patterns
    return [
        path for path in filenames
        if not any([pat in path.as_posix() for pat in patterns])
    ]


def filter_with_include_patterns(filenames, include_patterns):
    """
    Keep only the paths whose POSIX string form contains at least one include pattern.

    :param filenames: iterable of path objects providing ``as_posix()``.
    :param include_patterns: None (``filenames`` is returned unchanged), a single
        substring, or an iterable of substrings (an empty iterable keeps nothing).
    :return: list of the matching paths.
    """
    if include_patterns is None:
        return filenames
    patterns = [include_patterns] if isinstance(include_patterns, str) else include_patterns
    return [
        path for path in filenames
        if any([pat in path.as_posix() for pat in patterns])
    ]


def get_all_subdirs(directory, exclude_patterns=None, include_patterns=None, sort=True):
    """
    List the immediate subdirectories of ``directory`` (non-recursive).

    :param directory: str or Path; a missing directory yields [].
    :param exclude_patterns: optional substring(s); matching paths are dropped.
    :param include_patterns: optional substring(s); only matching paths are kept.
    :param sort: sort the result when true.
    :return: list of Path objects.
    """
    root = pathlib_file(directory)
    if not root.exists():
        return []
    subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
    if exclude_patterns is not None:
        subdirs = filter_with_exclude_patterns(subdirs, exclude_patterns)
    if include_patterns is not None:
        subdirs = filter_with_include_patterns(subdirs, include_patterns)
    return sorted(subdirs) if sort else subdirs


def get_all_subfiles(directory, exclude_patterns=None, include_patterns=None, sort=True):
    """
    List the regular files directly inside ``directory`` (non-recursive).

    :param directory: str or Path; a missing directory yields [].
    :param exclude_patterns: optional substring(s); matching paths are dropped.
    :param include_patterns: optional substring(s); only matching paths are kept.
    :param sort: sort the result when true.
    :return: list of Path objects.
    """
    root = pathlib_file(directory)
    if not root.exists():
        return []
    subfiles = [entry for entry in root.iterdir() if entry.is_file()]
    if exclude_patterns is not None:
        subfiles = filter_with_exclude_patterns(subfiles, exclude_patterns)
    if include_patterns is not None:
        subfiles = filter_with_include_patterns(subfiles, include_patterns)
    return sorted(subfiles) if sort else subfiles


def get_all_files_with_suffix(directory, suffix,
                              exclude_patterns=None,
                              include_patterns=None,
                              sort=True,
                              max_num=-1, rand_choice=True):
    """
    Recursively collect the files under ``directory`` that end with ``suffix``.

    :param directory: str or Path; a missing directory yields [].
    :param suffix: file suffix with or without the leading dot; multi-part
        suffixes such as ``'tar.gz'`` are supported.
    :param exclude_patterns: optional substring(s); matching paths are dropped.
    :param include_patterns: optional substring(s); only matching paths are kept.
    :param sort: sort the (filtered) list before any sub-sampling.
    :param max_num: if not -1 and smaller than the number of matches, keep only
        ``max_num`` files.
    :param rand_choice: when sub-sampling, pick files at random without
        replacement using numpy's global RNG (the result is then in random
        order); otherwise keep the first ``max_num``.
    :return: list of Path objects.
    """
    directory = pathlib_file(directory)
    if not directory.exists():
        return []
    if not suffix.startswith('.'):
        suffix = '.' + suffix
    # Compare against all suffixes joined: Path.suffix is only the last one
    # ('.gz' for 'a.tar.gz'), so a '.tar.gz' request would never match.
    files = [x for x in directory.glob(f'**/*{suffix}')
             if x.is_file() and ''.join(x.suffixes).endswith(suffix)]
    if exclude_patterns is not None:
        files = filter_with_exclude_patterns(files, exclude_patterns)
    if include_patterns is not None:
        files = filter_with_include_patterns(files, include_patterns)
    if sort:
        files = sorted(files)
    if max_num != -1 and max_num < len(files):
        if rand_choice:
            selected_files_idx = np.random.choice(len(files), max_num, replace=False)
            files = [files[idx] for idx in selected_files_idx]
        else:
            files = files[:max_num]
    return files


def get_all_files_with_name(directory, name,
                            exclude_patterns=None,
                            include_patterns=None,
                            sort=True,
                            ):
    """
    Recursively find every regular file under ``directory`` whose name is exactly ``name``.

    :param directory: str or Path; a missing directory yields [].
    :param name: exact file name to match (e.g. ``'config.yaml'``).
    :param exclude_patterns: optional substring(s); matching paths are dropped.
    :param include_patterns: optional substring(s); only matching paths are kept.
    :param sort: sort the result when true.
    :return: list of Path objects.
    """
    root = pathlib_file(directory)
    if not root.exists():
        return []
    matches = [
        candidate
        for candidate in root.glob(f'**/{name}')
        if candidate.is_file() and candidate.name == name
    ]
    if exclude_patterns is not None:
        matches = filter_with_exclude_patterns(matches, exclude_patterns)
    if include_patterns is not None:
        matches = filter_with_include_patterns(matches, include_patterns)
    return sorted(matches) if sort else matches


def list_class_names(dir_path):
    """
    Map every top-level class defined in the Python files under ``dir_path``
    (recursively, skipping ``__init__.py``) to the file that defines it.

    Files are scanned in sorted path order, so when two files define a class
    with the same name, the one with the later path wins deterministically.

    Args:
        dir_path (str | Path): folder to scan.
    Returns:
        dict: class name -> Path of the defining file. Only classes at module
        level are listed (nested classes are not).
    Raises:
        SyntaxError: if one of the files cannot be parsed.
    """
    dir_path = pathlib_file(dir_path)
    py_files = sorted(f for f in dir_path.rglob('*.py')
                      if f.is_file() and f.name != '__init__.py')
    cls_name_to_path = dict()
    for py_file in py_files:
        # Parse raw bytes so the source encoding follows Python's rules
        # (UTF-8 default / PEP 263 cookie) rather than the locale encoding.
        tree = ast.parse(py_file.read_bytes(), filename=str(py_file))
        for node in tree.body:
            if isinstance(node, ast.ClassDef):
                cls_name_to_path[node.name] = py_file
    return cls_name_to_path


def load_class_from_path(cls_name, path):
    """
    Import the Python file at ``path`` as a fresh module and return its attribute ``cls_name``.

    The module is created and registered in ``sys.modules`` under the same
    name, ``'MOD' + cls_name``. Tools that resolve a class through
    ``sys.modules[cls.__module__]`` (e.g. ``pickle``) can therefore find it,
    and no unrelated module named ``cls_name`` is shadowed.

    :param cls_name: name of the class (or any module attribute) to return.
    :param path: path to a ``.py`` file (str or Path).
    :return: ``getattr(module, cls_name)``.
    :raises ImportError: if no import spec can be built for ``path``
        (e.g. an unsupported file extension).
    :raises AttributeError: if the module does not define ``cls_name``.
    """
    mod_name = 'MOD%s' % cls_name
    spec = importlib.util.spec_from_file_location(mod_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot build an import spec for {path}', path=str(path))
    mod = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(mod_name, None)
        raise
    return getattr(mod, cls_name)


def git_commit_id(path=None, search_par_dir=True):
    """
    Return the hex SHA of the commit HEAD points to.

    :param path: directory inside the repository (str or Path); None lets
        GitPython use the current working directory.
    :param search_par_dir: also look in parent directories for the ``.git`` folder.
    """
    repo_path = None if path is None else pathlib_file(path).as_posix()
    repo = git.Repo(path=repo_path, search_parent_directories=search_par_dir)
    head_commit = repo.head.object
    return head_commit.hexsha


def ask_del_dir(dir_path, ask_del=False):
    '''
    Remove the directory ``dir_path`` if it exists, optionally asking for confirmation first.

    :param dir_path: directory to remove (str or Path).
    :param ask_del: when true, prompt on stdin until the user types exactly
        'y' or 'n'; with false the directory is removed without asking.
    :return: False if the directory exists and the user declined; True otherwise
        (it did not exist, or it was removed).
    '''
    target = pathlib_file(dir_path)
    if not target.exists():
        return True
    if ask_del:
        answer = None
        while answer not in ('y', 'n'):
            answer = input(f'{target.resolve().as_posix()} exists, delete it? [y/n]')
        if answer == 'n':
            return False
    shutil.rmtree(target.as_posix())
    logger.info(f'{target.resolve().as_posix()} deleted')
    return True


def wait_for_proc(proc):
    """
    Echo a child process's output line by line until it exits, then print and
    return its exit code.

    :param proc: a ``subprocess.Popen``. Its output is streamed when it was
        started with ``stdout=subprocess.PIPE`` in text mode (as ``run_cmd``
        does). If ``proc.stdout`` is None the function just waits.
    :return: the process's return code.
    """
    if proc.stdout is not None:
        # Iterating the pipe blocks until a line arrives and stops at EOF, so
        # there is no busy polling and no spurious blank lines. The ``with``
        # closes our end of the pipe afterwards.
        with proc.stdout:
            for line in proc.stdout:
                # rstrip only: keep leading indentation (e.g. tracebacks).
                print(line.rstrip())
    return_code = proc.wait()
    print('RETURN CODE', return_code)
    return return_code


def run_cmd(cmd):
    """
    Run ``cmd`` through the shell and block until it finishes, echoing its output.

    stdout and stderr are merged into one text pipe that ``wait_for_proc`` drains.
    NOTE: ``shell=True`` means ``cmd`` is interpreted by the shell; it must never
    contain untrusted input.

    :param cmd: shell command line.
    """
    logger.info('====================================================================')
    logger.info(f'RUNNING CMD: \n {cmd}')
    logger.info('====================================================================')
    child = subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    wait_for_proc(proc=child)
    logger.info(f'CMD [{cmd}] finished!')
    logger.info('-------------------------------------------------')


def run_cmd_no_wait(cmd):
    """
    Launch ``cmd`` through the shell in the background and return immediately.

    The child's stdout and stderr are discarded. Piping them without anyone
    reading the pipe would make a chatty child block once the OS pipe buffer
    fills. NOTE: ``shell=True`` means ``cmd`` is interpreted by the shell; it
    must never contain untrusted input.

    :param cmd: shell command line.
    :return: the ``subprocess.Popen`` handle, so callers may ``poll()``,
        ``wait()`` or ``kill()`` the child; callers may also ignore it.
    """
    logger.info(f'====================================================================')
    logger.info(f'RUNNING CMD: \n {cmd}')
    logger.info(f'====================================================================')
    proc = subprocess.Popen(cmd,
                            shell=True,
                            stderr=subprocess.STDOUT,
                            stdout=subprocess.DEVNULL,
                            )
    return proc


def append_item_to_dict(d: dict, ks: list, v):
    '''
    Perform ``d[ks[0]][ks[1]]...[ks[-1]].append(v)``, creating missing levels on the way.

    Missing intermediate keys get a new empty dict; a missing final key gets a
    new list ``[v]``.
    :param d: root dict, modified in place
    :param ks: non-empty sequence of keys (an empty one raises IndexError)
    :param v: value appended to the list at the final key
    :return: None
    '''
    last_key = ks[-1]
    node = d
    for key in ks[:-1]:
        if key not in node:
            node[key] = dict()
        node = node[key]
    if last_key in node:
        node[last_key].append(v)
    else:
        node[last_key] = [v]


def assign_item_to_dict(d: dict, ks: list, v):
    '''
    Perform ``d[ks[0]][ks[1]]...[ks[-1]] = v``, creating empty dicts for missing intermediate keys.
    :param d: root dict, modified in place
    :param ks: non-empty sequence of keys (an empty one raises IndexError)
    :param v: value stored at the final key (overwrites any existing value)
    :return: None
    '''
    last_key = ks[-1]
    node = d
    for key in ks[:-1]:
        if key not in node:
            node[key] = dict()
        node = node[key]
    node[last_key] = v


def _symlink_folder(src, dst):
    """
    Recursively mirror directory ``src`` at ``dst``. Directories are created
    for real; every other entry (regular file, broken link, special file)
    becomes a symlink to the original.

    :param src: existing source directory (``pathlib.Path``).
    :param dst: destination directory (``pathlib.Path``); must not exist yet.
    :raises FileExistsError: if ``dst`` already exists.
    """
    dst.mkdir(parents=True, exist_ok=False)
    for entry in src.iterdir():
        if entry.is_dir():
            _symlink_folder(entry, dst / entry.name)
        else:
            # Absolute target: a relative target is resolved relative to the
            # link's own directory (not the cwd) and would dangle.
            os.symlink(entry.absolute(), dst / entry.name)


def symlink_folder(src, dst):
    '''
    Like ``os.symlink(src, dst)``, but instead of one link to the folder, recreate
    the directory tree of ``src`` at ``dst`` and create a symlink for every file.
    :param src: existing source directory (str or Path)
    :param dst: destination directory (str or Path); must not exist yet, its
        parent directories are created as needed
    :return: None
    :raises NotADirectoryError: if ``src`` is not an existing directory
    :raises FileExistsError: if ``dst`` already exists
    '''
    src, dst = pathlib_file(src), pathlib_file(dst)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not src.is_dir():
        raise NotADirectoryError(f'symlink_folder: source is not a directory: {src}')
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Absolute source so the created links never depend on the current working directory.
    _symlink_folder(src.absolute(), dst)
