import argparse
import torch
import torch.nn as nn
from torch.utils import data, model_zoo
from torchvision import models
import numpy as np
import pickle
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import sys, os
import os.path as osp
import matplotlib

import random

# Three per-channel mean values stored as a float32 vector.
# NOTE(review): the numbers look like the commonly used Caffe-style BGR
# ImageNet means -- confirm the channel order against the data loader.
IMG_MEAN = np.array(
    [104.00698793, 116.66876762, 122.67891434],
    dtype=np.float32,
)


def get_arguments():
    """Parse the command-line options of the weights extractor.

    Options are read from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: with attributes ``num_workers``, ``data_dir``,
        ``num_classes``, ``pretrained_ra``, ``gpu``, ``start``, ``end`` and
        ``length``.
    """
    # (flag, type, default, help) -- registered in this order.
    # NOTE(review): some help texts look copy-pasted (``--start`` and
    # ``--end`` both say "Number of epoch."); they are kept verbatim here.
    option_table = (
        ("--num-workers", int, 4,
         "number of workers for multithread dataloading."),
        ("--data-dir", str, '/path/to/dataset',
         "Path to the directory containing the source dataset."),
        ("--num-classes", int, 1000,
         "Number of classes to predict (including background)."),
        ("--pretrained-ra", str, '/path/to/checkpoints/',
         "Where to save snapshots of the model."),
        ("--gpu", int, 0,
         "choose gpu device."),
        ("--start", int, 0,
         "Number of epoch."),
        ("--end", int, None,
         "Number of epoch."),
        ("--length", int, None,
         "length of dataset."),
    )

    cli = argparse.ArgumentParser(description="Weights Extractor")
    for flag, value_type, fallback, blurb in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=blurb)
    return cli.parse_args()

class ExtractWeights(nn.Module):
    """Run a wrapped network and dump random feature samples of chosen layers.

    The direct children of ``submodule`` are applied in order.  A child whose
    name is listed in ``extracted_layers`` must itself have (at least) three
    children ``c0, c1, c2`` and is evaluated as::

        y = c0(x)        # sampled and saved under <path>/<name>_wh/
        y = c1(y)        # sampled and saved under <path>/<name>_rc/
        x = c2(x + y)

    Every other child is applied as ``x = child(x)``.

    Args:
        submodule: network whose direct children are iterated.
        extracted_layers: names of the children to instrument.
        path: root directory; ``<name>_wh`` and ``<name>_rc`` sub-directories
            are created for every extracted layer on construction.
        length: positive integer used as divisor of the per-call sample
            quota (the CLI describes it as the length of the dataset).
    """

    # Numerator of the per-call row quota:
    #   quota = _SAMPLE_BUDGET // channels // length
    _SAMPLE_BUDGET = 2000000 * 512

    def __init__(self, submodule, extracted_layers, path, length):
        super(ExtractWeights, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
        self.path = path
        self.length = length
        for layer in extracted_layers:
            for suffix in ('_wh', '_rc'):
                feat_path = osp.join(path, layer + suffix)
                if not os.path.exists(feat_path):
                    os.makedirs(feat_path)

    def _save_sample(self, features, quota, subdir, image):
        """Save at most ``quota`` random per-position channel vectors.

        ``features`` is a NumPy array with channels on axis 1.  It is
        flattened to rows of length ``channels`` (one row per position on the
        remaining axes) and a random subset of the rows is written with
        ``np.save`` to ``<self.path>/<subdir>/<image>``.
        """
        channels = features.shape[1]
        # Moving the channel axis last before flattening keeps every row a
        # genuine channel vector even when the leading (batch) axis is > 1;
        # for a single sample this equals reshape(C, -1).transpose(1, 0).
        rows = np.moveaxis(features, 1, -1).reshape(-1, channels)
        # Never ask for more rows than exist (non-square maps, or a second
        # tensor that is smaller than the one the quota was derived from).
        count = min(int(quota), rows.shape[0])
        # random.sample() needs a real sequence on Python 3, so draw row
        # indices (same RNG as before) and gather the rows afterwards.
        picked = random.sample(range(rows.shape[0]), count)
        np.save(osp.join(self.path, subdir, image), rows[picked])

    def forward(self, x, image=''):
        """Forward ``x`` through the network, saving samples on the way.

        Args:
            x: input tensor for the first child of ``submodule``.
            image: file name (without directory) used for the saved arrays;
                ``np.save`` appends ``.npy`` when it is missing.

        Returns:
            The output of the last child of ``submodule``.
        """
        for name, module in self.submodule._modules.items():
            if name not in self.extracted_layers:
                x = module(x)
                continue

            # list(): dict views cannot be indexed on Python 3.
            stages = list(module._modules.values())

            y = stages[0](x)
            features = y.data.cpu().numpy()
            # Floor division keeps the quota an integer on Python 3 as well.
            quota = self._SAMPLE_BUDGET // features.shape[1] // self.length
            self._save_sample(features, quota, name + '_wh', image)

            y = stages[1](y)
            # The quota derived from the first tensor is reused here.
            self._save_sample(y.data.cpu().numpy(), quota, name + '_rc', image)

            x = x + y
            x = stages[2](x)
        return x