import os

import numpy as np
import torch

from models.preactivate_resnet import *



# Build the network; in this script it is only used to enumerate its parameters.
model = ResNet18(num_classes=10)


# Hyper-parameters read only by the commented-out ERK density allocation below.
density = 0.1
erk_power_scale = 1.0

# One all-zeros float32 mask per weight of rank 2 or 4 (linear / conv weights);
# rank-1 parameters are skipped. Every parameter, kept or not, advances the
# counter, so the '_<index>' suffix is the parameter's position in
# model.named_parameters() and keeps keys unique.
masks = {}
for index, (name, tensor) in enumerate(model.named_parameters()):
    if tensor.dim() in (2, 4):
        masks[f'{name}_{index}'] = torch.zeros_like(
            tensor, dtype=torch.float32, requires_grad=False
        )

# for key in masks.keys():
#     print(key)

# print('initialize by fixed_ERK')
# total_params = 0
# for name, weight in masks.items():
#     total_params += weight.numel()
# is_epsilon_valid = False

# dense_layers = set()
# while not is_epsilon_valid:

#     divisor = 0
#     rhs = 0
#     raw_probabilities = {}
#     for name, mask in masks.items():
#         n_param = np.prod(mask.shape)
#         n_zeros = n_param * (1 - density)
#         n_ones = n_param * density

#         if name in dense_layers:
#             rhs -= n_zeros

#         else:
#             rhs += n_ones
#             raw_probabilities[name] = (
#                                                 np.sum(mask.shape) / np.prod(mask.shape)
#                                         ) ** erk_power_scale

#             divisor += raw_probabilities[name] * n_param

#     epsilon = rhs / divisor

#     max_prob = np.max(list(raw_probabilities.values()))
#     max_prob_one = max_prob * epsilon
#     if max_prob_one > 1:
#         is_epsilon_valid = False
#         for mask_name, mask_raw_prob in raw_probabilities.items():
#             if mask_raw_prob == max_prob:
#                 print(f"Sparsity of var:{mask_name} had to be set to 0.")
#                 dense_layers.add(mask_name)
#     else:
#         is_epsilon_valid = True

# density_dict = {}
# total_nonzero = 0.0
# # With the valid epsilon, we can set sparsities of the remaining layers.
# for name, mask in masks.items():
#     n_param = np.prod(mask.shape)
#     if name in dense_layers:
#         density_dict[name] = 1.0
#     else:
#         probability_one = epsilon * raw_probabilities[name]
#         density_dict[name] = probability_one
#     print(
#         f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}"
#     )
#     masks[name][:] = (torch.rand(mask.shape) < density_dict[name]).float().data

#     total_nonzero += density_dict[name] * mask.numel()
# print(f"Overall sparsity {total_nonzero / total_params}")




# sparsity = {}
# for key in masks.keys():
#     element = masks[key]
#     new_name = '.'.join(key.split('.')[:-1])
#     zeros = (element == 0).float().sum()
#     sparse = zeros / element.nelement()
#     sparsity[new_name] = sparse
#     print(new_name, sparse)

# torch.save(sparsity, 'erk.pt')


# masks = torch.load('masks_seed/snip-0.1-seed1-mask.pt', map_location='cpu')
# print(masks.keys())

# sparsity = {}
# for key in masks.keys():
#     element = masks[key]
#     new_name = '.'.join(key.split('.')[:-1])
#     zeros = (element == 0).float().sum()
#     sparse = zeros / element.nelement()
#     sparse = 0.9
#     sparsity[new_name] = sparse

#     print(new_name, sparse)

# torch.save(sparsity, 'uniform.pt')


# Load the per-layer sparsity ratios produced by each pruning method from the
# current working directory. map_location='cpu' lets tensors that were saved
# during a GPU run load on a CPU-only machine.
snip = torch.load('snip.pt', map_location='cpu')
grasp = torch.load('grasp.pt', map_location='cpu')
synflow = torch.load('synflow.pt', map_location='cpu')
uniform = torch.load('uniform.pt', map_location='cpu')
# IGQ ratios are positional (indexed by integer, printed via .shape below);
# presumably one entry per prunable layer in the same order `masks` was
# built -- TODO confirm against the IGQ export.
igq_raw = torch.load('igq.pt', map_location='cpu')
erk = torch.load('erk.pt', map_location='cpu')


# Re-key the positional IGQ ratios by layer name. Mask keys look like
# '<module path>.<param name>_<index>', so dropping the last dotted component
# leaves the module path (e.g. 'layer1.0.conv1').
if len(igq_raw) < len(masks):
    raise ValueError(
        f'igq.pt has {len(igq_raw)} entries but {len(masks)} prunable layers were found'
    )
sparsity = {}
for ii, key in enumerate(masks):
    new_name = '.'.join(key.split('.')[:-1])
    sparsity[new_name] = igq_raw[ii]
    print(new_name, sparsity[new_name])

# Manual sanity check: number of named layers vs. size of the IGQ array.
print(len(snip.keys()))
print(igq_raw.shape)


overall = {
    'SNIP': snip,
    'GraSP': grasp,
    'SynFlow': synflow,
    'Uniform': uniform,
    'IGQ': sparsity,
    'ERK': erk
}

# Normalise every ratio to a plain Python float so all methods share one value
# type in the saved file (previously IGQ was left as tensor/array elements).
# 0-d tensors and NumPy scalars expose .item(); plain numbers (the
# commented-out uniform export stores 0.9 literals) go through float().
# Values are replaced in place, so snip/grasp/... are updated as well.
for method, ratios in overall.items():
    print(method)
    for layer, value in ratios.items():
        ratios[layer] = value.item() if hasattr(value, 'item') else float(value)


# Create the output directory on a fresh checkout instead of failing in save.
os.makedirs('process', exist_ok=True)
torch.save(overall, 'process/layer_wise_sparsity.pt')
