import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
import  copy
from    tensorboardX import SummaryWriter
import  random
import  math
class GradientAnalysor:
    """Accumulate absolute per-classifier gradient statistics across batches.

    For every sample, the gradient entry whose index equals the sample's label
    is treated as the "positive" gradient of that classifier; every other
    entry is a "negative" gradient of its classifier. Absolute values are
    recorded, summed, and the positive/negative ratio is tracked per
    classifier.

    NOTE: ``pos_grad_list`` / ``neg_grad_list`` keep every value ever seen, so
    they grow without bound for as long as ``update`` keeps being called.
    """

    def __init__(self, class_num=10):
        """Create empty statistics for ``class_num`` classifiers."""
        # Every absolute gradient value seen, grouped by classifier index.
        self.pos_grad_list = [[] for _ in range(class_num)]
        self.neg_grad_list = [[] for _ in range(class_num)]
        # Running sums of the values stored in the lists above.
        self.pos_accum = [0 for _ in range(class_num)]
        self.neg_accum = [0 for _ in range(class_num)]
        # pos_accum / neg_accum per classifier; None while neg_accum is 0.
        self.pos_neg_ratio = [None for _ in range(class_num)]
        # Number of samples seen per label.
        self.label_counter = [0 for _ in range(class_num)]

    def update(self, gradient_batch, label_batch):
        """Fold one batch of gradients into the running statistics.

        Args:
            gradient_batch: list with one entry per sample, each entry a list
                holding one gradient value per classifier.
            label_batch: list of integer labels, indexed like
                ``gradient_batch``.

        The call is silently ignored when either argument is not a ``list``.
        An empty batch is a no-op. The input lists are not modified.

        Raises:
            IndexError: if ``label_batch`` is shorter than ``gradient_batch``
                or a label / classifier index exceeds the configured class
                count.
        """
        # Explicit type check instead of assert + bare except: asserts vanish
        # under ``python -O`` and the bare except hid unrelated errors.
        if not isinstance(gradient_batch, list) or not isinstance(label_batch, list):
            return

        for sample_id, raw_gradient in enumerate(gradient_batch):
            # __abs returns a fresh list, so the caller's data stays intact.
            gradient = self.__abs(raw_gradient)
            label = label_batch[sample_id]
            self.label_counter[label] += 1
            for classifier_id, magnitude in enumerate(gradient):
                if classifier_id == label:
                    self.pos_grad_list[classifier_id].append(magnitude)
                    self.pos_accum[classifier_id] += magnitude
                else:
                    self.neg_grad_list[classifier_id].append(magnitude)
                    self.neg_accum[classifier_id] += magnitude
                neg_total = self.neg_accum[classifier_id]
                # The ratio is undefined until a negative gradient was seen.
                self.pos_neg_ratio[classifier_id] = (
                    self.pos_accum[classifier_id] / neg_total
                    if neg_total != 0 else None
                )

    def print_for_debug(self):
        """Print the accumulated sums, ratios and label counts to stdout."""
        print("pos accum", self.pos_accum)
        print("neg accum", self.neg_accum)
        print("pos neg ratio", self.pos_neg_ratio)
        print("label_counter", self.label_counter)

    def __norm(self, values, norm=2):
        """Return the p-norm (``p = norm``) of the numbers in ``values``."""
        total = sum(math.pow(item, norm) for item in values)
        return math.pow(total, 1 / norm)

    def __abs(self, values):
        """Return a new list holding the absolute value of each element."""
        # Conditional negation (rather than abs()) keeps the original
        # treatment of every element, including signed zeros.
        return [item if item >= 0 else -item for item in values]
class Hook:
    """Capture the gradient passed to a tensor backward hook.

    ``hook_func_tensor`` is meant to be handed to ``Tensor.register_hook``;
    each time it fires it stores the gradient as a nested Python list and
    bumps ``m_count``.
    """

    def __init__(self):
        self.m_count = 0  # number of times hook_func_tensor has fired
        # The three lists below are not populated by any method of this class.
        self.input_grad_list = []
        self.output_grad_list = []
        self.gradient = None  # most recent gradient, as a nested list
        self.gradient_list = []

    def has_gradient(self):
        """Return True once a gradient has been captured."""
        return self.gradient is not None

    def get_gradient(self):
        """Return the most recent gradient (nested list), or None if none yet."""
        return self.gradient

    def hook_func_tensor(self, grad):
        """Record ``grad`` as a nested list of Python numbers.

        Returns None on purpose: a tensor hook that returns a value replaces
        the gradient used by the rest of the backward pass.
        """
        # detach() lets this work even when the gradient itself requires grad;
        # tolist() already builds an independent copy, so no deepcopy is
        # needed.
        self.gradient = grad.detach().cpu().tolist()
        self.m_count += 1

    def hook_func_model(self, module, grad_input, grad_output):
        """Placeholder; the signature resembles a module backward hook. No-op."""
        pass

    def hook_func_operator(self, module, grad_input, grad_output):
        """Placeholder; the signature resembles a module backward hook. No-op."""
        pass
# Training hyperparameters shared by the data loaders, optimizer and epoch loop.
batch_size = 200        # samples per mini-batch
learning_rate = 0.01    # SGD step size
epochs = 100            # full passes over the training set
class MLP(nn.Module):
    """Fully connected 784 -> 200 -> 200 -> 10 network.

    Every linear layer, including the last one, is followed by an in-place
    LeakyReLU.
    """

    def __init__(self):
        super().__init__()
        layer_widths = (784, 200, 200, 10)
        stages = []
        # Layers are created in input-to-output order so that parameter
        # initialisation and the Sequential indices match a hand-written stack.
        for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU(inplace=True))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)
# MNIST data loaders. Both are built when this module is executed; the
# training split is requested with download=True.
_train_set = datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
train_loader = torch.utils.data.DataLoader(
    _train_set, batch_size=batch_size, shuffle=True)

# The test split is opened without download=True, so it presumably relies on
# the files already being present under '../data' -- confirm.
_test_set = datasets.MNIST(
    root='../data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_loader = torch.utils.data.DataLoader(
    _test_set, batch_size=batch_size, shuffle=True)
if __name__ == "__main__":
    # Train the MLP on MNIST while recording, for every batch, the gradient of
    # the loss with respect to the logits; evaluate on the test split after
    # each epoch.

    # Use the first CUDA device when there is one, otherwise run on the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    criteon = nn.CrossEntropyLoss().to(device)
    net = MLP().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    L2_similarityList = []
    cosine_similarityList = []
    hookObj = Hook()
    gradAnalysor = GradientAnalysor()
    count = 0
    for epoch in range(epochs):
        net.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten each 28x28 image into a 784-long vector.
            data = data.view(-1, 28 * 28)
            data, target = data.to(device), target.to(device)
            logits = net(data)
            loss = criteon(logits, target)
            optimizer.zero_grad()
            # The hook fires during backward() and stores d(loss)/d(logits).
            hook_handle = logits.register_hook(hookObj.hook_func_tensor)
            try:
                loss.backward()
            finally:
                # Always detach the hook, even if backward() raises.
                hook_handle.remove()
            if hookObj.has_gradient():
                gradAnalysor.update(hookObj.get_gradient(), target.cpu().numpy().tolist())
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            gradAnalysor.print_for_debug()
            count += 1

        # Evaluation: no gradients are needed, so skip building the graph.
        net.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data = data.view(-1, 28 * 28)
                data, target = data.to(device), target.to(device)
                logits = net(data)
                # criteon returns the batch mean; weight it by the batch size
                # so that dividing by the dataset size below yields the true
                # per-sample average (also correct for a short last batch).
                test_loss += criteon(logits, target).item() * target.size(0)
                pred = logits.argmax(dim=1)
                # Integer count of correct predictions.
                correct += pred.eq(target).sum().item()
        test_loss /= len(test_loader.dataset)
        print('\nTest set 1: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))