import numpy as np
import os.path as osp
import torch
from torch_geometric.data import NeighborSampler
from torch_geometric.utils.loop import remove_self_loops, add_self_loops, add_remaining_self_loops
from torch_geometric.utils import is_undirected, to_undirected
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from utils import get_missing_feature_mask, to_dgl, add_labels
from data_utils import get_dataset, set_train_val_test_split
from filling_strategies import filling
from collections import Counter
from utils import get_symmetrically_normalized_adjacency, get_row_normalized_adjacency, knn_fast
from copy import deepcopy
import data_load

class embedder:
    """Base class that loads a dataset and caches the tensors a model needs.

    Construction loads the data, builds a kNN graph from the raw features,
    creates the train/val/test split and, depending on ``args.embedder``,
    runs label propagation and/or feature filling.

    Attributes always set by ``__init__``:
        x, x_og, edge_index, edge_weight, evaluator (always ``None``),
        train_loader, inference_loader, train_mask, val_mask, test_mask,
        labels, missing_feature_mask, args.

    Attributes set only for some configurations:
        y, missing_y_mask, lp_mat -- when ``args.embedder`` is ``'LP'``,
            ``'LP_label_trick'`` or contains ``'GOODIE'``.
        dgl_graph -- when ``args.embedder == 'TWIRLS'``.

    Mask convention (for both feature and label masks): ``True`` marks a
    known entry, ``False`` an unknown one.
    """

    # ``args.ver`` values for which the ``ver`` keyword is forwarded to
    # ``filling``; every other version relies on ``filling``'s own default.
    _VERSIONS_FORWARDING_VER = (4, 5, 6, 7, 8, 9)

    def __init__(self, args, seed, logger=None):
        """Load ``args.dataset`` and prepare features, graph, masks and loaders.

        Args:
            args: Run configuration. Fields read here: ``device``, ``dataset``,
                ``missing_rate``, ``initial_filling``, ``k``, ``num_layers``,
                ``graph_sampling``, ``batch_size``, ``train_ratio``,
                ``embedder``, ``num_iterations``, ``pseudo_type``, ``ver``,
                ``filling_method``, ``hidden_dim``, ``label_trick`` and
                ``num_heads``. The fields ``n_nodes``, ``n_feat``, ``n_hid``,
                ``n_class``, ``n_layer`` and ``n_head`` are written back onto
                it before it is stored as ``self.args``.
            seed: Seed forwarded to the data loader and to the split helper.
            logger: Optional logger forwarded to ``filling`` in the
                label-augmented GOODIE variants.

        NOTE(review): earlier revisions carried (disabled) assertions that
        neighbourhood sampling requires the SAGE model and is incompatible
        with Jumping Knowledge, plus OGBN-specific overrides of ``args``.
        None of that is enforced here any more.
        """
        device = torch.device(
            f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
        )

        dataset, _ = data_load.load_data(
            name=args.dataset,
            missing_rate=args.missing_rate,
            initial_filling=args.initial_filling,
            seed=seed,
        )
        # True where a feature value is observed, False where it is NaN.
        missing_feature_mask = ~torch.isnan(dataset.x).to(device)

        # Original variant: the kNN graph is built from the raw features
        # (NaNs replaced by 0) and stored on the dataset object.
        dataset.edge_index, dataset.edge_weight = self._build_knn_graph(
            dataset.x, args.k, device
        )

        split_idx = dataset.get_idx_split() if hasattr(dataset, "get_idx_split") else None

        if args.dataset == 'OGBN-Mag':
            # The OGBN-Mag split is nested per node type; keep the 'paper' ids.
            split_idx['train'] = split_idx['train']['paper']
            split_idx['valid'] = split_idx['valid']['paper']
            split_idx['test'] = split_idx['test']['paper']

        n_nodes, n_features = dataset.x.shape
        # Number of distinct label values actually present in the data.
        num_classes = len(dataset.y.unique())

        train_loader, inference_loader = self._build_samplers(args, dataset, split_idx)

        data = set_train_val_test_split(
            seed=seed,
            data=dataset,
            split_idx=split_idx,
            dataset=args.dataset,
            train_ratio=args.train_ratio,
        ).to(device)

        x = data.x.clone()

        # Label propagation (needed by the LP baselines and every GOODIE variant).
        lp_mat = None
        missing_y_mask = None
        if (args.embedder in ['LP', 'LP_label_trick']) or ('GOODIE' in args.embedder):
            y, missing_y_mask, lp_mat = self._propagate_labels(args, data, num_classes)
            self.missing_y_mask = missing_y_mask
            self.y = y
            self.lp_mat = lp_mat

        # Feature preparation / propagation.
        if ('GOODIE' in args.embedder) and (args.pseudo_type == -1):
            x, missing_feature_mask = self._prepare_goodie_features(
                args, data.edge_index, x, missing_feature_mask,
                lp_mat, missing_y_mask, num_classes, device, logger,
            )
            n_features = x.shape[1]
        elif (args.pseudo_type >= 0) or (args.embedder == 'GNN'):
            filled_features = filling(
                args.filling_method, data.edge_index, x, missing_feature_mask, args.num_iterations
            )
            # Keep observed values, take the filled value everywhere else.
            x = torch.where(missing_feature_mask, x, filled_features)
        # Any other configuration leaves ``x`` as loaded (NaNs included).

        # NOTE: a "refined kNN" step that rebuilt the graph from the processed
        # ``x`` used to follow here; it is currently disabled.

        if args.embedder == 'TWIRLS':
            data.x = x
            self.dgl_graph = to_dgl(data, ogb='OGBN' in args.dataset).to(device)
        else:
            if args.embedder == 'node2vec_gnn_concat_1':
                n_features += args.hidden_dim
            if args.label_trick:
                n_features += num_classes

        self.x = x
        self.x_og = x if args.label_trick else None
        self.edge_index = data.edge_index
        self.edge_weight = data.edge_weight
        self.evaluator = None
        self.train_loader = train_loader
        self.inference_loader = inference_loader

        self.train_mask = data.train_mask
        self.val_mask = data.val_mask
        self.test_mask = data.test_mask
        self.labels = data.y
        self.missing_feature_mask = missing_feature_mask

        # Publish the derived sizes on ``args`` for the model constructors.
        args.n_nodes = n_nodes
        args.n_feat = n_features
        args.n_hid = args.hidden_dim
        args.n_class = num_classes
        args.n_layer = args.num_layers
        args.n_head = args.num_heads

        self.args = args

    @staticmethod
    def _build_knn_graph(features, k, device):
        """Return ``(edge_index, edge_weight)`` of an undirected kNN graph.

        NaNs in ``features`` are replaced by 0 before ``knn_fast`` is called,
        and edges whose weight comes back as NaN are dropped.
        """
        # nan_to_num is out-of-place, so ``features`` itself is left untouched
        # and no defensive copy is needed.
        knn_input = torch.nan_to_num(features, 0).to(device)
        row, col, edge_weight = knn_fast(knn_input, k)

        valid = ~torch.isnan(edge_weight)
        row, col, edge_weight = row[valid], col[valid], edge_weight[valid]
        edge_index = torch.stack([row, col], 0)

        # Symmetrisation is done on CPU copies, then moved back to the device
        # the kNN result lives on.
        sym_index, sym_weight = to_undirected(
            edge_index=edge_index.cpu(), edge_attr=edge_weight.cpu()
        )
        return sym_index.to(edge_index.device), sym_weight.to(edge_index.device)

    @staticmethod
    def _build_samplers(args, dataset, split_idx):
        """Return ``(train_loader, inference_loader)``; both ``None`` without sampling."""
        if not args.graph_sampling:
            return None, None

        # NOTE(review): this path reads ``dataset.data.edge_index`` while the
        # rest of the constructor reads/writes ``dataset.edge_index`` -- confirm
        # which graph the samplers are meant to see.
        train_loader = NeighborSampler(
            dataset.data.edge_index,
            node_idx=split_idx["train"],
            sizes=[15, 10],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=12,
        )
        inference_loader = NeighborSampler(
            dataset.data.edge_index,
            node_idx=None,
            sizes=[-1],
            batch_size=4096,
            shuffle=False,
            num_workers=12,
        )
        return train_loader, inference_loader

    @staticmethod
    def _propagate_labels(args, data, num_classes):
        """Run label propagation from the training labels.

        Returns:
            ``(y, known_mask, lp_mat)`` where ``y`` is the float one-hot label
            matrix with NaN outside ``known_mask``, ``known_mask`` is True on
            the rows of training nodes, and ``lp_mat`` is the result of
            ``filling('fp', ...)`` moved to the CPU.
        """
        y = torch.nn.functional.one_hot(data.y.view(-1)) + 0.0
        if 'OGBN' in args.dataset:
            known_mask = (data.y == -1)  # starts out all False for valid labels
            known_mask[data.train_mask] = True
            known_mask = known_mask.repeat(1, num_classes)
        else:
            known_mask = data.train_mask.reshape(-1, 1).repeat(1, num_classes)

        y[~known_mask] = float("nan")
        lp_mat = filling('fp', data.edge_index, y, known_mask, args.num_iterations)
        return y, known_mask, lp_mat.cpu()

    @staticmethod
    def _prepare_goodie_features(args, edge_index, x, feature_mask, lp_mat,
                                 label_mask, num_classes, device, logger):
        """Build the GOODIE input matrix and its known-entry mask.

        Behaviour by ``args.ver``:
            1      -- features only; ``filling('zero', ...)`` is invoked.
            2      -- label columns appended and all marked known; ``filling``
                      is invoked with 0 iterations.
            3      -- as 2 but with ``args.num_iterations`` and ``k=0``.
            4-6    -- label columns all marked known; ``ver`` is forwarded.
            7      -- label columns all marked unknown; ``ver`` is forwarded.
            8      -- label columns known where ``label_mask`` is True;
                      ``ver`` is forwarded.
            9      -- as 8, and additionally each node's arg-max class column
                      is marked known; ``ver`` is forwarded.
            other  -- label columns all marked known; ``ver`` not forwarded.

        Returns:
            ``(x, feature_mask)``; for every version except 1 both have
            ``num_classes`` extra trailing columns.

        NOTE(review): the tensor returned by ``filling`` is discarded here, as
        in the original code -- confirm the call is still wanted.
        """
        if args.ver == 1:
            x[~feature_mask] = float("nan")
            filling('zero', edge_index, x, feature_mask, args.num_iterations,
                    n_class=num_classes, k=args.k)
            return x, feature_mask

        lp_mat = F.normalize(lp_mat.to(x.device))

        # Known/unknown mask for the appended label-propagation columns.
        if args.ver == 7:
            label_cols = torch.zeros((x.shape[0], num_classes)).bool().to(feature_mask.device)
        elif args.ver == 8:
            label_cols = label_mask
        elif args.ver == 9:
            top_class = lp_mat.argmax(1)
            # In-place on purpose: ``label_mask`` is the tensor stored as
            # ``self.missing_y_mask``, so the attribute sees this update too.
            label_mask[torch.arange(lp_mat.shape[0]).to(top_class.device), top_class] = True
            label_cols = label_mask
        else:
            label_cols = torch.ones((x.shape[0], num_classes)).bool().to(feature_mask.device)

        # Min-max scale the features to [0, 1] (NaNs counted as 0 while
        # fitting), then restore the NaNs and append the label columns.
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled = scaler.fit_transform(x.nan_to_num(0).cpu())
        x = torch.tensor(scaled, dtype=torch.float).to(device)
        x[~feature_mask] = float("nan")
        x = torch.cat([x, lp_mat], 1)
        feature_mask = torch.cat([feature_mask, label_cols], 1)

        num_iterations = 0 if args.ver == 2 else args.num_iterations
        fill_kwargs = {
            'n_class': num_classes,
            'k': 0 if args.ver == 3 else args.k,
            'logger': logger,
        }
        if args.ver in embedder._VERSIONS_FORWARDING_VER:
            fill_kwargs['ver'] = args.ver

        filling(args.filling_method, edge_index, x, feature_mask, num_iterations, **fill_kwargs)
        return x, feature_mask