import random
import torch
import numpy as np
from tabulate import tabulate

from csbm_s import csbms_graph


def load_graph(args):
    """Build the graph selected by ``args.dataset``.

    Only the ``'csbms'`` dataset is currently supported; its graph is produced
    by ``csbms_graph(args)``.

    Args:
        args: Namespace-like config. ``args.dataset`` selects the dataset and
            the whole object is forwarded to the dataset generator.

    Returns:
        The graph object returned by the dataset generator.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset in ['csbms']:
        return csbms_graph(args)

    # Fail loudly on unknown names instead of falling through to an
    # unbound local variable (which surfaced as an UnboundLocalError).
    raise ValueError(
        f"Unsupported dataset: {args.dataset!r} (expected one of ['csbms'])"
    )


def split_per_label_ratio(graph, train_ratio, val_ratio):
    """Split node indices into train/val/test sets, stratified by label.

    Each label ``c`` in ``0 .. int(graph.y.max())`` is handled independently:

    1. The nodes carrying label ``c`` are shuffled with Python's ``random``
       module, so the result follows ``random.seed(...)``.
    2. The shuffled nodes are cut into fixed pools:
       - the first half is the train pool;
       - the next quarter is the validation pool;
       - everything after that is the test set.
    3. The first ``int(n_c * train_ratio)`` nodes of the train pool and the
       first ``int(n_c * val_ratio)`` nodes of the validation pool are kept.

    Consequences:
    - ``train_ratio`` is effectively capped at 0.5 and ``val_ratio`` at 0.25.
    - The test set does not depend on either ratio.
    - Pool nodes beyond the requested counts are left unused.

    Args:
        graph: object whose ``x`` tensor's first dimension is the node count
            and whose ``y`` tensor holds per-node labels (assumed to be
            non-negative integers).
        train_ratio: fraction of each label's nodes to place in the train set.
        val_ratio: fraction of each label's nodes to place in the validation set.

    Returns:
        ``(train_idx, val_idx, test_idx)``: 1-D tensors of node indices,
        concatenated label by label in ascending label order.
    """
    total = graph.x.size(0)
    label_count = int(graph.y.max() + 1)
    all_nodes = torch.arange(total)

    train_parts, val_parts, test_parts = [], [], []
    for label in range(label_count):
        members = all_nodes[graph.y == label]
        size = members.shape[0]
        # A full-length random.sample is a permutation drawn from the global
        # `random` state (not torch's RNG).
        members = members[random.sample(range(size), size)]

        half = size // 2
        quarter = size // 4
        n_train = int(size * train_ratio)
        n_val = int(size * val_ratio)

        train_pool = members[:half]
        val_pool = members[half:half + quarter]

        train_parts.append(train_pool[:n_train])
        val_parts.append(val_pool[:n_val])
        test_parts.append(members[half + quarter:])

    return torch.cat(train_parts), torch.cat(val_parts), torch.cat(test_parts)


def process_data(args, graph):
    """Split the nodes of a copy of ``graph`` and move that copy to ``args.device``.

    Args:
        args: Namespace-like config. Reads:
            - ``args.split_type``: only ``'ratio'`` is supported;
            - ``args.train_ratio`` and ``args.val_ratio``: forwarded to
              ``split_per_label_ratio``;
            - ``args.device``: the target device.
        graph: graph object providing ``clone()``, ``to(device)`` and the
            ``x``/``y`` attributes used by ``split_per_label_ratio``.

    Returns:
        ``(graph, train_nodes, val_nodes, test_nodes)``, where ``graph`` is
        the cloned graph after ``.to(args.device)``. The index tensors are
        returned as produced by the split and are not moved to ``args.device``.

    Raises:
        ValueError: if ``args.split_type`` is not a supported split type.
    """
    # Validate up front: an unsupported value would otherwise leave the split
    # result unbound and crash with an UnboundLocalError.
    if args.split_type != 'ratio':
        raise ValueError(
            f"Unsupported split_type: {args.split_type!r} (expected 'ratio')"
        )

    # Work on a copy so the caller's graph object is left untouched.
    graph = graph.clone()

    # Data split: stratified per label.
    train_nodes, val_nodes, test_nodes = split_per_label_ratio(
        graph, train_ratio=args.train_ratio, val_ratio=args.val_ratio
    )

    # Move the graph (not the index tensors) to the target device.
    graph = graph.to(args.device)

    return graph, train_nodes, val_nodes, test_nodes


def print_outcome(args, exp, test_acc, sp, eo, hc, print_summary):
    """Print experiment metrics as org-mode tables.

    Args:
        args: config namespace. ``args.dataset`` selects the per-trial output.
            For ``'csbms'`` the fields ``tau``, ``mu_y``, ``mu_s`` and
            ``num_nodes`` are printed next to the metrics.
        exp: trial index. It appears to be zero-based: the per-trial header is
            printed when ``exp == 0``, and summary mode reports ``exp + 1`` as
            'Num-Trials'.
        test_acc, sp, eo, hc: metric values for a single trial. In summary
            mode, these are collections of values that are averaged with
            ``np.mean``.
        print_summary: if true, print one row of averaged metrics. Otherwise
            print the row for trial ``exp``; this only happens for the
            ``'csbms'`` dataset, and other datasets print nothing.
    """
    table_opts = dict(
        tablefmt='orgtbl',
        stralign='center',
        numalign='center',
        floatfmt='.4f',
    )

    if print_summary:
        averages = [np.mean(values) for values in (test_acc, sp, eo, hc)]
        summary_row = [int(exp + 1)] + averages
        print('\n Results Summary:')
        print(tabulate([summary_row],
                       headers=['Num-Trials', 'Test Acc', 'sp', 'eo', 'hc'],
                       **table_opts))
        print('\n')
        return

    if args.dataset not in ['csbms']:
        return

    # The header is emitted once, just before the first trial's row.
    if exp == 0:
        header = ['Trial', 'Test Acc', 'sp', 'eo', 'hc', 'tau', 'mu_y', 'mu_s', 'n']
        print(tabulate([header], tablefmt='orgtbl'))
    trial_row = [exp, test_acc, sp, eo, hc,
                 args.tau, args.mu_y, args.mu_s, args.num_nodes]
    print(tabulate([trial_row], **table_opts))
