import os.path

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pointnet2_utils import farthest_point_sample, index_points, knn2, l2_normalize, distributed_sinkhorn, \
    distributed_sinkhorn_topk, distributed_sinkhorn_l1, distributed_sinkhorn_topk_grad

from pointnet_PDG import Feature_Extractor, Classifier, Projecter
from pointnet2_utils import momentum_update
from geoopt.manifolds.stereographic import PoincareBall

def Poincare_dist(x, y, c=1.0):
    """
    Computes the pairwise *squared* Poincare-ball distance between rows of x and y.

    Both inputs are projected onto the ball with ``PoincareBall(c=c).projx``,
    then ``PoincareBall.dist2`` (the squared geodesic distance) is evaluated
    for every (row of x, row of y) pair.

    Args:
        x (torch.Tensor): shape (n, d). n usually n_way*n_query
        y (torch.Tensor): shape (m, d). m usually n_way
        c (float): curvature parameter forwarded to ``PoincareBall``.
    Returns:
        torch.Tensor: shape (n, m). For each row of x, the squared distance
        to each row of y.
    """
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    assert d == y.size(1)
    manifold = PoincareBall(c=c)

    # projx works independently on each row (norm over the last dim), so
    # projecting the n + m input rows before broadcasting yields the same
    # values as projecting all n * m expanded pairs, at a fraction of the cost.
    hx = manifold.projx(x).unsqueeze(1).expand(n, m, d)
    hy = manifold.projx(y).unsqueeze(0).expand(n, m, d)

    return manifold.dist2(hx, hy)


class LGA(nn.Module):
    """Normalise neighbourhood features relative to their centre feature.

    Only ``lc_x`` and ``knn_x`` take part in the computation. The coordinate
    arguments of ``forward`` and all constructor arguments are accepted but
    unused, and the module holds no parameters.
    """

    def __init__(self, out_dim, alpha, beta):
        # out_dim / alpha / beta are kept only so existing call sites work.
        super().__init__()

    def forward(self, lc_xyz, lc_x, knn_xyz, knn_x):
        """Return ``(permuted knn_x - centre) / (std + 1e-5)``.

        ``knn_x`` has its 2nd and 3rd dimensions swapped. ``lc_x`` gets a
        singleton axis at dim 2 so it broadcasts against dim 2 of the permuted
        tensor. The std is one scalar taken over the whole centred tensor.
        Presumably lc_x is (B, G, C) and knn_x is (B, K, G, C) -- confirm
        against the caller.
        """
        neighbours = knn_x.permute(0, 2, 1, 3)
        centred = neighbours - lc_x.unsqueeze(dim=2)
        spread = centred.std()
        return centred / (spread + 1e-5)

class ScaledDotProductAttention(nn.Module):
    '''
    Cross-attention whose scores are squared Poincare distances.

    Despite the name, ``forward`` does not use scaled dot products. Queries and
    keys are linearly projected (``fc_q`` / ``fc_k``), ``Poincare_dist`` is
    evaluated between every query row and every key row, and the scores are
    softmax-normalised over the keys. The values are used without projection.
    ``fc_v`` and ``fc_o`` are created and initialised but not used by
    ``forward``; they are kept so parameter names (and checkpoints) stay stable.
    '''

    def __init__(self, d_model, d_k, d_v, h):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and zero biases for all four projections."""
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)
        nn.init.constant_(self.fc_o.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, mode='known'):
        '''
        Attend from ``queries`` to ``keys``/``values``.

        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys, reshaped to (nk, d_model) before projection
        :param values: Values (nk, d_model); used unprojected
        :param attention_mask: Optional bool mask broadcastable to the score
            matrix (b_s * nq, nk). True indicates masking.
        :param attention_weights: Optional multiplicative weights broadcastable
            to (b_s * nq, nk), applied to the scores before masking/softmax.
        :param mode: Accepted for interface compatibility; it currently does
            not change the computation.
        :return: (out (b_s, nq, d_model), attention (b_s, nq, nk), sk_att),
            where sk_att is always None.
        '''
        b_s, nq = queries.shape[:2]

        # Hyperbolic scores: flatten queries/keys to 2-D rows, project, and
        # compare every query row with every key row -> (b_s * nq, nk).
        q = self.fc_q(queries.reshape(-1, self.d_model))
        k = self.fc_k(keys.reshape(-1, self.d_model))
        # NOTE(review): these scores are distances (larger = farther), so the
        # softmax below gives the most weight to the *farthest* key. If
        # nearest-prototype attention is intended, the scores should be negated.
        att = Poincare_dist(q, k)

        # Sinkhorn / top-k derived targets were experimented with and are
        # disabled, so no target is produced.
        sk_att = None

        # Apply the optional weights and mask to the scores that are actually
        # softmaxed (previously the weights were applied to an unused alias
        # and a caller-supplied mask was discarded).
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        out = torch.matmul(att, values).contiguous()  # (b_s * nq, d_model)

        att = att.view(b_s, nq, -1)
        out = out.view(b_s, nq, self.d_model)
        return out, att, sk_att


class get_part_feat_relate(nn.Module):
    """Group points into ``k`` local parts and relate them to part prototypes.

    Part centres come from farthest point sampling (or are supplied), each
    centre gathers ``num_points`` neighbours via ``knn2``, and several part
    descriptors are built:

    * a max-pooled raw neighbourhood feature,
    * a relation feature (``LGA`` normalisation -> ``part_projection`` -> max
      pool),
    * a transformer-encoded (``PartFormer``) version of the max-pooled feature.

    The relation features then attend to the prototypes through ``CrossAtt``.
    """

    def __init__(self, k=8, num_points=512, emb_dim=1024 ):
        super(get_part_feat_relate, self).__init__()
        self.CrossAtt = ScaledDotProductAttention(d_model=emb_dim, d_k=emb_dim, d_v=emb_dim, h=1)
        self.PartFormer = nn.TransformerEncoderLayer(emb_dim, 4, emb_dim, 0.5, batch_first=True)
        self.PosEm = LGA(emb_dim*2, 1000, 1000)
        self.k = k
        self.num_points = num_points

        self.part_projection = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def cal_delta(self, part_xyz, cent_xyz, part_feat):
        # NOTE(review): ``self.delta_net`` is not created in ``__init__``, so
        # this method raises AttributeError unless delta_net is attached to the
        # instance elsewhere. It is not called by ``forward``.
        xyz_off = cent_xyz.unsqueeze(2) - part_xyz
        feat_off = self.delta_net((part_feat[:,:1,:,:] - part_feat).reshape(-1, part_feat.shape[-1])).reshape(-1, self.num_points, self.k, 3)
        delta_xyz = torch.mean(feat_off.permute(0,2,1,3) * xyz_off, dim=2)
        return delta_xyz

    def forward(self, xyz, feat_xyz, temp, cent_xyz=None, mode='known'):
        """
        Args:
            xyz: point coordinates; dims 1 and 2 are swapped on entry
                (presumably (B, 3, N) -> (B, N, 3) -- confirm).
            feat_xyz: per-point features; transposed on entry like ``xyz``.
            temp: prototype matrix used as keys and values of ``CrossAtt``.
            cent_xyz: optional part-centre coordinates, e.g. the ``cent_xyz``
                returned by an earlier call on the same cloud. When None,
                centres are picked by farthest point sampling.
            mode: forwarded to ``CrossAtt``.
        Returns:
            tuple (transformed_part_feat, cent_xyz, part_feat_max, part_xyz,
            att2, part_target, part_relate_emb, part_related).
        """
        B = feat_xyz.shape[0]
        xyz = xyz.transpose(1, 2)
        feat_xyz = feat_xyz.transpose(1, 2)
        if cent_xyz is None:
            cent_index = farthest_point_sample(xyz, self.k)
            cent_xyz = index_points(xyz, cent_index)
        else:
            # ``cent_index`` used to be bound only on the FPS path, so passing
            # centres crashed. Recover, for every supplied centre, the index of
            # the nearest cloud point (exact when the centres were gathered from
            # this same cloud; with duplicate points the first one wins).
            # Assumes cent_xyz is (B, k, 3) like the value returned above.
            cent_index = torch.cdist(cent_xyz.to(xyz.dtype), xyz).argmin(dim=-1)
        knn_idx = knn2(xyz, cent_xyz, self.num_points)
        part_feat = index_points(feat_xyz, knn_idx)
        center_feat = index_points(feat_xyz, cent_index)

        part_xyz = index_points(xyz, knn_idx)
        part_xyz = part_xyz.transpose(1, 2).contiguous()
        part_feat_max = torch.max(part_feat, 1)[0]

        # Relation branch: normalise neighbours w.r.t. their centre feature,
        # project every neighbour row, then max-pool over the neighbours.
        part_related = self.PosEm(cent_xyz, center_feat, part_xyz, part_feat)
        part_related = self.part_projection(part_related.reshape(part_related.shape[0] * self.k * self.num_points ,-1)).reshape(B, self.k, self.num_points, -1).contiguous()
        part_related = torch.max(part_related, 2)[0]
        part_relate_emb = self.PartFormer(part_feat_max)

        transformed_part_feat, att2, part_target = self.CrossAtt(part_related, temp, temp, mode= mode)

        return transformed_part_feat, cent_xyz, part_feat_max, part_xyz, att2, part_target, part_relate_emb, part_related

class get_model(nn.Module):
    """Part-prototype point-cloud classifier with separate known / novel heads.

    Per-point features from ``encoder`` are grouped into parts by
    ``get_part_feat``. The parts are pooled by ``multi_pool`` into one
    L2-normalised embedding, which is scored against either the known-class
    prototypes (``flag == 0``) or the novel-class prototypes (``flag == 1``).

    NOTE(review): ``update_usage``, ``update_code`` and ``update_usage_epoch``
    read ``self.code_usage`` / ``self.code_usage_cur``, which this class never
    creates; callers must set them before using those methods.
    """

    def __init__(self, args, num_class, num_unknown_class, normal_channel=False):
        super(get_model, self).__init__()
        self.args = args
        # Forces the embedding width on the shared ``args`` object.
        args.emb_dims = 256

        model_name = 'pointnet'
        if model_name == 'pointnet':
            self.encoder = Feature_Extractor(args.emb_dims)

        elif model_name == 'dgcnn':
            # NOTE(review): Feature_Extractor_DGCNN / pointnext_PDG are not
            # imported here; these branches are unreachable while model_name
            # is hard-coded to 'pointnet'.
            self.encoder = Feature_Extractor_DGCNN(args.emb_dims)

        elif model_name == 'pointnext':
            self.encoder = pointnext_PDG.get_model()

        # torch.Tensor([a, b]) would create a 1-D tensor holding the two numbers
        # (and xavier init would then raise); allocate the (a, b) matrix instead.
        # The values are overwritten by xavier_normal_ below.
        self.part_prototypes = nn.Parameter(torch.empty(5 * num_class, args.emb_dims), requires_grad=True)
        self.get_part_feat = get_part_feat_relate(k=args.nb_primitives, num_points=args.number_points,
                                                  emb_dim=args.emb_dims)

        # Input is [compressed part feat | pos emb | relation feat], hence 3x.
        self.projection = nn.Sequential(
            nn.Linear(args.emb_dims*3 , args.emb_dims),
            nn.BatchNorm1d(args.emb_dims),

            nn.ReLU(),
            nn.Linear(args.emb_dims, args.emb_dims),
            nn.BatchNorm1d(args.emb_dims),
            nn.ReLU(),
            nn.Linear(args.emb_dims, args.emb_dims),
        )

        self.cat_prototypes = nn.Parameter(torch.zeros(num_class, args.emb_dims),
                                           requires_grad=True)
        self.novel_prototypes = nn.Parameter(torch.zeros(num_unknown_class, args.emb_dims),
                                             requires_grad=True)
        self.train_flag = True
        torch.nn.init.xavier_normal_(self.part_prototypes)
        torch.nn.init.xavier_normal_(self.cat_prototypes)
        torch.nn.init.xavier_normal_(self.novel_prototypes)

        self.gamma = 0.6

    def prototype_learning(self, att, _c, gamma=0.9):
        """Momentum-update the part prototypes from hard assignments.

        Each row of ``att`` (flattened to 2-D) is assigned to its argmax
        prototype. The rows of ``_c`` assigned to a prototype are averaged,
        L2-normalised and combined with the current prototypes via
        ``momentum_update`` (momentum ``gamma``).

        Returns:
            The one-hot assignment matrix, shape (n, num_prototypes).
        """
        with torch.no_grad():
            protos = self.part_prototypes.data.clone()
            _c = _c.reshape(-1, _c.shape[-1])
            # Optimal-transport (Sinkhorn) assignment is not used; a hard
            # argmax assignment is used instead.
            att = att.reshape([-1, att.shape[-1]])
            indexs = torch.argmax(att, dim=1)  # n
            q = F.one_hot(indexs, num_classes=len(self.part_prototypes)).float()  # n, m
            # Per-prototype mean of assigned features; the epsilon avoids a
            # division by zero for prototypes that received no rows.
            f = q.T @ _c / (q.sum(0).unsqueeze(-1) + 0.0001)
            f = l2_normalize(f)
            new_value = momentum_update(old_value=protos, new_value=f, momentum=gamma, debug=False)
            # NOTE(review): re-wrapping as a new Parameter (requires_grad=False)
            # detaches it from any optimizer holding the old Parameter.
            self.part_prototypes = nn.Parameter(new_value, requires_grad=False)
            return q

    def update_usage(self, part_score, train=False):
        """EMA-update ``self.code_usage`` with which codes were used.

        A code counts as used when it is among the top-2 scores of at least
        one row of ``part_score`` (flattened to (rows, num_codes)). Does
        nothing unless ``train`` is true.
        """
        if not train:
            return
        with torch.no_grad():
            scores = part_score.reshape(-1, part_score.shape[-1])
            _, topk_indices = torch.topk(scores, 2, dim=1)
            # Scatter into a zero tensor (not into the scores themselves), so a
            # code whose score already equals 1 is not mistaken for a top-k hit.
            hits = torch.zeros_like(scores).scatter_(1, topk_indices, 1)
            mask = (hits.sum(dim=0) != 0).float()

            self.code_usage = 0.99 * self.code_usage + (1 - 0.99) * mask

    def update_usage_epoch(self):
        """Fold the per-epoch usage counter into ``code_usage`` and reset it."""
        cur_mask = (self.code_usage_cur != 0).float()
        self.code_usage = 0.99 * self.code_usage + (1 - 0.99) * cur_mask
        print('update_usage', self.code_usage)
        self.code_usage_cur[self.code_usage_cur != 0] = 0

    def _tile(self, x):
        """Resize ``x`` (n, c) to exactly one row per part prototype.

        The rows are shuffled first. If there are too few, they are repeated
        (with a partial final copy) and perturbed by small Gaussian noise. If
        there are too many, the first ``code_size`` shuffled rows are kept.
        Empty input returns the current prototypes.
        """
        n, c = x.shape
        code_size = self.part_prototypes.shape[0]
        idx = torch.randperm(n)
        x = x[idx]
        if n < code_size and n>0:
            n_repeats = code_size // n
            remainder = code_size % n
            std = 0.01 / np.sqrt(c)
            x = torch.cat([x] * n_repeats + [x[:remainder]], dim=0)
            x = x + torch.randn_like(x) * std
        elif n == 0:
            x = self.part_prototypes
        else:
            x = x[:code_size]
        return x

    def update_code(self, dis_sim_part_feat, train):
        """Re-seed rarely used part prototypes (usage < 0.4) from given features.

        The selected prototypes are replaced by rows of the tiled
        ``dis_sim_part_feat``, and their usage is reset to 1. Does nothing
        unless ``train`` is true and there is something to replace.
        """
        if not train:
            return
        with torch.no_grad():
            mask = (self.code_usage < 0.4).float().unsqueeze(-1)
            if len(dis_sim_part_feat) > 0 and torch.sum(mask>0) > 0:
                dis_sim_part_feat = self._tile(dis_sim_part_feat)
                protos = self.part_prototypes.data.clone()
                protos = (1-mask) * protos + mask * dis_sim_part_feat
                self.part_prototypes = nn.Parameter(protos, requires_grad=True)
                self.code_usage[mask.squeeze(-1).bool()] = 1

    def multi_pool(self, feat, part_pos_emb, part_relate):
        """Concatenate the three (B, N, D) part descriptors, project, max-pool over N."""
        B, N, _ = feat.shape
        feat = torch.cat([feat, part_pos_emb, part_relate], dim=-1).reshape(B * N, -1)
        feat = self.projection(feat).reshape(B, N, -1)
        pooled_x_max = torch.max(feat, 1)[0]
        return pooled_x_max

    def forward(self, xyz, gt, seg, epoch, flag=0, purn=False, cent_idx=None):
        """Embed a batch of clouds and score it against class prototypes.

        ``flag == 0`` scores against the known-class prototypes and lets
        gradients reach the part prototypes. ``flag == 1`` scores against the
        novel-class prototypes, with the part prototypes detached. ``gt``,
        ``seg``, ``epoch`` and ``purn`` are accepted but unused. ``cent_idx`` is
        forwarded as the part-centre coordinates.

        Raises:
            ValueError: if ``flag`` is neither 0 nor 1.
        """
        if flag not in (0, 1):
            raise ValueError(f"flag must be 0 (known) or 1 (unknown), got {flag!r}")
        known = flag == 0

        B = xyz.shape[0]
        points_feat = self.encoder(xyz)
        if known:
            prototypes, mode = self.part_prototypes, 'known'
        else:
            prototypes, mode = self.part_prototypes.detach(), 'unknown'
        (q1, cent_q1, part_feat, part_xyz, att, part_target,
         part_pos_emb, part_related) = self.get_part_feat(xyz, points_feat, prototypes, cent_idx, mode=mode)
        part_xyz = part_xyz.view(B, -1, 3)
        part_score = att.sum(1)

        feat_q1 = l2_normalize(self.multi_pool(q1, part_pos_emb, part_related))
        class_protos = self.cat_prototypes if known else self.novel_prototypes
        out_q1 = torch.einsum('bc,kc->bk', feat_q1, l2_normalize(class_protos))

        if known:
            return {'logits': out_q1,
                    'part_logits': att,
                    'part_target': part_target,
                    'vis_emd': feat_q1,
                    'part_protos': l2_normalize(self.part_prototypes),
                    'feature': feat_q1,
                    'res_points': None,
                    'part_xyz': part_xyz,
                    'cat_protos': l2_normalize(self.cat_prototypes),
                    'novel_protos': l2_normalize(self.novel_prototypes),
                    'ori_part': part_feat,
                    'com_part': q1,
                    'part_score': part_score,
                    'cent_idx': cent_q1
                    }

        return {'logits': out_q1,
                'part_logits': att,
                'vis_emd': feat_q1,
                'part_protos': l2_normalize(self.part_prototypes),
                'feature': feat_q1,
                'res_points': None,
                'part_xyz': part_xyz,
                'cat_protos': l2_normalize(self.cat_prototypes),
                'novel_protos': l2_normalize(self.novel_prototypes),
                'com_part': q1,
                'ori_part': part_feat,
                'part_score': part_score,
                'cent_idx': cent_q1
                }



