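'''Some helper functions for PyTorch, including:
    - get_mean_and_std: compute the per-channel mean and std of an image dataset.
    - init_params: initialize network parameters.
    - mkdir_p: create a directory (and its parents) if it does not exist.
    - save_checkpoint, torch_accuracy, AverageMeter: training utilities.
    - get_vid_module_dict: build mean/variance adaptation modules for chosen layers.
'''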
import errno
import os
import sys
import time
import math
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

# Public API exported by `from <module> import *`.
__all__ = [
    'get_mean_and_std',
    'init_params',
    'mkdir_p',
    'save_checkpoint',
    'torch_accuracy',
    'AverageMeter',
    'get_vid_module_dict',
]


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of an image dataset.

    Samples are loaded one at a time. The returned mean is the average of the
    per-image channel means, and the returned std is the average of the
    per-image (unbiased) channel standard deviations -- an approximation of
    the dataset-wide std, not the exact pooled value.

    Args:
        dataset: a map-style dataset yielding ``(image, target)`` pairs, where
            ``image`` is a tensor whose first axis is the channel axis
            (e.g. C x H x W).
        num_workers: number of DataLoader worker processes.

    Returns:
        (mean, std): two 1-D float tensors with one entry per channel.

    Raises:
        ValueError: if ``dataset`` is empty (the averages would be 0/0).
    '''
    if len(dataset) == 0:
        raise ValueError('get_mean_and_std: dataset is empty')
    # The result is an average, so sample order is irrelevant: no shuffling.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        # batch_size=1, so inputs[0] is the single image; flatten everything
        # after the channel axis so each row holds one channel's values.
        per_channel = inputs[0].reshape(inputs.size(1), -1)
        if mean is None:
            # Channel count is taken from the data rather than assumed to be 3.
            mean = torch.zeros(per_channel.size(0))
            std = torch.zeros(per_channel.size(0))
        mean += per_channel.mean(1)
        std += per_channel.std(1)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters of every supported module in ``net``, in place.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: normal weights with std 1e-3, zero bias.

    Other module types are left untouched. Parameters that are disabled on a
    layer (``bias=False``, or BatchNorm with ``affine=False``) are ``None`` and
    are skipped.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Test for presence, not truthiness: bool() of a multi-element
            # tensor raises "Boolean value of Tensor ... is ambiguous".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

def mkdir_p(path):
    '''Create ``path`` and any missing parent directories, like ``mkdir -p``.

    An already existing directory is silently accepted. Any other ``OSError``,
    including ``path`` existing as a non-directory, is re-raised unchanged.
    '''
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise

def save_checkpoint(state, checkpoint, filename='checkpoint.pth.tar'):
    '''Serialize ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``.

    The ``checkpoint`` directory is not created here and must already exist.
    '''
    torch.save(state, os.path.join(checkpoint, filename))


def torch_accuracy(output, target, topk=(1,)):
    '''Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; dim 0 is the batch and the top-k predictions are
            taken along dim 1 (presumably shape (batch, num_classes)).
        target: either class indices of shape (batch,), or a 2-D tensor whose
            argmax along dim 1 is used as the class index (e.g. one-hot or
            soft labels).
        topk: the k values to evaluate.

    Returns:
        list with one 1-element float tensor per k, holding the percentage of
        samples whose label is among the top-k predictions.

    Raises:
        ValueError: if ``target`` is neither 1-D nor 2-D.
    '''
    topn = max(topk)
    batch_size = output.size(0)

    if target.dim() == 1:
        labels = target
    elif target.dim() == 2:
        labels = target.max(1)[1]
    else:
        raise ValueError(
            'torch_accuracy: target must be 1-D or 2-D, got %d-D' % target.dim())

    _, pred = output.topk(topn, 1, True, True)
    pred = pred.t()  # (topn, batch): row j holds every sample's j-th guess
    is_correct = pred.eq(labels.view(1, -1).expand_as(pred))

    ans = []
    for k in topk:
        # reshape, not view: is_correct derives from a transposed tensor and
        # may be non-contiguous, in which case view(-1) raises.
        correct_k = is_correct[:k].reshape(-1).float().sum(0, keepdim=True)
        ans.append(correct_k.mul_(100.0 / batch_size))

    return ans

def get_vid_module_dict(model, hook_layers):
    '''Build mean/variance adaptation modules for selected layers of ``model``.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        hook_layers: iterable of layer names; a layer is selected when
            ``<name>.weight`` appears among the model's parameter names.

    Returns:
        dict mapping ``'feature_maps<i>_mean'`` and ``'feature_maps<i>_var'``
        to the two modules returned by ``get_mean_and_variance`` for the i-th
        selected layer. ``i`` follows ``model.named_parameters()`` order, not
        the order of ``hook_layers``.
    '''
    # Set for O(1) membership while scanning every parameter of the model.
    hooked_weights = {layer_name + '.weight' for layer_name in hook_layers}
    selected = (p for name, p in model.named_parameters() if name in hooked_weights)

    vid_module_dict = {}
    for i, weight in enumerate(selected):
        key = 'feature_maps' + str(i)
        # For nn.Conv2d / nn.Linear weights, dim 0 is the output channel count.
        channels = weight.shape[0]
        mean, var = get_mean_and_variance(channels)
        vid_module_dict[key + '_mean'] = mean
        vid_module_dict[key + '_var'] = var

    return vid_module_dict

def get_mean_and_variance(in_channels):
    '''Build the (mean, var) module pair for a student feature map.

    Both modules are created by ``get_adaptation_layer`` and keep the channel
    count unchanged (``in_channels`` -> ``in_channels``). ``var`` additionally
    ends with ``nn.Softplus`` so its output is strictly positive.

    Returns:
        (mean, var): two ``nn.Sequential`` modules.
    '''
    out_channels = in_channels
    mean = get_adaptation_layer(in_channels, out_channels, False)
    var = get_adaptation_layer(in_channels, out_channels, False)
    # Sequential children are named '0', '1', ...; str(len(var)) keeps the
    # numbering contiguous. Softplus has no parameters, so state_dict keys
    # are unaffected by this name.
    var.add_module(str(len(var)), nn.Softplus())

    return mean, var

def get_adaptation_layer(in_channels, out_channels, adap_avg_pool):
    '''Build a 1x1-convolution adaptation head (``in_channels`` -> ``out_channels``).

    Layout: Conv2d(in, in, 1x1) -> ReLU -> Conv2d(in, out, 1x1), with stride 1
    and no padding, so the spatial size is preserved. ``adap_avg_pool`` is
    accepted but not used. A BatchNorm2d after the ReLU is left disabled.
    '''
    def conv1x1(n_in, n_out):
        return nn.Conv2d(in_channels=n_in, out_channels=n_out,
                         kernel_size=1, stride=1, padding=0)

    hidden = in_channels
    return nn.Sequential(
        conv1x1(in_channels, hidden),
        nn.ReLU(),
        conv1x1(hidden, out_channels),
    )

class AverageMeter(object):
    '''Track the latest value and the running weighted mean of a scalar.

    Attributes (all reset to 0 by ``reset``):
        now: most recent value recorded by ``update``.
        sum: running total of ``value * count``.
        num: running total of ``count``.
        mean: ``sum / num`` as a float.
    '''
    name = 'No name'

    def __init__(self, name='No name'):
        self.name = name
        self.reset()

    def reset(self):
        '''Discard all accumulated statistics.'''
        self.sum = self.mean = self.num = self.now = 0

    def update(self, mean_var, count=1):
        '''Record ``mean_var`` as the mean of ``count`` new samples.

        A NaN is replaced by 1e6 (and a warning is printed), which keeps the
        running mean finite; a NaN would otherwise propagate into every later
        mean.
        '''
        value = mean_var
        if math.isnan(value):
            print('Avgmeter getting Nan!')
            value = 1e6

        self.now = value
        self.num += count
        self.sum += value * count
        self.mean = float(self.sum) / self.num

