import torch
import numpy as np
import torch.nn.functional as F
from torch.nn import Parameter
import torchvision.transforms as transforms
import torch.optim as optim
import os
from . import misc
import types
from torch.autograd import Variable
from tqdm import tqdm


def init_control_gates(m):
    """Attach a fresh all-ones ``control_gates`` parameter to a ReLU module.

    Meant to be used as ``model.apply(init_control_gates)``. Modules whose
    class name does not contain ``'ReLU'`` are left untouched.

    The gate shape is read from ``relu_shapes[relu_counter].shape``; the
    module-level ``relu_counter`` is then advanced and wraps around at
    ``len(relu_shapes)``.
    """
    global relu_counter
    if 'ReLU' not in m.__class__.__name__:
        return
    # Build the gates from an explicit size. The legacy
    # ``torch.FloatTensor(x)`` constructor only interprets ``x`` as a size when
    # it is a ``torch.Size``; a plain tuple/list is treated as data and would
    # yield a wrongly shaped tensor.
    gate_shape = tuple(relu_shapes[relu_counter].shape)
    m.control_gates = Parameter(torch.ones(gate_shape, dtype=torch.float32))
    relu_counter = (relu_counter + 1) % len(relu_shapes)

def reset_control_gates(m):
    """Restore a ReLU module's control gates to 1.0 and zero their gradient.

    Meant to be used as ``model.apply(reset_control_gates)``. Modules whose
    class name does not contain ``'ReLU'`` are ignored. A matching module is
    expected to already carry a ``control_gates`` attribute.
    """
    if 'ReLU' not in m.__class__.__name__:
        return
    gates = m.control_gates
    gates.data.fill_(1.0)
    # ``grad`` is None until a backward pass has reached the gates, so only
    # clear it when it actually exists.
    if gates.grad is not None:
        gates.grad.data.fill_(0.0)

def new_forward(self, x):
    """Gated ReLU forward pass: ``control_gates * relu(x)``.

    Written to be bound onto a ReLU module as its ``forward``; it reads the
    module's ``inplace`` flag and its ``control_gates`` tensor.
    """
    return self.control_gates * F.relu(x, inplace=self.inplace)

def replace(m):
    """Swap a ReLU module's ``forward`` for the gated ``new_forward``.

    Meant to be used as ``model.apply(replace)``. The replacement is bound to
    the individual instance; modules whose class name does not contain
    ``'ReLU'`` are left as they are.
    """
    is_relu = 'ReLU' in m.__class__.__name__
    if is_relu:
        setattr(m, 'forward', types.MethodType(new_forward, m))

def collect_control_gates(m):
    """Append a ReLU module's ``control_gates`` to the module-level list.

    Meant to be used as ``model.apply(collect_control_gates)``; modules whose
    class name does not contain ``'ReLU'`` contribute nothing.
    """
    if 'ReLU' not in m.__class__.__name__:
        return
    control_gates.append(m.control_gates)


# Module-level scratch state shared by the ``model.apply`` callbacks and
# ``get_path``; ``get_path`` re-initialises all three on every call.
relu_counter = 0    # position in ``relu_shapes`` used for the next ReLU
relu_shapes = []    # objects whose ``.shape`` sizes each ReLU's gates
control_gates = []  # gate parameters gathered by ``collect_control_gates``


def cdrp_sparsity(cgs, verbose=True):
    """Return the percentage of control-gate entries that are <= 0.

    Args:
        cgs: iterable of gate tensors.
        verbose: when true (the default, which keeps the historical
            per-tensor progress output), print the tensor index, the running
            count of non-positive entries and the tensor's minimum value.

    Returns:
        A 0-dim float tensor holding ``100 * dead / total`` over all entries
        of all tensors, or ``0`` when ``cgs`` holds no entries at all.
    """
    dead = 0
    total = 0
    for i, cg in enumerate(cgs):
        # Count directly on the tensor's own device. Allocating a CPU mask
        # buffer and indexing it with a mask that lives on ``cg``'s device
        # breaks as soon as the gates are not CPU tensors.
        dead += int((cg <= 0).sum())
        total += cg.numel()
        if verbose:
            # ``torch.min`` raises on an empty tensor, so skip it there.
            lowest = torch.min(cg) if cg.numel() else None
            print(i, dead, lowest)
    if total == 0:
        # Nothing to measure; avoid dividing by zero.
        return torch.tensor(0.0)
    return 100 * torch.tensor(float(dead)) / total


def get_path(model, data, target, percentile, shapes, lambd=0.01,
             n_iter=1000, lr=0.01):
    """Optimise per-ReLU control gates for one input and return the best set.

    Every module of ``model`` whose class name contains ``'ReLU'`` receives a
    ``control_gates`` parameter (initialised to 1) and a gated forward pass.
    The gates are then trained with SGD to minimise the cross-entropy between
    the gated prediction and the model's own un-gated prediction, plus an L1
    penalty ``lambd * sum(|gates|)``. After each step the gates are clamped
    to ``[0, 100]``.

    Args:
        model: network to analyse. It is modified in place: moved to the
            compute device, ReLU forwards stay patched, and the gates are
            reset to 1 before returning.
        data: input tensor. As written the routine only supports a batch of
            one sample: the loss keeps the batch dimension and ``backward``
            is called on it directly, and only sample 0 is compared.
        target: unused; kept for interface compatibility.
        percentile: unused; an earlier loop variant that stopped at a target
            sparsity is disabled.
        shapes: sequence of objects whose ``.shape`` gives the gate shape of
            each ReLU, in ``model.apply`` traversal order.
        lambd: weight of the L1 penalty on the gates.
        n_iter: number of optimisation steps.
        lr: SGD learning rate.

    Returns:
        A list with one detached gate tensor per ReLU, taken from the
        lowest-loss iteration whose prediction matched the un-gated one, or
        ``None`` if no iteration matched.

    Note:
        ``loss.backward()`` also populates gradients of any other model
        parameter that requires grad; only the gates are zeroed and updated
        here.
    """
    global control_gates, relu_shapes, relu_counter

    control_gates = []
    relu_counter = 0
    relu_shapes = shapes
    model.apply(init_control_gates)
    model.apply(replace)
    model.apply(collect_control_gates)

    # Prefer the GPU as before, but stay usable on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = optim.SGD(control_gates, lr=lr, momentum=0.9, weight_decay=0)

    data_dev = data.to(device)

    # Reference prediction with all gates at 1; no autograd graph is needed.
    with torch.no_grad():
        reference_logits = model(data_dev)
        reference_pred = reference_logits.max(1)[1]
        # dim=1 matches the ``max(1)`` / ``sum(1)`` class axis used below and
        # avoids the deprecated implicit-dimension softmax.
        reference_prob = F.softmax(reference_logits, dim=1)

    min_loss = 1e10
    cg_list = None
    for _ in range(n_iter):
        output = model(data_dev)
        prob = F.softmax(output, dim=1)
        pred = output.detach().max(1)[1]

        # Cross-entropy against the un-gated distribution; the epsilon keeps
        # log() finite for zero probabilities.
        loss = -(reference_prob * torch.log(prob + 1e-20)).sum(1)
        for v in control_gates:
            loss = loss + lambd * v.abs().sum()

        # Snapshot the gates only while the predicted class is preserved.
        if pred[0] == reference_pred[0]:
            current = loss.detach()[0].item()
            if current < min_loss:
                cg_list = [v.detach().clone() for v in control_gates]
                min_loss = current

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the gates inside [0, 100] without recording the clamp.
        with torch.no_grad():
            for v in control_gates:
                v.clamp_(0, 100)

    model.apply(reset_control_gates)
    return cg_list


#num_layers = len(all_control_gates[0])
#new_all_control_gates = []
#for i in range(num_layers):
#    new_all_control_gates.append([])

#for value in all_control_gates:
#    for i, v in enumerate(value):
#        new_all_control_gates[i].append(v.cpu().numpy())

#for i, v in enumerate(new_all_control_gates):
#    new_all_control_gates[i] = np.vstack(v)

#all_targets = np.array(all_targets)
#misc.dump_pickle([new_all_control_gates, all_targets], os.path.join(args.logdir, 'all_infos.pkl'))
