from PIL import Image
from anonlibrary.loader.model import get_module
from anonlibrary.util.vecquantile import QuantileVector
from anonlibrary.visit_strategy import Flat, Leaves, BottomUp, TopDown
from multiprocessing import Pool, Manager
from tqdm.auto import tqdm
import numpy as np
import os
import pickle
import torch


def is_convolutional(batch):
    '''
    Tell whether a batch of activations comes
    from a convolutional layer, i.e. whether
    its shape has exactly four axes.
    '''
    n_axes = len(batch.shape)
    return n_axes == 4


def record_activations(model,
                       modules_names,
                       dataset,
                       batch_size=128,
                       cache='',
                       gpu=False):
    """
    Stores the activations of the selected
    modules into a NumPy array.

    Parameters
    ----------
    model: torch.nn.Module
        The PyTorch model to analyze
    modules_names: list of str
        Names of the modules to analyze
    dataset: anonlibrary.loader.Dataset
        Dataset containing the inputs
    batch_size: int, optional
        Batch size for the forward pass
    cache: str, optional
        Path of the folder in which to
        eventually store the activations
    gpu: bool, optional
        Flag to handle GPU usage

    Returns
    -------
    activations: dict of array_like
        Dictionary mapping the module
        name to either a NumPy array
        or a memmap containing the
        activations per input and
        per unit. A module that is never
        executed during the forward pass
        keeps the value None (or its
        previously cached memmap).
    """

    # fix batch size for the image loader
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # retrieve modules from names
    modules = [get_module(model, name) for name in modules_names]
    n_modules = len(modules)

    # Per-module output arrays and their shapes
    activations = [None] * n_modules
    features_size = [None] * n_modules

    # if the module has been already studied
    # loads its activations from file
    if cache:
        # Filenames
        size_files = [os.path.join(cache, "size_%s.npy" % feature_name)
                      for feature_name in modules_names]
        mmap_files = [os.path.join(cache, "act_%s.mmap" % feature_name)
                      for feature_name in modules_names]

        all_cached = True
        for i, (mmap_fn, size_fn) in enumerate(zip(mmap_files, size_files)):
            if os.path.exists(mmap_fn) and os.path.exists(size_fn):
                features_size[i] = np.load(size_fn)
                activations[i] = np.memmap(mmap_fn,
                                           dtype=float,
                                           mode='r',
                                           shape=tuple(features_size[i]))
            else:
                all_cached = False

        # All the activations are on disk
        if all_cached:
            return dict(zip(modules_names, activations))

    # Last output of each hooked module, stored at the module's own
    # position in modules_names. Hooks fire in forward-execution order,
    # which is not necessarily the order of modules_names, so a shared
    # append-only list would attribute outputs to the wrong module.
    features_blobs = [None] * n_modules

    def make_hook(slot):
        def hook_feature(module, input, output):
            '''
            This function is attached to a PyTorch module,
            and it is called at each forward pass.
            The output is stored in the slot of that module
            (the last call wins if the module runs repeatedly).
            '''
            features_blobs[slot] = output.data.cpu().numpy()
        return hook_feature

    # Keep track of the status of the
    # model when it was called,
    # to return it as it was
    was_training = model.training
    hooks = []

    try:
        # Register dissection hooks
        for slot, module in enumerate(modules):
            hooks.append(module.register_forward_hook(make_hook(slot)))

        model.eval()

        # batch iteration over the inputs
        num_batches = (len(dataset) + batch_size - 1) // batch_size

        # Output arrays are (re)created lazily, the first time each
        # module produces an output in this run
        initialized = [False] * n_modules

        for batch_idx, batch in enumerate(tqdm(loader, total=num_batches)):

            # Forget the outputs of the previous batch
            features_blobs[:] = [None] * n_modules

            # Prepare input batch
            if gpu:
                batch = batch.cuda()

            # Forward pass of the input. Calling the model (rather than
            # model.forward) also triggers hooks registered on the model
            # itself.
            with torch.no_grad():
                model(batch)

            # Input images range
            start_idx = batch_idx * loader.batch_size
            end_idx = min((batch_idx + 1) * loader.batch_size, len(dataset))

            # The activations of a convolutional layer have shape
            #   (N, C_out, H_out, W_out).
            # The activations of a fully connected layer have shape
            #   (N, C_out)
            for i, feat_batch in enumerate(features_blobs):
                # Module not executed for this batch
                if feat_batch is None:
                    continue

                if not initialized[i]:
                    # Size of the activation for the i-th module
                    features_size[i] = (len(dataset), *feat_batch.shape[1:])

                    if cache:
                        np.save(size_files[i], np.array(features_size[i]))
                        activations[i] = np.memmap(mmap_files[i],
                                                   dtype=float,
                                                   mode='w+',
                                                   shape=features_size[i])
                    else:
                        activations[i] = np.zeros(features_size[i])
                    initialized[i] = True

                activations[i][start_idx:end_idx] = feat_batch

        # Make sure freshly written memmaps reach the disk before
        # other processes reopen them by filename
        if cache:
            for i, done in enumerate(initialized):
                if done:
                    activations[i].flush()
    finally:
        # Remove hooks and restore the train/eval state
        # even if the forward pass failed
        for hook in hooks:
            hook.remove()
        model.train(was_training)

    return dict(zip(modules_names, activations))


def _compute_thresholds(activations, quantile, batch_size, seed, queue):
    """
    Worker routine estimating the activation
    threshold of every unit of one module.

    `activations` is either an array-like or a
    (filename, shape) tuple describing a float
    memmap to reopen read-only. When `queue` is
    truthy, a 1 is put on it after every batch
    and a None once the whole array has been
    consumed.
    """

    # A tuple is a descriptor of an on-disk memmap: reopen it
    if isinstance(activations, tuple):
        mmap_path = activations[0]
        mmap_shape = activations[1]
        activations = np.memmap(mmap_path,
                                dtype=float,
                                mode='r',
                                shape=mmap_shape)

    n_inputs = activations.shape[0]
    estimator = QuantileVector(depth=activations.shape[1], seed=seed)

    for offset in range(0, n_inputs, batch_size):
        chunk = activations[offset:offset + batch_size]

        # Convolutional activations are flattened to one row per
        # spatial location, with the units on the last axis
        if is_convolutional(chunk):
            channels_last = np.transpose(chunk, axes=(0, 2, 3, 1))
            chunk = channels_last.reshape(-1, activations.shape[1])

        estimator.add(chunk)

        # Progress notification
        if queue:
            queue.put(1)

    # End-of-work notification
    if queue:
        queue.put(None)

    # Column of the 1000-point readout matching the (1 - quantile) level
    readout = estimator.readout(1000)
    column = int(1000 * (1 - quantile) - 1)
    return readout[:, column]


def compute_thresholds(activations, quantile=5e-3, n_workers=0, batch_size=128,
                       cache='', seed=1, module_name=None):
    """
    Given the activations pattern, estimate
    the threshold to correctly compute the
    activation masks.

    Parameters
    ----------
    activations: dict
        Dictionary pointing to the
        array_like activations for
        each module. Both in-memory
        arrays and memmaps are accepted.
    quantile: float
        Quantile to consider the
        activations over 1-quantile
    n_workers: int, optional
        Number of workers to parallelize
        the computation of the thresholds.
        With 0 the computation runs in the
        calling process.
    batch_size: int, optional
        Number of inputs fed to the
        quantile estimator at a time
    cache: str, optional
        Path of the folder in which to
        eventually store the thresholds
    seed: int, optional
        Seed for reproducibility
    module_name: str, optional
        Currently unused; kept for
        interface compatibility.

    Returns
    -------
    thresholds: dict of array_like
        Dictionary mapping the module
        name to a NumPy array containing
        the threshold per unit
    """
    # Local import: the name `queue` is used for the progress queue below
    from queue import Empty as QueueEmpty

    modules_names = list(activations.keys())

    # Load from file
    thresholds = {mod: None for mod in modules_names}

    if cache:
        for module in modules_names:
            qtpath = os.path.join(cache, 'thresholds_%s.npy' % module)
            try:
                thresholds[module] = np.load(qtpath)
            except FileNotFoundError:
                pass

    # Only the modules without cached thresholds are processed
    modules_names = [m for m in modules_names if thresholds[m] is None]

    if not modules_names:
        return thresholds

    n_workers = min(n_workers, len(modules_names))

    if n_workers > 0:

        def worker_arg(act):
            # Memmaps are shipped to the workers as a (filename, shape)
            # descriptor to avoid pickling the data; in-memory arrays
            # have no backing file and are passed as they are.
            filename = getattr(act, 'filename', None)
            if filename is not None:
                return (filename, act.shape)
            return act

        # Total batches
        total = sum(int(np.ceil(activations[m].shape[0] / batch_size))
                    for m in modules_names)

        pool = Pool(n_workers)
        manager = Manager()
        try:
            queue = manager.Queue()

            # Parameter per task (one task per module)
            params = [(worker_arg(activations[m]), quantile, batch_size,
                       seed, queue) for m in modules_names]

            # Map
            map_result = pool.starmap_async(_compute_thresholds, params)

            # Read updates from the workers. Every task sends one None
            # when it ends, so wait for as many sentinels as tasks.
            ended = 0
            with tqdm(total=total) as pbar:
                while ended < len(params):
                    # Sampled before the read: if the map was already
                    # over and the queue is empty, a task died without
                    # sending its sentinel and nothing more will come.
                    finished = map_result.ready()
                    try:
                        e = queue.get(timeout=1)
                    except QueueEmpty:
                        if finished:
                            break
                        continue
                    if e is None:
                        ended += 1
                    else:
                        pbar.update(e)

            # Re-raises the exception of a failed task, if any
            partial = map_result.get()
        except BaseException:
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
            manager.shutdown()
    else:
        # Serial computation directly on the provided arrays
        partial = [_compute_thresholds(activations[m], quantile, batch_size,
                                       seed, None)
                   for m in modules_names]

    # Reduce results
    partial = {m: partial[i] for i, m in enumerate(modules_names)}
    thresholds = {**thresholds, **partial}

    if cache:
        for module in modules_names:
            qtpath = os.path.join(cache, 'thresholds_%s.npy' % module)
            np.save(qtpath, thresholds[module])

    return thresholds


def _compute_sigma(activations, annotations, concepts, thresholds, queue,
                   start, end, batch_size):
    """
    Accumulate, over a range of images, the pixel
    counts needed to compute the IoU between units
    and concepts.

    Parameters
    ----------
    activations: array_like or tuple
        Activations indexed as [image, unit], or a
        (filename, shape) tuple describing a float
        memmap that is reopened read-only.
    annotations: callable
        Called as annotations(start, end, batch_size);
        must yield batches of images. Each image is
        accessed through `index`, `shape`,
        `select_concepts` and `get_concept_mask`.
    concepts: sequence
        Concepts to evaluate.
    thresholds: array_like
        Activation threshold indexed by unit.
    queue: optional
        If truthy, receives a 1 after every batch
        and a None when the range is exhausted.
    start, end, batch_size:
        Forwarded untouched to `annotations`.

    Returns
    -------
    intersection: ndarray (n_units, n_concepts)
    union: ndarray (n_units, n_concepts)
    act_sum: ndarray (n_units,)
    cmask_sum: ndarray (n_concepts,)
    """

    # Reopen the memmap when only its descriptor was provided
    if isinstance(activations, tuple):
        activations = np.memmap(activations[0],
                                dtype=float,
                                mode='r',
                                shape=activations[1])

    # Allocate the accumulators
    n_units = activations.shape[1]
    n_concepts = len(concepts)
    intersection = np.zeros((n_units, n_concepts))
    act_sum = np.zeros((n_units))
    cmask_sum = np.zeros((n_concepts))

    first = True

    for batch in annotations(start, end, batch_size):
        # The mask buffers are allocated once, from the shape of the
        # first image of the first batch.
        # NOTE(review): this assumes every image has the same
        #               image.shape - confirm against the dataset.
        #
        #       a_mask[uid] (|units|, image.shape)
        #
        if first:
            # TODO: tricks could be done to avoid
            #       useless computations, something
            #       like:
            #
            # if len(activations[0][0].shape):
            #     a_mask = np.full((n_units, *batch[0].shape), False)
            # else:
            #     a_mask = np.full((n_units), False)

            a_mask = np.full((n_units, *batch[0].shape), False)
            c_mask = np.full(batch[0].shape, False)
            first = False

        for image in batch:

            # Select units activated by the image, i.e. whose
            # maximum activation exceeds the unit threshold
            valid_units = [unit for unit in range(n_units)
                           if activations[image.index][unit].max()
                           > thresholds[unit]]

            # Generate activation masks. Only the rows of the valid
            # units are refreshed; the other rows of a_mask keep stale
            # values but are never read for this image.
            for unit in valid_units:

                # Retrieve activations for the given image
                tmp_a_mask = activations[image.index][unit]

                # Resize if convolutional (non-scalar activation map).
                # NOTE(review): PIL's resize expects (width, height);
                #               confirm the ordering of image.shape.
                if len(tmp_a_mask.shape):
                    tmp_a_mask = Image.fromarray(tmp_a_mask) \
                                           .resize(image.shape,
                                                   resample=Image.BILINEAR)
                # Create mask
                a_mask[unit] = tmp_a_mask > thresholds[unit]

                # Update \sum_x |M_u(x)|
                act_sum[unit] += np.count_nonzero(a_mask[unit])

            # Retrieve concepts in the image
            selected_concepts = image.select_concepts(concepts)

            for c_idx, concept in enumerate(concepts):

                if concept in selected_concepts:
                    # retrieve L_c(x); c_mask is handed back to the
                    # image, presumably to be reused as output buffer
                    c_mask = image.get_concept_mask(concept, c_mask)

                    # update \sum_x |L_c(x)|
                    cmask_sum[c_idx] += np.count_nonzero(c_mask)

                    # Update counters
                    for unit in valid_units:

                        # |M_u(x) && L_c(x)|
                        intersection[unit, c_idx] += np.count_nonzero(
                                      np.logical_and(a_mask[unit], c_mask))

        # Notify end of batch
        if queue:
            queue.put(1)

    # Notify end of the whole range
    if queue:
        queue.put(None)

    # |M_u(x) || L_c(x)| by inclusion-exclusion
    union = act_sum[:, None] + cmask_sum[None, :] - intersection

    return intersection, union, act_sum, cmask_sum


def _parallel_sigma(pool, manager, activations, annotations, frontier,
                    thresholds, n_images, n_workers, batch_size):
    """
    Split the images in contiguous ranges, run
    _compute_sigma on each range through the pool
    and sum the partial results.

    Returns the reduced (intersection, union,
    act_sum, cmask_sum) arrays.
    """
    # Local import: the name `queue` is used for the progress queue below
    from queue import Empty as QueueEmpty

    # One contiguous range of images per task. There can be fewer
    # ranges than workers (e.g. 9 images and 4 workers give 3 ranges).
    psize = int(np.ceil(float(n_images) / n_workers))
    ranges = [(s, min(n_images, s + psize))
              for s in range(0, n_images, psize)]

    # Memmaps are shipped to the workers as a (filename, shape)
    # descriptor; in-memory arrays have no backing file and are
    # passed as they are.
    filename = getattr(activations, 'filename', None)
    if filename is not None:
        act_arg = (filename, activations.shape)
    else:
        act_arg = activations

    # Queue to handle progress
    queue = manager.Queue()

    # Parameter per task
    params = [(act_arg, annotations, frontier, thresholds, queue,
               r_start, r_end, batch_size) for r_start, r_end in ranges]

    # Map
    map_result = pool.starmap_async(_compute_sigma, params)

    # Total batches
    total = sum(int(np.ceil((r_end - r_start) / batch_size))
                for r_start, r_end in ranges)

    # Read updates from the workers. Every task sends one None when it
    # ends, so wait for as many sentinels as tasks (not as workers).
    ended = 0
    with tqdm(total=total) as pbar:
        while ended < len(ranges):
            # Sampled before the read: if the map was already over and
            # the queue is empty, a task died without sending its
            # sentinel and nothing more will come.
            finished = map_result.ready()
            try:
                e = queue.get(timeout=1)
            except QueueEmpty:
                if finished:
                    break
                continue
            if e is None:
                ended += 1
            else:
                pbar.update(e)

    # Re-raises the exception of a failed task, if any
    partial = map_result.get()

    # Reduce
    intersection = np.sum([e[0] for e in partial], axis=0)
    union = np.sum([e[1] for e in partial], axis=0)
    act_sum = np.sum([e[2] for e in partial], axis=0)
    cmask_sum = np.sum([e[3] for e in partial], axis=0)

    return intersection, union, act_sum, cmask_sum


def compute_sigma(activations, annotations, thresholds, ontology, n_workers=1,
                  cache='', batch_size=32, visit='flat', style='niou',
                  keep_n=None, module_name=None):
    """
    Given the activations patterns and the
    thresholds, it computes the IOU metric
    to estimate the semantic alignment
    between units and concepts.

    Parameters
    ----------
    activations: dict or array_like
        Dictionary pointing to the
        array_like activations for
        each module, or the activations
        of a single module.
    annotations: callable
        Callable iterator to retrieve
        the annotations for each image
        in the dataset. It is invoked as
        annotations(start, end, batch_size)
        and must support len() when
        n_workers is 0.
    thresholds: dict or array_like
        Dictionary pointing to the
        array_like thresholds for
        each module, or the thresholds
        of a single module.
    ontology: Ontology
        Ontological structure of the
        concepts in the dataset
    n_workers: int, optional
        Number of workers for parallel
        computation of the IoU. With 0
        the computation runs in the
        calling process.
    cache: str, optional
        Path of the folder in which to
        eventually store the results
    batch_size: int, optional
        Number of images per batch of
        annotations
    visit: str, optional
        Visit strategy to traverse the
        ontology. Admitted values are:
        flat, leaves, bottomup and
        topdown
    style: str, optional
        Determines what is returned by
        the method, admitted values are:
        iou, niou, fullniou, ml and raw
    keep_n: int, optional
        How many of the best concepts to
        return for each unit. If None
        all IoU scores are returned
    module_name: str, optional
        Name of the module to analyze if
        activations and thresholds are
        dictionaries. If none, all of the
        modules are analyzed.

    Returns
    -------
    max_iou: array_like
        Array (units x keep_n) with the
        best scores of each unit, sorted
        in ascending order.
    max_iou_concepts: array_like
        Array (units x keep_n) with the
        ids of the concepts relative to
        max_iou.

    When all the modules of a dictionary are
    analyzed, a dictionary mapping each module
    name to its (max_iou, max_iou_concepts)
    pair is returned. With style 'raw' the
    tuple (intersection, union, act_sum,
    cmask_sum) of the last visited frontier is
    returned instead.

    Raises
    ------
    NotImplementedError
        If visit or style is not one of the
        admitted values.
    """

    # Compute the IoU for multiple modules
    if isinstance(activations, dict) and not module_name:
        assert isinstance(thresholds, dict)
        assert activations.keys() == thresholds.keys()

        def build_args(module):
            return (activations[module], annotations, thresholds[module],
                    ontology, n_workers, cache, batch_size, visit, style,
                    keep_n, module)

        return {mod: compute_sigma(*build_args(mod)) for mod in activations}

    # Visit strategy
    if visit == 'flat':
        strategy = Flat(ontology)
    elif visit == 'bottomup':
        strategy = BottomUp(ontology)
    elif visit == 'topdown':
        strategy = TopDown(ontology)
    elif visit == 'leaves':
        strategy = Leaves(ontology)
    else:
        raise NotImplementedError

    # Return style
    if style not in ['ml', 'iou', 'fullniou', 'niou', 'raw']:
        raise NotImplementedError

    # Persistency
    if cache and module_name:

        iou_path = os.path.join(cache, '%s_metric_%s.npy'
                                % (style, module_name))
        concepts_path = os.path.join(cache, '%s_concepts_%s.npy'
                                     % (style, module_name))

        # Try to load from file
        try:
            max_iou = np.load(iou_path)
            max_iou_concept = np.load(concepts_path)
            return max_iou, max_iou_concept
        except FileNotFoundError:
            pass

    # A single module selected out of dictionaries
    if isinstance(activations, dict):
        activations = activations[module_name]
        if isinstance(thresholds, dict):
            thresholds = thresholds[module_name]

    # Units to analyze
    n_images = activations.shape[0]
    n_units = activations.shape[1]

    # The frontier contains the concepts
    # that are analyzed in the current run
    frontier = strategy.init_frontier()

    # Eventually keep IoU values for all concepts
    keep_n = keep_n if keep_n is not None else strategy.max_n

    # Keep track of the best concepts per unit (-1 marks an empty slot)
    max_iou = np.zeros((n_units, keep_n)) - 1
    max_iou_concept = np.zeros((n_units, keep_n)) - 1

    # Defined upfront so that style 'raw' does not fail
    # with an unbound name when the frontier starts empty
    intersection = union = act_sum = cmask_sum = None

    # One pool and one manager are shared by all the frontier iterations
    pool = Pool(n_workers) if n_workers > 0 else None
    manager = Manager() if n_workers > 0 else None

    try:
        # Iterate while all the relevant
        # concepts are analyzed
        while frontier:

            # Cache leaves
            for node in frontier:
                node.cache_leaves()

            # Compute intersection and union arrays
            if n_workers > 0:
                results = _parallel_sigma(pool, manager, activations,
                                          annotations, frontier, thresholds,
                                          n_images, n_workers, batch_size)
            else:
                # Total batches
                total = np.ceil(n_images / batch_size)

                def wrap_annotations(a, b, c):
                    return tqdm(annotations(a, b, c), total=total)

                results = _compute_sigma(activations, wrap_annotations,
                                         frontier, thresholds, None, 0,
                                         len(annotations), batch_size)
            intersection, union, act_sum, cmask_sum = results

            # Compute IoU and find the best keep_n results
            if style == 'ml':
                iou = intersection / (cmask_sum[None, :] + 1e-12)
            else:
                iou = intersection / (union + 1e-12)
                if style == 'niou':
                    iou = iou / (cmask_sum[None, :] + 1e-12)
                elif style == 'fullniou':
                    iou = iou * (act_sum[:, None]
                                 / (cmask_sum[None, :] + 1e-12))

            # full_iou is |Units| x (|Frontier|+keep_n): the scores of
            # the frontier followed by the best scores found so far
            full_iou = np.concatenate((iou, max_iou), axis=1)

            # max_iou_idx is |Units| x keep_n
            # Contains the sorted indexes of the best
            # keep_n concepts, relative to full_iou
            # with the highest IoU scores.
            max_iou_idx = np.argsort(full_iou, axis=1)[:, -keep_n:]

            # Update the structures containing
            # the best keep_n IoU scores per unit
            n_frontier = len(frontier)
            for unit_id in range(n_units):
                best_idx = max_iou_idx[unit_id, :]
                # Ids kept from the previous iterations, read before
                # the row is overwritten
                previous = max_iou_concept[unit_id].copy()

                max_iou[unit_id, :] = full_iou[unit_id, best_idx]

                # A column of full_iou refers either to a concept of
                # the frontier or to a previously kept concept
                max_iou_concept[unit_id] = np.array(
                    [frontier[c].id if c < n_frontier
                     else previous[c - n_frontier]
                     for c in best_idx])

            # sorted_max_iou is |Units| x keep_n
            # The last column [:,-1] contains the
            # maximum IoU value for each unit
            sorted_max_iou = np.sort(max_iou, axis=1)

            # Update the frontier according to the strategy
            frontier = strategy.update_frontier(intersection, union, act_sum,
                                                sorted_max_iou[:, -1])
    except BaseException:
        if pool is not None:
            pool.terminate()
        raise
    else:
        # Correctly close the pool
        if pool is not None:
            pool.close()
    finally:
        if pool is not None:
            pool.join()
        if manager is not None:
            manager.shutdown()

    # Store the results
    if cache and module_name:
        np.save(iou_path, max_iou)
        np.save(concepts_path, max_iou_concept)

    if style == 'raw':
        # NOTE: this won't work with topdown or bottomup styles
        return intersection, union, act_sum, cmask_sum
    else:
        return max_iou, max_iou_concept


def filter_lineage(iou, concepts, ontology):
    '''
    iou and concepts are two sequences
    where iou[i] is the IoU value of the
    concept with id concepts[i].

    this function filters out concepts
    which IoU value is derived by
    its ancestors or by its descendants.

    The concept with the highest IoU is kept
    (the shallowest one among those tied at
    the maximum), every concept of its lineage
    is discarded and the selection is repeated
    on what is left.

    Returns the pair (iou, concepts) of the
    kept entries as lists, the best concept
    being the last one.
    '''

    iou = list(iou)
    concepts = list(concepts)

    # Entries are selected from the best to the worst
    # and reversed at the end. The selection is iterative:
    # one recursion level per kept concept would exceed the
    # recursion limit on long lists.
    kept_iou = []
    kept_concepts = []

    while concepts:
        # Get first concept with the maximum IoU
        max_iou = max(iou)
        first = iou.index(max_iou)

        # Among the concepts tied at the maximum
        # IoU value, select the shallowest one
        min_depth = None
        min_cid = concepts[first]
        for value, cid in zip(iou[first:], concepts[first:]):
            if value != max_iou:
                continue
            depth = ontology.nodes[cid].depth
            if min_depth is None or depth < min_depth:
                min_depth = depth
                min_cid = cid

        # Now all the concepts that are in the lineage
        # of the best one should be removed. The selected
        # concept is dropped explicitly as well, so that the
        # loop ends even if a lineage omits its own node.
        lineage = ontology.nodes[min_cid].lineage
        pairs = [(v, c) for v, c in zip(iou, concepts)
                 if c != min_cid and c not in lineage]
        iou = [e[0] for e in pairs]
        concepts = [e[1] for e in pairs]

        kept_iou.append(max_iou)
        kept_concepts.append(min_cid)

    kept_iou.reverse()
    kept_concepts.reverse()

    return kept_iou, kept_concepts


def ids_to_concepts(concepts, ontology):
    '''
    Replace, for every unit, the concept id of each
    (value, id) pair with the matching node taken
    from ontology.nodes.
    '''
    resolved = []
    for unit in concepts:
        unit_pairs = []
        for value, concept_id in unit:
            unit_pairs.append((value, ontology.nodes[concept_id]))
        resolved.append(unit_pairs)
    return resolved


def retrieve_concepts(iou, ontology, module_name=None, cache='', filter=False):
    """
    Given the IOU metric it returns the
    best concepts for each unit.

    Parameters
    ----------
    iou: dict or tuple
        Dictionary pointing, for each
        module, to the pair (IoU values,
        concept ids) of array_like with
        one row per unit. The pair of a
        single module can be passed
        directly when module_name is None.
    ontology: Ontology
        Ontological structure of the
        concepts in the dataset
    module_name: str, optional
        Name of the module to analyze if
        iou is a dictionary. If none, all
        of the modules are analyzed.
    cache: str, optional
        Path of the folder in which to
        eventually store the results
    filter: bool, optional
        If True, the concepts whose IoU
        derives from the lineage of a
        better concept are removed

    Returns
    -------
    concepts: dict of list, or list
        Dictionary mapping the module
        name to a list of (IOU, Concept)
        tuples for each unit. For a single
        module only its list is returned.
    """
    # Select the best concepts for multiple modules
    if isinstance(iou, dict) and not module_name:
        # The filter flag must reach the per-module calls too
        return {mod: retrieve_concepts(iou, ontology, mod, cache, filter)
                for mod in iou}

    # Persistency
    # NOTE: the cache file name does not encode the `filter` flag, so a
    #       result cached with one setting is returned for the other.
    if cache and module_name:

        file_path = os.path.join(cache, 'rediou_%s.pkl' % module_name)

        # Try to load from file
        # NOTE: unpickling can execute arbitrary code, `cache` must
        #       point to a trusted folder.
        try:
            with open(file_path, 'rb') as fp:
                concepts = pickle.load(fp)
        except FileNotFoundError:
            pass
        else:
            return ids_to_concepts(concepts, ontology)

    # Without a module name, iou is already the pair of a single module
    if module_name is None:
        scores, concept_ids = iou
    else:
        scores, concept_ids = iou[module_name]
    n_units = scores.shape[0]

    # Filter out the lineage of the best concepts
    if filter:
        per_unit = [filter_lineage(scores[u], concept_ids[u], ontology)
                    for u in tqdm(range(n_units))]
    else:
        per_unit = [(list(scores[u]), list(concept_ids[u]))
                    for u in tqdm(range(n_units))]

    # Turn into (IoU, concept id) tuples
    concepts = [list(zip(*unit)) for unit in per_unit]

    # Store to file
    if cache and module_name:
        with open(file_path, 'wb') as fp:
            pickle.dump(concepts, fp)

    # Get Concept objects
    return ids_to_concepts(concepts, ontology)


def _filter_module_concepts(units, tau, quantile):
    """
    Filter the (IOU, Concept) tuples of the units of a
    single module; returns the filtered list and the
    threshold actually applied.
    """
    # Eventually estimate the threshold
    if quantile is not None:
        values = [v for unit in units for v, _ in unit]
        # No value means nothing to estimate (and nothing
        # to filter): the provided tau is kept
        if values:
            tau = np.quantile(values, quantile)

    # Filter according to the threshold
    kept = [[(v, c) for v, c in unit if v > tau] for unit in units]

    return kept, tau


def filter_concepts(concepts, tau=0, quantile=None, module_name=None):
    """
    Filters the concepts of a given
    module by evaluating their IoU
    value

    Parameters
    ----------
    concepts: dict of list
        Dictionary mapping the module
        name to a list of (IOU, Concept)
        tuples for each unit
    tau: float, optional
        Threshold to test against
        the IoU metric of each
        concept; only the concepts with
        a strictly greater IoU are kept
    quantile: float, optional
        If defined it overwrites the
        tau threshold by estimating
        it using the provided quantile
        over the values distribution
        of the module. It is ignored for
        a module without any value.
    module_name: str, optional
        Name of the module to analyze.
        If none, all of the modules
        are analyzed.

    Returns
    -------
    concepts: dict of list, or list
        Dictionary mapping the module
        name to a list of (IOU, Concept)
        tuples for each unit. For a single
        module only its list is returned.
    tau: dict of float, or float
        Threshold applied to each module,
        or to the single module analyzed.
    """
    # Filter the concepts of multiple modules. The helper is called
    # directly instead of recursing, so that a falsy module name
    # (e.g. '') cannot be dispatched again forever.
    if isinstance(concepts, dict) and not module_name:
        results = {mod: _filter_module_concepts(units, tau, quantile)
                   for mod, units in concepts.items()}
        filtered = {mod: res[0] for mod, res in results.items()}
        taus = {mod: res[1] for mod, res in results.items()}

        return filtered, taus

    return _filter_module_concepts(concepts[module_name], tau, quantile)
