import random
import torch
import random
import numpy as np



def generate_random_masks(total_samples, num_true_train, num_true_val, num_true_test, random_seed=None):
    """Build three boolean masks selecting random, consecutive slices of a shuffled index list.

    The indices ``0 .. total_samples - 1`` are shuffled with Python's ``random``
    module and then cut into three back-to-back slices (train, val, test) of
    the requested sizes. Positions past the end of the shuffled list are simply
    not selected, so a mask may contain fewer ``True`` values than requested.

    Args:
        total_samples: Length of every returned mask.
        num_true_train: Requested number of ``True`` entries in the train mask.
        num_true_val: Requested number of ``True`` entries in the validation mask.
        num_true_test: Requested number of ``True`` entries in the test mask.
        random_seed: When not ``None``, seeds ``random``, ``numpy.random`` and
            ``torch`` (global state) before shuffling.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of ``torch.bool`` tensors
        of shape ``(total_samples,)``.
    """
    # Seed every RNG the surrounding code may rely on, in a fixed order.
    if random_seed is not None:
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(random_seed)

    # One shared permutation; each split takes the next slice of it.
    permutation = list(range(total_samples))
    random.shuffle(permutation)

    val_start = num_true_train
    test_start = val_start + num_true_val
    test_stop = test_start + num_true_test
    slice_bounds = ((0, val_start), (val_start, test_start), (test_start, test_stop))

    masks = []
    for start, stop in slice_bounds:
        chosen = torch.tensor(permutation[start:stop], dtype=torch.long)
        mask = torch.zeros(total_samples, dtype=torch.bool)
        mask[chosen] = True
        masks.append(mask)

    return tuple(masks)


def generate_random_masks_inductive(total_samples, num_train_labeled, num_train_unlabeled, num_val, num_test, random_seed=None):
    """Build four boolean masks from consecutive slices of one shuffled index list.

    The indices ``0 .. total_samples - 1`` are shuffled with Python's ``random``
    module; the labeled-train, unlabeled-train, validation and test masks then
    take back-to-back slices of that shuffled list, in that order. Slices that
    run past the end of the list are truncated silently.

    Args:
        total_samples: Length of every returned mask.
        num_train_labeled: Requested size of the labeled training split.
        num_train_unlabeled: Requested size of the unlabeled training split.
        num_val: Requested size of the validation split.
        num_test: Requested size of the test split.
        random_seed: When not ``None``, seeds ``random``, ``numpy.random`` and
            ``torch`` (global state) before shuffling.

    Returns:
        Tuple ``(train_labeled_mask, train_unlabeled_mask, val_mask, test_mask)``
        of ``torch.bool`` tensors of shape ``(total_samples,)``.
    """
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    shuffled = list(range(total_samples))
    random.shuffle(shuffled)

    # Walk a cursor through the shuffled indices, one split at a time.
    split_masks = []
    cursor = 0
    for split_size in (num_train_labeled, num_train_unlabeled, num_val, num_test):
        stop = cursor + split_size
        selected = torch.tensor(shuffled[cursor:stop], dtype=torch.long)
        mask = torch.zeros(total_samples, dtype=torch.bool)
        mask[selected] = True
        split_masks.append(mask)
        cursor = stop

    return tuple(split_masks)


def get_mask_wrt_train(train_labeled_mask, train_unlabeled_mask):
    """Re-index the two training masks relative to the combined training subset.

    The inputs are boolean masks over all samples. The outputs are boolean
    masks over only those samples that are ``True`` in at least one input,
    kept in their original order.

    Args:
        train_labeled_mask: Boolean mask over all samples marking labeled
            training samples (NumPy array, CPU torch tensor, or sequence).
        train_unlabeled_mask: Boolean mask of the same length marking
            unlabeled training samples.

    Returns:
        Tuple ``(new_train_labeled_mask, new_train_unlabeled_mask)`` of NumPy
        ``bool`` arrays whose length equals the number of samples selected by
        either input mask.
    """
    labeled = np.asarray(train_labeled_mask, dtype=bool)
    unlabeled = np.asarray(train_unlabeled_mask, dtype=bool)

    # Samples that belong to the training subset at all.
    in_train = labeled | unlabeled

    # Restricting each mask to the training subset is exactly the membership
    # test "is this kept sample labeled / unlabeled", done in one vectorized
    # pass instead of a per-element array search.
    return labeled[in_train], unlabeled[in_train]



def get_data_mask(data, test_portion, random_seed, data_per_class=20):
    """Draw random labeled-train / unlabeled-train / val / test masks for ``data``.

    Split sizes are chosen as follows:

    * test: ``int(total_samples * test_portion)``;
    * labeled train and val: the number of ``True`` entries in
      ``data.train_mask`` and ``data.val_mask`` when ``data`` provides them,
      otherwise ``data_per_class`` samples per distinct value of ``data.y``
      for each of the two splits;
    * unlabeled train: every remaining sample.

    Args:
        data: Object exposing ``x`` (first dimension = number of samples) and
            either ``train_mask``/``val_mask`` or ``y``.
        test_portion: Fraction of all samples assigned to the test split.
        random_seed: Seed forwarded to ``generate_random_masks_inductive``.
        data_per_class: Per-class sample count used for the labeled-train and
            validation splits when ``data`` has no predefined masks.

    Returns:
        Tuple ``(train_labeled_mask, train_unlabeled_mask, val_mask, test_mask)``
        as returned by ``generate_random_masks_inductive``.

    Raises:
        ValueError: If the requested labeled-train, validation and test sizes
            together exceed the number of samples.
    """
    total_samples = data.x.shape[0]
    num_test = int(total_samples * test_portion)

    # Only the mask lookup is allowed to fail; anything else is a real error
    # and must not be hidden by silently switching to the per-class fallback.
    try:
        num_train_labeled = int(data.train_mask.sum())
        num_val = int(data.val_mask.sum())
    except AttributeError:
        # No predefined masks (attribute missing or None): size the labeled
        # and validation splits from the number of classes instead.
        num_classes = len(data.y.unique())
        num_train_labeled = num_val = num_classes * data_per_class

    num_train_unlabeled = total_samples - num_train_labeled - num_val - num_test
    if num_train_unlabeled < 0:
        # A negative remainder would make the generated slices overlap,
        # leaking validation samples into the labeled training split.
        raise ValueError(
            f"Requested splits exceed the dataset size: train_labeled={num_train_labeled}, "
            f"val={num_val}, test={num_test}, total_samples={total_samples}"
        )

    return generate_random_masks_inductive(
        total_samples=total_samples,
        num_train_labeled=num_train_labeled,
        num_train_unlabeled=num_train_unlabeled,
        num_val=num_val,
        num_test=num_test,
        random_seed=random_seed,
    )


def evaluate_inductive(y_pred, data, mask_list):
    """Compute per-split accuracy of ``y_pred`` against ``data.y``.

    Args:
        y_pred: Predicted labels, indexable by each mask in ``mask_list``.
        data: Object whose ``y`` attribute holds the reference labels.
        mask_list: Masks for the labeled-train, unlabeled-train, validation
            and test splits, in that order. Extra entries are ignored and
            missing trailing entries simply produce no result key.

    Returns:
        Dict mapping ``'train_labeled_acc'``, ``'train_unlabeled_acc'``,
        ``'val_acc'`` and ``'test_acc'`` to the fraction of correct
        predictions in the corresponding split. A split whose mask selects no
        samples is reported as ``nan`` rather than raising
        ``ZeroDivisionError``.
    """
    split_names = ('train_labeled_acc', 'train_unlabeled_acc', 'val_acc', 'test_acc')
    result = {}
    for split, mask in zip(split_names, mask_list):
        num_selected = int(mask.sum())
        if num_selected == 0:
            # Accuracy is undefined for an empty split (e.g. no unlabeled
            # training samples); report NaN instead of dividing by zero.
            result[split] = float('nan')
            continue
        num_correct = int((y_pred[mask] == data.y[mask]).sum())
        result[split] = num_correct / num_selected
    return result