import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from geoopt.tensor import ManifoldTensor,ManifoldParameter
from geoopt.optim import RiemannianAdam, RiemannianSGD
from geoopt.manifolds import PoincareBallExact

from utils import GraphDataset, get_son2parent

import argparse
from time import time

# Command-line interface. parse_args() runs at import time, so importing this
# module reads sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--csv_file', type=str, help='csv file', required=True)
parser.add_argument('--n_epochs', type=int, default=400, help='number of epochs')
# Not marked required: together with required=True the default of 128 could
# never take effect.
parser.add_argument('--dim', type=int, default=128, help='dimension of embedding')
parser.add_argument('--c', type=float, default=0.1, help='curvature for label embeddings')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')

args = parser.parse_args()


class Embedding(nn.Module):
    """Learnable table of points on a (Poincaré-ball) manifold.

    Each of ``size`` objects gets one ``dim``-dimensional point. The table is
    held in a single ``ManifoldParameter`` (``self.embs``) on the GPU.

    Args:
        size: number of objects to embed.
        dim: embedding dimension.
        manifold: geoopt manifold the points live on. Its ``c`` is read as the
            curvature, and its ``dist``/``projx`` are used for distances and
            projection.
        sparse: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, size, dim, manifold, sparse=True):

        super(Embedding, self).__init__()
        self.dim = dim
        self.nobjects = size
        self.manifold = manifold
        # embs is the embedding to be learned: standard-normal rows projected
        # onto the manifold.
        # NOTE(review): randn rows have norm ~sqrt(dim), far outside the ball
        # radius 1/sqrt(c), so proj_() presumably places them all just inside
        # the boundary. That is a numerically delicate region; consider a
        # small-scale init if NaNs keep appearing.
        embs = ManifoldTensor(torch.randn(size, dim).cuda(), manifold=manifold.cuda())
        embs = embs.proj_()
        self.embs = ManifoldParameter(embs)
        self.dist = manifold.dist

        self.lossfn = F.cross_entropy

    def forward(self, inputs):
        """Score each anchor against its candidates by negated manifold distance.

        Args:
            inputs: 2-D integer index tensor. Column 0 selects the anchor
                embedding; the remaining columns select the candidates it is
                compared with.

        Returns:
            ``-dist(anchor, candidate)`` for every candidate, so nearer
            candidates get larger logits (suitable for ``cross_entropy``).

        Raises:
            FloatingPointError: if the embedding table contains NaNs, which
                re-projection cannot repair.
        """
        r = 1 / torch.sqrt(self.manifold.c)

        with torch.no_grad():
            if self.embs.isnan().any():
                raise FloatingPointError('embedding table contains NaN values')
            max_norm = self.embs.norm(dim=1).max()
            if max_norm > r:
                print('embs norm', max_norm)
                # Project in place. Rebinding self.embs to a new parameter
                # would leave the optimizer updating the stale tensor.
                self.embs.copy_(self.manifold.projx(self.embs))

        e = self.embs[inputs]
        o = e.narrow(1, 1, e.size(1) - 1)   # candidates
        s = e.narrow(1, 0, 1).expand_as(o)  # anchor, repeated per candidate
        dists = self.dist(s, o).squeeze(-1)

        return -dists

    def embedding(self):
        """Return a detached CPU NumPy copy of the embedding table."""
        return self.embs.detach().cpu().numpy()

    def optim_params(self):
        """Return an iterator over the parameters to optimize (the embedding table)."""
        return iter([self.embs])

    def loss(self, preds, targets, weight=None, size_average=True):
        """Cross-entropy of ``preds`` against ``targets``.

        ``weight`` and ``size_average`` are accepted for interface
        compatibility but ignored; ``F.cross_entropy`` is called with its
        defaults.
        """
        return self.lossfn(preds, targets)

    def similarity(self, margin=0.8, eps=1e-12):
        """Hinge penalty on pairs of embeddings that point in similar directions.

        Computes the cosine similarity of every pair of rows. The diagonal is
        zeroed, and ``max(0, cos_sim - margin)`` is summed over the matrix.

        Args:
            margin: cosine similarity above which a pair is penalized
                (default 0.8, the previously hard-coded value).
            eps: lower bound on row norms before dividing, so a zero-norm row
                contributes 0 instead of NaN.

        Returns:
            Scalar tensor holding the summed penalty.
        """
        norm = torch.norm(self.embs, dim=1)  # Euclidean norm of each row

        boundary_radius = torch.sqrt(1/self.manifold.c)

        if (norm.detach() > boundary_radius).any():
            # Diagnostic only: rows are expected inside the ball of radius 1/sqrt(c).
            deviation = (norm.sum() - boundary_radius * self.embs.shape[0])
            print('deviation from disk boundary', deviation)

        safe_norm = norm.clamp_min(eps)
        t1 = safe_norm.unsqueeze(1)                                     # n_cls x 1
        t2 = safe_norm.unsqueeze(0)                                     # 1 x n_cls
        denominator = torch.matmul(t1, t2)                              # n_cls x n_cls, pairwise norm products
        numerator = torch.matmul(self.embs, self.embs.t())              # n_cls x n_cls, pairwise inner products
        cos_sim = numerator / denominator                               # n_cls x n_cls, pairwise cosine similarity
        # Zero the diagonal so self-similarity (always 1) is not penalized.
        cos_sim_off_diag = cos_sim - torch.diag(torch.diag(cos_sim))

        obj = torch.max(torch.zeros_like(cos_sim_off_diag), cos_sim_off_diag - margin)

        return obj.sum()


class Options:
    """Hyper-parameter bundle handed to the dataset and the model.

    Command-line values are read from the module-level ``args``; everything
    else is a fixed constant.
    """

    def __init__(self):
        # -- training ---------------------------------------------------
        self.sparse = False
        self.lr, self.epochs = args.lr, args.n_epochs
        self.batchsize, self.nnegs = 50, 50
        self.burnin, self.dampening = 20, 0.75
        self.ndproc = 8  # used as the DataLoader's num_workers in __main__

        # -- manifold ---------------------------------------------------
        self.dim = args.dim
        self.c = args.c  # curvature
        self.T = 1  # temperature
        self.manifold = PoincareBallExact(c=self.c)


opt = Options()


if __name__ == '__main__':
    from pathlib import Path  # only needed for building the output path

    # tree_file = './cifar100_hierarchy.csv'
    tree_file = args.csv_file

    son2parent = get_son2parent(tree_file)

    dataset = GraphDataset(son2parent, opt)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=opt.ndproc)
    n_batches = len(data_loader)

    # Resolve and create the output location before training so a bad path
    # fails fast. Only the file name of the CSV is used, so inputs such as
    # './data/tree.csv' don't collapse to '' through split('.').
    out_dir = Path('./embs')
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = Path(args.csv_file).name.split('.')[0]
    outfile = out_dir / f'{stem}_c{args.c}_{args.dim}d.pth'

    model = Embedding(
        len(dataset.objects),
        opt.dim,
        opt.manifold,
        sparse=opt.sparse,
    ).cuda()

    # One slot per optimization step of an epoch, so .mean() averages only the
    # losses actually recorded (the early-stop test below relies on it).
    epoch_loss1 = torch.zeros(n_batches)
    epoch_loss2 = torch.zeros(n_batches)

    # Build the optimizer once so Adam's moment estimates and step counts
    # persist across epochs; only the learning rate changes between burn-in
    # and regular training.
    optimizer = RiemannianAdam(model.parameters(), lr=opt.lr, stabilize=10)
    burnin_lr_multiplier = 0.1

    for epoch in range(opt.epochs):
        st = time()

        epoch_loss1.fill_(0)
        epoch_loss2.fill_(0)

        in_burnin = epoch < opt.burnin
        dataset.burnin = in_burnin
        lr = opt.lr * burnin_lr_multiplier if in_burnin else opt.lr
        for group in optimizer.param_groups:
            group['lr'] = lr

        # NOTE(review): the DataLoader's batches are discarded; every step
        # trains on dataset.get_batch() instead, and the loader only sets the
        # number of steps per epoch. Confirm which data source is intended.
        for i, _ in enumerate(data_loader):

            # inputs are the indices of the class_names
            # targets are always zero, refer to forward() in Embedding class
            inputs, targets = dataset.get_batch()

            inputs = inputs.cuda()
            targets = targets.cuda()

            optimizer.zero_grad()
            preds = model(inputs)

            loss1 = model.loss(preds, targets, size_average=True)
            loss2 = 1e-1 * model.similarity()  # weighted angular-spread penalty

            epoch_loss1[i] = loss1.item()
            epoch_loss2[i] = loss2.item()

            loss = loss1 + loss2
            loss.backward()
            optimizer.step()

        epoch_time = time() - st
        if epoch % 5 == 0:
            print('epoch:',epoch,'loss1:',epoch_loss1.mean(),'loss2:',epoch_loss2.mean(),'epoch time:',epoch_time)
        if epoch_loss1.mean() < 1e-5:
            break
        gc.collect()

    save_dict = {'objects': dataset.objects, 'embeddings': model.embs.detach()}
    torch.save(save_dict, outfile)