# from nested_dict import nested_dict
from __future__ import division
from __future__ import print_function

import io
import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import resnet_ilsvrc

class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain fields that callers read directly:
    ``val`` is the value most recently passed to :meth:`update`, ``sum`` is
    the weighted total of all values seen since the last reset, ``count`` is
    the total weight, and ``avg`` is ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: Latest value of the metric.
            n: Weight applied to ``val`` when accumulating (defaults to 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With a zero total weight (e.g. the first update arrives with n=0)
        # the division would raise ZeroDivisionError; keep the reset value.
        self.avg = self.sum / self.count if self.count else 0


def getUniqueFileHandler(results_filename, ext='.pkl', mode='wb'):
    """Open the first not-yet-existing file in a numbered series.

    Candidate paths are tried in order: ``<results_filename><ext>``, then
    ``<results_filename>(1)<ext>``, ``<results_filename>(2)<ext>``, and so on.

    Args:
        results_filename: Path without the extension.
        ext: Extension appended after the optional ``(k)`` suffix.
        mode: Mode string handed to ``io.open``.

    Returns:
        The file object returned by ``io.open`` for the first candidate for
        which ``os.path.isfile`` is false.
    """
    candidate = results_filename + ext
    attempt = 0
    # Keep bumping the bracketed counter until the path is free.
    while os.path.isfile(candidate):
        attempt += 1
        candidate = results_filename + '(' + str(attempt) + ')' + ext
    return io.open(candidate, mode)


def get_arch(model_name):
    """Return the architecture key for ``model_name``.

    The name is matched by substring: ``'resnet'`` is checked first, then
    ``'vgg'``. ``None`` is returned when neither substring is present.
    """
    families = (
        ('resnet', 'resnet_atl'),
        ('vgg', 'vgg_atl'),
    )
    for needle, arch in families:
        if needle in model_name:
            return arch
    return None

def _get_num_features(model):
    """Return the list of feature sizes associated with a model name.

    ``model`` must look like ``'resnet<N>'`` or ``'vgg<N>[_suffix]'``.

    * ``resnet`` with N in {18, 34, 50, 101, 152} gives a fixed 5-entry list;
      any other N gives ``(N - 2) // 6`` repetitions each of 16, 32 and 64.
    * ``vgg`` with N == 4 gives a 3-entry list; N in {9, 11, 13, 16, 19}
      gives a 5-entry list.

    NOTE(review): the values are presumably per-stage channel widths --
    confirm against the model definitions.

    Raises:
        NotImplementedError: for any other model family or VGG depth.
        ValueError: if the depth part of the name is not an integer.
    """
    if model.startswith('resnet'):
        depth = int(model[len('resnet'):])
        if depth in (18, 34, 50, 101, 152):
            return [64, 64, 128, 256, 512]
        per_stage = (depth - 2) // 6
        return [16] * per_stage + [32] * per_stage + [64] * per_stage

    if model.startswith('vgg'):
        # Anything after the first underscore (e.g. a "_bn" suffix) is ignored.
        depth = int(model[len('vgg'):].split('_')[0])
        if depth == 4:
            return [64, 128, 512]
        if depth in (9, 11, 13, 16, 19):
            return [64, 128, 256, 512, 512]

    raise NotImplementedError


def count_parameters(model):
    """Return the total number of elements across trainable parameters.

    Only parameters from ``model.parameters()`` whose ``requires_grad`` is
    truthy contribute their ``numel()`` to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def load_weights_to_flatresnet(target, pretrain_model='resnet18'):
    """Copy pretrained conv / batch-norm tensors from a ResNet into ``target``.

    A source network is built via ``resnet_ilsvrc.__dict__[pretrain_model]``
    with ``pretrained=True``. Its ``nn.Conv2d`` weights and its
    ``nn.BatchNorm2d`` weight, bias and running statistics are collected in
    ``named_modules()`` order and cloned, position by position, into the
    layers of the same type in ``target``:

    * layers whose name does not contain ``'parallel_'`` read the source
      list starting at index 0;
    * layers whose name contains ``'parallel_'`` read it starting at index 1
      (the first source layer is skipped). A message is printed when the
      weight shapes differ, but the tensor is assigned regardless.

    Args:
        target: Module to fill; it is modified in place.
        pretrain_model: Key of the constructor looked up in ``resnet_ilsvrc``.

    Returns:
        The same ``target`` object.
    """
    pretrained = resnet_ilsvrc.__dict__[pretrain_model](pretrained=True)

    # Each entry: (is this the 'parallel_' pass?, first source index to read).
    passes = ((False, 0), (True, 1))

    # ---- convolution weights ----
    conv_weights = [layer.weight.data
                    for _, layer in pretrained.named_modules()
                    if isinstance(layer, nn.Conv2d)]
    for parallel_pass, cursor in passes:
        for layer_name, layer in target.named_modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            if ('parallel_' in layer_name) != parallel_pass:
                continue
            # Shape mismatches are only reported for the parallel layers.
            if parallel_pass and layer.weight.data.size() != conv_weights[cursor].size():
                print('Error: Mismatch is assignments')
            layer.weight.data = torch.nn.Parameter(conv_weights[cursor].clone())
            cursor += 1

    # ---- batch-norm affine parameters and running statistics ----
    bn_tensors = [(layer.weight.data, layer.bias.data,
                   layer.running_mean, layer.running_var)
                  for _, layer in pretrained.named_modules()
                  if isinstance(layer, nn.BatchNorm2d)]
    for parallel_pass, cursor in passes:
        for layer_name, layer in target.named_modules():
            if not isinstance(layer, nn.BatchNorm2d):
                continue
            if ('parallel_' in layer_name) != parallel_pass:
                continue
            src_weight, src_bias, src_mean, src_var = bn_tensors[cursor]
            if parallel_pass and layer.weight.data.size() != src_weight.size():
                print('Error: Mismatch is assignments')
            layer.weight.data = torch.nn.Parameter(src_weight.clone())
            layer.bias.data = torch.nn.Parameter(src_bias.clone())
            layer.running_var = src_var.clone()
            layer.running_mean = src_mean.clone()
            cursor += 1

    return target


def interpolateFeatures(input_feature, target_size):
    """Bilinearly rescale a feature map so its last axis matches ``target_size``.

    The current size is read from dimension 3 of ``input_feature``. When it
    already equals ``target_size`` the input object is returned untouched;
    otherwise ``F.interpolate`` is called with the single scale factor
    ``target_size / current_size`` and ``mode='bilinear'``.
    """
    current_size = input_feature.size(3)
    if current_size == target_size:
        return input_feature
    ratio = target_size / current_size
    return F.interpolate(input_feature, scale_factor=ratio, mode='bilinear')


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k, in percent, for each requested k.

    Args:
        output: 2-D tensor of scores; the top-k search runs along dim 1, so
            rows are samples and columns are classes.
        target: Tensor holding one class index per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk`` (same
        order): the percentage of samples whose target index is among the
        k highest-scoring columns.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best scores per sample, best first: (batch, maxk).
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # -> (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` can inherit the transposed, non-contiguous layout of
        # `pred`, in which case .view(-1) raises for k > 1; reshape() falls
        # back to a copy when a view is not possible.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def set_logging_config(logdir):
    """Configure root logging to write to ``<logdir>/log.txt`` and stdout.

    Messages at ``INFO`` level and above are formatted as
    ``[time] [logger name] message``. ``logdir`` is created if it does not
    exist yet. As with any ``logging.basicConfig`` call, the configuration
    is not applied if the root logger already has handlers.

    Args:
        logdir: Directory that receives ``log.txt``. An empty string means
            the current working directory.
    """
    # Local import: `sys` is not imported at module level, and reaching it
    # through `os.sys` depends on an undocumented detail of the `os` module.
    import sys

    # FileHandler opens the file immediately and fails if the directory is
    # missing, so create it first. Skip '' because makedirs('') raises.
    if logdir:
        os.makedirs(logdir, exist_ok=True)

    handlers = [
        logging.FileHandler(os.path.join(logdir, 'log.txt')),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s",
                        level=logging.INFO,
                        handlers=handlers)
