from google.cloud import storage
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import random
import subprocess
import os
import time
import shutil
from datetime import datetime

from .. import models
from . import get_model


try:
  from collections.abc import MutableMapping  # the `collections` alias was removed in Python 3.10
except ImportError:  # very old Pythons only
  from collections import MutableMapping

def flatten(d, parent_key='', sep='_'):
  """ Collapse a nested mapping into a single-level dict.

  Keys of nested mappings are joined to their parent's key with `sep`
  (e.g. {'a': {'b': 1}} -> {'a_b': 1}); `parent_key`, when non-empty, prefixes
  every top-level key. Empty sub-mappings contribute no entries. If two paths
  produce the same flat key, the later value wins. """
  flat = {}
  for key, value in d.items():
    full_key = parent_key + sep + key if parent_key else key
    if not isinstance(value, MutableMapping):
      flat[full_key] = value
      continue
    # nested mapping: merge its flattened entries under the extended prefix
    flat.update(flatten(value, full_key, sep=sep))
  return flat


def get_tensor_dims(model, ltypes):
  """ Return a flat dict mapping layer name -> list of weight-tensor dims for every
  submodule of `model` whose type name is in `ltypes`; names of nested submodules
  are joined with '_' (flatten's default separator) """
  nested_dims = {}
  fill_tensor_dims(model, ltypes, nested_dims)
  return flatten(nested_dims)


def fill_tensor_dims(model, ltypes, tensor_dims):
  """ Recursively fill the dict `tensor_dims` in place, keyed by child name.

  For each direct child of `model`:
    - if its class name (`_get_name()`) is in `ltypes`, store `list(child.weight.shape)`;
    - otherwise store a nested dict filled the same way from that child's own children
      (a child with no matching descendants leaves an empty dict). """
  for name, module in model.named_children():
    if module._get_name() in ltypes:
      tensor_dims[name] = list(module.weight.shape)
      continue
    # not a target layer type: descend one level into its children
    sub_dims = {}
    tensor_dims[name] = sub_dims
    fill_tensor_dims(module, ltypes, sub_dims)



def get_num_W(tensor_dims):
  """ Map every layer name in `tensor_dims` to its number of weights, i.e. the
  product (np.prod) of that layer's tensor dims """
  return {lname: np.prod(dims) for lname, dims in tensor_dims.items()}



def get_num_W_tot(args, noc1, ltypes):
  """ Return the total number of weights of the model of width `noc1` described by `args`.

  The model is built via get_model only to read the dims of its layers whose type is in
  `ltypes`; the total is the sum over those layers of the product of their tensor dims. """
  throwaway_model = get_model(args, noc1)
  dims_per_layer = get_tensor_dims(throwaway_model, ltypes)
  del throwaway_model  # only the shapes are needed; release the model before counting

  weights_per_layer = get_num_W(dims_per_layer)  # key= layer name, value= num weights in layer
  return sum(weights_per_layer.values())


import bisect
def find_ge(a, x):
  """ Find the leftmost item of the ascending-sorted sequence `a` that is >= x.

  returns: (value, index) of that item
  raises: ValueError if every item of `a` is < x (including when `a` is empty) """
  i = bisect.bisect_left(a, x)
  if i == len(a):
    raise ValueError(f"find_ge: no item >= {x!r} in sorted sequence of length {len(a)}")
  return a[i], i


def get_ntf(num_to_freeze_tot, num_W, tensor_dims, lnames_sorted, io_only):
  """ returns array num_to_freeze containing ntf (number of weights to freeze) for each layer,
      ordered as in lnames_sorted.

      Weights are frozen in the first layer until its size matches the next layer, then in both
      until they match the third, and so on; whatever is left is spread evenly over all layers
      involved. This assumes lnames_sorted lists the layers by decreasing num_W (largest first).

      num_to_freeze_tot: total number of weights to freeze across all layers
      num_W: dict layer name -> number of weights in that layer
      tensor_dims: dict layer name -> weight tensor dims (only read when io_only=True)
      lnames_sorted: layer names; defines the order of the returned array
      io_only: if True, the ntf of each 4d (conv) layer is rounded to a multiple of its kernel_size
        (product of the last two dims) and the rounding remainder is passed to distribute_remainder

      note: in case of io_only=True, it is in general not possible to match num_to_freeze_tot exactly,
      bc we are freezing weights in packs of kernel_size (for most conv layers kernel_size=9);
      therefore, there is going to be a mismatch - num_W_tot will deviate from num_W_tot_base
      by up to kernel size

      raises ValueError if lnames_sorted is empty or if num_to_freeze_tot exceeds the total number
      of weights in the listed layers """

  num_layers = len(lnames_sorted)
  if num_layers == 0:
    raise ValueError("get_ntf: lnames_sorted is empty - there are no layers to freeze weights in")

  # list of num_W in sorted order
  num_W_sorted_list = [num_W[lname] for lname in lnames_sorted]
  num_W_avail = sum(num_W_sorted_list)
  if num_to_freeze_tot > num_W_avail:
    raise ValueError(f"get_ntf: cannot freeze {num_to_freeze_tot} weights, the layers only have {num_W_avail}")

  num_to_freeze = np.zeros(num_layers, dtype=int) # init

  # a) from num_W, compute the limits that determine over how many layers
  # num_to_freeze_tot is distributed

  # (i)
  # differences in num_W between neighbouring sorted layers
  num_W_diffs = [abs(d) for d in np.diff(num_W_sorted_list)]

  # (ii)
  # aux vector for the following dot product: diff j (between layers j and j+1)
  # has to be frozen in each of the first j+1 layers
  aux_vect = np.arange( 1,len(num_W_diffs)+1 )

  # (iii)
  # the bins: ntf_lims[k-1] is the max number of weights that can be frozen within the first
  # k layers before the next-smaller layer gets involved into sparsification
  ntf_lims = [np.dot(aux_vect[:k], num_W_diffs[:k]) for k in range(1,num_layers)]

  # (iv)
  # find in which bin num_to_freeze_tot falls - this gives you the number of layers to sparsify;
  # bisect_left returns len(ntf_lims) == num_layers-1 when num_to_freeze_tot is beyond the last
  # limit (or when there is a single layer), i.e. then all layers are sparsified
  lim_ind = bisect.bisect_left(ntf_lims, num_to_freeze_tot)
  num_layers_to_sparsify = lim_ind+1

  # (v)
  # base fill: chunks of num_W that are frozen before the rest is distributed evenly;
  # layer lind gets the difference between its num_W and that of the last sparsified layer
  base_fill = [sum(num_W_diffs[lind:lim_ind]) for lind in range(lim_ind)]
  base_fill.append(0)

  # (vi)
  # the rest that is distributed evenly over all layers that are sparsified
  rest_tot = num_to_freeze_tot-sum(base_fill)
  rest = int(rest_tot // num_layers_to_sparsify) # exact integer floor, no float rounding
  num_to_freeze[:num_layers_to_sparsify] = np.array(base_fill)+rest

  if io_only:
    ntf_sum=0
    for l_ind in range(num_layers_to_sparsify):
      lname=lnames_sorted[l_ind]
      kernel_size=np.prod(tensor_dims[lname][-2:]) if len(tensor_dims[lname])==4 else 1
      y=int(np.around(num_to_freeze[l_ind]/kernel_size))
      num_to_freeze[l_ind]=y*kernel_size
      ntf_sum+=(y*kernel_size)

    remainder = num_to_freeze_tot - ntf_sum
    remainder = distribute_remainder(remainder, num_layers_to_sparsify, num_to_freeze, lnames_sorted, tensor_dims)
    assert sum(num_to_freeze)+remainder==num_to_freeze_tot, f"(!) error: num_to_freeze+remainder ({sum(num_to_freeze)+remainder}) not as expected ({num_to_freeze_tot}) ! "
  else:
    # first layer gets the few additional frozen weights when rest_tot is not evenly divisible
    # by num_layers_to_sparsify
    rest_mismatch = rest_tot - rest*num_layers_to_sparsify
    num_to_freeze[0]+= rest_mismatch
    assert sum(num_to_freeze)==num_to_freeze_tot, f"(!) error: num_to_freeze ({sum(num_to_freeze)}) not as expected ({num_to_freeze_tot}) ! "

  return num_to_freeze


def get_lname_for_statedict(lname):
  """ Turn a '_'-separated layer name into the state_dict key of its weight tensor,
  e.g. 'layer1_0_conv1' -> 'layer1.0.conv1.weight' """
  return lname.replace('_', '.') + '.weight'


def make_smask_for_layer(num_W_layer, num_to_freeze_layer, tensor_dims_layer, io_only, device='cuda'):
  """ Creates a random sparsity mask for one layer as a boolean torch tensor (True = element to freeze).

      num_W_layer: number of weights in the layer
      num_to_freeze_layer: number of mask entries set to True
      tensor_dims_layer: dims of the layer's weight tensor; the mask is reshaped to these dims
      io_only: if True, only the io dims (the first two dims of a conv layer) are sparsified: for a
        4d layer num_W_layer and num_to_freeze_layer are divided by kernel_size (product of the last
        two dims) and the mask gets shape tensor_dims_layer[:2]; else the mask is uniform in all
        dimensions (i.e., all tensor_dims are sparsified)
      device: device the mask is allocated on (default 'cuda'; pass 'cpu' to run without a GPU)
      returns: bool tensor of shape tensor_dims_layer (tensor_dims_layer[:2] if io_only)

      Indices are drawn with python's `random` module, so the mask depends on its seed. """
  num_W_layer = int(num_W_layer)
  num_to_freeze_layer = int(num_to_freeze_layer)
  if io_only:
    if len(tensor_dims_layer)==4:
      kernel_size= int(np.prod(tensor_dims_layer[-2:]))
      # whole kernels are frozen, so both counts are expected to be multiples of kernel_size
      num_W_layer= num_W_layer // kernel_size
      num_to_freeze_layer= num_to_freeze_layer // kernel_size
    tensor_dims_layer= tensor_dims_layer[:2]

  # randomly generate (1d) indices of tensor elements to freeze
  inds_1d_to_freeze= random.sample( range(num_W_layer), num_to_freeze_layer )
  smask_for_layer= torch.zeros(num_W_layer, dtype=torch.bool, device=device)
  # explicit long index tensor: also well-defined when nothing is frozen (empty selection)
  smask_for_layer[torch.as_tensor(inds_1d_to_freeze, dtype=torch.long, device=device)] = True
  smask_for_layer= smask_for_layer.view(tensor_dims_layer) # reshape to tensor dims
  return smask_for_layer


def get_fanin(lkey, dims, connectivity):
  """ compute fan-in and "bound" for param initialization for layer of type fc (linear) or conv

      lkey: layer name; the layer type is inferred from substrings, checked in this order:
        'conv'/'downsample'/'shortcut' -> conv, 'fc'/'cl'/'linear' -> fc
      dims: weight tensor dims; the fan-in is taken from dims[1:] (for torch conv/linear weights:
        in-channels times kernel dims, resp. in-features)
      connectivity: multiplier applied to the fan-in (e.g. the fraction of weights not frozen)
      returns: (fan_in, bound) with bound = 1/sqrt(fan_in)
      raises ValueError if the layer type cannot be inferred from lkey """
  if 'conv' in lkey or 'downsample' in lkey or 'shortcut' in lkey:
    # conv layer: product over dims[1:] (equals dims[1]*dims[2]*dims[3] for 2d kernels, also fits 1d/3d)
    fan_in = int(np.prod(dims[1:]))*connectivity
  elif 'fc' in lkey or 'cl' in lkey or 'linear' in lkey:
    fan_in = dims[1]*connectivity # for fc layer
  else:
    # raise instead of only printing: fan_in would be undefined below
    raise ValueError(f"* * * Error: can not compute fan-in - unknown layer type of layer '{lkey}'! * * *")
  bound = 1 / np.sqrt(fan_in)
  return fan_in, bound


def adjust_layer_init(net, lname, lname_for_statedict, tensor_dims_layer, num_to_freeze_layer, num_W_layer):
  """Re-initialize a sparse layer in place: its weights, and its bias if the state_dict has one,
  are drawn uniformly from [-bound, bound], where bound comes from get_fanin with
  connectivity = 1 - num_to_freeze_layer/num_W_layer.
  Relies on the tensors returned by net.state_dict() sharing storage with the module's parameters."""
  connectivity = 1-num_to_freeze_layer/num_W_layer
  _, bound = get_fanin(lname, tensor_dims_layer, connectivity)

  params = net.state_dict()
  params[lname_for_statedict].data.uniform_(-bound, bound)

  # bias (if present) gets the same init range as the weights
  bias_key = lname_for_statedict.replace('weight', 'bias')
  if bias_key in params:
    params[bias_key].data.uniform_(-bound, bound)