import os
from torch.autograd import Variable
import sys
import numpy as np
from math import floor
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
#import cv2
import math
import torchvision
from torchvision import models
import pdb

def uniform_quantize(k, maxval):
    """Build a straight-through quantization function for ``k`` bits.

    Args:
        k: Bit width. ``32`` returns the input unchanged, ``1`` maps the
            input through ``torch.sign``. Any other value rounds the input to
            the nearest multiple of ``maxval / (2**k - 1)``.
        maxval: Scale of the quantization grid (see ``k``). The input is NOT
            clamped here; callers are expected to clamp beforehand.

    Returns:
        The ``apply`` callable of a custom autograd Function. It quantizes in
        the forward pass and passes the incoming gradient through unchanged
        in the backward pass (straight-through estimator).
    """
    class qfn(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input):
            if k == 32:
                # Full precision: quantization disabled.
                out = input
            elif k == 1:
                # Binary case: sign of the input (-1, 0 or +1).
                out = torch.sign(input)
            else:
                # 2**k - 1 steps per `maxval`; snap to the nearest step.
                n = float(2 ** k - 1)
                sf = n / maxval
                out = torch.round(input * sf) / sf
            return out

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through estimator: round/sign have zero gradient almost
            # everywhere, so forward the incoming gradient unchanged.
            return grad_output.clone()

    # Autograd Functions with static forward/backward are invoked through the
    # class's `apply`; instantiating them is deprecated in recent PyTorch.
    return qfn.apply

class activation_quantize_fn(nn.Module):
    """Clamp activations to ``[0, maxval]`` and quantize them to ``a_bit`` bits.

    When ``a_bit == 32`` the module is an identity and the input is returned
    untouched.

    Args:
        a_bit: Bit width forwarded to ``uniform_quantize``.
        maxval: Upper clamp bound, also used as the quantizer's scale.
    """

    def __init__(self, a_bit, maxval):
        super().__init__()
        self.a_bit = a_bit
        self.maxval = maxval
        self.uniform_q = uniform_quantize(k=a_bit, maxval=maxval)

    def forward(self, x):
        # Full precision: nothing to do.
        if self.a_bit == 32:
            return x
        clamped = torch.clamp(x, 0, self.maxval)
        return self.uniform_q(clamped)


class ResNet(nn.Module):
    """torchvision ResNet backbone with a quantized embedding head.

    The torchvision model named by ``conf.netname`` is split by its children:
    ``conv3`` = all children except the last four, ``conv4`` = the fourth from
    last, ``conv5`` = the third from last (presumably ``layer3`` / ``layer4``
    for torchvision's standard ResNet child order -- confirm if the backbone
    changes). The pooled ``conv5`` features are projected to
    ``conf.dimension``, quantized, and classified. An optional mid-level head
    classifies from ``conv4``.

    Config fields read: ``netname``, ``pretrained``, ``bits``, ``dimension``,
    ``num_class``, and optionally ``midlevel`` / ``isdetach`` (both default to
    True).
    """

    def __init__(self, conf):
        super(ResNet, self).__init__()
        # getattr performs the same lookup as the former eval('models.' + name)
        # for valid model names, without evaluating arbitrary config text.
        basenet = getattr(models, conf.netname)(pretrained=conf.pretrained)
        children = list(basenet.children())
        self.bitwidth = conf.bits
        # maxval = 2**bits - 1 makes the quantizer's step size 1 (for bits != 32).
        self.Quantizer = activation_quantize_fn(self.bitwidth, 2 ** self.bitwidth - 1)
        self.dimension = conf.dimension
        self.conv3 = nn.Sequential(*children[:-4])
        self.conv4 = children[-4]
        self.midlevel = True
        self.isdetach = True
        if 'midlevel' in conf:
            self.midlevel = conf.midlevel
        if 'isdetach' in conf:
            self.isdetach = conf.isdetach

        mid_dim = 1024
        feadim = 2048
        if conf.netname in ['resnet18', 'resnet34']:
            # resnet18/34 have narrower conv4/conv5 outputs than deeper variants.
            mid_dim = 256
            feadim = 512

        if self.midlevel:
            self.mcls = nn.Linear(mid_dim, conf.num_class)
            self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
            self.conv4_1 = nn.Sequential(nn.Conv2d(mid_dim, mid_dim, 1, 1), nn.ReLU())

        self.conv5 = children[-3]
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Use the same class count as the mid-level head. Fall back to the
        # previously hard-coded 200 for configs that do not define num_class.
        num_class = conf.num_class if 'num_class' in conf else 200
        self.classifier = nn.Linear(self.dimension, num_class)
        self.projector = nn.Linear(feadim, self.dimension)

    def set_detach(self, isdetach=True):
        """Set whether the mid-level head is cut off from the backbone's graph."""
        self.isdetach = isdetach

    def forward(self, x):
        """Run the backbone and both heads.

        Returns:
            A 5-tuple ``(logits, feat, mlogits, patches, quantized)``:
            ``logits`` from the main classifier on the quantized embedding;
            ``feat`` is the conv5 feature map, detached; ``mlogits`` from the
            mid-level head, or None when ``midlevel`` is off; ``patches`` is
            the (non-detached) conv4 feature map; ``quantized`` is the
            quantized embedding fed to the classifier.
        """
        x = self.conv3(x)
        conv4 = self.conv4(x)
        patches = conv4
        x = self.conv5(conv4)
        fea_pool = self.avg_pool(x).view(x.size(0), -1)
        fea_pool = self.projector(fea_pool)
        quantized = self.Quantizer(fea_pool)
        logits = self.classifier(quantized)

        if self.midlevel:
            # Optionally block mid-level-head gradients from reaching conv3/conv4.
            conv4_1 = conv4.detach() if self.isdetach else conv4
            conv4_1 = self.conv4_1(conv4_1)
            pool4_1 = self.max_pool(conv4_1).view(conv4_1.size(0), -1)
            mlogits = self.mcls(pool4_1)
        else:
            mlogits = None

        return logits, x.detach(), mlogits, patches, quantized

    def _init_weight(self, block):
        """Kaiming-init Conv2d weights and reset BatchNorm/GroupNorm affine params in ``block``."""
        for m in block.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def get_params(self, param_name):
        """Return one optimizer parameter group.

        Args:
            param_name: ``'ftlayer'`` for the backbone stages (conv3, conv4,
                conv5) or ``'freshlayer'`` for every other parameter.

        Returns:
            A list for ``'ftlayer'``; a lazy, one-shot ``filter`` iterator for
            ``'freshlayer'``.

        Raises:
            ValueError: if ``param_name`` names neither group.
        """
        ftlayer_params = list(self.conv3.parameters()) + \
                         list(self.conv4.parameters()) + \
                         list(self.conv5.parameters())
        ftlayer_params_ids = set(map(id, ftlayer_params))
        freshlayer_params = filter(lambda p: id(p) not in ftlayer_params_ids, self.parameters())

        # Explicit lookup replaces eval(param_name + '_params').
        groups = {'ftlayer': ftlayer_params, 'freshlayer': freshlayer_params}
        if param_name not in groups:
            raise ValueError("unknown param_name %r; expected 'ftlayer' or 'freshlayer'" % (param_name,))
        return groups[param_name]


def get_net(conf):
    """Construct the ResNet model described by ``conf`` and return it."""
    net = ResNet(conf)
    return net
