# codebase from hyperspherical prototype networks, Pascal Mettes, NeurIPS2019
import argparse
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn


def parse_args():
    parser = argparse.ArgumentParser(description="Hyperspherical/hyperbolic prototypes")
    parser.add_argument('-c', dest="classes", default=100, type=int)
    parser.add_argument('-d', dest="dims", default=128, type=int)
    parser.add_argument('-l', dest="learning_rate", default=0.1, type=float)
    parser.add_argument('-m', dest="momentum", default=0.9, type=float)
    parser.add_argument('-e', dest="epochs", default=1000, type=int, )
    parser.add_argument('-s', dest="seed", default=300, type=int)
    parser.add_argument('-r', dest="resdir", default="prototypes", type=str)
    parser.add_argument('-n', dest="nn", default=2, type=int)
    args = parser.parse_args()
    return args


# in dimension d > 2 prototypes can be placed by minimizing their mutual cosine similarities, 
# as proposed in HPN(NeurIPS2019). Since the peBu-loss is a decreasing function of the cosine similarity
# between directional coordinate and prototype, the arguments outlined in HPN(NeurIPS2019) for
# hyperspherical prototypes equally apply to hyperbolic prototypes.

def prototype_loss(prototype):
    # Dot product of normalized prototypes is cosine similarity.
    product = torch.matmul(prototype, prototype.t()) + 1
    # Remove diagnonal from loss.
    product -= 2. * torch.diag(torch.diag(product))
    # Minimize maximum cosine similarity.
    loss = product.max(dim=1)[0]

    return loss.mean(), product.max()


def prototype_unify(num_classes):
    single_angle = 2 * math.pi / num_classes
    help_list = np.array(range(0, num_classes))
    angles = (help_list * single_angle).reshape(-1, 1)

    sin_points = np.sin(angles)
    cos_points = np.cos(angles)

    set_prototypes = torch.tensor(np.concatenate((cos_points, sin_points), axis=1))
    return set_prototypes


#
# Compute the semantic relation loss.
#
def prototype_loss_sem(prototypes, triplets):
    product = torch.matmul(prototypes, prototypes.t()) + 1
    product -= 2. * torch.diag(torch.diag(product))
    loss1 = -product[triplets[:, 0], triplets[:, 1]]
    loss2 = product[triplets[:, 2], triplets[:, 3]]
    # return loss1.mean() + loss2.mean(), product.max()
    return loss1.mean() + loss2.mean()




def initialize_prototype(args): 
    # First Randomly Initialize Prototypes.
    prototypes = torch.randn(args.n_cls, args.feat_dim)
    prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
    print('initializing finished.')

    optimizer = optim.SGD([prototypes], lr=args.learning_rate, momentum=args.momentum)

    if args.feat_dim > 2:
        # Optimize for separation.
        for i in range(args.epochs):
            # Compute loss.
            loss, _ = prototype_loss(prototypes)
            # Update.
            loss.backward()
            optimizer.step()
            # Normalize prototypes again
            prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
            optimizer = optim.SGD([prototypes], lr=args.learning_rate,
                                  momentum=args.momentum)

    elif args.dims == 2:
        prototypes = prototype_unify(args.n_cls)
        prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))

    else:
        raise Exception('Dimension is not correct.')

    return prototypes


#
# Main entry point of the script.
#
if __name__ == "__main__":
    # Parse user arguments.
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda")
    kwargs = {'num_workers': 64, 'pin_memory': True}

    # Set seed.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

    # Store result.
    os.makedirs(args.resdir, exist_ok=True)
    np.save(os.path.join(args.resdir, "prototypes-%dd-%dc.npy" % (args.feat_dim, args.n_cls)),
            prototypes.data.numpy())
