from typing import List, Dict, Optional, Iterable, Any, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn

from .utils.common import to
from .utils.input_collector import ForwardInterrupt, InputCollector
from .utils.model import select_layers, LINEAR_LAYERS

from .obc import OBC
from .fast_obc import FastOBC


# Registry mapping a `pruning_method` name (as accepted by `Pruner`) to the
# handle class that is instantiated for every selected layer.
OBC_CLASSES = {"OBC": OBC, "FastOBC": FastOBC}


class Pruner:
    """Drive pruning of the selected layers of a model with OBC-style handles.

    For every layer picked by ``select_layers`` a handle of the configured
    class (looked up in ``OBC_CLASSES``) is created.  Calibration batches are
    pushed through the model while forward hooks feed each layer's first
    positional input to ``handle.update``; afterwards ``handle.prune`` is
    called for every layer.

    Two modes are available:

    * parallel (``prune_parallel``): all selected layers are calibrated in a
      single pass over the data and pruned afterwards;
    * sequential (``prune_sequential``): the blocks found under ``blocks`` are
      processed one after another, and every block receives the outputs of
      the already pruned previous block as input.

    Args:
        model: Model to prune. It is modified in place.
        data_loader: Iterable yielding ``(inp_args, inp_kwargs)`` pairs that
            are passed to the model as ``model(*inp_args, **inp_kwargs)``.
        module_regex: Pattern forwarded to ``select_layers`` to choose layers.
        weights_orig: Mapping from layer name to a reference weight tensor.
            Only read when pruning is run with ``alpha > 0``.
        pruning_method: Key of ``OBC_CLASSES`` selecting the handle class.
        rel_damp: Forwarded to the handle constructor as ``rel_damp``.
        obc_util_kwargs: Extra keyword arguments for the handle constructor.
        sequential: If true, ``prune`` dispatches to ``prune_sequential``.
        device: Device used for calibration. When unset, the device of the
            first model parameter is used.
        cpu_offload: In sequential mode, move ``pre_modules`` to the device
            only while inputs are collected and move every block back to CPU
            once it has been processed.
        blocks: Name of the submodule holding the sequence of blocks
            (required in sequential mode).
        pre_modules: Names of submodules that must be on the device for the
            model to run up to the first block (used with ``cpu_offload``).
        max_samples: Stop calibration once at least this many samples have
            been seen. ``None`` (or ``0``) means no limit.
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: Iterable,
        module_regex: str,
        weights_orig: Optional[Dict[str, Any]] = None,
        pruning_method: str = "FastOBC",
        rel_damp: float = 1e-2,
        obc_util_kwargs: Optional[Dict[str, Any]] = None,
        sequential: bool = False,
        device: Optional[torch.device] = None,
        cpu_offload: bool = False,
        blocks: Optional[str] = None,
        pre_modules: Optional[List[str]] = None,
        max_samples: Optional[int] = None,
    ):
        self.model = model
        self.data_loader = data_loader
        self.module_regex = module_regex
        # `None` sentinels instead of `{}` / `[]` defaults: a mutable default
        # would be one object shared by every Pruner instance.
        self.weights_orig = {} if weights_orig is None else weights_orig
        self.obc_class = OBC_CLASSES[pruning_method]
        self.sequential = sequential
        self.device = device
        self.cpu_offload = cpu_offload
        # keyword arguments for OBC util constructor
        self.rel_damp = rel_damp
        self.obc_util_kwargs = {} if obc_util_kwargs is None else obc_util_kwargs
        # arguments for sequential pruner
        self.blocks = blocks
        self.pre_modules = [] if pre_modules is None else pre_modules
        # limit on number of samples
        self.max_samples = max_samples

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #

    def _get_blocks(self, blocks: str):
        """Return the submodule of the model registered under ``blocks``."""
        return self.model.get_submodule(blocks)

    def _resolve_device(self):
        """Return the configured device or, if unset, the model's device."""
        return self.device or next(self.model.parameters()).device

    @staticmethod
    def _infer_batch_size(inp_args, inp_kwargs) -> int:
        """Return the number of samples in one calibration batch.

        The length of the first positional input is used; if there are no
        positional inputs, the length of the first keyword input is used.

        Raises:
            ValueError: If the batch has neither positional nor keyword
                inputs, so no size can be derived.
        """
        if inp_args:
            return len(inp_args[0])
        if inp_kwargs:
            return len(next(iter(inp_kwargs.values())))
        raise ValueError(
            "Cannot infer batch size: batch has neither positional nor keyword inputs."
        )

    def _reached_sample_limit(self, samples_collected: int) -> bool:
        """Tell whether ``max_samples`` (if set) has been reached.

        ``>=`` rather than ``==`` so that the limit also triggers when batch
        sizes do not add up exactly to ``max_samples``.
        """
        return bool(self.max_samples) and samples_collected >= self.max_samples

    @staticmethod
    def _remove_hooks(hooks: Dict[str, Any]) -> None:
        """Detach every forward hook stored in ``hooks``."""
        for hook in hooks.values():
            hook.remove()

    def _prepare_hooks_and_handles(
        self, layers: Dict[str, nn.Module]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Create one handle per layer and hook it to the layer's forward pass.

        Args:
            layers: Mapping from layer name to layer module.

        Returns:
            ``(handles, hooks)``: two dicts keyed by layer name. ``handles``
            holds the created handle objects, ``hooks`` the removable hook
            handles returned by ``register_forward_hook``.
        """
        handles: Dict[str, Any] = {}
        hooks: Dict[str, Any] = {}

        def make_update_hook(name: str):
            # `name` is bound per call, so each hook updates its own handle.
            def _hook(_, inp, out):
                handles[name].update(inp[0])

            return _hook

        for layer_name, layer in layers.items():
            handles[layer_name] = self.obc_class(
                layer, rel_damp=self.rel_damp, **self.obc_util_kwargs
            )
            hooks[layer_name] = layer.register_forward_hook(
                make_update_hook(layer_name)
            )
        return handles, hooks

    def _prune_group(self, handles: Dict[str, Any], sparsity: float, alpha: float = 0) -> None:
        """Prune every handle in ``handles`` to the requested sparsity.

        Args:
            handles: Mapping from layer name to handle.
            sparsity: Value forwarded to ``handle.prune``.
            alpha: When positive, ``handle.gradient_step`` is first called
                with the matching entry of ``weights_orig`` (cast to the
                device and dtype of ``handle.W``) and ``alpha``.

        Raises:
            KeyError: If ``alpha > 0`` and a layer has no entry in
                ``weights_orig``.
        """
        for handle_name, handle in handles.items():
            if alpha > 0:
                W_orig = self.weights_orig[handle_name].to(
                    device=handle.W.device, dtype=handle.W.dtype
                )
                handle.gradient_step(W_orig, alpha)
            handle.prune(sparsity)

    # ------------------------------------------------------------------ #
    # parallel mode
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def prune_parallel(self, sparsity: float, alpha: float = 0.0) -> None:
        """Calibrate all selected layers in one pass, then prune them.

        The whole model is moved to the calibration device, every batch of
        ``data_loader`` (up to ``max_samples``) is run through it, and all
        handles are pruned afterwards.
        """
        device = self._resolve_device()
        self.model = self.model.to(device)

        # find layers
        layers = select_layers(self.model, '', self.module_regex, LINEAR_LAYERS)
        handles, hooks = self._prepare_hooks_and_handles(layers)

        try:
            samples_collected = 0
            for inp_args, inp_kwargs in self.data_loader:
                batch_size = self._infer_batch_size(inp_args, inp_kwargs)
                self.model(*to(inp_args, device=device), **to(inp_kwargs, device=device))
                samples_collected += batch_size
                if self._reached_sample_limit(samples_collected):
                    break
        finally:
            # never leave calibration hooks attached, even if a forward fails
            self._remove_hooks(hooks)

        self._prune_group(handles, sparsity, alpha)

    # ------------------------------------------------------------------ #
    # sequential mode
    # ------------------------------------------------------------------ #

    def _collect_block_inputs(self, blocks, device) -> Tuple[Any, Any]:
        """Record the inputs that reach the first block for every batch.

        The first block is temporarily replaced by an ``InputCollector``.
        The collector is expected to store the arguments it is called with
        and to stop the forward pass by raising ``ForwardInterrupt`` (see
        ``.utils.input_collector``). The original block is always restored.

        Intended to be called from ``prune_sequential`` (under ``no_grad``).

        Returns:
            ``(input_args, input_kwargs)`` as exposed by the collector.
        """
        blocks[0] = blocks[0].to(device)
        if self.cpu_offload:
            # load input embeddings or any other preprocessing step
            for module_name in self.pre_modules:
                self.model.get_submodule(module_name).to(device)

        collector = InputCollector(blocks[0])
        blocks[0] = collector
        try:
            samples_collected = 0
            for inp_args, inp_kwargs in self.data_loader:
                batch_size = self._infer_batch_size(inp_args, inp_kwargs)
                try:
                    self.model(
                        *to(inp_args, device=device),
                        **to(inp_kwargs, device=device),
                    )
                except ForwardInterrupt:
                    # expected: only the inputs of the first block are needed
                    pass
                samples_collected += batch_size
                if self._reached_sample_limit(samples_collected):
                    break
        finally:
            # put the real block back even if a forward pass failed
            blocks[0] = collector.module
            if self.cpu_offload:
                # offload input embeddings or any other preprocessing step
                for module_name in self.pre_modules:
                    self.model.get_submodule(module_name).cpu()

        return collector.input_args, collector.input_kwargs

    def _prune_blocks(self, blocks, input_args, input_kwargs, sparsity, alpha, device) -> None:
        """Calibrate and prune the blocks one by one.

        After a block has been pruned it is run once more on its inputs and
        the results replace those inputs, so the next block is calibrated on
        the outputs of the pruned predecessor.

        Intended to be called from ``prune_sequential`` (under ``no_grad``).

        Raises:
            ValueError: If a recorded call has neither positional arguments
                nor a ``hidden_states`` keyword argument.
        """
        blocks_name = self.blocks
        num_blocks = len(blocks)
        progress = tqdm(enumerate(blocks), total=num_blocks)
        for block_id, block in progress:
            progress.set_description(f'Processing layer {block_id+1}/{num_blocks}')
            block = block.to(device)
            # get layer prefix to select layers only within the block
            layer_prefix = f'{blocks_name}.{block_id}.'
            layers = select_layers(self.model, layer_prefix, self.module_regex, LINEAR_LAYERS)
            handles, hooks = self._prepare_hooks_and_handles(layers)

            try:
                for inp_args, inp_kwargs in zip(input_args, input_kwargs):
                    block(*inp_args, **inp_kwargs)
            finally:
                self._remove_hooks(hooks)

            self._prune_group(handles, sparsity, alpha)

            # propagate: outputs of the pruned block become the next inputs
            for inp_args, inp_kwargs in zip(input_args, input_kwargs):
                out = block(*inp_args, **inp_kwargs)
                if isinstance(out, (list, tuple)):
                    out = out[0]
                # change only first input argument
                if len(inp_args) > 0:
                    inp_args[0].data = out
                elif 'hidden_states' in inp_kwargs:
                    inp_kwargs['hidden_states'] = out
                else:
                    raise ValueError("Unsupported block input format.")

            if self.cpu_offload:
                block.cpu()

            # drop the references first so that whatever device memory the
            # handles hold can actually be returned by empty_cache()
            del handles, hooks
            torch.cuda.empty_cache()

    @torch.no_grad()
    def prune_sequential(self, sparsity: float, alpha: float = 0.0) -> None:
        """Prune the model block by block.

        Requires ``blocks`` to be set. If the model has a ``config`` with a
        ``use_cache`` attribute, it is set to ``False`` for the duration of
        the call and restored afterwards, also when an error occurs.
        """
        assert self.blocks, "Blocks have to be defined"
        device = self._resolve_device()

        config = getattr(self.model, "config", None)
        manage_cache = hasattr(config, 'use_cache')
        if manage_cache:
            use_cache = config.use_cache
            config.use_cache = False

        try:
            # get first stage blocks (either encoder or decoder)
            blocks = self._get_blocks(self.blocks)
            ### Input preparation ###
            input_args, input_kwargs = self._collect_block_inputs(blocks, device)
            ### Encoder/Decoder pruning ###
            self._prune_blocks(blocks, input_args, input_kwargs, sparsity, alpha, device)
        finally:
            if manage_cache:
                config.use_cache = use_cache

    # ------------------------------------------------------------------ #
    # entry point
    # ------------------------------------------------------------------ #

    def prune(self, sparsity: float, alpha: float = 0.0) -> None:
        """Prune the model in sequential or parallel mode, as configured."""
        if self.sequential:
            self.prune_sequential(sparsity, alpha)
        else:
            self.prune_parallel(sparsity, alpha)