class Dict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` reads ``d["key"]`` and ``d.key = value`` stores
    ``d["key"] = value``.
    """

    def __getattr__(self, name):
        # __getattr__ is only consulted after normal attribute lookup fails.
        # It must raise AttributeError (not KeyError) for unknown names so
        # that hasattr(), getattr(obj, name, default) and copy.deepcopy()
        # keep working on instances.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

# Run configuration. Every entry is reachable as an attribute (``args.lr``).
args = Dict({
    "data": "Squirrel",             # dataset name handed to load_data
    "gpu": 3,                       # CUDA device index used when CUDA is available
    "missing_rate": 0,              # not read by the active code in this file
    "categorical": False,           # not read by the active code in this file
    "verbose": False,               # not read by the active code in this file
    "num_epochs": 200,              # number of training epochs
    "vdim": 64,                     # width of the initial node / feature-node embeddings
    "units": 200,                   # hidden width of the GNN
    "depth": 3,                     # number of GNN layers
    "homo": False,                  # True -> paired linear layers in the GNN are shared
    "attn_temp": 0.01,              # edge-softmax temperature; falsy -> sigmoid gate
    "drop_rate": 0.,                # dropout probability inside the GNN
    # "edge_value_thresh": 0.01,
    "imputation_method": 'zero',    # not read by the active code in this file
    "lr": 1e-3,                     # AdamW learning rate
    "weight_decay": 1e-6,           # AdamW weight decay
    "sim_batch": 20,                # feature rows per Jaccard-similarity batch
    "top_k_feature_similarity": 1,  # feature-feature edges per feature (was 10)
})

#import os
#os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
#torch.use_deterministic_algorithms(True)

#
import os
import sys
import time
import torch
from torch import nn
import torch_geometric as pyg
from torch_geometric import nn as gnn
#from torch_geometric.loader import NeighborSampler
from tqdm import tqdm, trange
import numpy as np

from sklearn.decomposition import PCA, FastICA

import torch.nn.functional as F
import random
import copy
import argparse
#src_dir = os.path.dirname(os.path.dirname(__file__))
#sys.path.append(src_dir)
from utils.data_loader import load_data
from utils.utils import seed_everything,create_otf_edges,get_feature_mask
# from gnn_utils import GraphAnalyzer
from sklearn.metrics import pairwise_distances

# Fix the random seed before any data loading or model construction.
# NOTE(review): seed_everything comes from utils.utils; presumably it seeds
# Python, NumPy and torch RNGs — confirm against its implementation.
seed_everything(0)



#parser = argparse.ArgumentParser()
#parser.add_argument("--data", default="Cora", help="name of the dataset", type=str)
#parser.add_argument("--gpu", default=0, help="GPU no. to use, -1 in case of no gpu", type=int)
#parser.add_argument("--missing_rate",default=0,help="ratio of features to be missed randomly", type=float)
#parser.add_argument("--categorical",default=False,help="Make edges only when feature is present/categorical", type=bool)
#parser.add_argument("--verbose",default=False,help="Print Model output during training", type=bool)
#parser.add_argument("--num_epochs",default=200,help="Print Model output during training", type=int)
#parser.add_argument("--num_layers",default=1,help="Num of layers (1,2)", type=int)
#parser.add_argument("--bs_train_nbd",default=512,help="Num of nodes in training computation subgraph", type=int)
#parser.add_argument("--bs_test_nbd",default=-1,help="Num of nodes in testing computation subgraph", type=int)
#parser.add_argument("--drop_rate",default=0.2,help="Drop rate", type=float)
#parser.add_argument("--result_file",default='result.txt',type=str)
#parser.add_argument("--edge_value_thresh",default=0.01,type=float)
#parser.add_argument("--imputation",default='zero',type=str)
#parser.add_argument("--heads",default=4,type=int)
#parser.add_argument("--weight_decay",default=0,type=float)

#args = parser.parse_args()

# Select the compute device. A negative ``args.gpu`` means "no GPU" (the
# convention documented in the CLI help above) and selects the CPU instead of
# building the invalid device string 'cuda:-1'. Otherwise the requested CUDA
# device is used whenever CUDA is available.
device = torch.device(f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu')
# Load the dataset with a 60/20 train/validation ratio (the remainder is
# presumably the test split — confirm in utils.data_loader.load_data).
data = load_data(args.data,train_ratio=0.6,val_ratio=0.2)
print("train dataset, val dataset and test dataset ", data.train_mask.sum(),data.val_mask.sum(),data.test_mask.sum())
    
# NOTE: the missing-feature simulation / imputation path below is disabled
# (commented out). As a result every entry of data['x'] is treated as
# observed and the mask built at the end of this section is all True.
# if args.missing_rate > 0:
#     print("missing rate,", args.missing_rate)
#     feature_mask = get_feature_mask(args.missing_rate,data['x'].shape[0],data['x'].shape[1])
#     # print("feature mask size: ", feature_mask.shape)
#     data['x'][~feature_mask] = float('nan')  ### replaced values with nan
#     if args.imputation_method=='zero':
#         X_reconstructed = torch.zeros_like(data['x'])
#     if args.imputation_method == 'nf':
#         print("Neighbourhood mean")
#         X_reconstructed = neighborhood_mean_filling(data.edge_index,data.x,feature_mask)
#     if args.imputation_method == 'fp':
#         print("Feature propogation")
#         X_reconstructed = feature_propagation(data.edge_index,data.x,feature_mask,50)  
#     #X_reconstructed = feature_propagation(data.edge_index,data.x,feature_mask,50)#
#     data['x'] = torch.where(feature_mask, data.x, X_reconstructed)
# else:
# Boolean mask with the same shape as data['x']; True everywhere.
feature_mask = torch.ones_like(data['x']).bool()
# print("data x shape: ", data['x'].shape)
# print("Remaining edges ", feature_mask.sum(), data['x'].shape[0] * data['x'].shape[1])

#num_samples = [20,15]

#if bs_train_nbd == -1:
#    bs_train_nbd = data.x.shape[0]

#if bs_test_nbd == -1:
#    bs_test_nbd = data.x.shape[0]
    
#print("bs_train_nbd and test_nbd", bs_train_nbd,bs_test_nbd)
#train_neigh_sampler = NeighborSampler(data.edge_index, node_idx= data.train_mask, sizes=num_samples, batch_size=bs_train_nbd, shuffle=True, num_workers=0)
#subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1,-1], batch_size=bs_test_nbd, shuffle=False, num_workers=0)



# Number of classes is taken as (largest label + 1).
n_cls = data.y.max().item() + 1

# Print a short summary of the loaded dataset.
_num_rows, _num_cols = data.x.shape
print("Dataset: ", args.data)
print(f"Node Feature Matrix Info: # Nodes: {_num_rows}")
print(f"Node Feature Matrix Info: # Node Features: {_num_cols}")
print(f"Edge Index Shape: {data.edge_index.shape}")
print(f"Edge Weight: {data.edge_attr}")
print(f"# Classes: {n_cls}")


# Build the graph tensors consumed by the model. Node ids are laid out as
#   [0, n_nodes)                    real nodes
#   [n_nodes, n_nodes + n_feats)    one "feature node" per feature column
# and three edge types are produced: node-node (0), node-feature (1) and
# feature-feature (2). Nothing here needs gradients.
with torch.no_grad():
    feature_mask = feature_mask.to(device)
    n_nodes, n_feats = data.x.shape
    dtype = torch.float
    y = data.y.to(device)

    # Real nodes start from an all-zero embedding of width vdim.
    node_x = torch.zeros((n_nodes, args.vdim), dtype = dtype, device = device)

    # Initial feature-node embeddings: PCA of the transposed feature matrix
    # (one row per feature column) down to vdim components.
    pca = PCA(n_components=args.vdim)
    feat_x = torch.from_numpy(pca.fit_transform(data.x.T)).to(dtype).to(device)

    # Only the real-node embeddings are passed as model input; feat_x is
    # handed to the model separately and concatenated inside its forward().
    x = node_x

    # Edges between real nodes, taken as-is from the dataset.
    node_node = data.edge_index.to(device)

    # One (node, feature) pair for every non-zero entry of data.x.
    node_feat_unidir = torch.stack(torch.nonzero(data.x.to(device), as_tuple = True), dim = 0)
    # Shift feature indices into the feature-node id range.
    node_feat_unidir[1] += n_nodes
    # Add both directions of every node-feature edge.
    node_feat = torch.cat([node_feat_unidir, torch.stack([node_feat_unidir[1], node_feat_unidir[0]], dim = 0)], dim = 1)

    # Pairwise Jaccard similarity between features, where a feature's support
    # is the set of nodes with a strictly positive value. Computed in row
    # batches of sim_batch to bound the size of the broadcast intermediates.
    # NOTE(review): supports use `> 0` while node-feature edges above use
    # "non-zero"; these differ if data.x contains negative values.
    # NOTE(review): two features with empty support give 0 / 0 = NaN here.
    x_ = (data.x.T.to(device) > 0).bool()
    jaccard_sim = []
    for i0 in trange(0, x_.shape[0], args.sim_batch):
        i1 = min(i0 + args.sim_batch, x_.shape[0])
        jaccard_sim.append((x_[None] & x_[i0 : i1, None]).sum(dim = 2) / (x_[None] | x_[i0 : i1, None]).sum(dim = 2))
    jaccard_sim = torch.cat(jaccard_sim, dim = 0)
    # A feature must not be selected as its own nearest neighbour.
    jaccard_sim.fill_diagonal_(0.)

    # Connect every feature to its top-k most similar features. Indices are
    # built directly as int64 on the target device instead of being staged in
    # a float32 CPU tensor (float32 cannot represent every integer above 2**24).
    if args.top_k_feature_similarity:
        top_k_jaccard = torch.topk(jaccard_sim, k=args.top_k_feature_similarity)
        feat_src = torch.arange(n_feats, device = device).repeat_interleave(args.top_k_feature_similarity) + n_nodes
        feat_dst = top_k_jaccard.indices.flatten() + n_nodes
        feat_feat_unidir = torch.stack([feat_src, feat_dst], dim = 0)
    else:
        # No feature-feature edges requested.
        feat_feat_unidir = torch.zeros((2, 0), dtype = torch.long, device = device)
    # Add both directions of every feature-feature edge.
    feat_feat = torch.cat([feat_feat_unidir, torch.stack([feat_feat_unidir[1], feat_feat_unidir[0]], dim = 0)], dim = 1).to(device)

    # Full edge list, ordered: node-node, node-feature, feature-feature.
    edges = torch.cat([node_node, node_feat, feat_feat], dim = 1)

    # Edge type id per edge, in the same order as `edges`.
    edge_types = torch.cat([torch.zeros(node_node.size(1), dtype = torch.long, device = device),
                      torch.ones(node_feat.size(1), dtype = torch.long, device = device),
                      torch.full((feat_feat.size(1),), 2, dtype = torch.long, device = device)], dim = 0)

    # Initial edge features: one-hot encoding of the edge type.
    w = torch.zeros((edges.shape[1], 3), dtype = dtype, device = device)
    w.scatter_(dim=1, index=edge_types.unsqueeze(1), value=1)

# Return cached, unreferenced GPU blocks to the driver.
# NOTE(review): jaccard_sim and x_ are still referenced at module level, so
# their memory is not released by this call.
torch.cuda.empty_cache()

#import gc 
class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    ``units_list`` gives the layer widths: ``units_list[0]`` is the input
    width and each following entry is the output width of one linear layer.
    ``act_fn`` is applied after every layer except the last one.
    """

    def __init__(self, units_list, act_fn):
        super().__init__()
        self.units_list = units_list
        self.depth = len(self.units_list) - 1
        self.act_fn = act_fn
        # One Linear per consecutive (in_width, out_width) pair.
        width_pairs = zip(self.units_list[:-1], self.units_list[1:])
        self.lins = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in width_pairs])

    def forward(self, x):
        """Apply all layers in order; the final layer's output is returned raw."""
        final_idx = self.depth - 1
        for idx, layer in enumerate(self.lins):
            x = layer(x)
            if idx != final_idx:
                x = self.act_fn(x)
        return x

class GNNAggr(gnn.MessagePassing):
    """Parameter-free message-passing layer.

    Each message is the unmodified ``x_j`` tensor; the messages are reduced
    with the aggregation scheme given by ``aggr``.
    """

    def __init__(self, aggr):
        # `aggr` is forwarded unchanged to the MessagePassing base class.
        super().__init__(aggr = aggr)

    def forward(self, x, edge_index):
        """Propagate ``x`` along ``edge_index`` and return the aggregated result."""
        aggregated = self.propagate(edge_index, x = x)
        return aggregated

    def message(self, x_i, x_j):
        # x_i is accepted but not used; the message is x_j as-is.
        return x_j

import torch_scatter
class GNN(nn.Module): # Anisotropic GNN + MLP
    """Edge-gated residual GNN followed by an MLP classification head.

    Args:
        x_feat: Tensor stored as a learnable parameter and appended below the
            input rows in ``forward`` (the embeddings of the extra nodes).
        homo: If true, ``v_lins2`` shares weights with ``v_lins1`` and
            ``v_lins4`` shares weights with ``v_lins3``.
        attn_temp: Temperature of the per-target edge softmax. A falsy value
            switches to a plain sigmoid gate.
        vdim: Width of the input node embeddings.
        edim: Width of the input edge features.
        units: Hidden width used by every layer.
        depth: Number of message-passing layers.
        act_fn: Activation function.
        v_mlp: Layer widths of the node head, appended after ``units``.
        e_mlp: Optional layer widths of an edge head. When ``None`` only node
            outputs are returned.
        drop_rate: Dropout probability. ``None`` (default) keeps the previous
            behaviour of reading the module-level ``args.drop_rate``.
    """

    def __init__(self, x_feat, homo, attn_temp, vdim, edim, units, depth, act_fn, v_mlp, e_mlp = None, drop_rate = None):
        super().__init__()
        self.x_feat = nn.Parameter(x_feat)
        self.homo = homo
        self.attn_temp = attn_temp
        self.vdim = vdim
        self.edim = edim
        self.units = units
        self.depth = depth
        self.act_fn = act_fn
        self.drop_rate = drop_rate
        # NOTE: the construction order of the layers below is kept unchanged
        # so that seeded parameter initialisation is reproducible.
        if self.vdim != self.units:
            # Input projection for node embeddings, only needed on width mismatch.
            self.v_lin0 = nn.Linear(self.vdim, self.units)
        self.v_lins1 = nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.v_lins2 = self.v_lins1 if self.homo else nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.v_lins3 = nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.v_lins4 = self.v_lins3 if self.homo else nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.v_bns = nn.ModuleList([gnn.BatchNorm(self.units) for i in range(self.depth)])
        if self.edim != self.units:
            # Input projection for edge features, only needed on width mismatch.
            self.e_lin0 = nn.Linear(self.edim, self.units)
        self.e_lins = nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.e_bns = nn.ModuleList([gnn.BatchNorm(self.units) for i in range(self.depth)])
        self.w_lins = nn.ModuleList([nn.Linear(self.units, self.units) for i in range(self.depth)])
        self.v_mlp = MLP([self.units] + list(v_mlp), act_fn)
        if e_mlp is None:
            self.e_mlp = None
        else:
            self.e_mlp = MLP([self.units] + list(e_mlp), act_fn)

    def forward(self, x, w, edge_index):
        """Run message passing and the output head(s).

        Args:
            x: Input node embeddings (last dimension ``vdim``); ``x_feat`` is
                concatenated below them along dim 0.
            w: Edge features (last dimension ``edim``), one row per edge.
            edge_index: Edge list; messages are gathered from
                ``edge_index[0]`` and summed into ``edge_index[1]``.

        Returns:
            Log-softmax node outputs for every row (input rows followed by
            the ``x_feat`` rows). If an edge head was configured, a tuple
            ``(node_outputs, edge_outputs)`` is returned instead.
        """
        x = torch.cat([x, self.x_feat], dim = 0)
        # Fall back to the module-level configuration when no rate was given.
        drop_rate = args.drop_rate if self.drop_rate is None else self.drop_rate
        if self.vdim != self.units:
            x = self.v_lin0(x)
            x = self.act_fn(x)
        if self.edim != self.units:
            w = self.e_lin0(w)
            w = self.act_fn(w)
        for i in range(self.depth):
            x0, w0 = x, w
            if drop_rate:
                x0 = F.dropout(x0, p = drop_rate, training = self.training)
                w0 = F.dropout(w0, p = drop_rate, training = self.training)
            x1 = self.v_lins1[i](x0)
            x2 = self.v_lins2[i](x0)
            w0 = self.w_lins[i](w0)
            # Per-channel edge gate: temperature-scaled softmax over the edges
            # sharing a target node, or a plain sigmoid when attn_temp is falsy.
            w1 = pyg.utils.softmax(F.logsigmoid(w0) / self.attn_temp, index = edge_index[1]) if self.attn_temp else torch.sigmoid(w0)
            # Residual update: self term + gated sum of neighbour messages.
            x = x0 + self.act_fn(self.v_bns[i](x1 + torch_scatter.scatter(src = w1 * x2[edge_index[0]], index = edge_index[1], dim = 0, reduce = 'sum', dim_size = x1.size(0))))
            # Edge states are only refreshed in the last layer, and only when
            # an edge head will consume them.
            if i == self.depth - 1 and  self.e_mlp is not None:
                x3 = self.v_lins3[i](x0)
                x4 = self.v_lins4[i](x0)
                w2 = self.e_lins[i](w0)
                w = w0 + self.act_fn(self.e_bns[i](w2 + x3[edge_index[0]] + x4[edge_index[1]]))
        x = self.v_mlp(x)
        x = F.log_softmax(x, dim = -1) # classification
        if self.e_mlp is None:
            return x
        else:
            w = self.e_mlp(w)
            return x, w

# Instantiate the network on the selected device. The node head ends in
# n_cls outputs and no edge head is configured.
model = GNN(
    x_feat = feat_x,
    homo = args.homo,
    attn_temp = args.attn_temp,
    vdim = args.vdim,
    edim = w.size(-1),
    units = args.units,
    depth = args.depth,
    act_fn = F.relu,
    v_mlp = [n_cls],
).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay = args.weight_decay,
)

# Make sure the checkpoint directory exists.
from pathlib import Path
model_save_path = "models"
Path(model_save_path).mkdir(parents=True, exist_ok=True)

# Model-selection bookkeeping: the reported test accuracy is the one measured
# at the epoch with the best validation accuracy so far.
actual_test_acc = best_val_acc = best_epoch = 0
#numPosSamples = 64


from tqdm import tqdm, trange

# Place the split masks on the compute device once, instead of handing the
# dataset's own mask objects to every indexing operation in every epoch.
train_mask = torch.as_tensor(data.train_mask, device = device)
val_mask = torch.as_tensor(data.val_mask, device = device)
test_mask = torch.as_tensor(data.test_mask, device = device)

tbar = trange(1, args.num_epochs + 1)
start = time.time()
# (elapsed seconds, selected test accuracy in %) after every epoch.
perf_time_list = []
for epoch in tbar:
    # One full-graph optimisation step on the training nodes.
    model.train()
    optimizer.zero_grad()
    # The model also returns rows for the feature nodes; keep real nodes only.
    out = model(x = x, w = w, edge_index = edges)[: n_nodes]
    loss = F.nll_loss(out[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        # NOTE(review): eval mode is only entered when dropout is enabled;
        # with drop_rate == 0 this evaluation pass runs in train mode.
        if args.drop_rate > 0:
            model.eval()
        out = model(x = x, w = w, edge_index = edges)[: n_nodes]
        res = (out.argmax(dim = -1) == y).float()
        tr_acc = res[train_mask].mean().item()
        va_acc = res[val_mask].mean().item()
        te_acc = res[test_mask].mean().item()
        # Keep the test accuracy of the best-validation epoch.
        if va_acc > best_val_acc:
            best_val_acc = va_acc
            actual_test_acc = te_acc
            best_epoch = epoch

    end = time.time()
    perf_time_list.append((end-start, actual_test_acc * 100.0))
    tbar.set_description(f'[epoch={epoch}] loss={loss.item():.4f} tr_acc={tr_acc:.4f} va_acc={va_acc:.4f} te_acc={te_acc:.4f} result={actual_test_acc:.4f}')


print("Test accuracy,", actual_test_acc)

# Persist the configuration and the selected test accuracy, one per line.
# The file is overwritten on every run.
with open('result.txt', "w", encoding="utf-8") as f:
    # The trailing newline keeps the two records on separate lines.
    f.write('args: ' + repr(args) + "\n")
    f.write('test acc: ' + str(actual_test_acc) + "\n")



# import pickle
# pickle.dump(perf_time_list, open(f"{args.data}_perf_time_list.pkl", "wb"))
