import os
import time
import torch
import shutil

from torch import Tensor
from torch.nn import Module
from typing import List, Optional

from .pruning_handle_base import *


# public pruning handles exported by this module
__all__ = ["ExactOBCPruningHandle", "FastOBCPruningHandle"]


class ExactOBCPruningHandle(BasePruningHandle):
    '''
    Exact Optimal Brain Compression (OBC) pruning handle.

    Every output row of the (flattened) weight is pruned one weight at a time:
    the weight with the smallest saliency ``w_p^2 / [H^{-1}]_pp`` is removed, the
    remaining weights of the row receive the OBS correction and the row's
    inverse hessian is downdated. The full trace (the row for every number of
    zeros) and the accumulated loss at which every weight was removed are
    recorded, so weights for arbitrary sparsity levels can be extracted later.
    '''

    def __init__(self, 
        layer: Module, 
        storage_dir: Optional[str] = None,
        # OBC params
        blocks_in_parallel: Optional[int] = None,
        store_traces_on_drive: bool = False,
        store_database_on_drive: bool = False,
    ) -> None:
        '''
        :param layer: layer whose weight is pruned
        :param storage_dir: directory used for offloaded traces / database;
            required if ``store_traces_on_drive`` or ``store_database_on_drive``
        :param blocks_in_parallel: number of output rows processed simultaneously
            (defaults to all output rows)
        :param store_traces_on_drive: save pruning traces to ``storage_dir``
            instead of keeping them in memory
        :param store_database_on_drive: save per-sparsity weights to ``storage_dir``
            instead of keeping them in memory
        '''
        if store_traces_on_drive or store_database_on_drive:
            assert storage_dir is not None, "storage_dir is required to store data on drive"
        super().__init__(layer, storage_dir)
        # device where the weight is stored during the OBC step
        self._device = None
        # traces and losses during the OBC step
        self._traces = None
        self._losses = None
        # by default set to output dim
        self._blocks_in_parallel = blocks_in_parallel or self._dim_out
        self._store_traces_on_drive = store_traces_on_drive
        self._store_database_on_drive = store_database_on_drive
        # reconstruction loss per sparsity level
        self._reconstruction_losses = {}
        # loss hessian
        self._H = None

    @torch.no_grad()
    def prepare(self, H: Tensor):
        '''
        Store the hessian and move the flattened weight to the hessian's device.

        Input channels with a zero hessian diagonal are dead: their weights are
        zeroed and the diagonal entry is set to 1 so that H stays invertible.
        NOTE: ``H`` is modified in place.

        :param H: hessian for the given layer, shape (dim_in, dim_in)
        '''
        self._H = H
        # set the device to H.device
        self._device = H.device
        # flatten the weight in case of len(weight.shape) > 2
        self._weight = self._weight.to(self._device).view(self._weight_shape[0], -1)
        # if the entire input channel is 0 -> channel is dead and doesn't contribute
        M_dead_channel = torch.diag(self._H) == 0
        self._H[M_dead_channel, M_dead_channel] = 1
        self._weight[:, M_dead_channel] = 0

    @torch.no_grad()
    def _prepare_block(self, row_start, row_end):
        '''
        Prepare rows ``row_start:row_end`` for the OBC traversal.

        :return: tuple of
            - copy of the weight rows, shape (B, dim_in)
            - mask that is True where the weight is nonzero
            - per-row inverse hessian (B, dim_in, dim_in), where rows/columns of
              zero weights are replaced by identity before the inversion
            - minimal number of zeros over the rows of the block
        '''
        # make a copy of weight being pruned
        W_block = self._weight[row_start: row_end, :].clone()
        # create mask of already pruned weights
        M_block = (W_block != 0)
        # get minimal number of zeros in block
        min_zeros_block = (~M_block).sum(dim=1).min().item()
        # create N copies (d_in, d_in) -> (P, dim_in, dim_in)
        H_inv_block = self._H.unsqueeze(0).repeat((row_end - row_start, 1, 1))
        # mask rows with zeroed weights
        H_inv_block.masked_fill_(~M_block.unsqueeze(-1), 0)
        H_inv_block.masked_fill_(~M_block.unsqueeze(-2), 0)
        H_inv_block.masked_fill_(torch.diag_embed(~M_block, dim1=-2, dim2=-1), 1.0)
        # invert
        H_inv_block = torch.cholesky_inverse(torch.linalg.cholesky(H_inv_block))

        return W_block, M_block, H_inv_block, min_zeros_block

    @torch.no_grad()
    def _obc_traverse_block(self, row_start: int, row_end: int, losses: Tensor) -> Tensor:
        '''
        Run the OBC pruning traversal for rows ``row_start:row_end``.

        :param losses: (dim_out, dim_in) tensor; entry (r, p) receives the
            accumulated loss of row r at the moment weight p is pruned
        :return: CPU tensor (dim_in + 1, B, dim_in) where index k holds the
            rows with exactly k zeros
        '''
        device = self._weight.device
        block_size = row_end - row_start
        block_ids = torch.arange(block_size, device=device)
        row_ids = block_ids + row_start
        # prepare block for obc step
        W_block, M_block, H_inv_block, min_zeros_block = self._prepare_block(row_start, row_end)
        # create traces for a given block
        traces_block = torch.zeros(
            (self._dim_in + 1, block_size, self._dim_in), 
            device=device, 
            dtype=self._weight.dtype
        )
        traces_block[:(min_zeros_block + 1)] = W_block
        # accumulated losses for a given block
        accum_loss = torch.zeros(block_size, device=device)

        for col in range(min_zeros_block + 1, self._dim_in + 1):
            # Rows with more pre-existing zeros than `col - 1` must stay untouched
            # until the traversal reaches their zero count; otherwise their traces
            # would be shifted and they would run out of weights to prune.
            active = (~M_block).sum(dim=1) < col
            H_inv_block_d = torch.diagonal(H_inv_block, dim1=1, dim2=2)
            scores = W_block ** 2 / H_inv_block_d
            scores.masked_fill_(~M_block, torch.inf)
            # scores and ids of pruned columns
            p_scores, p = torch.min(scores, dim=1)
            # torch.where (not multiplication): p_scores is inf for all-zero rows
            accum_loss += torch.where(active, 0.5 * p_scores, torch.zeros_like(p_scores))
            losses[row_ids, p] = torch.where(
                active, accum_loss.to(losses.dtype), losses[row_ids, p]
            )
            # zero the update direction of inactive rows so nothing changes for them
            H_inv_block_p = H_inv_block[block_ids, p, :] * active.to(H_inv_block.dtype).unsqueeze(1)
            H_inv_block_pp = H_inv_block_d[block_ids, p]
            W_block -= H_inv_block_p * (W_block[block_ids, p] / H_inv_block_pp).unsqueeze(1)
            M_block[block_ids, p] = M_block[block_ids, p] & ~active
            W_block.masked_fill_(~M_block, 0)
            traces_block[col, :, :] = W_block
            # do not update H_inv on the last iteration
            if col == self._dim_in:
                break
            H_inv_block_p /= torch.sqrt(H_inv_block_pp).unsqueeze(1)
            H_inv_block -= torch.bmm(H_inv_block_p.unsqueeze(2), H_inv_block_p.unsqueeze(1))

        return traces_block.cpu()

    def _extract_from_traces(self, traces: Tensor, losses: Tensor, sparsity_level: float):
        '''
        Select the weight for ``sparsity_level`` from the pruning traces.

        The ``(1 - sparsity_level)`` fraction of weights with the largest
        accumulated losses is kept; per row, the trace with the matching number
        of zeros is returned.

        :param traces: CPU tensor (dim_in + 1, dim_out, dim_in)
        :param losses: tensor (dim_out, dim_in) of accumulated losses
        :return: (weight of shape (dim_out, dim_in), reconstruction loss)
        '''
        num_weights = self._dim_in * self._dim_out
        _, topk_indices = torch.topk(
            losses.reshape(-1), 
            k=int((1 - sparsity_level) * num_weights)
        )
        # mask with 0 for pruned weights and 1 else where (on CPU, like traces)
        sparsity_mask = torch.zeros(num_weights, dtype=torch.bool)
        # indices may live on the losses' device -> move them to the mask's device
        sparsity_mask[topk_indices.cpu()] = True
        # (dim_out, dim_in) layout of traces/losses, also for weights with ndim > 2
        sparsity_mask = sparsity_mask.reshape(self._dim_out, self._dim_in)
        # count number of zeros per row
        zeros_per_row = (~sparsity_mask).sum(dim=1)
        return (
            # weight for the given sparsity level
            traces[zeros_per_row, torch.arange(self._dim_out)],
            # reconstruction loss: per row, the accumulated loss of the last pruned weight
            (~sparsity_mask * losses.cpu()).max(dim=1)[0].sum().item()
        )

    @property
    def losses(self):
        '''Accumulated per-weight losses; only kept when built without a sparsity list.'''
        assert self._is_built, "Losses are not prepared. Run prepare() and build() first."
        assert self._losses is not None, "Database was called with list of sparsities"
        return self._losses    

    def get_reconstruction_loss(self, sparsity: float):
        '''Return the stored reconstruction loss for ``sparsity``.'''
        # check whether the weight with given sparsity is present in the database
        if sparsity not in self._sparsity_levels:
            raise ValueError("This sparsity is not present in the database.")
        return self._reconstruction_losses[sparsity]

    @torch.no_grad()
    def build(self, sparsity_levels: Optional[List[float]] = None):
        '''
        Run OBC on all rows of the weight.

        If ``sparsity_levels`` is given, a database {sparsity: weight} and the
        matching reconstruction losses are built and the traces are discarded.
        Otherwise the traces and losses are kept for later queries.
        '''
        super().build()
        # traversed pruning traces
        traces = torch.zeros(
            (self._dim_in + 1, self._dim_out, self._dim_in), 
            device='cpu', 
            dtype=self._weight.dtype
        )
        # accumulated losses
        losses = torch.zeros(
            (self._dim_out, self._dim_in), 
            device=self._weight.device, 
            dtype=self._weight.dtype
        )

        for row_start in range(0, self._dim_out, self._blocks_in_parallel):
            row_end = min(row_start + self._blocks_in_parallel, self._dim_out)
            traces[:, row_start: row_end, :] = self._obc_traverse_block(row_start, row_end, losses)

        if sparsity_levels is not None:
            self._sparsity_levels = sparsity_levels
            # construct database of sparsities for a given list of sparsities
            for sparsity_level in self._sparsity_levels:
                rec_weight, rec_loss = self._extract_from_traces(traces, losses, sparsity_level)
                self._reconstruction_database[sparsity_level] = rec_weight
                self._reconstruction_losses[sparsity_level] = rec_loss
            # free memory allocated for traces (and drop stale ones from earlier builds)
            del traces
            del losses
            self._traces = None
            self._losses = None
        else:
            self._traces = traces
            self._losses = losses

        # offload database to drive to save memory (only exists for a sparsity list)
        if self._store_database_on_drive and sparsity_levels is not None:
            for sparsity_level in self._sparsity_levels:
                torch.save(
                    self._reconstruction_database[sparsity_level], 
                    os.path.join(self._storage_dir, f'weight_sparsity={sparsity_level}.pth')
                )
                # set to None
                self._reconstruction_database[sparsity_level] = None
        # offload traces to drive to save memory (if they are needed for latter)
        if self._traces is not None and self._store_traces_on_drive:
            torch.save(self._traces, os.path.join(self._storage_dir, 'traces.pth'))
            self._traces = None

        # empty CUDA cache
        torch.cuda.empty_cache()

        super().prepare()
        
    @torch.no_grad()
    def set(self, sparsity: float) -> float:
        '''
        Load the weight for ``sparsity`` into the layer.

        :return: reconstruction loss for the given sparsity
        :raises ValueError: if a database was built and ``sparsity`` is not in it
        '''
        # if weight database was constructed in build step
        if len(self._reconstruction_database) > 0 :
            # check whether the weight with given sparsity is present in the database
            if sparsity not in self._sparsity_levels:
                raise ValueError("This sparsity is not present in the database.")

            if self._store_database_on_drive:
                weight = torch.load(
                    os.path.join(self._storage_dir, f'weight_sparsity={sparsity}.pth')
                )
            else:
                weight = self._reconstruction_database[sparsity]
            loss = self._reconstruction_losses[sparsity]
        else:
            # traces live either on drive or in memory; use a local reference so
            # in-memory traces stay available for subsequent queries
            if self._store_traces_on_drive:
                traces = torch.load(os.path.join(self._storage_dir, 'traces.pth'))
            else:
                traces = self._traces
            assert traces is not None and self._losses is not None, \
                "Traces are not prepared. Run prepare() and build() first."
            # get weight and loss for the queried sparsity
            weight, loss = self._extract_from_traces(traces, self._losses, sparsity)
            # drop the local reference (frees traces loaded from drive)
            del traces
        # set weight in the handle
        self._layer.weight.data = \
            weight.to(self._orig_device).reshape(self._weight_shape)
        self._weight = self._layer.weight
        return loss
        
    def free(self):
        '''Release hessian, traces, losses and database bookkeeping.'''
        # cleanup hessian
        self._H = None
        # cleanup losses and traces
        self._losses = None
        self._traces = None
        # reset reconstruction losses (empty dict so that build() can run again)
        self._sparsity_levels = None
        self._reconstruction_losses = {}
        super().free()
        

class FastOBCPruningHandle(BasePruningHandle):
    '''
    Fast approximate OBC pruning handle.

    Input columns are processed left to right in blocks of ``block_size``.
    Inside a block, the weights with the smallest scores ``(w / U_ii)^2``
    (``U`` is the upper Cholesky factor of ``H^{-1}``) are masked to reach the
    requested sparsity, and the error of each pruned column is propagated to
    the not yet processed columns. One pruned weight is stored per requested
    sparsity level.
    '''

    def __init__(self, 
        layer: Module, 
        storage_dir: Optional[str] = None,
        # OBC params
        block_size: Optional[int] = None,
        store_database_on_drive: bool = False,
    ) -> None:
        '''
        :param layer: layer whose weight is pruned
        :param storage_dir: directory used for the offloaded database;
            required if ``store_database_on_drive``
        :param block_size: number of input columns per block (defaults to all columns)
        :param store_database_on_drive: save per-sparsity weights to ``storage_dir``
            instead of keeping them in memory
        '''
        if store_database_on_drive:
            assert storage_dir is not None, "storage_dir is required to store data on drive"
        super().__init__(layer, storage_dir)
        # device where the weight is stored during the OBC step
        self._device = None
        # traces and losses (not used by this class within this file)
        self._traces = None
        self._losses = None
        # by default all input columns are processed as a single block
        self._block_size = block_size or self._dim_in
        self._store_database_on_drive = store_database_on_drive
        # reconstruction loss per sparsity level
        self._reconstruction_losses = {}
        # loss hessian
        self._H = None
        # upper Cholesky factor of the inverse hessian (set in prepare)
        self._H_chol = None

    @torch.no_grad()
    def prepare(self, H: Tensor):
        '''
        Flatten the weight, handle dead channels and factorize the inverse hessian.

        Input channels with a zero hessian diagonal are dead: their weights are
        zeroed and the diagonal entry is set to 1 so that H stays invertible.
        NOTE: ``H`` is modified in place.

        :param H: hessian for the given layer
        '''
        # set the device to H.device
        self._device = H.device
        # flatten the weight in case of len(weight.shape) > 2
        self._weight = self._weight.to(self._device).view(self._weight_shape[0], -1)
        # if the entire input channel is 0 -> channel is dead and doesn't contribute
        M_dead_channel = torch.diag(H) == 0
        H[M_dead_channel, M_dead_channel] = 1
        self._weight[:, M_dead_channel] = 0
        # invert hessian matrix
        H_inv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        # compute cholesky decomposition of the inverse and save
        H_chol = torch.linalg.cholesky(H_inv, upper=True)
        self._H_chol = H_chol

    @torch.no_grad()
    def _prepare_block(self, W_new, col_start, col_end):
        '''Return a copy of columns ``col_start:col_end`` and the matching diagonal block of the factor.'''
        # create copy of weight to allow for repeated update
        W_block = W_new[:, col_start: col_end].clone()
        H_chol_block = self._H_chol[col_start: col_end, col_start: col_end]
        
        return W_block, H_chol_block

    @staticmethod
    @torch.no_grad()
    def _block_pruning_mask(W_block: Tensor, H_chol_block: Tensor, sparsity_level: float) -> Tensor:
        '''Return a bool mask (True = prune) selecting the lowest-score weights of the block.'''
        scores = (W_block / torch.diag(H_chol_block).unsqueeze(0)) ** 2
        num_pruned = int(sparsity_level * scores.numel())
        # torch.kthvalue requires k >= 1: nothing to prune in this block
        if num_pruned <= 0:
            return torch.zeros_like(scores, dtype=torch.bool)
        threshold, _ = torch.kthvalue(scores.reshape(-1), k=num_pruned)
        return scores <= threshold

    @torch.no_grad()
    def _prune_to_sparsity(self, sparsity_level: float):
        '''
        Prune a copy of the weight to ``sparsity_level`` (applied per column block).

        :return: pruned weight and per-weight losses, both shaped like the flattened weight
        '''
        # init new weight as copy of the original
        W_new = self._weight.clone()
        # reconstruction loss
        losses = torch.zeros_like(self._weight)
        for col_start in range(0, self._dim_in, self._block_size):
            col_end = min(col_start + self._block_size, self._dim_in)
            # size of given block
            block_size = col_end - col_start
            # prepare block
            W_block, H_chol_block = self._prepare_block(W_new, col_start, col_end)
            # new weight after updates
            W_new_block = torch.zeros_like(W_block)
            # error 
            E_block = torch.zeros_like(W_block)
            # losses in block
            L_block = torch.zeros_like(W_block)
            # mask specifying pruned (True) / kept (False) weights
            M_block = self._block_pruning_mask(W_block, H_chol_block, sparsity_level)

            # TODO optimize if possible
            for i in range(block_size):
                W_block_i = W_block[:, i]
                H_chol_block_ii = H_chol_block[i, i]
                # get masked weight
                W_new_block_i = W_block_i.masked_fill(M_block[:, i], 0)
                W_new_block[:, i] = W_new_block_i
                E_block_i = (W_block_i - W_new_block_i) / H_chol_block_ii
                L_block[:, i] = 0.5 * E_block_i ** 2
                # compensate the error on the remaining columns of the block
                W_block[:, i:] -= E_block_i.unsqueeze(1).matmul(H_chol_block[i, i:].unsqueeze(0))
                E_block[:, i] = E_block_i

            W_new[:, col_start: col_end] = W_new_block
            losses[:, col_start: col_end] = L_block
            # propagate the accumulated block error to the columns after the block
            W_new[:, col_end:] -= \
                E_block.matmul(self._H_chol[col_start:col_end, col_end:])
            # clean-up
            del E_block, L_block, W_block, W_new_block, M_block
            torch.cuda.empty_cache()
        return W_new, losses

    @torch.no_grad()
    def build(self, sparsity_levels: Optional[List[float]] = None):
        '''
        Build the database {sparsity: weight} and reconstruction losses.

        :param sparsity_levels: sparsities to build (required by this handle)
        '''
        super().build()
        # TODO add version supporting arbitrary sparsity
        assert sparsity_levels is not None, "This method supports only querying a particular sparsity"
        # set sparsity levels
        self._sparsity_levels = sparsity_levels
        # TODO can one reuse the results of pruning from previous sparsities?
        # construct the database for given list of sparsities
        for sparsity_level in sparsity_levels:
            W_new, losses = self._prune_to_sparsity(sparsity_level)
            # .cpu() / .item() synchronize with the device, so no explicit
            # torch.cuda.synchronize() (which fails on CPU-only setups) is needed
            self._reconstruction_database[sparsity_level] = W_new.cpu()
            self._reconstruction_losses[sparsity_level] = losses.sum().item()

            # offload database to drive to save memory
            if self._store_database_on_drive:
                torch.save(
                    self._reconstruction_database[sparsity_level], 
                    os.path.join(self._storage_dir, f'weight_sparsity={sparsity_level}.pth')
                )
                # set to None to release memory
                self._reconstruction_database[sparsity_level] = None

    def get_reconstruction_loss(self, sparsity: float):
        '''Return the stored reconstruction loss for ``sparsity``.'''
        # check whether the weight with given sparsity is present in the database
        if sparsity not in self._sparsity_levels:
            raise ValueError("This sparsity is not present in the database.")
        return self._reconstruction_losses[sparsity]

    @torch.no_grad()
    def set(self, sparsity: float) -> float:
        '''
        Load the weight for ``sparsity`` into the layer.

        :return: reconstruction loss for the given sparsity
        :raises ValueError: if ``sparsity`` is not in the database
        '''
        # check whether the weight with given sparsity is present in the database
        if sparsity not in self._sparsity_levels:
            raise ValueError("This sparsity is not present in the database.")

        if self._store_database_on_drive:
            weight = torch.load(
                os.path.join(self._storage_dir, f'weight_sparsity={sparsity}.pth')
            )
        else:
            weight = self._reconstruction_database[sparsity]
        # set weight in the handle
        self._layer.weight.data = \
            weight.to(self._orig_device).reshape(self._weight_shape)
        self._weight = self._layer.weight
        return self._reconstruction_losses[sparsity]

    def free(self):
        '''Release the factorized inverse hessian and database bookkeeping.'''
        # reset cholesky inverse of hessian
        self._H_chol = None
        self._sparsity_levels = None
        # empty dict (not None) so that build() can run again
        self._reconstruction_losses = {}
        super().free()
        