from typing import List, Sequence

import numpy as np


def check_shape(x_shape: Sequence[int], allowed_shape: str) -> bool:
    """Check whether ``x_shape`` matches a comma-separated shape pattern.

    Each comma-separated token of ``allowed_shape`` is lower-cased and
    stripped of surrounding whitespace, then interpreted as either:

    * a literal made only of decimal digits (e.g. ``'3'``), which must equal
      the corresponding dimension exactly; or
    * a symbol (anything else, e.g. ``'n'``, ``'h'``), which may stand for
      any size.

    The mapping between symbols and sizes must be one-to-one: a symbol must
    denote the same size everywhere it appears, and two different symbols
    must denote different sizes. Literal dimensions do not take part in this
    mapping, so ``'3, n'`` accepts ``(3, 3)``.

    Args:
        x_shape: Concrete shape to check, e.g. ``(3, 224, 256)``.
        allowed_shape: Pattern such as ``'3, h, w'``.

    Returns:
        True if ``x_shape`` matches the pattern, False otherwise.

    Raises:
        ValueError: If ``x_shape`` and the pattern have a different number
            of dimensions.
    """
    # Keep the parsed tokens in their own local instead of rebinding the
    # ``str`` parameter to a list.
    tokens: List[str] = [s.lower().strip() for s in allowed_shape.split(',')]
    if len(x_shape) != len(tokens):
        raise ValueError(f'Inconsistent shape: x: {x_shape}, allowed_shape: {tokens}')

    size_to_symbol = {}
    symbol_to_size = {}
    for size, token in zip(x_shape, tokens):
        # isdecimal() accepts exactly the characters int() can parse, whereas
        # isdigit() also accepts e.g. superscripts ('²'), on which int() raises.
        if token.isdecimal():
            if size != int(token):
                return False
            continue

        # Symbolic dimension: enforce a bijection between sizes and symbols.
        # setdefault records the first pairing and returns the existing one
        # on later occurrences, so a mismatch means the mapping is not 1:1.
        if size_to_symbol.setdefault(size, token) != token:
            return False
        if symbol_to_size.setdefault(token, size) != size:
            return False
    return True


def min_max_norm(attr_map: np.ndarray) -> np.ndarray:
    """Linearly rescale ``attr_map`` so its minimum maps to 0 and its maximum to ~1.

    The value range is padded with 1e-8 before dividing, so a constant map
    yields all zeros instead of dividing by zero. Because of that padding, the
    maximum maps to slightly less than 1.

    Args:
        attr_map: Array of values to normalize.

    Returns:
        Array of the same shape with values in ``[0, 1)``.
    """
    lowest = np.min(attr_map)
    value_range = np.max(attr_map) - lowest
    shifted = attr_map - lowest
    return shifted / (value_range + 1e-8)


def zero_max_norm(attr_map: np.ndarray) -> np.ndarray:
    """Scale ``attr_map`` by its maximum so that 0 stays at 0 and the maximum maps to ~1.

    The divisor is ``max + 1e-8``, so an all-zero map yields zeros rather
    than dividing by zero.

    NOTE(review): this assumes non-negative input (``norm_and_to_uint8``
    applies ``np.abs`` first). A negative maximum would flip signs, and a
    maximum of exactly -1e-8 would divide by zero.

    Args:
        attr_map: Array of values to normalize.

    Returns:
        Array of the same shape as ``attr_map``.
    """
    denominator = np.max(attr_map) + 1e-8
    return attr_map / denominator


def norm_and_to_uint8(attr_map: np.ndarray, norm='abs_zero_max') -> np.ndarray:
    """Normalize an attribution map and quantize it to ``uint8`` in ``[0, 255]``.

    Args:
        attr_map: Array of attribution values.
        norm: Normalization mode, one of:

            * ``'min_max'``: rescale with ``min_max_norm``.
            * ``'abs_min_max'``: take absolute values, then ``min_max_norm``.
            * ``'abs_zero_max'`` (default): take absolute values, then
              ``zero_max_norm``.

    Returns:
        ``uint8`` array with the same shape as ``attr_map``.

    Raises:
        NotImplementedError: If ``norm`` is not one of the modes above.
    """
    if norm == 'min_max':
        attr_map = min_max_norm(attr_map)
    elif norm == 'abs_min_max':
        attr_map = np.abs(attr_map)
        attr_map = min_max_norm(attr_map)
    elif norm == 'abs_zero_max':
        attr_map = np.abs(attr_map)
        attr_map = zero_max_norm(attr_map)
    else:
        raise NotImplementedError(f'normalization mode: {norm} is not supported.')

    # Round to nearest before casting: astype() truncates toward zero. The
    # normalizers divide by (range + 1e-8), so for float64 input the maximum
    # typically lands just below 1.0, and truncation would quantize it to 254
    # instead of 255.
    attr_map = np.clip(np.rint(attr_map * 255), a_min=0, a_max=255).astype(np.uint8)
    return attr_map
