import os
import pathlib
import random
import shutil
import time
import json
import numpy as np
import csv

import torch
import torch.nn as nn

from utils.conv_type import STRConv, STRConvER, ConvER, ConvMask
from utils.conv_type import sparseFunction
from utils.compensation import CompensatePrune
from utils.custom_activation import TrackActReLU

import networkx as nx

def get_sparsity(model):
    """Return the overall mask density of ``model``.

    Despite its name, this is a *density*: the sum of the ``mask`` entries of
    every ``ConvMask`` submodule divided by the total number of those entries
    (for 0/1 masks, the fraction of weights that are kept).
    """
    kept, size = 0, 0
    for module in model.modules():
        if not isinstance(module, ConvMask):
            continue
        kept += module.mask.sum()
        size += module.mask.numel()
    return kept / size

# The following function builds a bipartite graph with prescribed, nearly uniform degree sequences (configuration-model style) using the Havel-Hakimi construction
def make_config_model(p, ni, no):
    """Return an ``(no, ni)`` 0/1 tensor describing a degree-balanced bipartite graph.

    The number of edges is ``np.rint(p * ni * no)``. Each output node gets
    ``int(edges / no)`` edges and each input node ``int(edges / ni)``. Any
    remainder is handed out, one extra edge each, to the first nodes of that
    side. A graph with exactly these degree sequences is then built with
    ``networkx.bipartite.havel_hakimi_graph``. No random numbers are drawn
    here, so identical arguments give identical patterns.

    Args:
        p: target density, i.e. the fraction of the ``ni * no`` possible
            edges to keep.
        ni: number of input nodes (columns of the result).
        no: number of output nodes (rows of the result).

    Returns:
        ``torch.Tensor`` of shape ``(no, ni)``. Entry ``[r, c]`` is 1 when
        output node ``r`` is connected to input node ``c``.
    """
    n_possible = ni * no
    n_edges = np.rint(p * ni * no)
    out_share = int(n_edges / no)  # base degree of every output node
    in_share = int(n_edges / ni)   # base degree of every input node
    print(n_edges, ni, in_share, no, out_share)

    # Output-side degrees; spread the remainder over the first output nodes.
    out_degrees = np.array(no * [out_share])
    if n_edges > no * out_share:
        print('Diff', n_possible, n_edges, ni * in_share, no * out_share)
        out_degrees[:int(n_edges - no * out_share)] += 1

    # Input-side degrees, balanced the same way.
    in_degrees = np.array(ni * [in_share])
    if ni * in_share < n_edges:
        print('Diff', n_possible, n_edges, ni * in_share, no * out_share)
        in_degrees[:int(n_edges - ni * in_share)] += 1

    assert in_degrees.sum() == out_degrees.sum()
    # Nodes 0..ni-1 form the input side and ni..ni+no-1 the output side. Keep the
    # (output rows, input columns) block of the full adjacency matrix.
    graph = nx.bipartite.havel_hakimi_graph(in_degrees, out_degrees, create_using=nx.Graph())
    adjacency = nx.adjacency_matrix(graph).todense()
    return torch.tensor(adjacency[ni:, :ni])


def make_regular_random_graph(model, args):
    """Initialise every ``ConvMask`` mask with a degree-balanced sparse pattern.

    A per-layer nonzero budget ``X = args.er_sparse_init * total_params / L``
    is computed. Presumably ``er_sparse_init`` is the target overall density;
    confirm against the argument parser. Layers too small to use the whole
    budget are kept dense. Their leftover budget is divided by ``L - l`` and
    added to the budget of the layers that follow. Every 2-D slice
    ``mask[:, :, i, j]`` of a sparse layer is filled with the 0/1 pattern
    returned by ``make_config_model``. The last ``ConvMask`` layer is kept
    dense.

    Args:
        model: module whose ``ConvMask`` submodules expose ``weight`` and a 4-D
            ``mask`` that is overwritten in place.
        args: namespace providing ``er_sparse_init``.

    Returns:
        The same ``model``.
    """
    # Pass 1: count the mask-carrying layers (L) and their total number of weights.
    total_params = 0
    l = 0
    sparsity_list = []
    for n, m in model.named_modules():
        if isinstance(m, (STRConvER, ConvER, ConvMask)):
            total_params += m.weight.numel()
            l += 1
    L = l
    # Equal nonzero budget for every layer.
    X = args.er_sparse_init * total_params / L
    l = 0
    # Pass 2: per-layer density = budget / layer size, capped at 1 (dense).
    for n, m in model.named_modules():
        if isinstance(m, (STRConvER, ConvER, ConvMask)):
            if X / m.weight.numel() < 1.0:
                sparsity_list.append(X / m.weight.numel())
            else:
                sparsity_list.append(1)
                # The layer cannot absorb its whole budget, so raise the budget of
                # later layers to keep the overall density near the target.
                # NOTE(review): the surplus is divided by L - l, which still counts
                # this layer; confirm whether L - l - 1 was intended.
                diff = X - m.mask.numel()
                X = X + diff / (L - l)
            l += 1

    # Pass 3: build and install the masks.
    # NOTE(review): sparsity_list was filled over (STRConvER, ConvER, ConvMask)
    # but is indexed here by a counter over ConvMask only. The entries line up
    # only if every STRConvER/ConvER is also a ConvMask; confirm.
    l = 0
    total = 0  # accumulated below but never read after the loop
    nz = 0
    for n, m in model.named_modules():
        if isinstance(m, ConvMask):
            p = sparsity_list[l]
            print('Layer: ', p, m.weight.shape)
            curr_mask = torch.ones_like(m.weight)
            # Every layer except the last gets a sparse pattern; the last stays dense.
            if l < L-1:
                if p == 1:
                    # Dense layer (curr_mask is already all ones).
                    curr_mask[:, :, :, :] = 1
                    total += curr_mask.numel()
                    nz += curr_mask.sum()

                else:
                    no, ni, k1, k2 = curr_mask.shape[0], curr_mask.shape[1], curr_mask.shape[2], curr_mask.shape[3]
                    # One (no, ni) bipartite pattern per kernel position (i, j).
                    # NOTE(review): make_config_model draws no random numbers, so
                    # every slice likely receives the same pattern and the call
                    # could be hoisted out of the loops; confirm this is intended.
                    for i in range(k1):
                        for j in range(k2):
                            Al = make_config_model(p, ni, no)
                            curr_mask[:, :, i, j] = Al[:, :]
                    total += curr_mask.numel()
                    nz += curr_mask.sum()

            else:
                curr_mask[:, :, :, :] = 1
                total += curr_mask.numel()
                nz += curr_mask.sum()
            # Copy into the existing mask tensor in place.
            m.mask[:, :, :, :] = curr_mask
            l += 1
    # get_sparsity returns the density (sum of mask entries / number of entries).
    print('The target sparsity of the mask achieved is: ', get_sparsity(model))
    return model



def make_config_from_mask(model, args):
    """Rewire every ``ConvMask`` mask while preserving a reference mask's degrees.

    Reference masks are read with ``torch.load`` from
    ``'runs/' + args.ref_sparsity_mask_name``. They are matched, in
    ``named_modules`` order, with the model's ``ConvMask`` layers. For every
    kernel position ``(i, j)``, the 2-D slice ``mask[:, :, i, j]`` (output
    channels x input channels) is treated as a bipartite graph. Its row sums
    and column sums are kept, and a new 0/1 slice with exactly those sums is
    built with ``networkx.bipartite.havel_hakimi_graph``.

    Args:
        model: module whose ``ConvMask`` submodules receive a new ``mask``.
        args: namespace providing ``ref_sparsity_mask_name``.

    Returns:
        The same ``model``. Each ``ConvMask.mask`` is replaced by a tensor with
        the shape and dtype of its ``weight``, on the weight's device.
    """
    base_dir = ''

    mask_name = base_dir + 'runs/' + args.ref_sparsity_mask_name
    # NOTE(review): torch.load unpickles the file; only load trusted mask files.
    mask_list = torch.load(mask_name)

    l = 0
    for m in model.modules():
        if isinstance(m, ConvMask):
            mask = mask_list[l]
            new_mask = torch.zeros_like(m.weight)
            no, ni, k1, k2 = mask.shape
            for i in range(k1):
                for j in range(k2):
                    # detach() so .numpy() also works if the loaded mask tracks gradients.
                    kernel_slice = mask[:, :, i, j].detach().cpu()
                    # Row sums: inputs feeding each output channel (in-degree).
                    # Column sums: outputs fed by each input channel (out-degree).
                    indeg = kernel_slice.sum(dim=1).numpy().astype(int)
                    outdeg = kernel_slice.sum(dim=0).numpy().astype(int)
                    G = nx.bipartite.havel_hakimi_graph(indeg, outdeg, create_using=nx.Graph())
                    A = nx.adjacency_matrix(G).todense()
                    # Nodes 0..no-1 are output channels and no..no+ni-1 input channels.
                    # Take the (input, output) block and transpose it to (no, ni).
                    new_mask[:, :, i, j] = torch.tensor(A[no:, :no].T)
            l += 1
            m.mask = new_mask.to(m.weight.device)

    print('Modified the mask by making a configuration model of a loaded reference mask, degrees are same')
    return model

def prune_width_gradual(model, args):
    """Structurally prune a fixed fraction of each layer's output neurons.

    Eligible layers are the ``ConvMask`` modules whose name does not contain
    ``'downsample'``, visited in ``named_modules`` order. For each one, the
    first ``int(args.prune_ratio * out_channels)`` output neurons that still
    have a nonzero incoming mask entry are selected. All layers are selected
    before any mask is modified.

    Each layer's selected output rows are zeroed, together with the input
    columns of the following layer that have the same indices.
    NOTE(review): this assumes each eligible layer consumes the previous one's
    outputs. When ``args.fix_first_last_structured`` is true, the first layer's
    rows and the last layer's columns are left untouched. The last layer's
    rows are never zeroed.

    Returns:
        The same ``model`` with its masks modified in place.
    """
    prune_ratio = args.prune_ratio
    print('identify neurons with nonzero indegrees and zero out {} fraction of them'.format(prune_ratio))
    layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, ConvMask) and 'downsample' not in name
    ]

    # Phase 1: pick the neurons to remove in every layer.
    selected = []
    for name, module in layers:
        print(name, module.weight.shape)
        alive = torch.where(module.mask.sum(dim=(1, 2, 3)) != 0)[0]
        chosen = alive[:int(prune_ratio * module.mask.shape[0])]
        print('Index layer: ', chosen)
        selected.append(chosen)

    print('pruning these neurons to zero')

    # Phase 2: zero the chosen output rows and the matching next-layer input columns.
    last = len(selected) - 1
    kept = 0
    size = 0
    for pos, (name, module) in enumerate(layers):
        print('layer name: ', name)
        if pos == 0:
            cut_rows, cut_cols = not args.fix_first_last_structured, False
        elif pos == last:
            cut_rows, cut_cols = False, not args.fix_first_last_structured
        else:
            cut_rows = cut_cols = True
        if cut_rows:
            module.mask[selected[pos], :, :, :] = 0
        if cut_cols:
            module.mask[:, selected[pos - 1], :, :] = 0
        kept += (module.mask == 1).sum()
        size += module.mask.numel()

    print('Pruned neuron in each layer and now the post pruning density is: ', kept / size)
    return model

def prune_width_gradual_nonuniform_ratio(model, args):
    """Structurally prune output neurons with a width-proportional ratio.

    Eligible layers are the ``ConvMask`` modules whose name does not contain
    ``'downsample'``, visited in ``named_modules`` order. The pruning ratio of
    each layer is ``args.prune_ratio * out_channels / first_out_channels``,
    where ``first_out_channels`` is the output width of the first eligible
    layer, so wider layers lose a larger fraction. In each layer, the first
    ``int(ratio * out_channels)`` output neurons that still have a nonzero
    incoming mask entry are selected. All layers are selected before any mask
    is modified.

    Each layer's selected output rows are zeroed, together with the input
    columns of the following layer that have the same indices.
    NOTE(review): this assumes each eligible layer consumes the previous one's
    outputs. When ``args.fix_first_last_structured`` is true, the first layer's
    rows and the last layer's columns are left untouched.

    Returns:
        The same ``model`` with its masks modified in place.
    """
    print('identify neurons with nonzero indegrees and zero out {} fraction of them'.format(args.prune_ratio))
    layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, ConvMask) and 'downsample' not in name
    ]

    # Phase 1: pick the neurons to remove, scaling the ratio by relative width.
    selected = []
    for pos, (name, module) in enumerate(layers):
        if pos == 0:
            base_width = module.weight.shape[0]
        layer_ratio = args.prune_ratio * (module.weight.shape[0] / base_width)
        print(name, module.weight.shape)
        alive = torch.where(module.mask.sum(dim=(1, 2, 3)) != 0)[0]
        chosen = alive[:int(layer_ratio * module.mask.shape[0])]
        print('Index layer: ', chosen)
        selected.append(chosen)

    print('pruning these neurons to zero')

    # Phase 2: zero the chosen output rows and the matching next-layer input columns.
    last = len(selected) - 1
    kept = 0
    size = 0
    for pos, (name, module) in enumerate(layers):
        print('layer name: ', name)
        if pos == 0:
            cut_rows, cut_cols = not args.fix_first_last_structured, False
        elif pos == last:
            cut_rows, cut_cols = False, not args.fix_first_last_structured
        else:
            cut_rows = cut_cols = True
        if cut_rows:
            module.mask[selected[pos], :, :, :] = 0
        if cut_cols:
            module.mask[:, selected[pos - 1], :, :] = 0
        kept += (module.mask == 1).sum()
        size += module.mask.numel()

    print('Pruned neuron in each layer and now the post pruning density is: ', kept / size)
    return model

        
# Randomly prune weights while preferentially keeping those inside a given target mask
class PruneRandInMask:
    """Randomly prune ``ConvMask`` weights while preferring to keep a target mask.

    The target (reference) masks are loaded once at construction. Each call to
    ``prune_random_balanced`` re-draws every ``ConvMask`` mask at the requested
    global density. Weights in the target mask are removed only after all
    other weights.
    """

    def __init__(self, model, args):
        """Load the reference masks and record the coordinates of their ones.

        Args:
            model: module whose ``ConvMask`` submodules are matched, in
                ``named_modules`` order, with the entries of the loaded list.
            args: namespace providing ``ref_sparsity_mask_name``. The masks are
                read from ``'runs/' + args.ref_sparsity_mask_name``.
        """
        base_dir = ''

        mask_name = base_dir + 'runs/' + args.ref_sparsity_mask_name
        # NOTE(review): torch.load unpickles the file; only load trusted mask files.
        mask_list = torch.load(mask_name)

        # Maps module name -> stacked index tensor (one row per mask dimension,
        # one column per entry equal to 1 in the target mask).
        self.prune_idx = {}
        l = 0
        nz = 0
        total = 0
        for n, m in model.named_modules():
            if isinstance(m, (ConvMask)):
                mask = mask_list[l]
                self.prune_idx[n] = torch.stack(torch.where(mask == 1), dim=0)
                nz += mask.sum()
                total += mask.numel()
                l += 1
        print('Density of the target mask is: ', nz / total)

    def prune_random_balanced(self, model, density):
        """Replace every ``ConvMask`` mask so the global density is ``density``.

        Every weight gets a score ``|mask * u|`` with ``u`` drawn by
        ``torch.rand_like``, and target-mask entries get ``+5`` on top. Weights
        whose current mask is 0 (score 0) go first, then random non-target
        weights, and target weights only last. With
        ``k = int((1 - density) * N)``, every weight whose score is <= the k-th
        smallest score is pruned and the rest are set to 1. If ``k < 1``,
        nothing is pruned and the masks are left unchanged.

        Args:
            model: module whose ``ConvMask`` submodules must match the names
                recorded in ``__init__``. Their masks are assumed to be 4-D,
                like the recorded target indices.
            density: fraction of weights to keep across all ``ConvMask``
                layers.

        Returns:
            The same ``model``, with ``m.mask`` replaced on every ``ConvMask``
            when pruning occurs.
        """
        score_list = {}
        for n, m in model.named_modules():
            if isinstance(m, ConvMask):
                curr_idx = self.prune_idx[n]
                score = (m.mask.to(m.weight.device) * torch.rand_like(m.weight).to(m.weight.device)).detach().abs_()
                # Offset target entries instead of overwriting them with a constant.
                # They still outrank every non-target score (< 1), but no longer all
                # tie. With a constant, a threshold that landed on the tie (density
                # below the target density) pruned the entire mask.
                score[curr_idx[0, :], curr_idx[1, :], curr_idx[2, :], curr_idx[3, :]] += 5
                score_list[n] = score

        global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
        k = int((1 - density) * global_scores.numel())

        total_num = 0
        total_den = 0
        if k >= 1:
            # torch.kthvalue requires k >= 1, so it is only evaluated when pruning.
            threshold, _ = torch.kthvalue(global_scores, k)
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    score = score_list[n].to(m.weight.device)
                    zero = torch.tensor([0.]).to(m.weight.device)
                    one = torch.tensor([1.]).to(m.weight.device)
                    m.mask = torch.where(score <= threshold, zero, one)
                    total_num += (m.mask == 1).sum()
                    total_den += m.mask.numel()
        else:
            # Nothing to prune: keep the masks and report their current density.
            for n, m in model.named_modules():
                if isinstance(m, ConvMask):
                    total_num += (m.mask == 1).sum()
                    total_den += m.mask.numel()

        print('Overall model density after magnitude pruning at current iteration = ', total_num / total_den)
        print('Density after fixing the mask: ', get_sparsity(model))

        return model
