import torch
import random
import numpy as np

def generate_requests(env, seed=42, num_requests=10, num_sequences=1):
    """Sample starting states and request sequences for ``env``.

    Reseeds Python's global ``random`` module and NumPy's global RNG with ``seed``
    (a process-wide side effect), so identical arguments give identical output.

    Attributes read from ``env``: ``num_nodes``, ``num_servers``, ``graph.nodes``,
    ``probabilities``, ``request_same_node`` and ``device``.

    Args:
        env: environment description (see the attributes listed above).
        seed: seed applied to both global RNGs.
        num_requests: requests per sequence (one extra column is produced when
            ``env.request_same_node`` is false; see Returns).
        num_sequences: number of rows to generate.

    Returns:
        ``(requests, state_batch)``, both moved to ``env.device``:

        * ``requests`` -- ``torch.tensor`` of nodes drawn from ``env.graph.nodes``
          with weights ``env.probabilities``. Shape ``(num_sequences, num_requests)``
          when ``env.request_same_node`` is true (i.i.d. draws, repeats allowed).
          Otherwise shape ``(num_sequences, num_requests + 1)``: column 0 avoids
          every node in the matching ``state_batch`` row, and each later column
          differs from the column before it.
        * ``state_batch`` -- shape ``(num_sequences, env.num_servers + 1)`` in torch's
          default floating dtype. Each row holds distinct ids from
          ``range(env.num_nodes)``, with the first ``env.num_servers`` sorted ascending.

    Raises:
        ValueError: from ``random.sample`` if ``env.num_servers + 1 > env.num_nodes``,
            from ``np.random.choice`` for invalid ``env.probabilities``, or when the
            "distinct" constraints above can never be satisfied (this used to hang).

    NOTE(review): the ``request_same_node=False`` branch returns one more column than
    the other branch. It is kept as-is for backward compatibility; confirm callers
    expect it.
    NOTE(review): state ids come from ``range(env.num_nodes)`` while requests come
    from ``env.graph.nodes``, which assumes nodes are labelled ``0..num_nodes-1``.
    """
    random.seed(seed)
    np.random.seed(seed)

    state_batch = _sample_state_batch(env.num_nodes, env.num_servers, num_sequences)
    nodes = list(env.graph.nodes)  # built once instead of on every draw

    if env.request_same_node:
        sampled = np.random.choice(
            nodes, size=(num_sequences, num_requests), p=env.probabilities, replace=True
        )
    else:
        sampled = _sample_distinct_requests(nodes, env.probabilities, state_batch, num_requests)

    # Sampling happens on the host; move the finished tensors to the device once.
    return torch.tensor(sampled).to(env.device), state_batch.to(env.device)


def _sample_state_batch(num_nodes, num_servers, num_sequences):
    """Build a ``(num_sequences, num_servers + 1)`` tensor of random states.

    Each row holds ``num_servers + 1`` distinct ids drawn with Python's global
    ``random.sample`` from ``range(num_nodes)``. The first ``num_servers`` entries
    are sorted ascending, and the last entry is kept in draw order.

    The result has torch's default floating dtype, kept for backward compatibility.
    NOTE(review): the ids are integers stored as floats; confirm callers expect that.
    """
    width = num_servers + 1
    rows = []
    for _ in range(num_sequences):
        drawn = torch.LongTensor(random.sample(range(0, num_nodes), width))
        servers = torch.sort(drawn[:num_servers])[0]
        rows.append(torch.cat([servers, drawn[-1:]]).unsqueeze(0))
    # A single concatenation replaces the per-row one (which was quadratic).
    # Leading with an empty default-dtype tensor keeps the long -> float promotion
    # and still yields a (0, width) tensor when num_sequences == 0.
    return torch.cat([torch.empty((0, width)), *rows], dim=0)


def _sample_distinct_requests(nodes, probabilities, state_batch, num_requests):
    """Rejection-sample a ``(num_sequences, num_requests + 1)`` array of requests.

    Column 0 avoids every node in the matching ``state_batch`` row. Each later
    column differs from the previous one. NumPy's global RNG is consumed in the
    same order as the original inline implementation.

    Raises:
        ValueError: if a constraint can never be met given ``probabilities``.
    """
    num_sequences = state_batch.shape[0]

    def redraw():
        # size=1 plus [0] consumes exactly one draw, as before, but yields a scalar.
        # Assigning a 1-element array into a scalar slot is deprecated (NumPy 1.25).
        return np.random.choice(nodes, size=1, p=probabilities)[0]

    first = np.random.choice(nodes, size=(num_sequences,), p=probabilities, replace=True)
    # Nodes that can actually be drawn; zero-probability nodes are never selected.
    support = nodes if probabilities is None else [n for n, p in zip(nodes, probabilities) if p > 0]

    for row in range(num_sequences):
        occupied = set(state_batch[row].tolist())
        if all(node in occupied for node in support):
            raise ValueError(
                f"cannot draw a first request outside state row {row}: every node "
                "with positive probability is already in the state"
            )
        while first[row] in state_batch[row]:
            first[row] = redraw()

    if num_requests > 0 and num_sequences > 0 and len(support) < 2:
        raise ValueError(
            "consecutive requests must differ, but fewer than two nodes have "
            "positive probability"
        )

    columns = [first]
    for _ in range(num_requests):
        candidate = np.random.choice(nodes, size=(num_sequences,), p=probabilities, replace=True)
        previous = columns[-1]
        for row in range(num_sequences):
            while candidate[row] == previous[row]:
                candidate[row] = redraw()
        columns.append(candidate)
    return np.stack(columns, axis=1)