"""
Copyright 2020 Twitter, Inc.
SPDX-License-Identifier: Apache-2.0
"""
import numpy as np
import torch_sparse
from locale import normalize
import torch
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import to_dense_adj, to_undirected, remove_self_loops, add_remaining_self_loops
from tqdm import trange
import torch.nn.functional as F
from sklearn.decomposition import PCA
from utils import get_symmetrically_normalized_adjacency, get_row_normalized_adjacency, knn_fast
import warnings
# Importing this module silences every UserWarning for the whole process.
# NOTE(review): presumably meant to hide torch sparse / PyG deprecation noise, but it also
# hides unrelated UserWarnings from other libraries — consider narrowing with `module=`/`message=`.
warnings.filterwarnings("ignore", category=UserWarning)

def random_filling(X):
    """Return standard-normal noise with the same shape, dtype and device as ``X``.

    The values of ``X`` itself are ignored; only its metadata is used.
    """
    noise = torch.randn_like(X)
    return noise

def zero_filling(X):
    """Return an all-zero tensor with the same shape, dtype and device as ``X``."""
    filled = torch.zeros_like(X)
    return filled

def mean_filling(X, feature_mask):
    """Fill every row with the per-column mean of the observed entries.

    Args:
        X: 2-D node-feature matrix (one row per node).
        feature_mask: boolean mask of observed entries, indexed like ``X``.

    Returns:
        Tensor with ``X.shape[0]`` rows, each equal to ``compute_mean(X, feature_mask)``.
    """
    column_means = compute_mean(X, feature_mask)
    return column_means.repeat(X.shape[0], 1)

def neighborhood_mean_filling(edge_index, X, feature_mask):
    """Fill each node with the mean of its neighbours' observed features.

    For node ``i`` and feature ``f`` the result averages ``X[j, f]`` over every edge
    ``j -> i`` in ``edge_index`` (``edge_index[0]`` = source, ``edge_index[1]`` = target)
    whose source has ``feature_mask[j, f]`` True. Duplicate edges are counted once per
    occurrence. Entries with no observed neighbour are set to 0.

    Args:
        edge_index: ``(2, E)`` integer edge list.
        X: 2-D node-feature matrix (one row per node). It is NOT modified.
        feature_mask: boolean mask of observed entries, same shape as ``X``.

    Returns:
        Tensor of the same shape as ``X`` with the neighbourhood means.
    """
    n_nodes = X.shape[0]
    # Out-of-place fill so the caller's tensor keeps its original (missing) values.
    X_zero_filled = X.masked_fill(~feature_mask, 0.0)

    # Transposed adjacency: row i collects the sources j of edges j -> i.
    edge_values = torch.ones(edge_index.shape[1], dtype=X.dtype, device=edge_index.device)
    edge_index_mm = torch.stack([edge_index[1], edge_index[0]])
    adj_tmp = torch.sparse_coo_tensor(edge_index_mm, edge_values, (n_nodes, n_nodes))

    numerator = torch.sparse.mm(adj_tmp, X_zero_filled)
    denominator = torch.sparse.mm(adj_tmp, feature_mask.to(X.dtype))
    mean_neighborhood_features = numerator / denominator

    # 0/0: the feature is not observed on any neighbour -> fill with 0.
    mean_neighborhood_features[mean_neighborhood_features.isnan()] = 0

    # Drop intermediates before releasing cached GPU memory (no-op without CUDA).
    del X_zero_filled, edge_values, edge_index_mm, adj_tmp, numerator, denominator
    torch.cuda.empty_cache()

    return mean_neighborhood_features

def feature_propagation(edge_index, X, feature_mask, num_iterations, alpha=None, n_class=None, k=0, logger=None, ver=0):
    """Build a ``FeaturePropagation`` module and run its ``propagate`` once.

    All keyword settings are forwarded to the ``FeaturePropagation`` constructor;
    ``feature_mask`` marks the observed entries of ``X``.
    """
    model = FeaturePropagation(
        num_iterations=num_iterations,
        alpha=alpha,
        n_class=n_class,
        k=k,
        logger=logger,
        ver=ver,
    )
    reconstructed = model.propagate(x=X, edge_index=edge_index, mask=feature_mask)
    return reconstructed

def filling(filling_method, edge_index, X, feature_mask, num_iterations=None, alpha=None, n_class=None, k=0, logger=None, ver=0):
    """Reconstruct missing features of ``X`` with the named strategy.

    Args:
        filling_method: one of ``"random"``, ``"zero"``, ``"mean"``,
            ``"neighborhood_mean"`` or ``"fp"`` (feature propagation).
        edge_index, X, feature_mask: graph, feature matrix and observed-entry mask.
        num_iterations, alpha, n_class, k, logger, ver: only used by ``"fp"``.

    Raises:
        ValueError: if ``filling_method`` is not one of the names above.
    """
    if filling_method == "random":
        return random_filling(X)
    if filling_method == "zero":
        return zero_filling(X)
    if filling_method == "mean":
        return mean_filling(X, feature_mask)
    if filling_method == "neighborhood_mean":
        return neighborhood_mean_filling(edge_index, X, feature_mask)
    if filling_method == "fp":
        return feature_propagation(
            edge_index, X, feature_mask, num_iterations,
            alpha=alpha, n_class=n_class, k=k, logger=logger, ver=ver,
        )
    raise ValueError(f"{filling_method} method not implemented")

def compute_mean(X, feature_mask):
    """Return the per-column mean of the observed entries of ``X``.

    Args:
        X: 2-D node-feature matrix. It is NOT modified.
        feature_mask: boolean mask of observed entries, same shape as ``X``.

    Returns:
        1-D tensor with one mean per column; columns with no observed entry get 0.
    """
    # Out-of-place fill so the caller's tensor keeps its original values.
    X_zero_filled = X.masked_fill(~feature_mask, 0.0)
    num_observed = torch.count_nonzero(feature_mask, dim=0)
    mean_features = X_zero_filled.sum(dim=0) / num_observed
    # 0/0: the feature is not observed on any node -> fill with 0.
    mean_features[mean_features.isnan()] = 0

    return mean_features

class FeaturePropagation(torch.nn.Module):
    """Fill missing node features/labels by repeated diffusion over a graph.

    Observed entries (``mask == True``) are written back after every diffusion step, so
    only the missing entries change. Depending on ``k`` and ``ver``, ``propagate`` also
    diffuses over a kNN graph built between the *columns* of the matrix.

    Args:
        num_iterations: number of node-graph diffusion steps in ``propagate``.
        alpha: in ``propagate`` only its truthiness matters (log message, and whether the
            ``ver == 6`` column stage runs); ``correct``/``smooth`` use it as the mixing
            weight ``out <- alpha * A @ out + (1 - alpha) * start``.
        n_class: number of trailing columns treated as "class" columns when classifying
            column-graph edges (column index ``>= n_feat - n_class``). Required whenever
            the column graph is built. NOTE(review): presumably label columns appended
            to the features — confirm with the caller.
        n_hid: stored only; not used by this class.
        k: neighbours per column for the column kNN graph (``0`` disables the
            pre-propagation column stage).
        logger: optional logger; when set, column-graph edge statistics are logged.
        ver: experiment variant. ``4``/``5`` keep only feature-feature / feature-class
            column edges, ``6`` moves the column stage after node propagation,
            ``10``/``11`` change how the pre-propagation column diffusion is initialised.
    """

    def __init__(self, num_iterations: int, alpha=None, n_class=None, n_hid=64, k=0, logger=None, ver=0):
        super().__init__()
        self.num_iterations = num_iterations
        self.alpha = alpha
        self.n_hid = n_hid
        self.k = k
        self.n_class = n_class
        self.logger = logger
        self.ver = ver

    def propagate(self, x: Tensor, edge_index: Adj, mask: Tensor) -> Tensor:
        """Fill the missing entries of ``x`` by diffusion.

        Stages:
          1. if ``ver != 6`` and ``k > 0``: column-graph stage before node propagation
             (see ``_column_graph_pre_stage``);
          2. ``num_iterations`` steps over the normalised node graph, restoring the
             observed entries after each step;
          3. if ``ver == 6`` and ``alpha`` is falsy: column-graph stage after node
             propagation (see ``_column_graph_post_stage``).

        Args:
            x: matrix with one row per node (``x.shape[0]`` nodes).
            edge_index: node graph, normalised by ``get_symmetrically_normalized_adjacency``.
            mask: boolean mask of observed entries of ``x``. If ``None`` the diffusion
                starts from ``x`` itself; the column stages index with ``mask.T`` and so
                require a mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Missing entries start at 0; at convergence the initialisation does not affect
        # the final value.
        out = x
        if mask is not None:
            out = torch.zeros_like(x)
            out[mask] = x[mask]

        n_nodes = x.shape[0]
        _, adj = self.get_propagation_matrix(out, edge_index, n_nodes)

        if self.ver != 6 and self.k > 0:
            out = self._column_graph_pre_stage(out, x, mask, adj)

        if self.alpha:
            print('Start Label Propagation ...!')
        else:
            print('Start Label / Feature Propagation ...!')
        for _ in trange(self.num_iterations):
            out = torch.sparse.mm(adj, out)
            out[mask] = x[mask]

        if not self.alpha and self.ver == 6:
            out = self._column_graph_post_stage(out, x, mask)

        return out

    def _column_graph_pre_stage(self, out, x, mask, adj):
        """Build an undirected, row-normalised column kNN graph and diffuse over it."""
        # One node-graph smoothing step before building the column graph.
        out = torch.sparse.mm(adj, out)
        n_feat = out.shape[1]

        row, col, edge_weight = self._knn_column_edges(out)
        edge_index = torch.stack([row, col], 0)

        # to_undirected is run on CPU; move the result back to the original device.
        device = edge_index.device
        edge_index, edge_weight = to_undirected(edge_index=edge_index.cpu(), edge_attr=edge_weight.cpu())
        edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)

        edge_index, edge_weight = self._drop_class_class_edges(edge_index, edge_weight, n_feat)
        # Re-derive row/col from the graph actually in use: the raw kNN row/col no longer
        # line up with edge_weight after to_undirected and the class-class filter.
        row, col = edge_index[0], edge_index[1]

        # Ablations: keep only feature-feature (ver 4) or feature-class (ver 5) edges,
        # unless that would leave no edges at all.
        boundary = n_feat - self.n_class
        row_is_feat, col_is_feat = row < boundary, col < boundary
        keep = None
        if self.ver == 4:
            keep = row_is_feat & col_is_feat
        elif self.ver == 5:
            keep = row_is_feat ^ col_is_feat
        if keep is not None and keep.sum() > 0:
            row, col, edge_weight = row[keep], col[keep], edge_weight[keep]
            edge_index = torch.stack([row, col], 0)

        if self.logger:
            valid = row != -1
            self._log_edge_statistics(row[valid], col[valid], n_feat)

        edge_index, edge_weight = get_row_normalized_adjacency(edge_index, n_nodes=n_feat, edge_weight=edge_weight)
        edge_weight.nan_to_num_(0.0)
        adj_feat = self._to_sparse(edge_index, edge_weight, n_feat)

        # Starting point of the column diffusion:
        #   ver 10 -> the smoothed matrix as-is,
        #   ver 11 -> the smoothed matrix with observed entries restored,
        #   other  -> zero-filled x with observed entries (when a mask is given).
        if self.ver == 11:
            out[mask] = x[mask]
        elif self.ver != 10 and mask is not None:
            out = torch.zeros_like(x)
            out[mask] = x[mask]

        return self._diffuse_over_columns(out, x, mask, adj_feat)

    def _column_graph_post_stage(self, out, x, mask):
        """Build a directed column kNN graph (symmetric normalisation) and diffuse over it."""
        # TODO (translated from the original Korean note): try flipping the edge
        # direction, making the graph undirected, applying PCA, varying k, and using
        # only x instead of x + y.
        n_feat = out.shape[1]
        row, col, edge_weight = self._knn_column_edges(out)
        edge_index = torch.stack([row, col], 0)

        if self.logger:
            # Statistics exclude self-loops and are taken before class-class edges are dropped.
            not_loop = row != col
            self._log_edge_statistics(row[not_loop], col[not_loop], n_feat)

        edge_index, edge_weight = self._drop_class_class_edges(edge_index, edge_weight, n_feat)
        edge_index, edge_weight = get_symmetrically_normalized_adjacency(edge_index, n_nodes=n_feat, edge_weight=edge_weight)
        edge_weight.nan_to_num_(0.0)
        adj_feat = self._to_sparse(edge_index, edge_weight, n_feat)

        if mask is not None:
            out = torch.zeros_like(x)
            out[mask] = x[mask]

        return self._diffuse_over_columns(out, x, mask, adj_feat)

    def _knn_column_edges(self, out):
        """Return ``knn_fast`` edges between the columns of ``out``, dropping NaN weights."""
        row, col, edge_weight = knn_fast(out.T, self.k)
        valid = ~torch.isnan(edge_weight)
        return row[valid], col[valid], edge_weight[valid]

    def _drop_class_class_edges(self, edge_index, edge_weight, n_feat):
        """Remove edges whose endpoints are both class columns (index >= n_feat - n_class)."""
        boundary = n_feat - self.n_class
        keep = (edge_index[0] < boundary) | (edge_index[1] < boundary)
        return edge_index[:, keep], edge_weight[keep]

    def _diffuse_over_columns(self, out, x, mask, adj_feat, n_steps=40):
        """Diffuse ``out`` over the column graph, restoring observed entries every step."""
        out_feat = out.T
        print('Start Node Propagation ...!')
        for _ in trange(n_steps):
            out_feat = torch.sparse.mm(adj_feat, out_feat)
            out_feat[mask.T] = x.T[mask.T]
        return out_feat.T

    def _log_edge_statistics(self, row, col, n_feat):
        """Log how many column-graph edges are feature-feature, feature-class, class-class."""
        boundary = n_feat - self.n_class
        row_is_feat = row < boundary
        col_is_feat = col < boundary
        f_f = int((row_is_feat & col_is_feat).sum())
        f_c_f = int((row_is_feat ^ col_is_feat).sum())
        c_c = int((~row_is_feat & ~col_is_feat).sum())
        n_edges = row.shape[0]

        def pct(count):
            # Guard the empty-graph case instead of dividing by zero.
            return round(count * 100 / n_edges, 2) if n_edges else 0.0

        self.logger.info(f'===== k: {self.k} =====')
        self.logger.info(f'# of Feat-Feat edge: {f_f} / {n_edges} = {pct(f_f)}%')
        self.logger.info(f'# of Feat-Class edge: {f_c_f} / {n_edges} = {pct(f_c_f)}%')
        self.logger.info(f'# of Class-Class edge: {c_c} / {n_edges} = {pct(c_c)}%')

    @staticmethod
    def _to_sparse(edge_index, edge_weight, size):
        """Build a ``size x size`` sparse COO matrix on ``edge_index``'s device."""
        return torch.sparse_coo_tensor(edge_index, edge_weight, (size, size), device=edge_index.device)

    def get_propagation_matrix(self, x, edge_index, n_nodes):
        """Return ``(binary adjacency, normalised adjacency)`` as sparse matrices.

        Both share the sparsity pattern produced by
        ``get_symmetrically_normalized_adjacency``; the first has weight 1 wherever the
        normalised weight is positive and 0 elsewhere. ``x`` is accepted for interface
        compatibility and is not used.
        """
        edge_index, edge_weight = get_symmetrically_normalized_adjacency(edge_index, n_nodes=n_nodes)
        binary_weight = (edge_weight > 0).type_as(edge_weight)

        adj_unnormalized = self._to_sparse(edge_index, binary_weight, n_nodes)
        adj_normalized = self._to_sparse(edge_index, edge_weight, n_nodes)

        return adj_unnormalized, adj_normalized

    @staticmethod
    def _prepare_targets(y_soft, y_true, mask):
        """Check ``y_true`` has one row per masked node; one-hot encode 1-D long labels.

        Returns:
            ``(y_true, numel)`` where ``numel`` is the number of masked rows.
        """
        numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
        assert y_true.size(0) == numel

        if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
            y_true = torch.nn.functional.one_hot(y_true.view(-1), num_classes=y_soft.size(-1))
        return y_true, numel

    @staticmethod
    def _compute_error_scale(error, mask, numel):
        """Per-row factor that rescales propagated error to the mean observed error.

        Rows with no error (0/0 -> NaN, x/0 -> inf) or an implausibly large factor
        (> 1000) get 1.0, i.e. are left unscaled.
        """
        sigma = error[mask].abs().sum() / numel
        scale = sigma / error.abs().sum(dim=1, keepdim=True)
        scale[scale.isnan() | scale.isinf() | (scale > 1000)] = 1.0
        return scale

    def correct(self, y_soft, y_true, mask, edge_index, edge_weight=None):
        """Propagate the residual ``y_true - y_soft`` (known rows only) and add it back.

        The residual is diffused for 50 steps with
        ``err <- clamp(alpha * A @ err + (1 - alpha) * err0, -1, 1)`` and then rescaled
        per row by ``_compute_error_scale``. ``edge_weight`` is accepted but not used.

        Returns:
            ``y_soft + scale * err`` (new tensor).
        """
        y_true, numel = self._prepare_targets(y_soft, y_true, mask)

        error = torch.zeros_like(y_soft)
        error[mask] = y_true - y_soft[mask]

        n_nodes = error.shape[0]
        _, adj = self.get_propagation_matrix(error, edge_index, n_nodes)

        res = (1 - self.alpha) * error
        print('Start Correcting Step ...!')
        for _ in trange(50):
            error = torch.sparse.mm(adj, error)
            error.mul_(self.alpha).add_(res)
            error.clamp_(-1., 1.)

        scale = self._compute_error_scale(error, mask, numel)
        out = y_soft + scale * error

        # Drop local references before releasing cached GPU memory (no-op without CUDA).
        del error, res, y_true, y_soft
        torch.cuda.empty_cache()

        return out

    def smooth(self, y_soft, y_true, mask, edge_index, edge_weight=None):
        """Overwrite known rows with ``y_true`` and diffuse for 50 steps.

        Each step computes ``y <- clamp(alpha * A @ y + (1 - alpha) * y0, 0, 1)``.
        ``y_soft`` is cloned, so the caller's tensor is not modified. ``edge_weight`` is
        accepted but not used.
        """
        y_true, _ = self._prepare_targets(y_soft, y_true, mask)

        y_soft = y_soft.clone()
        y_soft[mask] = y_true + 0.0  # "+ 0.0" promotes integer one-hot targets to float

        n_nodes = y_soft.shape[0]
        _, adj = self.get_propagation_matrix(y_soft, edge_index, n_nodes)

        res = (1 - self.alpha) * y_soft
        print('Start Smoothing Step ...!')
        for _ in trange(50):
            y_soft = torch.sparse.mm(adj, y_soft)
            y_soft.mul_(self.alpha).add_(res)
            y_soft.clamp_(0., 1.)

        del res, y_true
        torch.cuda.empty_cache()

        return y_soft