import os
import json
import random
import torch
import torch.nn.functional as F

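'''
Checkpoint directory layout (paths are built by get_base_model_path and
get_processed_model_path; train_num is the `iteration` argument of
get_base_model_path):

model_pth/
    {architecture}/
        {dataset}/
            {edge_ratio}/
                {random_seed}/
                    base_{train_num}.pth
                    pp_{pp_method}/
                        processed_{train_num}_{processed_num}.pth
'''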

'''
Results directory layout:

results/
    {architecture}/
        {dataset}/
            {edge_ratio}/
                {random_seed}/
                    pp_{pp_method}/
                        add.jpg
                        remove.jpg
                        combined.jpg
                        result.json
                        results_mean_std.csv
'''

def get_base_model_path(args, iteration):
    """Build the checkpoint path of a base (unprocessed) model.

    The result has the form
    ``model_pth/{architecture}/{dataset}/{edge_ratio}/{random_seed}/base_{iteration}.pth``
    and always uses ``/`` as the separator.
    """
    segments = (
        'model_pth',
        args.architecture,
        args.dataset,
        args.edge_ratio,
        args.random_seed,
        f'base_{iteration}.pth',
    )
    return '/'.join(f'{segment}' for segment in segments)

def get_processed_model_path(args, train_num, processed_num):
    """Build the checkpoint path of a post-processed model.

    The result has the form
    ``model_pth/{architecture}/{dataset}/{edge_ratio}/{random_seed}/pp_{pp_method}/processed_{train_num}_{processed_num}.pth``
    and always uses ``/`` as the separator.
    """
    pieces = [
        'model_pth',
        args.architecture,
        args.dataset,
        args.edge_ratio,
        args.random_seed,
        f'pp_{args.pp_method}',
        f'processed_{train_num}_{processed_num}.pth',
    ]
    return '/'.join(f'{piece}' for piece in pieces)

def set_seed(seed):
    """Seed Python's ``random`` module and PyTorch with ``seed``.

    CUDA generators on all devices are seeded too, but only when CUDA
    is available.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def no_activation(x):
    """Identity activation, selected when the configured activation is ``'none'``.

    Hands back the very same object it was given (no copy is made).
    """
    unchanged = x
    return unchanged

def load_hyperparameters(args, data):
    """Load the hyperparameters for ``args.architecture`` on ``args.dataset``.

    Parameters are read from ``model/params_large.json`` when the dataset is
    one of ``ogbn-arxiv``, ``flickr`` or ``reddit``, and from
    ``model/params.json`` otherwise. Both paths are relative to the current
    working directory. The file's ``"common"`` section is the base, and the
    section named after the architecture (if present) overrides its keys.

    Args:
        args: Object with ``architecture`` and ``dataset`` attributes.
        data: Object exposing ``num_classes``. It is only read when the
            architecture is ``"sgc"``.

    Returns:
        dict: The merged hyperparameters. A string ``"activation"`` of
        ``relu`` / ``elu`` / ``sigmoid`` / ``tanh`` is replaced by the
        same-named function from ``torch.nn.functional``, and ``none`` is
        replaced by ``no_activation``. Any other value, or a missing
        ``"activation"`` key, is left untouched.

    Raises:
        FileNotFoundError: If the selected params file does not exist.
        KeyError: If the params file has no ``"common"`` section.
    """
    architecture = args.architecture

    large_datasets = ('ogbn-arxiv', 'flickr', 'reddit')
    if args.dataset in large_datasets:
        params_path = 'model/params_large.json'
    else:
        params_path = 'model/params.json'
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(params_path, 'r', encoding='utf-8') as f:
        params = json.load(f)

    # Start from the shared defaults, then apply architecture-specific overrides.
    hyperparameters = params["common"].copy()
    if architecture in params:
        hyperparameters.update(params[architecture])

    if architecture == "sgc":
        # SGC's hidden size is forced to the number of classes (presumably
        # because the model maps features straight to class scores — confirm
        # against the SGC model definition).
        hyperparameters["hidden_dim"] = data.num_classes

    # Resolve the activation name to a callable. .get() avoids a KeyError
    # when a config omits "activation" entirely.
    activation = hyperparameters.get('activation')
    if activation == 'none':
        hyperparameters['activation'] = no_activation
    elif activation in ('relu', 'elu', 'sigmoid', 'tanh'):
        # Each name matches a function in torch.nn.functional (F.relu, F.elu, ...).
        hyperparameters['activation'] = getattr(F, activation)

    return hyperparameters

def set_device():
    """Select CUDA when it is available, otherwise the CPU.

    The chosen device is printed and then returned.
    """
    backend = 'cpu'
    if torch.cuda.is_available():
        backend = 'cuda'
    chosen = torch.device(backend)
    print(f"Using device: {chosen}")
    return chosen

def create_directory(path):
    """Ensure the parent directory of the file path ``path`` exists.

    ``path`` is treated as a *file* path. Only its directory component is
    created, along with any missing intermediate directories. Existing
    directories are fine.

    If ``path`` has no directory component (e.g. ``"model.pth"``), nothing
    is created. ``os.makedirs('')`` would otherwise raise FileNotFoundError.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

def generate_random_numbers(n, seed=None):
    """Return a list of ``n`` random integers in the inclusive range ``[0, 2**32 - 1]``.

    Args:
        n: How many numbers to draw. ``n <= 0`` yields an empty list.
        seed: If not ``None``, ``set_seed(seed)`` is called first. That
            reseeds Python's ``random`` module and PyTorch (including CUDA
            when available), so it changes global RNG state. If ``None``,
            the current global ``random`` state is used as-is.

    Returns:
        list[int]: The numbers, drawn with ``random.randint``.
    """
    if seed is not None:
        set_seed(seed)
    # 2**32 - 1 is the largest unsigned 32-bit value, so every result fits in a uint32.
    return [random.randint(0, 2**32 - 1) for _ in range(n)]