from typing import Optional
import numpy as np
import torch.nn as nn
import torch
from torch_geometric.utils import degree
from torch_scatter import scatter_min, scatter_max, scatter_mean
from .quantizer_table import TableQuantizer
from torch_sparse import SparseTensor
from torch import Tensor
import math
from abc import *


class QuantizeChannelTable(nn.Module):  # Table for PTQ
    """Table-based activation quantizer with per-hop min/max ranges (PTQ).

    Rows of the activation are split by ``inf_mask`` into two families. Each
    family has its own range table, indexed by the row's ``hop_num``:

    * ``inf_table``: rows where ``inf_mask`` is True. These rows go through
      :meth:`degree_based_norm` before range collection and quantization,
      and through :meth:`revert_degree_based_norm` afterwards.
    * ``non_table``: the remaining rows, quantized without normalization.

    Both tables have shape ``(2, num_hops)``. Row 0 holds each hop's max and
    row 1 its min. ``NaN`` marks a hop that has not been observed yet.
    Subclasses provide the degree-normalization pair.
    """

    def __init__(self,
                 # used in quantizer
                 num_bits: int,
                 dependency: bool,
                 signed: bool,
                 use_ste: bool = False,
                 symmetric: bool = False,
                 percentile: float = -1,
                 # for the table
                 use_momentum: bool = False,
                 momentum: float = 0.01,
                 device=None,
                 nns=False,
                 **kwargs
                 ):
        """
        Args:
            num_bits: bit width handed to ``TableQuantizer``. 32 disables
                quantization: ``forward`` then returns its input unchanged.
            dependency: forwarded to the quantizer call as ``dependency=``.
            signed, use_ste, symmetric: forwarded to ``TableQuantizer``.
            percentile: if > 0, each hop's observed ``[min, max]`` is shrunk
                to its ``percentile`` / ``1 - percentile`` quantiles.
            use_momentum: update already-seen ranges with an exponential
                moving average instead of a running max/min.
            momentum: moving-average factor used when ``use_momentum`` is set.
            device: device of the initial (empty) range tables.
            nns: stored as ``self.nns``; not otherwise used by this class.
            **kwargs: accepted and ignored.
        """
        super(QuantizeChannelTable, self).__init__()
        # Empty (2, 0) range tables; columns are appended as new hops appear.
        self.register_buffer("inf_table", torch.tensor([[], []], device=device))
        self.register_buffer("non_table", torch.tensor([[], []], device=device))

        self.quantizer = TableQuantizer(
            num_bits=num_bits,
            signed=signed,
            use_ste=use_ste,
            symmetric=symmetric,
        )

        self.percentile = percentile
        self.dependency = dependency
        self.device = device
        self.update = True    # False => tables are frozen (see freeze_...)
        self.reduce = False   # forward(): this call's input was 1-D
        self.mul_dim = False  # forward(): this call's input had > 2 dims

        # 32-bit "quantization" is a pass-through.
        self.floating_point = num_bits == 32

        # momentum factor (when used)
        self.momentum_min_max = use_momentum
        self.momentum = momentum

        # quantizer allocation
        self.nns = nns
        self._clear_saved()

    def _clear_saved(self):
        # Reset the saved_* slots (this class only ever sets them to None).
        self.saved_inf_val = None
        self.saved_inf_ind = None
        self.saved_non_val = None
        self.saved_non_ind = None

    def freeze_quantization_parameters(self):
        """Stop updating the range tables and put the quantizer in eval mode."""
        self.update = False
        self.quantizer.eval()
        self._clear_saved()

    def unfreeze_quantization_parameters(self):
        """Resume updating the range tables and put the quantizer in train mode."""
        self.update = True
        self.quantizer.train()
        self._clear_saved()

    @staticmethod
    @abstractmethod
    def degree_based_norm(input, indegree):
        """Normalize the already-masked rows ``input`` by ``indegree``.

        Subclasses must override this.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def revert_degree_based_norm(input, indegree):
        """Undo :meth:`degree_based_norm` on already-masked rows.

        Subclasses must override this.
        """
        raise NotImplementedError

    def update_param(self, hop_max, hop_min, need_updates, inf=True):
        """Merge freshly observed per-hop ranges into a range table.

        Does nothing while the tables are frozen (``self.update`` is False).

        Args:
            hop_max, hop_min: 1-D per-hop max/min, indexed by hop id.
            need_updates: boolean mask over table columns (after growth)
                marking which hops were observed in this batch.
            inf: update ``inf_table`` (True) or ``non_table`` (False).
        """
        if not self.update:
            return
        table = self.inf_table if inf else self.non_table
        # Match the table so index_put below never sees a dtype/device mismatch.
        hop_max = hop_max.detach().to(dtype=table.dtype, device=table.device)
        hop_min = hop_min.detach().to(dtype=table.dtype, device=table.device)
        need_updates = need_updates.to(dtype=torch.bool, device=table.device)

        num_hops = hop_max.shape[0]
        if table.shape[1] < num_hops:
            # New columns start as NaN, which marks them as "never observed".
            pad = torch.full((2, num_hops - table.shape[1]), float("nan"),
                             dtype=table.dtype, device=table.device)
            table = torch.cat((table, pad), dim=1)

        unseen = torch.isnan(table[0])
        fresh = unseen & need_updates   # first observation of the hop
        seen = ~unseen & need_updates   # hop already has a range
        fresh_hop = fresh[:num_hops]
        seen_hop = seen[:num_hops]

        # First observation: take the batch range as-is.
        table[0][fresh] = hop_max[fresh_hop]
        table[1][fresh] = hop_min[fresh_hop]
        # Known hops: moving average or running extremes.
        if self.momentum_min_max:
            table[0][seen] = table[0][seen] + self.momentum * (hop_max[seen_hop] - table[0][seen])
            table[1][seen] = table[1][seen] + self.momentum * (hop_min[seen_hop] - table[1][seen])
        else:
            table[0][seen] = torch.maximum(hop_max[seen_hop], table[0][seen])
            table[1][seen] = torch.minimum(hop_min[seen_hop], table[1][seen])

        if inf:
            self.inf_table = table
        else:
            self.non_table = table

    def node_operation_minmax(self, _input, mask, hop_num, indegree, inf=True):
        """Collect per-hop ranges for the rows selected by ``mask`` and look them up.

        When ``inf`` is True, the selected rows of ``_input`` are first
        degree-normalized IN PLACE. The caller's tensor then holds the values
        that should be quantized.

        Args:
            _input: 2-D activation (rows, features).
            mask: boolean row mask selecting one family.
            hop_num: per-row hop index (long).
            indegree: per-row in-degree; only read when ``inf`` is True.
            inf: use ``inf_table`` (True) or ``non_table`` (False).

        Returns:
            ``(max_val, min_val)``: the stored range of each selected row.
        """
        table = self.inf_table if inf else self.non_table

        if inf:
            _input[mask] = self.degree_based_norm(_input[mask], indegree[mask])
        rows = _input[mask]
        hops = hop_num[mask]
        unique_hop = torch.unique(hops).to(torch.long)

        # Label every element of a row with that row's hop id, so a single
        # scatter reduces all features of a hop group together.
        group = hops.repeat_interleave(rows.shape[1])
        hop_min, _ = scatter_min(rows.flatten(), group)
        hop_max, _ = scatter_max(rows.flatten(), group)

        if (self.percentile > 0) and (hop_min.shape[0] > 0):
            ranges = torch.stack([hop_min, hop_max], dim=1)
            # torch.quantile requires q to share the input's dtype and device.
            q = torch.tensor([self.percentile, 1 - self.percentile],
                             dtype=ranges.dtype, device=ranges.device)
            hop_min, hop_max = torch.quantile(ranges, q, dim=1)

        # Every range must contain zero.
        hop_max = torch.maximum(hop_max, torch.zeros_like(hop_max))
        hop_min = torch.minimum(hop_min, torch.zeros_like(hop_min))

        # Only hops that actually occur in this batch get updated.
        update_mask = torch.zeros(max(table.shape[1], hop_max.shape[0]),
                                  dtype=torch.bool, device=unique_hop.device)
        update_mask[unique_hop] = True
        self.update_param(hop_max, hop_min, update_mask, inf=inf)
        table = self.inf_table if inf else self.non_table

        return table[0][hops], table[1][hops]

    def quantize(self, _input, node_wise, q_group=None):
        """Fake-quantize a 2-D activation with per-hop table ranges.

        Args:
            _input: 2-D activation. Rows are nodes when ``node_wise`` is True,
                otherwise they are the entries of ``index``.
            node_wise: if False, node-level metadata is gathered through
                ``index`` to obtain per-row metadata.
            q_group: ``(index, inf_mask, hop_num, indegree)``. The last three
                are node-level.

        Returns:
            ``(quantized, params)``, as returned by ``self.quantizer``, with
            the inf rows de-normalized.
        """
        assert q_group is not None
        index, inf_mask, hop_num, indegree = q_group
        assert indegree is not None
        # Private copies: the hop clamp and the in-place degree normalization
        # below must not leak into the caller's tensors (q_group may be shared
        # with other quantizers; _input may be a view of the caller's data).
        hop_num = hop_num.to(torch.long).clone()
        _input = _input.clone()

        if not self.update:
            # Frozen tables cannot grow: map out-of-range hops onto the last column.
            hop_num[inf_mask] = torch.clamp(hop_num[inf_mask], min=0, max=self.inf_table.shape[1] - 1)
            hop_num[~inf_mask] = torch.clamp(hop_num[~inf_mask], min=0, max=self.non_table.shape[1] - 1)

        # Gather node-level metadata onto the rows when rows are not nodes.
        if node_wise is False:
            assert index is not None
            index = index.to(torch.long)
            inf_mask = inf_mask[index]
            hop_num = hop_num[index]
            indegree = indegree[index]

        max_val = torch.zeros(_input.shape[0], device=_input.device)
        min_val = torch.zeros(_input.shape[0], device=_input.device)
        max_val[inf_mask], min_val[inf_mask] = self.node_operation_minmax(_input, inf_mask, hop_num, indegree, inf=True)
        max_val[~inf_mask], min_val[~inf_mask] = self.node_operation_minmax(_input, ~inf_mask, hop_num, indegree, inf=False)

        _input, params = self.quantizer(_input, max_val, min_val, dependency=self.dependency)
        _input[inf_mask] = self.revert_degree_based_norm(_input[inf_mask], indegree[inf_mask])
        return _input, params

    def forward(self, _input, node_wise, edge_weight, edge_index=None, q_group=None):
        """Fake-quantize ``_input`` and return it in its original shape.

        Args:
            _input: activation of any rank >= 1. 1-D input is treated as one
                column; inputs with more than 2 dims are flattened per row.
            node_wise: rows are nodes; ``edge_index`` is not used.
            edge_weight: when not node-wise, group rows by ``edge_index[1]``
                (``col`` of ``edge_index.t().coo()`` for a SparseTensor, the
                "target nodes"). Otherwise group by ``edge_index[0]`` / ``row``.
            edge_index: dense ``(2, E)`` index or a ``SparseTensor``.
            q_group: ``(inf_mask, hop_num, indegree)`` node-level metadata.
        """
        if self.floating_point:
            return _input
        assert not edge_weight or not node_wise

        size = _input.shape
        # Decide the reshape for THIS call; sticky flags from an earlier call
        # with a different rank would otherwise mis-shape the output.
        self.reduce = _input.dim() == 1
        self.mul_dim = _input.dim() > 2
        if self.reduce:
            _input = _input.unsqueeze(1)
        elif self.mul_dim:
            _input = _input.reshape(size[0], -1)

        inf_mask, hop_num, indegree = q_group
        if node_wise:
            index = None
        elif isinstance(edge_index, SparseTensor):
            row, col, _ = edge_index.t().coo()
            index = col if edge_weight else row
        else:
            index = edge_index[1] if edge_weight else edge_index[0]
        _input, _ = self.quantize(_input, node_wise, (index, inf_mask, hop_num, indegree))

        # reshape(size) restores the exact input shape; squeeze() would also
        # collapse a length-1 input to a 0-d tensor.
        if self.reduce or self.mul_dim:
            _input = _input.reshape(size)
        return _input


class GCNActivationTable(QuantizeChannelTable):
    """Table quantizer for GCN layers.

    Rows in the inf family are divided by ``round(sqrt(indegree))`` before
    their ranges are collected. They are multiplied back by the same factor
    after quantization.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.reduce = False

    @staticmethod
    def degree_based_norm(input, indegree):
        """Return ``input`` with every row divided by ``round(sqrt(indegree))``.

        Both arguments must already be masked to the same rows. ``indegree``
        is presumably 1-D with one entry per row (confirm against callers).
        A zero in-degree gives a zero divisor.
        """
        root = torch.sqrt(indegree).round()
        per_element = root.repeat(input.shape[1], 1).t()
        return input / per_element

    @staticmethod
    def revert_degree_based_norm(input, indegree):
        """Inverse of :meth:`degree_based_norm`: multiply every row back by ``round(sqrt(indegree))``."""
        root = torch.sqrt(indegree).round()
        per_element = root.repeat(input.shape[1], 1).t()
        return input * per_element



class GATActivationTable(QuantizeChannelTable):
    """Table quantizer for GAT layers.

    No degree normalization is applied: both hooks return their input
    unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mul_dim = False

    @staticmethod
    def degree_based_norm(input, indegree):
        """Identity: GAT rows are quantized as-is. ``indegree`` is ignored."""
        return input

    @staticmethod
    def revert_degree_based_norm(input, indegree):
        """Identity counterpart of :meth:`degree_based_norm`. ``indegree`` is ignored."""
        return input



class GINActivationTable(QuantizeChannelTable):
    """Table quantizer for GIN layers.

    Rows in the inf family are divided by their in-degree before their ranges
    are collected. They are multiplied back by the in-degree after
    quantization.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def degree_based_norm(input, indegree):
        """Return ``input`` with every row divided by its ``indegree``.

        Both arguments must already be masked to the same rows. ``indegree``
        is presumably 1-D with one entry per row (confirm against callers).
        A zero in-degree gives a zero divisor.
        """
        per_element = indegree.repeat(input.shape[1], 1).t()
        return input / per_element

    @staticmethod
    def revert_degree_based_norm(input, indegree):
        """Inverse of :meth:`degree_based_norm`: multiply every row back by its ``indegree``."""
        per_element = indegree.repeat(input.shape[1], 1).t()
        return input * per_element
    

class MergeQuantize(QuantizeChannelTable):
    """Quantize messages to integer grid values and absorb the scale into ``norm``.

    :meth:`forward` does not return fake-quantized floats. It returns the
    zero-point-shifted integer grid values of the input and multiplies the
    per-row mean quantization scale into the message normalization ``norm``.
    The degree normalization is borrowed from the architecture named by
    ``norm_name``.
    """

    def __init__(self, norm_name = "GCN", **kwargs):
        """
        Args:
            norm_name: architecture whose degree normalization is used:
                "GCN", "GAT"/"GATv2" (identity) or "GIN".
            **kwargs: forwarded to ``QuantizeChannelTable``.

        Raises:
            ValueError: for an unsupported ``norm_name`` when quantization is
                enabled (``num_bits != 32``); the norm would otherwise be
                missing and ``forward`` would fail later with an obscure error.
        """
        super(MergeQuantize, self).__init__(**kwargs)
        self.mul_dim = False
        norm_sources = {
            "GCN": GCNActivationTable,
            "GAT": GATActivationTable,
            "GATv2": GATActivationTable,
            "GIN": GINActivationTable,
        }
        source = norm_sources.get(norm_name)
        if source is not None:
            self.degree_based_norm = source.degree_based_norm
            self.revert_degree_based_norm = source.revert_degree_based_norm
        elif self.floating_point:
            # forward() returns early in 32-bit mode, so the norm is never used.
            print("NO ARCHITECTURE SUPPORTED FOR: {}".format(norm_name))
        else:
            raise ValueError("NO ARCHITECTURE SUPPORTED FOR: {}".format(norm_name))

    def forward(self, _input, norm, node_wise, edge_index=None, q_group=None):
        """Return ``(input_int, norm_scaled)``.

        Args:
            _input: messages. Rows are nodes when ``node_wise`` is True and
                edges otherwise. 1-D input is treated as one column; inputs
                with more than 2 dims are flattened per row.
            norm: per-edge normalization, shaped ``[num_edges]`` or
                ``[num_edges, ...]``.
            node_wise: rows are nodes and ``edge_index`` is a SparseTensor.
                Otherwise ``edge_index`` is a dense ``(2, E)`` tensor.
            edge_index: graph connectivity (see ``node_wise``).
            q_group: ``(inf_mask, hop_num, indegree)`` node-level metadata.

        Returns:
            ``input_int``: ``clamp(round(x / scale + zp), qmin, qmax) - zp``
            in the input's shape. ``norm_scaled``: ``norm`` multiplied by each
            edge's mean scale.
        """
        if self.floating_point:
            return _input, norm

        size = _input.shape
        # Decide the reshape for THIS call; sticky flags from an earlier call
        # with a different rank would otherwise mis-shape the output.
        self.reduce = _input.dim() == 1
        self.mul_dim = _input.dim() > 2
        if self.reduce:
            _input = _input.unsqueeze(1)
        elif self.mul_dim:
            _input = _input.reshape(size[0], -1)

        inf_mask, hop_num, indegree = q_group

        if node_wise:  # message is node-wise and edge_index is a SparseTensor
            row, _, _ = edge_index.t().coo()
        else:
            assert edge_index.shape[0] == 2
            row = edge_index[0]
        index = row.to(torch.long)

        _input, params = self.quantize(_input, node_wise, (index, inf_mask, hop_num, indegree))
        qmin, qmax, scale, zeropoint = params

        if not node_wise:
            inf_mask = inf_mask[index]
            indegree = indegree[index]
        # Undo the degree normalization on the scale of inf rows as well.
        # Assumes scale is 2-D (rows, k); TODO confirm against TableQuantizer.
        scale[inf_mask] = self.revert_degree_based_norm(scale[inf_mask], indegree[inf_mask])
        _input_int = torch.mul(_input, 1.0 / scale).add(zeropoint).round()
        _input_int = torch.clamp(_input_int, min=qmin, max=qmax)
        _input_int = _input_int - zeropoint

        scale = torch.mean(scale, dim=1)
        if node_wise:
            # Per-node scale -> per-edge scale via each edge's `row` node.
            scale = scale[index]
        # reshape(size) restores the exact input shape; squeeze() would also
        # collapse a length-1 input to a 0-d tensor.
        if self.reduce or self.mul_dim:
            _input_int = _input_int.reshape(size)

        # Absorb the scale into norm, broadcasting over any trailing norm dims.
        norm = norm * scale.reshape(-1, *([1] * (norm.dim() - 1)))
        return _input_int, norm
