import numpy as np
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .base import BaseOBCUtil
from ....utils.linalg import inv_sym


# Names exported by a star-import of this module.
__all__ = ["StructOBCUtil"]

    
class StructOBCUtil(BaseOBCUtil):
    """Greedy structured pruning of weight columns (or groups of columns).

    Works on ``self.weight`` and ``self.H``, both supplied by ``BaseOBCUtil``
    (presumably a 2-D weight of shape ``(d_row, d_col)`` and the matching
    ``(d_col, d_col)`` Hessian -- confirm against the base class).  Structures
    are removed one at a time; after each removal the remaining weights are
    adjusted using the inverse Hessian, and a snapshot of the weight is stored
    in ``self.sparse_weights`` whenever a requested sparsity level is reached.
    """

    # Must be a real tuple: ("structured") without the trailing comma is just
    # the string "structured", which makes ``in`` checks match substrings.
    _supported_sparsity_types = ("structured",)

    def __init__(
        self,
        layer: nn.Module,
        rel_damp: float = 0.0,
        struct_size: int = 1,
        nan_to_num: bool = True,
    ) -> None:
        """Set up the pruning utility for ``layer``.

        Args:
            layer: Layer to prune; forwarded to ``BaseOBCUtil``.
            rel_damp: Forwarded unchanged to ``BaseOBCUtil``.
            struct_size: Requested structure size; it is multiplied by
                ``struct_size_multiplier`` to get the number of weight
                columns that form one prunable structure.
            nan_to_num: If True, NaN/inf values are replaced in the weight
                after every pruning step (``Tensor.nan_to_num_``).
        """
        super().__init__(layer, rel_damp)
        # struct_size_multiplier reads self.layer, hence it is evaluated
        # only after the base-class initialisation.
        self.struct_size = struct_size * self.struct_size_multiplier
        self.nan_to_num = nan_to_num
        # Maps each requested sparsity to the pruned weight snapshot.
        self.sparse_weights: Dict[float, Tensor] = {}

    @property
    def struct_size_multiplier(self) -> int:
        """Number of weight columns per unit of ``struct_size``.

        Returns 1 for ``nn.Linear``.  Any other layer is assumed to expose a
        ``kernel_size`` attribute (e.g. a convolution); the product of its
        entries is returned as a plain Python ``int``.
        """
        if isinstance(self.layer, nn.Linear):
            return 1
        # for convolutional layer default struct size is the product of kernel sizes
        return int(np.prod(self.layer.kernel_size))

    @torch.no_grad()
    def prepare_data(self, struct_size=1):
        """Build the working copies used by the pruning loops.

        Args:
            struct_size: Number of consecutive weight columns per structure.

        Returns:
            A tuple ``(w, mask, H_inv, num_zeros)`` where ``w`` is a clone of
            ``self.weight``, ``mask`` is a boolean tensor with one entry per
            structure (True while the structure still has a non-zero weight),
            ``H_inv`` is ``inv_sym`` applied to a masked clone of ``self.H``,
            and ``num_zeros`` is the number of structures that are already
            entirely zero.
        """
        w = self.weight.clone()
        # create mask of already pruned weights
        if struct_size > 1:
            # Use the argument (not self.struct_size) for both the grouping
            # and the expansion so the two always agree.
            mask = w.T.reshape(self.d_col // struct_size, -1).ne(0).any(dim=1)
            col_mask = mask.repeat_interleave(struct_size)
        else:
            mask = w.ne(0).any(dim=0)
            col_mask = mask
        # get number of zero structures
        num_zeros = (~mask).sum().item()
        H_inv = self.H.clone()
        # Zero the rows and columns of already-pruned weights and put 1 on
        # their diagonal so the matrix stays invertible.
        H_inv.masked_fill_(~col_mask.unsqueeze(-1), 0)
        H_inv.masked_fill_(~col_mask.unsqueeze(-2), 0)
        H_inv.masked_fill_((~col_mask).diag_embed(dim1=-2, dim2=-1), 1)
        # invert
        H_inv = inv_sym(H_inv)
        return w, mask, H_inv, num_zeros

    @staticmethod
    def _group_by_num_zeros(
        sparsities: List[float], num_structs: int
    ) -> Dict[int, List[float]]:
        """Map each rounded zero-structure count to ALL sparsities requesting it."""
        grouped: Dict[int, List[float]] = {}
        for sparsity in sparsities:
            grouped.setdefault(round(sparsity * num_structs), []).append(sparsity)
        return grouped

    def _store_sparse_weight(self, w: Tensor, target_sparsities) -> None:
        """Save an independent, reshaped snapshot of ``w`` for every given sparsity."""
        for sparsity in target_sparsities:
            self.sparse_weights[sparsity] = self._reshape_to_orig_shape(
                w.clone(), 'structured'
            )

    def prepare_structured(self, sparsities: List[float]):
        """Fill ``self.sparse_weights`` for ``sparsities``.

        Dispatches to the single-column routine when ``self.struct_size`` is
        1 and to the multi-column routine otherwise.
        """
        if self.struct_size == 1:
            self.prepare_structured_col_single(sparsities)
        else:
            self.prepare_structured_col_multi(sparsities)

    @torch.no_grad()
    def prepare_structured_col_single(self, sparsities: List[float]):
        """Prune one weight column per step until ``max(sparsities)`` is reached.

        Each step scores every remaining column by
        ``sum_rows(w ** 2 / diag(H_inv))``, removes the column with the lowest
        score, compensates the remaining weights and updates ``H_inv``.
        Snapshots are written to ``self.sparse_weights`` for every requested
        sparsity whose rounded column count matches the current step.
        """
        d_col, device, dtype = (
            self.d_col,
            self.weight.device,
            self.weight.dtype,
        )

        max_zero_col = round(max(sparsities) * d_col)
        # Several sparsities may round to the same column count; keep all.
        zeros_to_sparsities = self._group_by_num_zeros(sparsities, d_col)

        w, mask, H_inv, num_zeros = self.prepare_data()
        # prepare losses TODO make useful
        # (every entry receives the same increment below, so this is not yet
        # a per-sparsity loss)
        self.losses = torch.zeros(len(sparsities), device=device, dtype=dtype)
        # if current sparsity is greater than the query, simply copy the weight
        for db_id, targets in zeros_to_sparsities.items():
            if db_id <= num_zeros:
                self._store_sparse_weight(w, targets)

        for col in range(num_zeros + 1, max_zero_col + 1):
            # 1) compute scores
            H_inv_d = torch.diag(H_inv)
            scores = (w ** 2 / H_inv_d).sum(dim=0)
            # already pruned columns must never be selected again
            scores[~mask] = torch.inf
            # 2) mask selection
            p_id = scores.argmin(dim=0)
            mask[p_id] = False
            # 3) loss update
            self.losses += scores[p_id]
            # 4) weight update
            # NOTE(review): H_inv_pr is obtained by indexing H_inv and may
            # alias it; the statement order below is intentionally unchanged.
            H_inv_pr = H_inv[p_id, :]
            H_inv_pd = H_inv_d[p_id]
            w.add_(H_inv_pr * (w[:, p_id] / H_inv_pd).unsqueeze(1), alpha=-1)
            w[:, ~mask] = 0
            if self.nan_to_num:
                w.nan_to_num_()
            # update weight database
            self._store_sparse_weight(w, zeros_to_sparsities.get(col, ()))
            # 5) hessian update
            H_inv_pr.div_(H_inv_pd.sqrt())
            H_inv.addr_(H_inv_pr, H_inv_pr, alpha=-1)

    @torch.no_grad()
    def prepare_structured_col_multi(self, sparsities: List[float]):
        """Prune one group of ``self.struct_size`` columns per step.

        Each step scores every remaining group by ``w_g^T inv(H_inv_gg) w_g``
        summed over rows, where ``H_inv_gg`` is the group's diagonal block of
        ``H_inv``.  The lowest-scoring group is removed, the remaining
        weights are compensated and ``H_inv`` is updated in place.  Snapshots
        are written to ``self.sparse_weights`` for every requested sparsity
        whose rounded group count matches the current step.
        """
        d_col, ss, device, dtype = (
            self.d_col,
            self.struct_size,
            self.weight.device,
            self.weight.dtype,
        )

        ns = d_col // ss
        s_ids = torch.arange(ns, device=device)
        # column offsets inside one structure (loop invariant)
        in_struct_offsets = torch.arange(ss, device=device)
        max_zero_str = round(max(sparsities) * ns)
        # Several sparsities may round to the same structure count; keep all.
        zeros_to_sparsities = self._group_by_num_zeros(sparsities, ns)

        w, mask, H_inv, num_zeros = self.prepare_data(ss)
        # prepare losses TODO make useful
        # (every entry receives the same increment below, so this is not yet
        # a per-sparsity loss)
        self.losses = torch.zeros(len(sparsities), device=device, dtype=dtype)
        # if current sparsity is greater than the query, simply copy the weight
        for db_id, targets in zeros_to_sparsities.items():
            if db_id <= num_zeros:
                self._store_sparse_weight(w, targets)

        for col in range(num_zeros + 1, max_zero_str + 1):
            # 1) compute scores
            H_inv_db = H_inv.view(ns, ss, ns, ss)[s_ids, :, s_ids, :]  # shape (ns, ss, ss)
            inv_H_inv_db = inv_sym(H_inv_db)  # shape (ns, ss, ss)
            w_s = w.view(-1, ns, ss, 1)  # shape (d_row, ns, ss, 1)
            inv_H_inv_db_w = inv_H_inv_db[None, :] @ w_s  # shape (d_row, ns, ss, 1)
            scores = (w_s * inv_H_inv_db_w).sum(dim=(0, 2, 3))
            # already pruned structures must never be selected again
            scores[~mask] = torch.inf
            # 2) mask selection
            p_id = scores.argmin(dim=0)
            p_ids = ss * p_id + in_struct_offsets
            mask[p_id] = False
            # 3) loss update
            self.losses += scores[p_id]
            # 4) weight update
            inv_H_inv_pdb = inv_H_inv_db[p_id]  # shape (ss, ss)
            w.addmm_(inv_H_inv_db_w[:, p_id, :, 0], H_inv[p_ids], alpha=-1)
            w[:, p_ids] = 0
            # Sanitize after the update (as in the single-column path) so the
            # stored snapshot and the next scoring pass see finite weights.
            if self.nan_to_num:
                w.nan_to_num_()
            # update weight database
            self._store_sparse_weight(w, zeros_to_sparsities.get(col, ()))
            # 5) hessian update
            # In-place on purpose: the out-of-place ``add`` would discard the
            # result and leave H_inv stale for all following steps.
            H_inv.add_(H_inv[:, p_ids] @ inv_H_inv_pdb @ H_inv[p_ids, :], alpha=-1)
            # isolate pruned columns (identity block keeps later inversions valid)
            H_inv[p_ids, :] = 0
            H_inv[:, p_ids] = 0
            H_inv[p_ids, p_ids] = 1

    @torch.no_grad()
    def prune(self, sparsity_type: str, sparsities: List[float], **prune_kw) -> Dict[float, Tensor]:
        """Delegate to ``BaseOBCUtil.prune`` with gradient tracking disabled."""
        return super().prune(sparsity_type, sparsities, **prune_kw)

    def prune_structured(self, sparsities: List[float]):
        """Run structured pruning and return ``self.sparse_weights``.

        Requires ``self.pre_step_completed`` to be truthy (set outside this
        class, presumably by the base class's preparation step).
        """
        assert self.pre_step_completed
        self.prepare_structured(sparsities)
        return self.sparse_weights

    def reset(self):
        """Reset the base-class state and drop all stored sparse weights."""
        super().reset()
        self.sparse_weights = {}
    