import argparse
import sys
import os.path as osp
import random
import numpy as np
from time import perf_counter as t
import argparse
import torch
import torch_geometric.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import GCNConv
from torch import Tensor
from model import Encoder, Model, drop_feature, drop_feature_fair
import networkx as nx
from utils import load_pokec, sens_classification
from eval import label_classification


def test(model: Model, x, edge_index, y, sens):
    """Encode the graph with ``model`` and run the label-classification evaluation.

    Args:
        model: Trained encoder; called as ``model(x, edge_index)`` to produce
            the embeddings ``z``.
        x: Node feature input forwarded to ``model``.
        edge_index: Graph connectivity forwarded to ``model``.
        y: Labels forwarded to ``label_classification``.
        sens: Sensitive-attribute values forwarded to ``label_classification``.

    Returns:
        None. Whatever ``label_classification`` reports or prints is its own
        side effect.
    """
    # Switch layers such as dropout to inference behaviour.
    model.eval()
    # Evaluation only: no autograd graph is needed, so don't build one. This
    # also keeps ``z`` from requiring grad when it is handed off below.
    with torch.no_grad():
        z = model(x, edge_index)

    # 0.1 is the ratio argument of label_classification (presumably a split
    # fraction -- confirm in eval.py).
    label_classification(z, y, sens, ratio=0.1)


if __name__ == '__main__':
    # Script entry point: load pre-computed embeddings, labels and a
    # sensitive-attribute vector from NumPy files and pass them to
    # ``label_classification``.
    parser = argparse.ArgumentParser()
    # ``--no-cuda`` is a ``store_true`` flag whose default is already True, so
    # passing it changes nothing and CPU is the default. It is kept for
    # backward compatibility. ``--cuda`` writes the same ``no_cuda`` dest and
    # is the opt-in that makes the GPU path reachable.
    parser.add_argument('--no-cuda', action='store_true', default=True,
                        help='Disables CUDA training.')
    parser.add_argument('--cuda', dest='no_cuda', action='store_false',
                        help='Use CUDA when available (default: CPU).')
    # The three paths are required. A placeholder default could only fail
    # later inside np.load with an unhelpful file-not-found error.
    parser.add_argument('--input', type=str, required=True,
                        help='NumPy file holding the embedding matrix.')
    parser.add_argument('--labels', type=str, required=True,
                        help='NumPy file holding the labels.')
    parser.add_argument('--sens', type=str, required=True,
                        help='NumPy file holding the sensitive attribute.')
    parser.add_argument('--seed', type=int, default=39788, help='Random seed.')
    parser.add_argument('--num-threads', type=int, default=8,
                        help='Number of CPU threads torch may use.')
    parser.add_argument('--ratio', type=float, default=0.1,
                        help='Ratio argument passed to label_classification.')
    # parse_known_args ignores unrecognised arguments instead of erroring.
    args = parser.parse_known_args()[0]
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Seed every RNG available here so repeated evaluations are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.num_threads)

    ordered_embeds = np.load(args.input)
    # DGI embeddings are saved with extra axes. Reshaping to
    # (shape[1], shape[2]) drops them, which requires the dropped axes to have
    # size 1. Plain 2-D matrices are used as-is instead of failing on shape[2].
    if ordered_embeds.ndim > 2:
        ordered_embeds = np.reshape(
            ordered_embeds,
            (ordered_embeds.shape[1], ordered_embeds.shape[2]),
        )
    ordered_embeds = torch.FloatTensor(ordered_embeds)
    # Optional feature normalisation, disabled:
    # from utils import feature_norm
    # ordered_embeds = feature_norm(ordered_embeds)
    labels = torch.LongTensor(np.load(args.labels))
    sens = torch.LongTensor(np.load(args.sens))

    device = torch.device('cuda' if args.cuda else 'cpu')
    ordered_embeds = ordered_embeds.to(device)
    labels = labels.to(device)
    sens = sens.to(device)

    label_classification(ordered_embeds, labels, sens, ratio=args.ratio)
    # Alternative evaluation, disabled. Judging by its arguments it presumably
    # classifies the sensitive attribute from the embeddings -- confirm in utils.
    # sens_classification(ordered_embeds, sens, args.ratio)
