import copy
import os
import pickle
import random
import shlex
import subprocess

import numpy as np
import torch
import torch.nn as nn
import torchvision
from google.cloud import storage
from torchvision import datasets, transforms

from trainer.nets import FFNN


def normalize_pixelwise(images):
    """Standardize every pixel position across the sample axis (axis 0).

    For each position, subtract the mean over axis 0 and divide by the standard
    deviation over axis 0. Positions whose standard deviation is zero are set to 0.

    Args:
        images: numpy array or torch tensor; axis 0 indexes the samples.

    Returns:
        A float64 numpy array for array input. For tensor input, a float32 tensor
        on the same device as the input.
    """
    is_tensor = isinstance(images, torch.Tensor)
    if is_tensor:
        device = images.device
        # .numpy() only works on CPU tensors that do not require grad.
        images = images.detach().cpu().numpy()

    #=== compute the mean and stdev of each pixel across images
    pix_mean = np.mean(images, axis=0)
    pix_stdev = np.std(images, axis=0)

    #=== normalize images pixel-wise (constant pixels -> 0 instead of NaN)
    centered = images - pix_mean
    images = np.divide(centered, pix_stdev, out=np.zeros(centered.shape, dtype=float),
                       where=pix_stdev != 0)

    if is_tensor:
        images = torch.tensor(images, dtype=torch.float32, device=device)

    return images


    
def load_checkpoint(ckpt_dir, ckpt_name, bucket_name='bucket2'):
    """Download ``{ckpt_dir}/{ckpt_name}`` from a GCS bucket and load it with torch.load.

    Args:
        ckpt_dir: directory of the checkpoint, used both in the bucket and locally.
        ckpt_name: checkpoint file name.
        bucket_name: GCS bucket to download from (defaults to 'bucket2').

    Returns:
        Whatever ``torch.load`` returns for the downloaded file.

    Raises:
        FileNotFoundError: if the checkpoint does not exist in the bucket.
    """
    full_ckpt_path = f'{ckpt_dir}/{ckpt_name}'
    exit_status = download_blob(bucket_name, ckpt_dir, ckpt_name)
    if exit_status != 0:
        print("<!> Error: no model checkpoint to load from!")
        # Raise explicitly; falling through would leave the result unbound.
        raise FileNotFoundError(
            f'no model checkpoint found at {full_ckpt_path} in bucket {bucket_name}')
    return torch.load(full_ckpt_path)




def compose_name_for_output(net_type, Nh_base, Nh, input_size, num_classes,
                            num_to_freeze_fc, num_to_freeze_cl,
                            make_linear, NTK_style,
                            dataset, batch_size, learning_rate,
                            pre_dir, seed):
    """Build the output (metrics) path that encodes a run's configuration.

    The result has the form
    ``{pre_dir}/{dataset}_{net_type}[_NTK_style]_{arch}_{ctvt_info}_mbs_{batch_size}_lr_{learning_rate}_seed_{seed}``.

    For ``net_type == 'ffnn'``, ``arch`` lists each layer's in/out sizes as given by
    get_arch_FFNN1 (prefixed with 'lin_' when ``make_linear == True``), and
    ``ctvt_info`` adds the per-layer connectivity ``1 - frozen / (Nh * fan)``.
    For other net types only the base weight count, input size and class count
    are used.
    """
    base_weight_count = (input_size + num_classes) * Nh_base
    total_ratio = (Nh_base / Nh) if Nh > 0 else 1

    if net_type != 'ffnn':
        arch = f'{base_weight_count}_{input_size}_{num_classes}'
        ctvt_info = f'ctvt_{total_ratio:.4f}'
    else:
        per_layer = {
            'fc': 1 - num_to_freeze_fc / (Nh * input_size),
            'cl': 1 - num_to_freeze_cl / (Nh * num_classes),
        }
        num_in, num_out = get_arch_FFNN1(Nh, input_size, num_classes)
        arch = f'{base_weight_count}_' + '_'.join(
            f'{key}_{num_in[key]}-{num_out[key]}' for key in num_in)
        ctvt_info = f'ctvt_{total_ratio:.4f}_' + '_'.join(
            f'{key}_{value:.4f}' for key, value in per_layer.items())
        if make_linear == True:  # equality (not truthiness) check kept on purpose
            arch = f'lin_{arch}'

    ntk_tag = '_NTK_style' if NTK_style else ''
    return (f'{pre_dir}/'
            f'{dataset}_{net_type}{ntk_tag}_{arch}_{ctvt_info}'
            f'_mbs_{batch_size}_lr_{learning_rate}_seed_{seed}')



def load_dataset(dataset, dataset_dir, batch_size, test_batch_size,
                 data_normalization, no_da, kwargs):
    """Build the data loaders and input geometry for ``dataset``.

    Three loaders are created with the dataset's loader function:
    a shuffled training loader (``batch_size``), an unshuffled training loader for
    evaluation (``test_batch_size``) and an unshuffled test loader
    (``test_batch_size``).

    Returns:
        (train_loader, train_loader_for_eval, test_loader, input_size, num_classes)

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # Select the loader function, the extra keyword arguments it needs, and the
    # flattened input size / class count. Each loader name is looked up only in
    # its own branch.
    if dataset == 'MNIST':
        loader = load_MNIST
        extra = {'data_dir': dataset_dir, 'normalization': data_normalization}
        input_size, num_classes = 28**2, 10
    elif dataset == 'FashionMNIST':
        loader = load_FashionMNIST
        extra = {'data_dir': dataset_dir, 'normalization': data_normalization}
        input_size, num_classes = 28**2, 10
    elif dataset == 'SVHN':
        loader = load_SVHN
        extra = {}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10':
        loader = load_CIFAR10
        extra = {'no_da': no_da}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'NonRedund':
        loader = load_NonRedund
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 8**2, 3
    elif dataset == 'MNIST_labeled_by_FFNN_56':
        loader = load_MNIST_labeled_by_FFNN_56
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 28**2, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_56_ReLU':
        loader = load_CIFAR10_labeled_by_FFNN_56_ReLU
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_56_ReLU_hc':
        loader = load_CIFAR10_labeled_by_FFNN_56_ReLU_hc
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_5_ReLU':
        loader = load_CIFAR10_labeled_by_FFNN_5_ReLU
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_5_Linear':
        loader = load_CIFAR10_labeled_by_FFNN_5_Linear
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_5_ReLU_hc':
        loader = load_CIFAR10_labeled_by_FFNN_5_ReLU_hc
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    elif dataset == 'CIFAR10_labeled_by_FFNN_5_Linear_hc':
        loader = load_CIFAR10_labeled_by_FFNN_5_Linear_hc
        extra = {'data_dir': dataset_dir}
        input_size, num_classes = 32**2*3, 10
    else:
        # Fail loudly: continuing would only crash later on unbound loaders.
        raise ValueError(f'error: dataset not available! ({dataset!r})')

    train_loader = loader(split='train', batch_size=batch_size, shuffle=True,
                          kwargs=kwargs, **extra)
    train_loader_for_eval = loader(split='train', batch_size=test_batch_size, shuffle=False,
                                   kwargs=kwargs, **extra)
    test_loader = loader(split='test', batch_size=test_batch_size, shuffle=False,
                         kwargs=kwargs, **extra)

    return train_loader, train_loader_for_eval, test_loader, input_size, num_classes





def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

def get_metrics_savedir(pre_dir,lkeys,num_to_freeze_fc,num_to_freeze_cl,
                        Nh_base,Nh,input_size,num_classes,learning_rate,seed):
    """Build the metrics directory path for an FFNN run.

    Format: ``{pre_dir}/MFE_{arch}_{ctvt_info}_lr_{learning_rate}_seed_{seed}``.
    For each key in ``lkeys`` ('fc' and/or 'cl'), ``arch`` lists the layer's
    in/out sizes from get_arch_FFNN1, and ``ctvt_info`` lists the layer's
    connectivity ``1 - frozen / (Nh * fan)``, after the overall ratio
    ``Nh_base / Nh``.
    """
    layer_fan = {'fc': input_size, 'cl': num_classes}
    frozen = {'fc': num_to_freeze_fc, 'cl': num_to_freeze_cl}
    ctvt = {key: 1 - frozen[key] / (Nh * layer_fan[key]) for key in ('fc', 'cl')}
    ctvt_total = Nh_base / Nh
    weights_base = (input_size + num_classes) * Nh_base

    num_in, num_out = get_arch_FFNN1(Nh, input_size, num_classes)
    arch = f'{weights_base}_' + '_'.join(
        [f'{key}_{num_in[key]}-{num_out[key]}' for key in lkeys])
    ctvt_info = f'ctvt_{ctvt_total:.4f}_' + '_'.join(
        [f'{key}_{ctvt[key]:.4f}' for key in lkeys])

    return f'{pre_dir}/MFE_{arch}_{ctvt_info}_lr_{learning_rate}_seed_{seed}'




def download_blob(bucket_name, source_blob_dir, source_blob_name):
    """Downloads a blob from the bucket.

    The blob ``{source_blob_dir}/{source_blob_name}`` is written to the same
    relative path on the local filesystem, creating ``source_blob_dir`` if needed.

    Returns:
        0 if the blob existed and was downloaded, 1 if it does not exist.
    """
    blob_full_path = f'{source_blob_dir}/{source_blob_name}'
    bucket = storage.Client().get_bucket(bucket_name)
    blob = bucket.blob(blob_full_path)

    if not blob.exists():
        print(f'<!> No checkpoint/data found in\n{blob_full_path}')
        return 1

    os.makedirs(source_blob_dir, exist_ok=True)
    blob.download_to_filename(blob_full_path)
    print(f'>> Downloaded blob {blob_full_path} to {blob_full_path}.')
    return 0







def make_lkey_to_name(names, to_sparsify):
    """ specific to this application

    Map every layer key in ``to_sparsify`` to a module name from ``names`` that
    contains the key as a substring. If several names match, the last one in
    ``names`` wins. Keys with no matching name are omitted.
    """
    return {lkey: name for name in names for lkey in to_sparsify if lkey in name}


def get_fanin(lkey, dims, connectivity):
    """ compute fan-in and "bound" for param initialization for layer of type fc or conv

    Args:
        lkey: layer key. Keys containing 'conv', 'downsample' or 'shortcut' are
            treated as conv layers; keys containing 'fc' or 'cl' as fully-connected.
        dims: weight shape. Conv layers use dims[1]*dims[2]*dims[3] (every dim but
            the first); fc layers use dims[1].
        connectivity: fraction of active weights; scales the fan-in.

    Returns:
        (fan_in, bound) with ``bound = 1 / sqrt(fan_in)``.

    Raises:
        ValueError: if ``lkey`` matches neither layer type.
    """
    if any(tag in lkey for tag in ('conv', 'downsample', 'shortcut')):
        fan_in = dims[1] * dims[2] * dims[3] * connectivity  # for convlayer
    elif 'fc' in lkey or 'cl' in lkey:
        fan_in = dims[1] * connectivity  # for fclayer
    else:
        # Previously this only printed and then failed on the unbound fan_in.
        raise ValueError(f'error: cant compute fan-in - unknown layer type! (lkey={lkey!r})')
    bound = 1 / np.sqrt(fan_in)
    return fan_in, bound


def adjust_init(model, lkey_to_lname, connectivity):
    """Re-initialize the sparse layers using their reduced fan-in.

    For each layer key, the weight and bias of module ``{name}.0`` are redrawn
    in place from U(-bound, bound), where ``bound = 1/sqrt(fan_in)`` and fan_in
    is scaled by ``connectivity[lkey]`` (see get_fanin).
    """
    params = model.state_dict()  # tensors share storage with the model parameters
    for lkey, name in lkey_to_lname.items():
        weight = params[f'{name}.0.weight']
        _, bound = get_fanin(lkey, weight.shape, connectivity[lkey])
        weight.data.uniform_(-bound, bound)
        params[f'{name}.0.bias'].data.uniform_(-bound, bound)



def adjust_init_ffnn(model, ctvt):
    """Re-initialize the FFNN's 'fc' and 'cl' layers using their reduced fan-in.

    Each layer's weight and bias are redrawn in place from U(-bound, bound), with
    ``bound = 1/sqrt(fan_in)`` and fan_in scaled by ``ctvt[lkey]`` (see get_fanin).
    """
    params = model.state_dict()  # tensors share storage with the model parameters
    for layer in ('fc', 'cl'):
        weight = params[f'{layer}.weight']
        _, bound = get_fanin(layer, weight.shape, ctvt[layer])
        weight.data.uniform_(-bound, bound)
        params[f'{layer}.bias'].data.uniform_(-bound, bound)


def adjust_init_ffnn_NTK(model, ctvt):
    """Re-initialize the FFNN's 'fc' and 'cl' layers (intended: NTK style).

    Each layer's weight and bias are redrawn in place from U(-bound, bound), with
    ``bound = 1/sqrt(fan_in)`` and fan_in scaled by ``ctvt[lkey]``.

    NOTE(review): as written this performs exactly the same initialization as
    adjust_init_ffnn; nothing NTK-specific happens here. Confirm whether an NTK
    parameterization was meant to be implemented.
    """
    params = model.state_dict()  # tensors share storage with the model parameters
    for layer in ('fc', 'cl'):
        weight = params[f'{layer}.weight']
        _, bound = get_fanin(layer, weight.shape, ctvt[layer])
        weight.data.uniform_(-bound, bound)
        params[f'{layer}.bias'].data.uniform_(-bound, bound)




def make_smask_and_sparsify_IO(model, lkey_to_lname, num_W_const, connectivity):
    """
        create smask: a tensor with 100*(1-connectivity)% values set to True
        (indicating weights that are set to zero and frozen),
        and apply smask to weights (not biases);
        sparsify IO dim of CNN only

        Args:
            model: module whose state_dict contains '{name}.0.weight' for every
                name in lkey_to_lname.values().
            lkey_to_lname: dict mapping layer key -> module name prefix.
            num_W_const: dict mapping layer key -> target number of weights; only
                used to size the top-k mask, which is logged but then discarded.
            connectivity: threshold compared against uniform samples for the mask
                that is actually applied (presumably a float in [0, 1] -- confirm).

        Returns:
            dict mapping layer key -> bool mask (CUDA tensor, same shape as the
            weight); True marks weights that were set to zero.

        NOTE(review): in each iteration a top-k mask is built first, but at the
        line ``smask[lkey] = torch.cuda.FloatTensor(dims).uniform_() > connectivity``
        it is overwritten by an independent random mask. num_W_const therefore has
        no effect on which weights are zeroed. Confirm which mask is intended.
        Requires a CUDA device (torch.cuda.FloatTensor).
    """
    
    smask={}
    for lkey in lkey_to_lname.keys():
        name = lkey_to_lname[lkey]
        dims = model.state_dict()[f'{name}.0.weight'].shape

        # current number of weights (more than base case)
        num_W_current = np.prod(dims) 
        # smask: a tensor of same dims as target weight tensor, populated with vals drawn from uniform distrib
        smask[lkey] = torch.cuda.FloatTensor(dims).uniform_()
        # number of weights to freeze
        num_to_freeze = int( num_W_current-num_W_const[lkey] )
        # set top num_to_freeze values in smask to 1 -- these are the values in W to be set to zero and frozen
        


        r = torch.topk(smask[lkey].view(-1), num_to_freeze)
        smask[lkey] = torch.cuda.FloatTensor(dims).fill_(0) # re-create smask and fill with 0
        # now put 1 where the top num_to_freeze values were
        # NOTE(review): the flat index is split into (row, col) using only dims[-1],
        # which is exact for 2-D weights. For a 4-D conv weight, i_row can exceed
        # dims[0] and the indexing addresses only the first two dims, so this likely
        # raises IndexError or marks the wrong entries -- verify. The loop is also
        # a Python-level loop over every frozen index; `v` is unused.
        for i, v in zip(r.indices, r.values):
            index = i.item()
            i_col = index%dims[-1]
            i_row = index//dims[-1]
            smask[lkey][i_row, i_col] = 1

        smask[lkey] = smask[lkey].to(bool)
        s=torch.sum(smask[lkey]).item()
        p=100*s/np.prod(dims)
        # logs the count of the top-k mask (which is discarded just below)
        print(f'applying smask to layer {lkey} - freezing {s} of {np.prod(dims)} weights ({p:.2f}%)')


        # Overwrites the top-k mask: each entry is True with probability
        # 1-connectivity (uniform_ samples from [0, 1)). This mask is applied and returned.
        smask[lkey] = torch.cuda.FloatTensor(dims).uniform_() > connectivity
        s=torch.sum(smask[lkey]).item()
        p=100*s/np.prod(dims)
        #print(f'smask for {lkey} removes {s} vals, i.e., {p}')
        print(f'applying smask to layer {lkey} - freezing {s} of {np.prod(dims)} weights ({p:.2f}%)')
        # state_dict tensors share storage with the parameters, so this zeroes the model weights in place
        with torch.no_grad():
            model.state_dict()[f'{name}.0.weight'][ smask[lkey] ] = 0

    return smask


def make_smask_and_sparsify(model, lkey_to_lname, connectivity):
    """Randomly zero out weights of the listed layers and return the masks.

    For every layer key, a boolean mask with the weight's shape is drawn on the
    GPU. Each entry is True with probability ``1 - connectivity``. Masked
    weights (not biases) of module ``{name}.0`` are set to zero in place.

    Returns:
        dict mapping layer key -> bool CUDA mask (True = zeroed / frozen weight).
    """
    params = model.state_dict()  # tensors share storage with the model parameters
    masks = {}
    for lkey, name in lkey_to_lname.items():
        weight = params[f'{name}.0.weight']
        dims = weight.shape
        n_total = np.prod(dims)

        mask = torch.cuda.FloatTensor(dims).uniform_() > connectivity
        masks[lkey] = mask

        n_frozen = torch.sum(mask).item()
        share = 100 * n_frozen / n_total
        print(f'applying smask to layer {lkey} - freezing {n_frozen} of {n_total} weights ({share:.2f}%)')

        with torch.no_grad():
            weight[mask] = 0

    return masks



def get_num_W(model, ltypes, num_W):
    """ fills the dict num_W with number of weights for each relevant layer in the model

    Walks ``model.named_children()`` recursively. A child whose module class name
    (e.g. 'Linear', 'Conv2d') is in ``ltypes`` is a leaf:
    ``num_W[child_name] = child.weight.numel()``. Any other child gets a nested
    dict, which is filled by recursing into it. ``num_W`` is modified in place;
    nothing is returned.
    """
    for child_name, child in model.named_children():
        if child._get_name() not in ltypes:
            nested = num_W[child_name] = {}
            get_num_W(child, ltypes, nested)
            continue
        num_W[child_name] = child.weight.numel()

            

from collections.abc import MutableMapping

def flatten(d, parent_key='', sep='_'):
    """Flatten nested mappings into a single dict, joining keys with ``sep``.

    Example: ``{'a': {'b': 1}, 'c': 2}`` -> ``{'a_b': 1, 'c': 2}``.
    Nested keys are concatenated as strings, so keys below the top level
    (and any non-empty ``parent_key``) must be str.
    """
    flat = {}
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            flat.update(flatten(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat



import bisect
def find_ge(a, x):
    """Return ``(value, index)`` of the leftmost item of sorted ``a`` that is >= ``x``.

    Raises:
        ValueError: if every item of ``a`` is smaller than ``x`` (or ``a`` is empty).
    """
    pos = bisect.bisect_left(a, x)
    if pos == len(a):
        raise ValueError
    return a[pos], pos


def get_lname_for_statedict(lname):
    """Turn an underscore-joined layer name into its state_dict weight key.

    Example: ``'layer1_0_conv1'`` -> ``'layer1.0.conv1.weight'``.
    """
    return lname.replace('_', '.') + '.weight'

    
def compute_ntf_for_each_layer(num_to_freeze_tot, num_W, lnames_sorted):
    """ returns array num_to_freeze containing ntf for each layer, sorted as in lnames_sorted

    ntf = "number to freeze". The first layers in ``lnames_sorted`` receive a
    "base fill" that reduces them to the size of the last sparsified layer; the
    remainder is spread evenly over all sparsified layers, and any indivisible
    remainder goes to the first layer.

    Args:
        num_to_freeze_tot: total number of weights to freeze.
        num_W: dict mapping layer name -> number of weights.
        lnames_sorted: layer names in the order used for the result.
            NOTE(review): the "reduce to the size of the next layer" reading
            assumes descending num_W order (the abs() on the differences hides
            the sign otherwise) -- confirm with the caller.

    Returns:
        int numpy array of length len(lnames_sorted); layers beyond the first
        ``num_layers_to_sparsify`` get 0.

    Raises:
        ValueError: from find_ge, when num_to_freeze_tot is larger than every
            entry of ntf_lims. This always happens for a single layer, because
            ntf_lims is then empty.
        AssertionError: if the per-layer counts do not sum to num_to_freeze_tot.
    """
    
    num_layers = len(lnames_sorted)
    num_to_freeze = np.zeros(num_layers, dtype=int)

    # a) from num_W, compute the limits that determine over how many layers 
    # num_to_freeze_tot is distributed

    # (i)
    # differences in num_W between sorted layers
    num_W_diffs = np.diff([num_W[lname] for lname in lnames_sorted])
    num_W_diffs = [abs(d) for d in num_W_diffs]

    # (ii)
    # aux vector for the following dot product
    aux_vect = np.arange( 1,len(num_W_diffs)+1 )

    # (iii) 
    # the bins: array of max number of weights that can be frozen within the given layers before 
    # the next-lower layer (lower on the hierarchy of num_W) gets involved into sparsification
    # ntf_lims[k-1] = sum_{j<k} (j+1)*num_W_diffs[j]; has num_layers-1 entries
    ntf_lims = [np.dot(aux_vect[:k], num_W_diffs[:k]) for k in range(1,num_layers)]

    # (iv)
    # find in which bin num_to_freeze_tot falls - this gives you the number of layers to sparsify
    # (lim_val is unused; find_ge raises ValueError if num_to_freeze_tot exceeds all limits)
    lim_val, lim_ind = find_ge(ntf_lims, num_to_freeze_tot)
    num_layers_to_sparsify = lim_ind+1

    # (v)
    # base fill: chunks of num_W that are frozen before the rest is distributed evenly
    # base_fill[lind] = sum(num_W_diffs[lind:lim_ind]); over all lind this sums to
    # ntf_lims[lim_ind-1] (and to 0 when lim_ind == 0). The appended 0 is the last
    # sparsified layer, so len(base_fill) == num_layers_to_sparsify.
    base_fill = [sum(num_W_diffs[lind:lim_ind]) for lind in range(lim_ind)]
    base_fill.append(0)

    # (vi)
    # the rest that is distributed evenly over all layers that are sparsified
    # (find_ge returned the leftmost limit >= num_to_freeze_tot, so for lim_ind >= 1
    # sum(base_fill) == ntf_lims[lim_ind-1] < num_to_freeze_tot and rest_tot > 0)
    rest_tot = num_to_freeze_tot-sum(base_fill)
    rest = int(np.floor(rest_tot/num_layers_to_sparsify))
    rest_mismatch = rest_tot-rest*num_layers_to_sparsify

    num_to_freeze[:num_layers_to_sparsify] = np.array(base_fill)+rest
    num_to_freeze[0] += rest_mismatch 
    # first layer gets the few additional frozen weights when rest_tot is not evenly divisible 
    # by num_layers_to_sparsify

    # check that all is done right
    assert sum(num_to_freeze)==num_to_freeze_tot, "(!) error: num_to_freeze not correct! "

    return num_to_freeze





def load_MNIST(data_dir, split, batch_size, shuffle, normalization, kwargs):
    """ Load and preprocess MNIST data, return data loader.

    Args:
        data_dir: root directory for torchvision's MNIST (downloaded if missing).
        split: 'train' selects the training set; any other value selects the test set.
        batch_size, shuffle: forwarded to the DataLoader.
        normalization: 'proper', 'proper2' or '0505' apply ToTensor + Normalize
            with the (mean, std) below. 'none' applies ToTensor only. 'pixelwise'
            also applies ToTensor only; the caller is expected to normalize each
            batch itself.
        kwargs: extra DataLoader keyword arguments.

    Raises:
        ValueError: for an unknown ``normalization``.
    """
    # (mean, std) passed to transforms.Normalize for each supported option
    norm_stats = {
        'proper': (0.1307, 0.3081),
        'proper2': (0.1309, 0.5528),
        '0505': (0.5, 0.5),
    }
    if normalization in norm_stats:
        mean, std = norm_stats[normalization]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
    elif normalization in ('none', 'pixelwise'):
        # no normalization inside the loader
        transform = transforms.Compose([transforms.ToTensor()])
        if normalization == 'pixelwise':
            print('>>> NOTE: pixelwise normalization does "none" normalization in data loader and assumes you normalize yourself in main ! <<<')
    else:
        # Fail here instead of crashing later on an unbound `transform`.
        raise ValueError(f'specified normalization for the dataset is not valid! ({normalization!r})')

    dataset = datasets.MNIST(data_dir, train=(split == 'train'), download=True, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)


def load_FashionMNIST(data_dir, split, batch_size, shuffle, normalization, kwargs):
    """ Load and preprocess FashionMNIST data, return data loader.

    Args:
        data_dir: root directory for torchvision's FashionMNIST (downloaded if missing).
        split: 'train' selects the training set; any other value selects the test set.
        batch_size, shuffle: forwarded to the DataLoader.
        normalization: 'proper' or '0505' apply ToTensor + Normalize with the
            (mean, std) below. 'none' applies ToTensor only. 'pixelwise' (as in
            load_MNIST) also applies ToTensor only; the caller normalizes each batch.
        kwargs: extra DataLoader keyword arguments.

    Raises:
        ValueError: for an unknown ``normalization``.
    """
    # (mean, std) passed to transforms.Normalize for each supported option
    norm_stats = {
        'proper': (0.2862, 0.3299),
        '0505': (0.5, 0.5),
    }
    if normalization in norm_stats:
        mean, std = norm_stats[normalization]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
    elif normalization in ('none', 'pixelwise'):
        # no normalization inside the loader
        transform = transforms.Compose([transforms.ToTensor()])
        if normalization == 'pixelwise':
            print('>>> NOTE: pixelwise normalization does "none" normalization in data loader and assumes you normalize yourself in main ! <<<')
    else:
        # Fail here instead of crashing later on an unbound `transform`.
        raise ValueError(f'specified normalization for the dataset is not valid! ({normalization!r})')

    dataset = datasets.FashionMNIST(data_dir, train=(split == 'train'), download=True, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)


def load_SVHN(split, batch_size, shuffle, kwargs):
    """ Load and preprocess SVHN data, return data loader.

    Images are resized to 32, converted to tensors and normalized per channel.
    The dataset is stored under './data/SVHN' and downloaded if missing.
    """
    # NOTE(review): these per-channel mean/std values are the same ones load_CIFAR10
    # uses (CIFAR-10 statistics), not SVHN statistics -- confirm this is intended.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    pipeline = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    svhn = datasets.SVHN('./data/SVHN', split=split, download=True, transform=pipeline)
    return torch.utils.data.DataLoader(svhn, batch_size=batch_size, shuffle=shuffle, **kwargs)


def load_CIFAR10(split, batch_size, shuffle, no_da, kwargs):
    """ Load and preprocess CIFAR10 data, return data loader.

    Unless ``no_da`` is set, the training split gets random crops (32, padding 4)
    and horizontal flips. Both splits are converted to tensors and normalized per
    channel. ``split`` must be 'train' or 'test' (KeyError otherwise). Data is
    stored under './data/CIFAR10' and downloaded if missing.
    """
    # transform from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
    tensor_and_normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    augmentation = [] if no_da else [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    transform = {
        'train': transforms.Compose(augmentation + tensor_and_normalize),
        'test': transforms.Compose(list(tensor_and_normalize)),
    }

    cifar10 = datasets.CIFAR10(root='./data/CIFAR10', train=(split == 'train'), download=True,
                               transform=transform[split])
    return torch.utils.data.DataLoader(cifar10, batch_size=batch_size, shuffle=shuffle, **kwargs)





def load_CIFAR100(split, batch_size, shuffle, kwargs):
    """ Load and preprocess CIFAR100 data, return data loader.

    The training split gets random crops (32, padding 4) and horizontal flips.
    Both splits are converted to tensors and normalized per channel. ``split``
    must be 'train' or 'test' (KeyError otherwise). Data is stored under
    './data/CIFAR100' and downloaded if missing.
    """
    # transform from ... (forgot to copy link)
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    steps = {
        'train': [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
                  transforms.ToTensor(), normalize],
        'test': [transforms.ToTensor(), normalize],
    }
    transform = {name: transforms.Compose(pipeline) for name, pipeline in steps.items()}

    cifar100 = torchvision.datasets.CIFAR100(root='./data/CIFAR100', train=(split == 'train'),
                                             download=True, transform=transform[split])
    return torch.utils.data.DataLoader(cifar100, batch_size=batch_size, shuffle=shuffle, **kwargs)




def evaluate(model, data_loader, data_normalization, num0, input_size, device, criterion):
    """ Evaluate given model on given data, return prediction accuracy and loss.

    Args:
        model: network to evaluate; it is switched to eval mode. The input handling
            depends on the class name: names containing 'CNN' or 'ResNet' get the
            batch unchanged, and names containing 'FFNN' get it reshaped to
            (-1, input_size). Other names only trigger a printed error, and the
            batch is used as-is without being moved to ``device``.
        data_loader: iterable of (images, labels) batches.
        data_normalization: if 'pixelwise', each batch goes through
            normalize_pixelwise before being moved to ``device``.
        num0: if > 0, this many random positions along dim 1 of every sample are
            set to zero (meant for flattened FFNN inputs).
        input_size: flattened input size.
        device: target device for inputs and labels.
        criterion: loss function. The batch loss is re-weighted by batch size, so
            it is assumed to return the batch mean -- confirm for other reductions.

    Returns:
        (acc, loss): fraction of correct predictions and sample-weighted mean loss.
    """
    model.eval()
    model_name = model.__class__.__name__
    loss_sum, correct, total = 0, 0, 0

    with torch.no_grad():
        for images, labels in data_loader:
            move_to_device = True
            if 'CNN' in model_name or 'ResNet' in model_name:
                pass  # conv nets take the batch in its original shape
            elif 'FFNN' in model_name:
                images = images.reshape(-1, input_size)
            else:
                print('error in eval: do not know how to handle input data for this model class name!')
                move_to_device = False

            # Normalize before moving to `device`: normalize_pixelwise goes through
            # numpy, which cannot read CUDA tensors.
            if data_normalization == 'pixelwise':
                images = normalize_pixelwise(images)
            if move_to_device:
                images = images.to(device)

            batch_size = images.size(0)
            if num0 > 0:
                rows = torch.arange(batch_size).unsqueeze(1)
                cols = torch.LongTensor([random.sample(range(input_size), num0) for _ in range(batch_size)])
                images[rows, cols] = 0
                # Check exactly the masked entries. This works for any batch size and
                # for inputs that already contain zeros elsewhere.
                assert bool((images[rows, cols] == 0).all()), "error: images not masked properly in eval loop!"

            labels = labels.to(device)

            outputs = model(images)

            loss = criterion(outputs, labels)
            _, predicted = outputs.max(1)

            n_samples = len(images)
            loss_sum += n_samples * loss.item()
            correct += (predicted == labels).sum().item()
            total += n_samples

    acc = correct / total
    loss = loss_sum / total

    return acc, loss




def save_checkpoint(state, savename, cp_to_bucket=True):
    """Save model checkpoint and copy to bucket.

    ``state`` is written with torch.save to ``checkpoints/{savename}.ckpt``. All
    missing parent directories are created, so nested savenames such as
    'run/sub/ep3' work.

    Args:
        state: object passed to torch.save.
        savename: path below 'checkpoints/' without the '.ckpt' suffix.
        cp_to_bucket: if True, start ``gsutil cp`` to gs://bucket2/checkpoints/.

    Returns:
        The subprocess.Popen of the background upload, or None if cp_to_bucket is
        False. Its stdin/stdout/stderr are pipes, so use communicate() to reap it
        without risking a pipe-buffer deadlock.
    """
    bucketpath = 'gs://bucket2/checkpoints/'
    local_path = f'checkpoints/{savename}.ckpt'
    # create (sub)dirs if not yet existent, including nested ones
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    torch.save(state, local_path)

    sshproc = None
    if cp_to_bucket:
        # The command runs through a shell, so quote the paths.
        cmd = ('gsutil -o GSUtil:parallel_composite_upload_threshold=150M cp '
               f'{shlex.quote(local_path)} {shlex.quote(f"{bucketpath}{savename}.ckpt")}')
        sshproc = subprocess.Popen(cmd, shell=True,
                                   stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sshproc


def save_data(save_dict, save_name):
    """Save pickle data and copy to bucket.

    ``save_dict`` is pickled to ``runs/{save_name}``. All missing parent
    directories are created, so both flat names ('x.pkl') and nested names
    ('exp/sub/x.pkl') work. A background ``gsutil cp`` to gs://bucket2/runs/ is
    then started.

    Returns:
        subprocess.Popen of the upload. Its stdin/stdout/stderr are pipes, so use
        communicate() to reap it without risking a pipe-buffer deadlock.
    """
    bucketpath = 'gs://bucket2/runs/'
    save_path = f'runs/{save_name}'
    # Create parent dirs of the file itself. Creating runs/{first path component}
    # would turn a flat save_name into a directory and break the open() below.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, 'wb') as handle:
        pickle.dump(save_dict, handle)

    # The command runs through a shell, so quote the paths.
    cmd = ('gsutil -o GSUtil:parallel_composite_upload_threshold=150M cp '
           f'{shlex.quote(save_path)} {shlex.quote(bucketpath + save_name)}')
    sshproc = subprocess.Popen(cmd, shell=True,
                               stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sshproc



def save_data_to_bucket(data, save_name):
    """Save data and copy to bucket.

    ``data`` is written with torch.save to ``runs/{save_name}``. All missing
    parent directories are created, so both flat and nested names work. A
    background ``gsutil cp`` to gs://bucket2/runs/ is then started.

    Returns:
        subprocess.Popen of the upload. Its stdin/stdout/stderr are pipes, so use
        communicate() to reap it without risking a pipe-buffer deadlock.
    """
    bucketpath = 'gs://bucket2/runs/'
    save_path = f'runs/{save_name}'
    # Create parent dirs of the file itself. Creating runs/{first path component}
    # would turn a flat save_name into a directory and break torch.save below.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    torch.save(data, save_path)

    # The command runs through a shell, so quote the paths.
    cmd = ('gsutil -o GSUtil:parallel_composite_upload_threshold=150M cp '
           f'{shlex.quote(save_path)} {shlex.quote(bucketpath + save_name)}')
    sshproc = subprocess.Popen(cmd, shell=True,
                               stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return sshproc






        


