import torch.nn as nn
import torch.nn.functional as F

from misc.utils import *
from misc.utils_hyp import is_on_lorentz_manifold
# from misc.lorentz import Lorentz
import math


class GCN(nn.Module):
    """Two-layer graph convolutional network followed by a linear classifier.

    Args:
        n_feat: input feature dimension of every node.
        n_dims: hidden dimension used by both ``GCNConv`` layers.
        n_clss: number of output classes (width of the final ``nn.Linear``).
        args: run configuration; stored on the module but not read by this class.
    """

    def __init__(self, n_feat=10, n_dims=128, n_clss=10, args=None):
        super().__init__()
        self.n_feat, self.n_dims, self.n_clss = n_feat, n_dims, n_clss
        self.args = args

        # Deferred import: torch_geometric is only needed once a GCN is built.
        from torch_geometric.nn import GCNConv

        self.conv1 = GCNConv(n_feat, n_dims, cached=False)
        self.conv2 = GCNConv(n_dims, n_dims, cached=False)
        self.clsif = nn.Linear(n_dims, n_clss)

    def forward(self, data):
        """Return per-node class logits for the graph ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_attr``. ``edge_attr``
        is handed to both convolutions as the edge weight. Each convolution is
        followed by ReLU and dropout (default rate, active only in training mode).
        """
        h = data.x
        edge_index = data.edge_index
        edge_weight = data.edge_attr

        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index, edge_weight)
            h = F.dropout(F.relu(h), training=self.training)

        return self.clsif(h)
    
    
class HGCN(nn.Module):
    """Two-layer hyperbolic GCN on a Lorentz manifold with a Euclidean linear head.

    Node features are mapped onto the manifold with ``expmap0``. They then pass
    through two ``HyperbolicGraphConvolution`` layers, are mapped back with
    ``logmap0``, and are classified by an ``nn.Linear``.

    Args:
        n_feat: node feature dimension. The first layer is built with
            ``n_feat + 1`` inputs, so ``data.x`` presumably already carries the
            extra (time) coordinate. TODO: confirm against the data loader.
        n_dims: hidden dimension of both convolution layers.
        n_clss: number of output classes.
        args: run configuration; only ``args.fname`` is read. May be ``None``,
            in which case ``self.fname`` is ``None``.
    """

    def __init__(self, n_feat=10, n_dims=128, n_clss=10, args=None):
        super().__init__()
        self.n_feat = n_feat
        self.n_dims = n_dims
        self.n_clss = n_clss
        self.args = args
        self.margin = 0.1  # not used by this class's methods
        # ``args`` is only consulted for the file name, so the default ``None``
        # is accepted instead of failing on ``None.fname``.
        self.fname = args.fname if args is not None else None
        k = 1.0  # curvature passed to the Lorentz manifold (constant here)

        from geoopt.manifolds.lorentz import Lorentz
        self.manifold = Lorentz(k)

        from models.layer_hyp import HyperbolicGraphConvolution

        self.conv1 = HyperbolicGraphConvolution(self.manifold, self.n_feat + 1, self.n_dims, use_bias=True, dropout=0.1)
        self.conv2 = HyperbolicGraphConvolution(self.manifold, self.n_dims, self.n_dims, use_bias=True, dropout=0.1)

        self.cls = nn.Linear(self.n_dims, self.n_clss)

    def forward(self, data, g_id):
        """Return per-node class logits.

        Args:
            data: graph object exposing ``x`` and ``edge_index``.
            g_id: unused here; kept for call-site compatibility.
        """
        x, edge_index = data.x, data.edge_index

        x = self.manifold.expmap0(x)
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)

        # Back to the tangent space at the origin so a Euclidean head can be used.
        x = self.manifold.logmap0(x)
        x = F.dropout(x, training=self.training)
        output = self.cls(x)

        return output

    def compute_lss(self, output, label):
        """Return the mean cross-entropy between logits ``output`` and ``label``."""
        loss = F.cross_entropy(output, label)
        return loss

class LGCN(nn.Module):
    """Two-layer Lorentz GCN with a curvature parameter and a Euclidean linear head.

    The curvature is kept in log-space in ``self.k``. Every forward pass maps it
    to ``sigmoid(exp(self.k) - 0.5) + args.bc``, which lies in
    ``(bc, bc + 1)``. It then rebuilds ``self.manifold`` with that value and
    feeds it to both ``LGCNConv`` layers.

    Args:
        k: initial value (Python number or tensor) whose log is stored in ``self.k``.
        n_feat: node feature dimension. The first layer takes ``n_feat + 1``
            inputs, so ``data.x`` presumably already carries the extra (time)
            coordinate. TODO: confirm against the data loader.
        n_dims: hidden dimension of both convolution layers.
        n_clss: number of output classes.
        args: run configuration. ``fname``, ``grad_k``, ``rescale`` and ``bc``
            are read from it, so it is required despite the ``None`` default.

    Raises:
        ValueError: if ``args`` is ``None``.
    """

    def __init__(self, k=1.0, n_feat=10, n_dims=128, n_clss=10, args=None):
        super().__init__()
        if args is None:
            raise ValueError("LGCN requires `args` providing fname, grad_k, rescale and bc")
        self.n_feat = n_feat
        self.n_dims = n_dims
        self.n_clss = n_clss
        self.args = args
        self.margin = 0.1  # not used by this class's methods
        self.fname = args.fname
        self.grad_k = args.grad_k

        # Log-curvature; receives gradients only when args.grad_k is truthy.
        self.k = self._log_curvature_param(k, self.grad_k)

        from models.layer_hyp import LGCNConv

        self.conv1 = LGCNConv(self.n_feat + 1, self.n_dims, use_bias=True, dropout=0.1, rescale=args.rescale)
        self.conv2 = LGCNConv(self.n_dims, self.n_dims, use_bias=True, dropout=0.1, rescale=args.rescale)

        self.cls = nn.Linear(self.n_dims, self.n_clss)

    @staticmethod
    def _log_curvature_param(k, trainable):
        """Return ``log(k)`` as an ``nn.Parameter`` with ``requires_grad=trainable``.

        ``k`` may be a Python number or a tensor. Non-floating inputs are
        promoted to the default float dtype. The value is copied, so the
        caller's tensor is never aliased.
        """
        k = torch.as_tensor(k)
        if not torch.is_floating_point(k):
            k = k.to(torch.get_default_dtype())
        # Pass requires_grad to Parameter explicitly: its default (True) would
        # otherwise override any requires_grad_ set on the wrapped tensor.
        return nn.Parameter(k.clone().detach().log(), requires_grad=bool(trainable))

    def cls_proj(self, cls_data, proj_method='1'):
        """Return ``cls_data`` (one point per row) moved onto ``self.manifold``.

        Requires ``forward`` to have run at least once, since ``self.manifold``
        is only created there. The input tensor is never modified in place.

        Args:
            cls_data: 2-D tensor whose column 0 is the time coordinate.
            proj_method: ``'1'`` rescales every row so that
                ``x0**2 - ||x[1:]||**2`` equals the curvature, but only when
                some row is off by more than ``atol=1e-3``. ``'2'`` applies
                ``self.manifold.projx`` to every row. Any other value returns
                ``cls_data`` unchanged.
        """
        if proj_method == '1':
            space_sq = torch.sum(cls_data[:, 1:] ** 2, dim=1)
            time_sq = cls_data[:, 0] ** 2
            results = time_sq - space_sq
            k = self.manifold.k.data

            if not torch.all(torch.isclose(results, k, atol=1e-03)):
                # NOTE(review): rows with results <= 0 produce inf/NaN factors;
                # scaling cannot move such rows onto the manifold.
                factor = (k / results) ** 0.5
                cls_data = cls_data * factor.unsqueeze(1)
                print('nets.py line100: modify cls....')

        elif proj_method == '2':
            # One call over all rows; projx works along the last dimension.
            cls_data = self.manifold.projx(cls_data)
        return cls_data

    def cinner(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the pairwise Lorentzian inner products ``-x0*y0 + <x[1:], y[1:]>``.

        Computed as ``x' @ y^T``, where ``x'`` is a copy of ``x`` with its first
        coordinate negated. ``x`` itself is left untouched.
        """
        x = x.clone()
        x.narrow(-1, 0, 1).mul_(-1)
        return x @ y.transpose(-1, -2)

    def forward(self, data, k, g_id):
        """Return per-node class logits.

        Args:
            data: graph object exposing ``x`` and ``edge_index``.
            k: ignored. The curvature is always recomputed from ``self.k``;
                the parameter is kept for call-site compatibility.
            g_id: unused here.
        """
        x, edge_index = data.x, data.edge_index

        curv = torch.sigmoid(self.k.exp() - 0.5) + self.args.bc

        from geoopt.manifolds.lorentz import Lorentz
        # NOTE(review): the manifold is rebuilt on every call, and cls_proj()
        # relies on it existing. If Lorentz is an nn.Module (as in current
        # geoopt; verify), this assignment registers ``manifold`` as a
        # submodule on the first call, changing parameters()/state_dict().
        self.manifold = Lorentz(curv)

        x = self.manifold.expmap0(x)
        # Return value is discarded; called only for its own checks/side effects.
        is_on_lorentz_manifold(x, self.manifold.k)

        x = self.conv1(x, edge_index, curv)
        x = self.conv2(x, edge_index, curv)
        x = self.manifold.logmap0(x)

        x = F.dropout(x, training=self.training)
        output = self.cls(x)
        return output

    def compute_lss(self, output, label):
        """Return the mean cross-entropy between logits ``output`` and ``label``."""
        loss = F.cross_entropy(output, label)
        return loss
    

