import os
import sys
from typing import Any, Dict, List, Optional, Tuple

import torch
from ruamel import yaml


class ConfigurationException(Exception):
    """Exception type intended for configuration-related errors."""


def boolean_nested_and(outer_mask: torch.Tensor, inner_mask: torch.Tensor) -> torch.Tensor:
    """
    Performs a logical and operation between the two boolean masks
    assuming that the inner mask only refers to the True values in the outer mask.

    Entry ``i`` of ``inner_mask`` corresponds to the ``i``-th True entry of
    ``outer_mask``. Every True entry of ``outer_mask`` whose inner counterpart
    is False is switched off. The inputs are not modified.

    :param outer_mask: 1-D boolean mask.
    :param inner_mask: 1-D boolean mask with exactly one entry per True value
        of ``outer_mask``.
    :return: a new boolean mask shaped like ``outer_mask``, on its device.

    >>> aa = torch.tensor([False, True, True, True, False, False])
    >>> bb = torch.tensor([True, False, True])
    >>> cc = torch.tensor([False, True, False, True, False, False])
    >>> assert torch.all(cc == boolean_nested_and(aa, bb))
    """
    result = outer_mask.clone().detach()
    # Positions of the True entries of the outer mask. The index tensor is
    # created on the mask's own device so that indexing also works for
    # non-CPU (e.g. CUDA) masks.
    true_idx = torch.arange(len(result), device=result.device)[result]
    # Switch off the outer positions whose inner entry is False.
    result[true_idx[~inner_mask]] = False
    return result


def flatten_dict(d: Dict[str, Any], key_separator: str = '.') -> Dict[str, Any]:
    """
    Flatten a nested dictionary into a single level.

    Keys of nested dictionaries are joined with ``key_separator`` at every
    depth, e.g. ``{'a': {'b': {'c': 1}}}`` becomes ``{'a.b.c': 1}``.
    Empty nested dictionaries contribute no entries.

    :param d: dictionary whose ``dict`` values are flattened recursively.
    :param key_separator: string placed between the keys of successive levels.
    :return: a new dictionary mapping joined key paths to leaf values.
    :raises ValueError: if two different paths produce the same flattened key
        (e.g. ``{'a': {'b': 2}, 'a.b': 1}`` with the default separator).
    """
    res: Dict[str, Any] = {}

    def add(key: Any, value: Any) -> None:
        # Every insertion is checked, so a plain key can no longer silently
        # overwrite a flattened one (or vice versa).
        if key in res:
            raise ValueError(f'duplicate key after flattening: {key!r}')
        res[key] = value

    for k, v in d.items():
        if isinstance(v, dict):
            # Propagate the separator so that all depths use the same one.
            for sub_key, w in flatten_dict(v, key_separator).items():
                add(f'{k}{key_separator}{sub_key}', w)
        else:
            add(k, v)
    return res


def update_config_with_arg(config: Dict[str, Any], key: str, value: Any
                           ) -> Optional[Tuple[str, Any, Any]]:
    """
    Override a single configuration entry in place, converting the new value
    to the type of the value it replaces.

    Conversion rules for ``value`` depend on the existing value:

    - None or missing: ``value`` is stored unchanged.
    - bool: ``y``/``yes``/``t``/``true`` (case-insensitive) give True;
      anything else gives False.
    - list, tuple or dict: rejected.
    - other types: the string ``none`` (any case) gives None; otherwise the
      existing value's type is called on ``value``.

    ``key`` can take two forms:

    - A dotted path: ``a.b.c`` addresses ``config['a']['b']['c']``, and the
      last component may be a key that does not exist yet.
    - A bare name: the first matching key found by a depth-first walk of
      ``config`` (in insertion order) is updated.

    :param config: nested configuration dictionary, modified in place.
    :param key: dotted path or bare leaf name of the entry to override.
    :param value: new value, normally the raw string from the command line.
    :return: ``(full dotted key, old value, new value)``, or None if the key
        cannot be located. That covers an unknown bare name, or a dotted path
        whose intermediate components are missing or not dictionaries.
    :raises ValueError: if the existing value is a list, tuple or dict, or
        ``value`` cannot be converted to the existing value's type.
    """
    def parse(old: Any, value: Any) -> Any:
        # Convert `value` to match the type of the entry it replaces.
        if old is None:
            return value

        t = type(old)
        if t is bool:
            return value.lower() in ('y', 'yes', 't', 'true')
        elif t in (list, tuple, dict):
            raise ValueError('cannot parse lists, dicts or tuples')
        elif value.lower() == 'none':
            return None
        else:
            return t(value)

    *path, name = key.split('.')
    if len(path) > 0:
        # keys with dots are a full path into nested dictionaries
        # a.b.c will map to config[a][b][c]
        cursor = config
        for k in path:
            # A missing or non-dict intermediate means the key does not exist.
            # Report that via None (like an unknown bare name) instead of
            # leaking a KeyError / AttributeError to the caller.
            if not isinstance(cursor, dict) or k not in cursor:
                return None
            cursor = cursor[k]
        if not isinstance(cursor, dict):
            return None

        old = cursor.get(name)
        cursor[name] = parse(old, value)

        return key, old, cursor[name]
    else:
        # keys without dots refer directly to leaves
        # thus we must seek where in the config they are
        def traverse_dict(p: List[str], d: Dict[str, Any]
                          ) -> Optional[Tuple[str, Any, Any]]:
            for k, v in d.items():
                if k == name:
                    old = v
                    d[k] = parse(v, value)
                    return '.'.join(p + [k]), old, d[k]
                elif isinstance(v, dict):
                    res = traverse_dict(p + [k], v)
                    if res is not None:
                        return res
            return None

        return traverse_dict([], config)


def get_config(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Build the training configuration from a YAML file plus CLI overrides.

    Usage: ``puupl-train [path/to/config.yaml] [key:value] ...``.

    - If the first argument names an existing file, it is loaded as the
      configuration.
    - Otherwise ``configs/baseline.yaml`` is loaded and every argument is
      treated as a ``key:value`` override.

    Overrides are split at the first ``:``, so values may contain colons.
    They are applied in order via ``update_config_with_arg``, and each change
    is printed.

    :param argv: command line including the program name at index 0;
        defaults to ``sys.argv``.
    :return: the configuration dictionary with overrides applied.
    :raises ValueError: if an override is not of the form ``key:value`` or
        names a property that cannot be found in the configuration.
    """
    if argv is None:
        argv = sys.argv

    if '--help' in argv or '-h' in argv:
        print('Usage: puupl-train [path/to/config.yaml] [arg.1:val.1] ... [arg.n:val.n]')
        print()
        print('Runs training according to the given configuration')
        print('CLI arguments override those in the configuration.')
        sys.exit(0)

    # load config file; a first argument that is not an existing file is
    # treated as an override rather than being silently discarded
    cfg_path = 'configs/baseline.yaml'
    overrides = list(argv[1:])
    if overrides and os.path.exists(overrides[0]):
        cfg_path, overrides = overrides[0], overrides[1:]
    with open(cfg_path, 'r', encoding='utf-8') as file:
        # the module-level ruamel `safe_load` is deprecated (removed in
        # ruamel.yaml 0.18); the YAML object API is the supported equivalent
        config = yaml.YAML(typ='safe').load(file)

    # apply overrides from command line
    for arg in overrides:
        k, sep, v = arg.partition(':')
        if not sep:
            raise ValueError(
                f'could not parse argument {arg!r}: expected key:value '
                '(or an existing config file path as the first argument)')
        res = update_config_with_arg(config, k, v)
        if res is None:
            raise ValueError(f'could not understand property {k}')
        print(f'Updated value of {res[0]} from {res[1]} to {res[2]}')

    return config
