import logging
import math
from typing import Tuple

import torch
import torch.nn as nn

import dark_kg.graph.utils as graph_utils


class RelationGraph(nn.Module):
    """
    Ingredient relation graph.

    Keeps one class-level graph per class (vertex weights plus an adjacency
    matrix) and converts ingredient sequences and their attention maps into
    per-instance graphs with the same layout.

    Parameters:
        vertex_weights: [K, num_vertices], the weight of each vertex for K classes
        edge_weights: [K, num_vertices, num_vertices], one adjacency matrix per class
    Weights:
        vertex_attribute_weights: [2, 1], mixes the two vertex attributes
        edge_attribute_weights: [2, 1], mixes the two edge attributes
    Tracking State:
        n_tracked: [K], number of samples accumulated per class by `update`
            (float buffer, registered with persistent=False)
    Args:
        num_vertices: number of graph vertices
        emb_dim: embedding dimension (stored; not used inside this class)
        num_classes: number of classes K
        dist_alpha: `alpha` passed to `graph_utils.pair_wise_point_sim`
        dist_pow: `pow` passed to `graph_utils.pair_wise_point_sim`
        feat_h: original feature height
        feat_w: original feature width
        constant_vertex_attr: if given, fixed (registered as buffer) vertex attribute
            weights; pos 0 weights the ingredient count, pos 1 the attention to the
            cls token
        constant_edge_attr: if given, fixed (registered as buffer) edge attribute
            weights; pos 0 weights the geometric similarity, pos 1 the pairwise
            attention
        clamp_vertex_attn: attentions to the cls token below this value are masked
            out before the softmax
        clamp_edge_attn: pairwise attentions below this value are masked out before
            the softmax
        remove_self_loop: remove self-loops in IR-graphs
        prune_node_threshold: edges connected to a vertex are removed when that
            vertex's weight is not above this threshold
        apply_normalize: renormalize vertex/edge weights in `normalize`
        clamp_weights: clamp the attribute weights to >= 0.01 in `normalize`
    """
    def __init__(
        self,
        num_vertices: int,
        emb_dim: int,
        num_classes: int = 10,
        dist_alpha: float = 1,
        dist_pow: float = 2,
        feat_h: int = 14,
        feat_w: int = 14,
        constant_vertex_attr: Tuple[float, float] = None,
        constant_edge_attr: Tuple[float, float] = None,
        clamp_vertex_attn: float = None,
        clamp_edge_attn: float = None,
        remove_self_loop: bool = False,
        prune_node_threshold: float = None,
        apply_normalize: bool = True,
        clamp_weights: bool = True
    ):
        super().__init__()
        self.logger = logging.getLogger("RelationGraph")
        # config
        self.num_vertices = num_vertices
        self.emb_dim = emb_dim
        self.num_classes = num_classes
        self.dist_alpha = dist_alpha
        self.dist_pow = dist_pow
        self.feat_h = feat_h
        self.feat_w = feat_w
        self.constant_vertex_attr = constant_vertex_attr
        self.constant_edge_attr = constant_edge_attr
        self.clamp_vertex_attn = clamp_vertex_attn
        self.clamp_edge_attn = clamp_edge_attn
        self.remove_self_loop = remove_self_loop
        self.prune_node_threshold = prune_node_threshold
        self.apply_normalize = apply_normalize
        self.clamp_weights = clamp_weights

        # when True, feat_to_edges returns all-zero edges
        self.disable_edge = False

        # parameters
        self.vertex_weights = graph_utils.MyParameter(
            shape=(num_classes, num_vertices),
            as_buffer=False
        )
        self.edge_weights = graph_utils.MyParameter(
            shape=(num_classes, num_vertices, num_vertices),
            as_buffer=False
        )
        # weights (fixed buffers when a constant value is supplied)
        self.vertex_attribute_weights = graph_utils.MyParameter(
            shape=(2, 1),
            as_buffer=constant_vertex_attr is not None
        )
        self.edge_attribute_weights = graph_utils.MyParameter(
            shape=(2, 1),
            as_buffer=constant_edge_attr is not None
        )
        # tracking state: a float counter (torch.zeros default dtype)
        self.n_tracked: torch.Tensor
        self.register_buffer("n_tracked", torch.zeros(num_classes), persistent=False)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize vertex/edge and attribute weights, then apply `normalize`."""
        # attribute weights start as an even 0.5 / 0.5 mix
        nn.init.constant_(self.vertex_attribute_weights.tensor, 0.5)
        nn.init.constant_(self.edge_attribute_weights.tensor, 0.5)
        # set all values between 0 and 1, mean 0.5, 3*std = 0.5
        nn.init.trunc_normal_(self.vertex_weights.tensor, mean=0.5, std=1 / 6, a=0, b=1)
        nn.init.trunc_normal_(self.edge_weights.tensor, mean=0.5, std=1 / 6, a=0, b=1)
        self.vertex_weights.normalize_sum_(dim=-1)
        self.edge_weights.normalize_sum_(dim=-1)
        # constant attribute weights override the default mix; copy into the
        # underlying tensor, as every other tensor op in this class does
        with torch.no_grad():
            if self.constant_vertex_attr is not None:
                init = torch.tensor(self.constant_vertex_attr).reshape(2, 1)
                self.vertex_attribute_weights.tensor.copy_(init)
            if self.constant_edge_attr is not None:
                init = torch.tensor(self.constant_edge_attr).reshape(2, 1)
                self.edge_attribute_weights.tensor.copy_(init)
        self.normalize()

    @torch.no_grad()
    def normalize(self):
        """
        In-place post-processing of the stored weights:
        clamp attribute weights to >= 0.01 (if `clamp_weights`), renormalize
        vertex/edge weights over the last dim and, if `remove_self_loop`, zero
        the adjacency diagonals (both only if `apply_normalize`).
        """
        if self.clamp_weights:
            self.vertex_attribute_weights.tensor.clamp_min_(0.01)
            self.edge_attribute_weights.tensor.clamp_min_(0.01)
        if self.apply_normalize:
            self.vertex_weights.normalize_sum_(dim=-1)
            self.edge_weights.normalize_sum_(dim=-1)
            if self.remove_self_loop:
                self.edge_weights.tensor.diagonal(dim1=1, dim2=2).fill_(0)

    def get_vertex_weights(self, detach: bool = False) -> torch.Tensor:
        """
        Return the class vertex weights [K, num_vertices], normalized with
        `graph_utils.normalize_sum_clamp` (detach_sum=True, min_val=1e-5).
        Args:
            detach: detach the stored weights before normalizing
        """
        vertex_weights = self.vertex_weights.tensor
        if detach:
            vertex_weights = vertex_weights.detach()
        # normalize
        vertex_weights = graph_utils.normalize_sum_clamp(vertex_weights, detach_sum=True, min_val=1.0e-5)
        return vertex_weights

    def get_edge_weights(self, detach: bool = False) -> torch.Tensor:
        """
        Return the class edge weights [K, num_vertices, num_vertices], optionally
        pruned, normalized with `graph_utils.normalize_sum_clamp` and, if
        `remove_self_loop`, with the diagonal zeroed.

        NOTE: when `prune_node_threshold` is set, pruned entries are also zeroed
        in place in the stored edge weights (even with detach=True, because a
        detached tensor shares storage with the original).
        Args:
            detach: detach the stored weights before processing
        """
        edge_weights = self.edge_weights.tensor
        if detach:
            edge_weights = edge_weights.detach()
        # apply mask
        if self.prune_node_threshold is not None:
            with torch.no_grad():
                # if a vertex has weight < threshold, then edge connected to this vertex is set to zero weight
                vertex_weights = self.get_vertex_weights(detach=True)
                mask = (vertex_weights > self.prune_node_threshold).float()
                mask = mask.unsqueeze(-1)
                # outer product: edge (i, j) is kept only if both endpoints are kept
                mask = torch.bmm(mask, mask.transpose(1, 2))
                edge_weights.masked_fill_(~mask.bool(), 0)
            # set the gradients of masked edges to zero as well
            edge_weights = edge_weights * mask
        # normalize
        edge_weights = graph_utils.normalize_sum_clamp(edge_weights, detach_sum=True)

        if self.remove_self_loop:
            with torch.no_grad():
                mask = torch.ones_like(edge_weights)
                mask.diagonal(dim1=1, dim2=2).fill_(0)
            edge_weights = edge_weights * mask
        return edge_weights

    def feat_to_vertices(
        self,
        ingredients: torch.LongTensor,
        attn_cls: torch.Tensor
    ) -> torch.Tensor:
        """
        Convert feature ingredients to per-instance vertex weights.
        Args:
            ingredients: [bs, L]
            attn_cls: [bs, L], attention to the cls token (not modified)
        Return: node weights for each sample, shape [bs, num_vertices]
        """
        if self.clamp_vertex_attn is not None:
            # out-of-place so the caller's tensor is left untouched; note that a
            # row with every entry below the threshold softmaxes to NaN
            attn_cls = attn_cls.masked_fill(attn_cls < self.clamp_vertex_attn, float("-inf"))
        attn_cls = attn_cls.softmax(dim=-1)
        # presumably [bs, num_vertices, 2] given the [2, 1] matmul below — TODO confirm
        vertices_attr = self._feat_to_v_attr(ingredients, attn_cls)
        # calculate weighted vertex weights
        graph_utils.normalize_max_(vertices_attr, dim=1)
        vertex_weights = vertices_attr @ self.vertex_attribute_weights.tensor
        return vertex_weights.squeeze_(-1)

    def _feat_to_v_attr(
        self,
        ingredients: torch.LongTensor,
        attn_cls: torch.Tensor
    ) -> torch.Tensor:
        """
        Get instance vertex attributes:
            1. count of each ingredient
            2. mean attention to cls token of each ingredient
        Computed by the compiled `cpp_extension` on CPU and moved back to the
        device of `ingredients`.
        Args:
            ingredients: [bs, L]
            attn_cls: [bs, L]
        """
        from cpp_extension import cpp_feat_to_v_attr
        return cpp_feat_to_v_attr(
            ingredients.cpu(),
            attn_cls.cpu(),
            n_vertices=self.num_vertices,
            mean=True
        ).to(ingredients.device)

    def feat_to_edges(
        self,
        ingredients: torch.LongTensor,
        attn: torch.Tensor
    ) -> torch.Tensor:
        """
        Convert feature ingredients to per-instance edge weights.
        Args:
            ingredients: [bs, L]
            attn: [bs, L, L] (not modified)
        Return: weighted edges for each sample, shape [bs, num_vertices, num_vertices]
            (all zeros when `disable_edge` is set)
        """
        if self.disable_edge:
            bs = ingredients.shape[0]
            edge = torch.zeros(bs, self.num_vertices, self.num_vertices, device=ingredients.device)
            return edge

        if self.clamp_edge_attn is not None:
            # edge attentions are filtered by the edge threshold (not the vertex
            # one); out-of-place so the caller's tensor is left untouched
            attn = attn.masked_fill(attn < self.clamp_edge_attn, float("-inf"))
        attn = torch.softmax(attn, dim=-1)
        geo_sim = graph_utils.pair_wise_point_sim(
            h=self.feat_h,
            w=self.feat_w,
            alpha=self.dist_alpha,
            pow=self.dist_pow,
            device=ingredients.device
        )
        # [bs, num_vertices, num_vertices, 2]
        edges_attr = self._feat_to_e_attr(ingredients, attn, geo_sim)
        graph_utils.normalize_sum_(edges_attr, dim=2)
        if self.remove_self_loop:
            edges_attr.diagonal(dim1=1, dim2=2).fill_(0)
        # calculate weighted edge weights
        edges_attr = edges_attr @ self.edge_attribute_weights.tensor
        return edges_attr.squeeze_(-1)

    def _feat_to_e_attr(
        self,
        ingredients: torch.LongTensor,
        attn: torch.Tensor,
        geo_sim: torch.Tensor
    ) -> torch.Tensor:
        """
        Get instance edge attributes:
            1. geometric similarity
            2. sum attention between each pair
        Computed by the compiled `cpp_extension` on CPU and moved back to the
        device of `ingredients`.
        Args:
            ingredients: [bs, L]
            attn: [bs, L, L]
            geo_sim: [L, L]
        """
        from cpp_extension import cpp_feat_to_e_attr
        edge = cpp_feat_to_e_attr(
            ingredients.cpu(),
            attn.cpu(),
            geo_sim.cpu(),
            n_vertices=self.num_vertices,
            mean=True
        ).to(ingredients.device)
        return edge

    @torch.no_grad()
    def update(
        self,
        ingredients: torch.LongTensor,
        attn: torch.Tensor,
        attn_cls: torch.Tensor,
        label: torch.LongTensor
    ):
        """
        Add each sample's instance graph to the stored graph of its class and
        count it in `n_tracked` (call `accumulate` afterwards to average).
        Runs without autograd: the results only feed in-place accumulation.
        Args:
            ingredients: [bs, L]
            attn: [bs, L, L]
            attn_cls: [bs, L]
            label: [bs]
        """
        vertex_store = self.vertex_weights.tensor
        edge_store = self.edge_weights.tensor

        vertices = self.feat_to_vertices(ingredients, attn_cls)
        edges = self.feat_to_edges(ingredients, attn)
        # index_add_ needs matching dtype/device between target and source
        vertices = vertices.to(device=vertex_store.device, dtype=vertex_store.dtype)
        edges = edges.to(device=edge_store.device, dtype=edge_store.dtype)

        # index_add_ sums rows sharing the same label (a plain indexed `+=`
        # would drop duplicates), replacing a per-sample Python loop
        index = torch.as_tensor(label, dtype=torch.long)
        vertex_store.index_add_(0, index.to(vertex_store.device), vertices)
        edge_store.index_add_(0, index.to(edge_store.device), edges)
        count_index = index.to(self.n_tracked.device)
        self.n_tracked.index_add_(
            0, count_index, torch.ones_like(count_index, dtype=self.n_tracked.dtype)
        )

    @torch.no_grad()
    def accumulate(self):
        """
        Turn the sums collected by `update` into per-class averages (vertex and
        edge weights alike), then apply `normalize`. Classes with no tracked
        samples are left unscaled instead of being divided by zero.
        """
        counts = self.n_tracked.clamp_min(1)
        vertex_store = self.vertex_weights.tensor
        edge_store = self.edge_weights.tensor
        vertex_store /= counts.to(vertex_store.device)[:, None]
        edge_store /= counts.to(edge_store.device)[:, None, None]
        self.normalize()

    def forward(
        self,
        ingredients: torch.LongTensor,
        attn: torch.Tensor,
        attn_cls: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert ingredient sequence to instance relation graphs
        Args:
            ingredients: [bs, L]
            attn: [bs, L, L]
            attn_cls: [bs, L]
        Return:
            nodes: [bs, num_vertices]
            edges: [bs, num_vertices, num_vertices]
        """
        vertices = self.feat_to_vertices(ingredients, attn_cls)
        edges = self.feat_to_edges(ingredients, attn)
        return vertices, edges
