
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from losses import LMSoftmaxLoss
from losses import NormFaceLoss as NormLoss
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

seed = 123
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.enabled = True
# benchmark=True lets cuDNN auto-tune and pick possibly nondeterministic algorithms,
# which defeats deterministic=True; disable it so runs are reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the script uses: Python, NumPy, PyTorch (CPU and CUDA).
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

n_classes = 10
n_hiddens = 2  # dimensionality of the learned feature / class-weight vectors
n_samples = n_classes * 100

n_per_class = n_samples // n_classes

# Balanced labels: n_per_class copies of every class id, shuffled with the seeded NumPy RNG.
labels = [c for c in range(n_classes) for _ in range(n_per_class)]
np.random.shuffle(labels)
labels = torch.tensor(labels, dtype=torch.long, device=device)


class MeanMarginLoss(nn.Module):
    """Gap between the mean logit and the target-class logit, averaged over the batch.

    For each row ``i`` the per-sample value is
    ``mean_j(logits[i, j]) - logits[i, labels[i]]``; ``forward`` returns the
    mean of these values over all rows.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        num_classes = logits.size(1)
        target_mask = F.one_hot(labels, num_classes).float().to(logits.device)
        # Summing against the one-hot mask picks out the target logit of each row.
        target_logit = (logits * target_mask).sum(dim=1)
        mean_logit = logits.mean(dim=1)
        return (mean_logit - target_logit).mean()

class MarginLoss(nn.Module):
    """Hardest-negative margin loss, averaged over the batch.

    For each row ``i`` the per-sample value is
    ``max_{j != labels[i]} logits[i, j] - logits[i, labels[i]]``; a negative
    value means the target logit beats every other class. Requires at least two
    classes (``logits.size(1) >= 2``).
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        if logits.size(1) < 2:
            raise ValueError('MarginLoss needs at least two classes to form a margin')
        one_hot = F.one_hot(labels, logits.size(1)).to(logits.device)
        target_logit = (logits * one_hot.float()).sum(dim=1)
        # Mask the target with -inf rather than a fixed sentinel such as -1, so the
        # target can never be chosen as the hardest negative whatever the logit range.
        negatives = logits.masked_fill(one_hot.bool(), float('-inf'))
        hardest_negative = negatives.max(dim=1).values
        return (hardest_negative - target_logit).mean()

def evaluate(out, labels):
    """Return the fraction of rows of ``out`` whose highest-scoring class equals ``labels``.

    Args:
        out: class scores, one row per sample, classes along dim 1.
        labels: integer class ids, one per row of ``out``.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    with torch.no_grad():
        # softmax is monotonic, so argmax over the raw scores gives the prediction
        # directly -- without an extra pass and without softmax rounding
        # near-equal scores into exact ties (argmax then returns the first index).
        pred = torch.argmax(out, dim=1)
        correct = (pred == labels).sum().item()
    return float(correct) / float(labels.size(0))

def get_margin(weight):
    """Return the smallest angle, in degrees, between any two distinct rows of ``weight``.

    Rows are L2-normalised and the largest off-diagonal cosine similarity is
    converted to an angle. Cosines are clamped just inside [-1, 1] so rounding
    cannot push ``acos`` outside its domain.
    """
    unit = F.normalize(weight, dim=1)
    n = unit.size(0)
    # Shift the diagonal (self-similarity ~1) down by 2 so it never wins the max.
    gram = unit @ unit.t() - 2 * torch.eye(n, device=weight.device)
    largest_cos = gram.clamp(-1 + 1e-7, 1 - 1e-7).max()
    return largest_cos.acos().item() / math.pi * 180

def get_weight_margin(weight):
    """Return the average pairwise cosine similarity of the rows of ``weight`` (0-d tensor).

    The rows are L2-normalised first. Diagonal (self-similarity) entries are
    zeroed but still counted in the mean, so the sum of off-diagonal cosines is
    divided by ``n**2`` rather than ``n * (n - 1)``.
    """
    unit = F.normalize(weight, dim=1)
    n = unit.size(0)
    off_diagonal = 1 - torch.eye(n, device=weight.device)
    gram = unit @ unit.t()
    return (gram * off_diagonal).mean()

def norm_weights(weights):
    """Return the squared L2 norm of the centroid of the L2-normalised rows of ``weights``.

    The result is 0 when the unit row vectors cancel out (centroid at the origin)
    and 1 when they all point in the same direction.
    """
    centroid = F.normalize(weights, dim=1).mean(dim=0)
    return centroid.pow(2).sum()


def norm(weight):
    """Return the mean of ``1 / (1 + eps - cos)`` over all row pairs of ``weight`` (0-d tensor).

    ``cos`` is the cosine similarity of the L2-normalised rows, and ``eps = 1e-5``
    keeps the value finite when two rows coincide. The penalty grows as rows
    become more similar. Diagonal cosines are zeroed before the transform, so
    each diagonal entry contributes the constant ``1 / (1 + eps)`` to the
    ``n**2``-entry mean.
    """
    eps = 1e-5
    unit = F.normalize(weight, dim=1)
    n = unit.size(0)
    cos_sim = (unit @ unit.t()) * (1 - torch.eye(n, device=weight.device))
    return (1. / ((1 + eps) - cos_sim)).mean()



def _cosine_similarity(weight):
    """Return the pairwise cosine-similarity matrix of the rows of ``weight`` as a NumPy array."""
    unit = F.normalize(weight, dim=1)
    return torch.matmul(unit, unit.transpose(1, 0)).cpu().detach().numpy()


def _show_similarity(similarity):
    """Display a cosine-similarity matrix as a grayscale image, mapping [-1, 1] to [0, 1]."""
    plt.imshow((similarity + 1) / 2, cmap='gray')
    plt.show()


# Free parameters: one feature vector per sample (Z) and one weight vector per class (W).
# Z is drawn with the CPU RNG and then moved to `device`.
Z = torch.randn(n_samples, n_hiddens).to(device)
Z.requires_grad = True
W = torch.randn(n_classes, n_hiddens).to(device)
W.requires_grad = True
nn.init.kaiming_uniform_(W)  # overwrites the randn values of W in place

optimizer = torch.optim.SGD([Z, W], lr=0.1, momentum=0.9, weight_decay=1e-4)
# criterion = MarginLoss()
# criterion = NormLoss(scale=1)
criterion = LMSoftmaxLoss(scale=1)
# NOTE: scheduler.step() is commented out in the loop, so the LR stays constant.
scheduler = CosineAnnealingLR(optimizer, T_max=10000, eta_min=0.0)

similarity = _cosine_similarity(W)
_show_similarity(similarity)

# Create the output directory before training so a missing ./toy fails fast
# instead of crashing np.save after the whole run.
os.makedirs('./toy', exist_ok=True)

epochs = 1000000
for ep in range(epochs):
    # Both features and class weights are L2-normalised, so out[i, j] = cos(z_i, w_j).
    out = F.linear(F.normalize(Z, dim=1), F.normalize(W, dim=1))
    loss = criterion(out, labels)
    # Clear the previous iteration's gradients; without this they accumulate forever.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # scheduler.step()
    if ep % 200 == 0:
        test_acc = evaluate(out, labels)
        margin = get_margin(W)
        # criterion = LNormLoss(scale=0.1 + ep / epochs * 5)
        print('Iter {}: loss={:.4f}, test_acc={:.4f}, margin={:.4f}'.format(ep, loss.item(), test_acc, margin))

similarity = _cosine_similarity(W)
np.save('./sim.npy', similarity)

z = Z.cpu().detach().numpy()
np.save(f'./toy/norm_features_{n_classes}x{n_hiddens}.npy', z)

weight = W.cpu().detach().numpy()
np.save(f'./toy/norm_weights_{n_classes}x{n_hiddens}.npy', weight)

_show_similarity(similarity)
