import dataclasses
import numpy as np
import torch
import copy

from foundations import hparams
import models.base
from pruning import base
from pruning.mask import Mask


@dataclasses.dataclass
class PruningHparams(hparams.PruningHparams):
    # Hyperparameters for the pruning strategy defined in this file.

    # Fraction of additional weights to prune (see `_pruning_fraction` below).
    # NOTE(review): `Strategy.prune` in this file never reads this field; it
    # uses a hard-coded per-round ratio instead. Confirm whether that is intended.
    pruning_fraction: float = 0.2
    # Comma-separated names of tensors to exclude from pruning. `Strategy.prune`
    # splits this on ',' and removes the names from the prunable set; a falsy
    # value (None or '') means nothing extra is excluded.
    pruning_layers_to_ignore: str = None

    # Descriptive strings. Presumably consumed by the `hparams` base class to
    # build help text -- confirm against `foundations.hparams`.
    _name = 'Hyperparameters for Sparse Global Pruning'
    _description = 'Hyperparameters that modify the way pruning occurs.'
    _pruning_fraction = 'The fraction of additional weights to prune from the network.'
    # NOTE(review): this attribute is named `_layers_to_ignore`, while the field
    # is `pruning_layers_to_ignore`; `_pruning_fraction` above follows a
    # `_<field name>` pattern. If the base class looks descriptions up by field
    # name, this one will not be found -- verify.
    _layers_to_ignore = 'A comma-separated list of addititonal tensors that should not be pruned.'


class Strategy(base.Strategy):
    """Block-structured (N:M) magnitude pruning with an early-stopping check.

    `prune` builds a new mask from one or more snapshots of a model. Within
    every block of `block_size` weights the `block_budget` largest entries are
    always kept; among the remaining entries the globally smallest ones are
    pruned. Masks derived from successive training checkpoints are compared,
    and the search stops early once two consecutive masks are close enough.
    """

    @staticmethod
    def get_pruning_hparams() -> type:
        """Return the hyperparameter dataclass used by this strategy."""
        return PruningHparams

    @staticmethod
    def prune(
        pruning_hparams: PruningHparams,
        trained_model: models.base.Model,
        current_mask: Mask = None,
        older_model1: models.base.Model = None,
        older_model2: models.base.Model = None,
        trained_model_5000=None,
        trained_model_10000=None,
        trained_model_15000=None,
        trained_model_20000=None,
        trained_model_25000=None,
    ):
        """Compute the next pruning mask.

        Args:
            pruning_hparams: Only `pruning_layers_to_ignore` is read (a
                comma-separated list of tensor names to leave unpruned).
            trained_model: The final trained model; always evaluated last.
            current_mask: The mask in force before this round. When None, an
                all-ones mask shaped like `trained_model` is used.
            older_model1, older_model2: Two earlier snapshots of the model.
                When both are given, a weight is additionally pre-pruned only
                if its normalised magnitude is below the threshold in the
                current snapshot and in both older ones.
            trained_model_5000 ... trained_model_25000: Optional intermediate
                checkpoints, evaluated in this order before `trained_model`.
                Checkpoints left as None are skipped.

        Returns:
            A tuple `(new_mask, time_saved, stop)`:
            * `new_mask` -- a `Mask` holding the newly computed entries for the
              prunable tensors and the entries of `current_mask` for all other
              keys;
            * `time_saved` -- 0.0 unless the checkpoint loop stopped early, in
              which case `1 - position * 3 / 90`, where `position` is the
              1-based index of the checkpoint at which it stopped
              (NOTE(review): presumably 3 epochs per checkpoint out of 90 --
              confirm);
            * `stop` -- True when the density of the new mask is within
              `final_truncate_thres` of `block_budget / block_size`; in that
              case the mask is replaced by the pure block mask.
        """

        def combine_rows(t, block_size, block_budget, row_sharing, msg=None):
            """Merge rows that share a sparsity pattern into one score row.

            The score is the mean of the squared magnitudes of the merged rows.
            * 2-D input (N, C): adjacent row pairs are merged -> (N/2, C). The
              pair size is fixed at 2; `row_sharing` is not used here.
            * 4-D input (N, C, H, W): groups of `row_sharing` consecutive
              entries along the first axis are merged -> (N/row_sharing, C, H, W).
              When C == 3 the block size is forced to 3.
            `block_budget` and `msg` are accepted but unused.
            """
            t = torch.from_numpy(t)
            if t.dim() == 2:
                N, C = t.shape
                # (N, C) -> (N/2, 2, C/bs, bs) -> (N/2, C/bs, 2, bs)
                t = t.abs().reshape(N // 2, 2, C // block_size, block_size).permute(0, 2, 1, 3)
                t = (t ** 2).sum(2) / 2  # -> (N/2, C/bs, bs)
                t = t.reshape(N // 2, C)
            elif t.dim() == 4:
                N, C, H, W = t.shape
                if C == 3:
                    block_size = 3
                groups = N // row_sharing
                # (N, C, H, W) -> (H, W, N, C) -> (H*W, N, C/bs, bs)
                t = t.abs().permute(2, 3, 0, 1).reshape(H * W, N, C // block_size, block_size)
                # -> (H*W, N/rs, rs, C/bs, bs) -> (H*W, N/rs, C/bs, rs, bs)
                t = t.reshape(H * W, groups, row_sharing, C // block_size, block_size).permute(0, 1, 3, 2, 4)
                t = (t ** 2).sum(3) / row_sharing  # -> (H*W, N/rs, C/bs, bs)
                # -> (H*W, N/rs, C) -> (N/rs, C, H*W) -> (N/rs, C, H, W)
                t = t.reshape(H * W, groups, C).permute(1, 2, 0).reshape(groups, C, H, W)
            else:
                raise ValueError('Wrong Shape for Input!')
            return t.numpy()

        def decombine_rows(t, block_size, block_budget, row_sharing):
            """Undo `combine_rows` on a mask by repeating each merged row.

            2-D input rows are repeated twice; 4-D input is repeated
            `row_sharing` times along the first axis. The result is a float
            array. `block_budget` is accepted but unused.
            """
            t = torch.from_numpy(t).float()
            if t.dim() == 2:
                t = torch.repeat_interleave(t, 2, 0)
            elif t.dim() == 4:
                N, C, H, W = t.shape
                if C == 3:
                    block_size = 3
                # (N, C, H, W) -> (H, W, N, C) -> (H*W, N, C/bs, bs)
                t = t.abs().permute(2, 3, 0, 1).reshape(H * W, N, C // block_size, block_size)
                t = torch.repeat_interleave(t, row_sharing, dim=1)
                t = t.reshape(H, W, N * row_sharing, C).permute(2, 3, 0, 1)
            else:
                raise ValueError('Wrong Shape for Input!')
            return t.numpy()

        def to_blocks(x, two_d, channels, block_size):
            """Return |x| laid out as rows of `block_size` consecutive entries.

            2-D tensors are transposed first; 4-D tensors are moved to a
            channels-last layout and flattened to (-1, channels) first.
            """
            x = x.abs()
            if two_d:
                x = x.permute(1, 0)
            else:
                x = x.permute(0, 2, 3, 1).reshape(-1, channels)
            return x.reshape(-1, block_size)

        def normalized_magnitudes(row):
            """Standardise a row of magnitudes by the mean/std of its non-zero entries.

            Rows with fewer than two non-zero entries yield NaN, which makes
            every later `<` comparison False (i.e. nothing is pre-pruned).
            """
            nonzero = row[row > 0]
            return (row.abs() - nonzero.mean()) / nonzero.std()

        def predict_mask_with_history(wgt_matrix, wgt_matrix_old1, wgt_matrix_old2, curr_mask, block_mask,
                                      block_size, pre_pruning_thres=None, name=None):
            """Pre-prune weights that are consistently small across three snapshots.

            An entry of `curr_mask` is cleared when it is currently kept, is
            not protected by `block_mask`, and its normalised magnitude is
            below `pre_pruning_thres` in the current weights and in both older
            weight sets. Returns a boolean array with the shape of `curr_mask`.
            `name` is accepted but unused.
            """
            curr_mask = torch.from_numpy(curr_mask).float()
            block_mask = torch.from_numpy(block_mask).float()
            wgt_matrix = torch.from_numpy(wgt_matrix)
            wgt_matrix_old1 = torch.from_numpy(wgt_matrix_old1)
            wgt_matrix_old2 = torch.from_numpy(wgt_matrix_old2)

            if curr_mask.dim() == 2:
                two_d = True
                C, N = curr_mask.shape
            elif curr_mask.dim() == 4:
                two_d = False
                N, C, H, W = curr_mask.shape
            else:
                raise ValueError('Wrong Shape for Input!')

            curr_mask = to_blocks(curr_mask, two_d, C, block_size)
            block_mask = to_blocks(block_mask, two_d, C, block_size)
            wgt_matrix = to_blocks(wgt_matrix, two_d, C, block_size)
            wgt_matrix_old1 = to_blocks(wgt_matrix_old1, two_d, C, block_size)
            wgt_matrix_old2 = to_blocks(wgt_matrix_old2, two_d, C, block_size)
            result_mask = curr_mask

            thres = pre_pruning_thres
            for i in range(curr_mask.shape[0]):
                nor_weights = normalized_magnitudes(wgt_matrix[i, :])
                nor_weights_old1 = normalized_magnitudes(wgt_matrix_old1[i, :])
                nor_weights_old2 = normalized_magnitudes(wgt_matrix_old2[i, :])
                prune_here = (
                    (nor_weights < thres)
                    & (nor_weights_old1 < thres)
                    & (nor_weights_old2 < thres)
                    & (block_mask[i] < 1.)
                    & (curr_mask[i] > 0)
                )
                result_mask[i][prune_here] = 0.0

            if two_d:
                result_mask = result_mask.reshape(N, C).permute(1, 0)
            else:
                # The forward layout was (N, H, W, C); restore it with the same
                # axis order so that non-square kernels are handled correctly.
                result_mask = result_mask.reshape(N, H, W, C).permute(0, 3, 1, 2)

            return result_mask.bool().numpy()

        def predict_mask(wgt_matrix, curr_mask, block_mask, block_size, pre_pruning_thres=None):
            """Pre-prune weights that are small in the current snapshot only.

            An entry of `curr_mask` is cleared when it is currently kept, its
            `block_mask` entry is 0, and its normalised magnitude is below
            `pre_pruning_thres`. Returns a boolean array shaped like `curr_mask`.
            """
            curr_mask = torch.from_numpy(curr_mask).float()
            block_mask = torch.from_numpy(block_mask).float()
            wgt_matrix = torch.from_numpy(wgt_matrix)

            if curr_mask.dim() == 2:
                two_d = True
                C, N = curr_mask.shape
            elif curr_mask.dim() == 4:
                two_d = False
                N, C, H, W = curr_mask.shape
            else:
                raise ValueError('Wrong Shape for Input!')

            curr_mask = to_blocks(curr_mask, two_d, C, block_size)
            # block_mask is compared element-wise with curr_mask, so it must be
            # brought into the same layout (as predict_mask_with_history does).
            block_mask = to_blocks(block_mask, two_d, C, block_size)
            wgt_matrix = to_blocks(wgt_matrix, two_d, C, block_size)
            result_mask = curr_mask

            thres = pre_pruning_thres
            for i in range(curr_mask.shape[0]):
                nor_weights = normalized_magnitudes(wgt_matrix[i, :])
                prune_here = (nor_weights < thres) & (block_mask[i] == 0) & (curr_mask[i] > 0)
                result_mask[i][prune_here] = 0.0

            if two_d:
                result_mask = result_mask.reshape(N, C).permute(1, 0)
            else:
                # Same axis order as the forward (N, H, W, C) layout.
                result_mask = result_mask.reshape(N, H, W, C).permute(0, 3, 1, 2)

            return result_mask.bool().numpy()

        def find_mask(t, block_size, block_budget, row_sharing):
            """Mark the `block_budget` largest entries of every block with 1.

            Scores are the mean squared magnitudes of merged rows, exactly as
            in `combine_rows`. An entry is marked when its score is strictly
            greater than the (block_size - block_budget)-th smallest score of
            its block. The mask is expanded back to the input's shape and
            returned as a float array.
            """
            t = torch.from_numpy(t)
            if t.dim() == 2:
                N, C = t.shape
                t = t.abs().reshape(N, C // block_size, block_size)
                # -> (N/2, 2, C/bs, bs) -> (N/2, C/bs, 2, bs)
                t = t.reshape(N // 2, 2, C // block_size, block_size).permute(0, 2, 1, 3)
                t = (t ** 2).sum(2) / 2  # -> (N/2, C/bs, bs)

                # Largest of the (bs - budget) smallest scores in each block.
                b = t.topk(block_size - block_budget, largest=False)[0][:, :, -1]
                m = (t > b.unsqueeze(2)).float()  # broadcast over the block axis
                m = torch.repeat_interleave(m, 2, dim=0)
                m = m.reshape(N, C)
            elif t.dim() == 4:
                N, C, H, W = t.shape
                if C == 3:
                    block_size = 3
                groups = N // row_sharing
                # (N, C, H, W) -> (H, W, N, C) -> (H*W, N, C/bs, bs)
                t = t.abs().permute(2, 3, 0, 1).reshape(H * W, N, C // block_size, block_size)
                # -> (H*W, N/rs, rs, C/bs, bs) -> (H*W, N/rs, C/bs, rs, bs)
                t = t.reshape(H * W, groups, row_sharing, C // block_size, block_size).permute(0, 1, 3, 2, 4)
                t = (t ** 2).sum(3) / row_sharing  # -> (H*W, N/rs, C/bs, bs)

                b = t.topk(block_size - block_budget, largest=False)[0][:, :, :, -1]
                m = (t > b.unsqueeze(3)).float()  # broadcast over the block axis
                m = torch.repeat_interleave(m, row_sharing, dim=1)
                m = m.reshape(H, W, N, C).permute(2, 3, 0, 1)
            else:
                raise ValueError('Wrong Shape for Input!')

            return m.numpy()

        def produce_diff(wgt1, wgt2):
            """Return the mean absolute element-wise difference of two dicts of arrays.

            Iterates over the keys of `wgt1`; `wgt2` must contain the same keys.
            """
            total_dist = 0.
            total_ele = 0.
            for k in wgt1:
                w1 = torch.from_numpy(wgt1[k]).float()
                w2 = torch.from_numpy(wgt2[k]).float()
                total_dist += (w1 - w2).abs().sum()
                total_ele += w1.nelement()
            return (total_dist / total_ele).numpy()

        ####### hyperparameters ########
        row_sharing = 1   # no row sharing is allowed
        block_size = 8    # this is M in N:M sparsity
        block_budget = 2  # this is N in N:M sparsity
        predict_with_history = True     # whether to use the early stopping mechanism
        pruning_ratio_per_round = 0.2   # this is p in L-mIMP
        final_truncate_thres = 0.01     # this is the epsilon in L-mIMP
        pre_pruning_thres = -0.1        # this is the alpha in L-mIMP
        early_bird_thres = 0.1          # this is the beta in L-mIMP
        ################################

        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        # Determine the number of weights that need to be pruned: a fixed
        # fraction of the weights still kept beyond the N:M budget.
        number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
        weight_size = np.sum([v.size for v in current_mask.values()])
        number_of_weights_to_prune = np.ceil(
            pruning_ratio_per_round
            * (number_of_remaining_weights - (float(block_budget) / block_size) * weight_size)
            / row_sharing
        ).astype(int)

        # Determine which layers can be pruned.
        prunable_tensors = set(trained_model.prunable_layer_names)
        if pruning_hparams.pruning_layers_to_ignore:
            prunable_tensors -= set(pruning_hparams.pruning_layers_to_ignore.split(','))

        def prunable_weights(model):
            """Copy the prunable tensors of `model` to CPU numpy arrays."""
            return {k: v.clone().cpu().detach().numpy()
                    for k, v in model.state_dict().items() if k in prunable_tensors}

        saved_mask_old = None
        time_saved = 0.

        # Weights of the two older snapshots, used by predict_mask_with_history.
        if older_model1:
            weights_old1 = prunable_weights(older_model1)
        if older_model2:
            weights_old2 = prunable_weights(older_model2)

        checkpoints = [trained_model_5000, trained_model_10000, trained_model_15000,
                       trained_model_20000, trained_model_25000, trained_model]
        # `counter` is the 1-based position in the checkpoint list.
        for counter, tried_model in enumerate(checkpoints, start=1):
            # The intermediate checkpoints default to None; skip the missing ones.
            if tried_model is None:
                continue

            weights = prunable_weights(tried_model)

            # mask == 1 for the top "budget" largest values in each block
            block_masks = {k: find_mask(v, block_size, block_budget, row_sharing)
                           for k, v in weights.items()}

            # Collect the scores of all weights that are currently kept and
            # not protected by the block mask.
            weight_vector = []
            for k, v in weights.items():
                final_mask = current_mask[k] * (1 - block_masks[k])
                curr_v = combine_rows(v, block_size, block_budget, row_sharing)
                curr_final_mask = combine_rows(final_mask, block_size, block_budget, row_sharing, 'ff')
                weight_vector.append(curr_v[curr_final_mask == 1])
            weight_vector = np.concatenate(weight_vector)
            if weight_vector.any():
                threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune]
            else:
                threshold = 0.0

            dummy = {}
            for k, v in weights.items():
                curr_v = combine_rows(v, block_size, block_budget, row_sharing)
                combined_current_mask = combine_rows(current_mask[k], block_size, block_budget, row_sharing)

                # Keep what is currently kept and above the threshold, plus
                # everything protected by the block mask.
                mask1 = np.where(np.abs(curr_v) > threshold, combined_current_mask, np.zeros_like(curr_v))
                mask2 = combine_rows(block_masks[k], block_size, block_budget, row_sharing)
                mask = ((mask1 + mask2) > 1e-7)

                # Masks whose leading shape is (4, 3) use a block size of 4.
                if mask.shape[0] == 4 and mask.shape[1] == 3:
                    layer_block_size = 4
                else:
                    layer_block_size = block_size

                if predict_with_history:
                    if older_model1 and older_model2:
                        mask_f = predict_mask_with_history(v, weights_old1[k], weights_old2[k], mask, mask2,
                                                           layer_block_size, pre_pruning_thres, k)
                    else:
                        mask_f = mask
                else:
                    mask_f = predict_mask(v, mask, mask2, layer_block_size, pre_pruning_thres)

                dummy[k] = decombine_rows(mask_f, layer_block_size, block_budget, row_sharing)

            if saved_mask_old:
                # Stop early once the mask no longer changes significantly
                # between consecutive checkpoints.
                dist = produce_diff(dummy, saved_mask_old)
                saved_mask_old = copy.deepcopy(dummy)
                if dist < early_bird_thres:
                    time_saved = 1 - counter * 3. / 90
                    break
            else:
                saved_mask_old = copy.deepcopy(dummy)

        # Density of the mask from the last evaluated checkpoint.
        num_nonzeros = 0
        num_weights = 0
        for k in weights:
            num_nonzeros += dummy[k].sum()
            num_weights += dummy[k].size

        # For the final round, hard-prune the remaining weights: once the
        # density is close enough to the N:M budget, fall back to the pure
        # block mask and signal the caller to stop.
        ratio = num_nonzeros / num_weights
        print(ratio - block_budget / block_size)
        stop = False
        if (ratio - block_budget / block_size) < final_truncate_thres:
            stop = True
            for k in weights:
                dummy[k] = block_masks[k]

        new_mask = Mask(dummy)

        # Carry over the entries of tensors that were not pruned here.
        for k in current_mask:
            if k not in new_mask:
                new_mask[k] = current_mask[k]

        return new_mask, time_saved, stop
