import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import get_son2parent, GraphDataset, load_edge_list


def similarity(embs, eps=1e-12):
    """Return the summed cosine similarity over all pairs of distinct rows.

    Args:
        embs: 2-D tensor holding one embedding per row (n_cls x dim).
        eps: Lower bound applied to every row norm. Without it an all-zero
            row produces 0/0 = NaN, which then poisons the whole sum.

    Returns:
        Scalar tensor: the sum of the off-diagonal entries of the
        n_cls x n_cls cosine-similarity matrix. The matrix is symmetric, so
        each unordered pair of rows contributes twice.
    """
    # Row norms, floored at eps so the division below is always defined.
    norm = torch.norm(embs, dim=1).clamp_min(eps)                   # n_cls

    # Outer product of the norms: entry (i, j) is |e_i| * |e_j|.
    denominator = norm.unsqueeze(1) * norm.unsqueeze(0)             # n_cls x n_cls
    # Gram matrix: entry (i, j) is the inner product <e_i, e_j>.
    numerator = torch.matmul(embs, embs.t())                        # n_cls x n_cls
    cos_sim = numerator / denominator

    # Remove each row's similarity with itself; only distinct pairs count.
    cos_sim_off_diag = cos_sim - torch.diag(torch.diag(cos_sim))

    return cos_sim_off_diag.sum()
    
# Command-line interface. Only the parser is built at import time; the actual
# parsing happens under the __main__ guard so that importing this module
# (e.g. to reuse `similarity`) does not consume the importer's sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--csv_file', type=str, required=True)
parser.add_argument('--dim', type=int, required=True)
parser.add_argument('--lr', type=float, default=1e-4)


def _project_to_sphere(embs, eps=1e-12):
    """Rescale every row of ``embs`` in place to unit L2 norm."""
    with torch.no_grad():
        # dim=1: one norm per row (per class), matching the row-wise cosine
        # similarity that is being optimised.
        embs.div_(embs.norm(dim=1, keepdim=True).clamp_min(eps))


def _train_embeddings(n_cls, dim, lr, n_iters=1000, log_every=10):
    """Optimise ``n_cls`` unit-norm embeddings to minimise ``similarity``.

    Args:
        n_cls: Number of embeddings (rows).
        dim: Embedding dimensionality (columns).
        lr: Adam learning rate.
        n_iters: Number of optimisation steps.
        log_every: Print the smoothed loss every this many steps.

    Returns:
        Detached (n_cls x dim) tensor whose rows have unit L2 norm.
    """
    embs = nn.Parameter(torch.empty(n_cls, dim))
    nn.init.normal_(embs, 0, 0.01)
    _project_to_sphere(embs)

    # The parameter and optimizer are created once. Re-creating them inside
    # the loop would throw away Adam's running moment estimates every step.
    optimizer = optim.Adam([embs], lr)

    loss_ema = None
    for step in range(n_iters):
        optimizer.zero_grad()
        obj = similarity(embs)
        obj.backward()
        optimizer.step()

        # Pull the updated rows back onto the unit sphere, so the tensor
        # returned after the final step is normalised as well.
        _project_to_sphere(embs)

        loss = obj.item()
        # Exponential moving average of the loss, for smoother logging.
        loss_ema = loss if loss_ema is None else loss_ema * 0.9 + loss * 0.1

        if step % log_every == 0:
            print('iter', step, 'loss', loss_ema)

    return embs.detach()


def _output_path(csv_file, dim):
    """Build the checkpoint path under ``./embs/`` for ``csv_file``.

    Only the final extension is stripped, so inputs such as ``./tree.csv``
    keep their stem instead of collapsing to an empty string.
    """
    stem = os.path.splitext(csv_file)[0]
    return './embs/' + stem + f'sph_{dim}d' + '.pth'


if __name__ == '__main__':

    args = parser.parse_args()

    son2parent = get_son2parent(args.csv_file)

    _, objects, _ = load_edge_list(son2parent)

    # NOTE(review): the +1 presumably accounts for the root node, which has
    # no parent entry in son2parent -- confirm against get_son2parent.
    n_cls = len(son2parent) + 1

    embs = _train_embeddings(n_cls, args.dim, args.lr)

    outfile = _output_path(args.csv_file, args.dim)
    # Create the target directory up front so the save cannot fail on a
    # missing folder after training has finished.
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    save_dict = {'objects': objects, 'embeddings': embs}
    torch.save(save_dict, outfile)


