from cmath import inf
import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import convert
import torch_geometric
from Policies.NET import Net
import neptune
import csv
import os
import time

from Policies.GraphDataset import MyData
from torch_geometric.utils import to_dense_adj



class GCN_RL():

    def __init__(self, env, seed = 42, gamma=0.99, lr=0.001, num_layers = 12, memory_size=10000, hidden_channels = 128, hidden_channels_node = 32, est_pr_acc = False, shared_weights = False, gen = False, use_batch_norm = False, tr_dist = False, tr_att = False, dir_graph = False):
        """Build the online/target Q-networks, optimizer and replay memory.

        Args:
            env: Environment object. This constructor reads ``batch_size``,
                ``device``, ``uniform_random``, ``constant_probability``,
                ``var_distance``, ``num_nodes`` and ``graph`` from it.
            seed: Seed applied to ``random``, NumPy and PyTorch.
            gamma: Discount factor used for the TD target in ``update``.
            lr: Learning rate of the Adam optimizer.
            num_layers: Forwarded to ``Net``; also the distance threshold of
                the ``failsafe`` rule in ``get_action``.
            memory_size: Maximum number of transitions kept in the replay deque.
            hidden_channels: Forwarded to ``Net``.
            hidden_channels_node: Accepted for backward compatibility; not used
                by this constructor.
            est_pr_acc: If true, ``observation_formation`` uses ``self.pr_acc``
                (created in ``estimate_seq``) instead of ``env.probabilities``.
            shared_weights: Forwarded to ``Net``.
            gen: Stored as ``self.gen``; ``optimize`` skips writing result
                files when it is true.
            use_batch_norm: Forwarded to ``Net``.
            tr_dist: Forwarded to ``Net``; requires ``env.var_distance``.
            tr_att: Forwarded to ``Net``; requires ``tr_dist``.
            dir_graph: Selects the dense (``MyData``) observation format.

        Raises:
            ValueError: If ``tr_dist`` is set without ``env.var_distance``,
                ``tr_att`` is set without ``tr_dist``, or ``env.uniform_random``
                is set without ``env.constant_probability``.
        """
        # Seed every RNG in use before any network weights are created.
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.gen = gen
        self.env = env
        self.lr = lr
        self.batch_size = self.env.batch_size
        self.device = self.env.device
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma
        self.uniform_random = self.env.uniform_random
        self.constant_probability = self.env.constant_probability
        print(self.constant_probability)
        self.var_distance = self.env.var_distance
        self.tr_dist = tr_dist
        self.tr_att = tr_att
        self.dir_graph = dir_graph

        # Reject option combinations the model cannot handle.
        if not self.var_distance and self.tr_dist:
            raise ValueError("Unfeasible configuration")
        if not self.tr_dist and self.tr_att:
            raise ValueError("Unfeasible configuration")

        self.est_pr_acc = est_pr_acc
        self.hidden_channels = hidden_channels

        if self.uniform_random and not self.constant_probability:
            raise ValueError("Unfeasible configuration")

        self.num_layers = num_layers

        # Node features: server mask + request mask, plus a probability
        # channel when requests are not uniformly random.
        self.in_channells = 2 if self.uniform_random else 3

        # Both networks are built from one argument set so they cannot drift apart.
        net_kwargs = dict(
            in_channels=self.in_channells,
            hidden_channels=hidden_channels,
            num_layers=self.num_layers,
            num_nodes=self.env.num_nodes,
            shared_weights=shared_weights,
            use_batch_norm=use_batch_norm,
            var_distance=self.var_distance,
            tr_dist=tr_dist,
            tr_att=tr_att,
            dir_graph=self.dir_graph,
        )
        # Action-value function Q.
        self.q_network = Net(**net_kwargs).to(self.device)
        # Target action-value function Q', started as an exact copy of Q.
        self.target_network = Net(**net_kwargs).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # Per-step reward tensors appended by ``optimize``.
        self.total_reward = []
        # Sparse edge list of the environment graph.
        self.edge_index = convert.from_networkx(self.env.graph).edge_index
        # Bounds of the exponential epsilon schedule used by ``optimize``.
        self.max_epsilon = 1
        self.min_epsilon = 0.05

       

    
    def print_network_weights(self):
        network_weights = self.q_network.state_dict()
        # Print the weights of the network
        for name, param in network_weights.items():
            # if "W_dist" in name: 
                print(f"Layer: {name}\nWeights: {param}")

    
    def total_params(self):
        """Return the number of scalar entries across all Q-network parameters."""
        count = 0
        for weight in self.q_network.parameters():
            count += weight.numel()
        return count
        
    def observation_formation(self, state, node_pbs = 1, constant_probability = True): 
        
        if self.uniform_random:
            X = torch.zeros(self.env.batch_size, 2, self.env.num_nodes).to(self.device)
            for i in range(self.env.batch_size):
                # use row i of qt_index as an index to extract a row from q_values
                X[i][0][state[i][:-1].long()] = 1
                X[i][1][state[i][-1].long()] = 1
        
        else: 
            X = torch.zeros(self.env.batch_size, 3, self.env.num_nodes).to(self.device)
            for i in range(self.env.batch_size):
                X[i][0][state[i][:-1].long()] = 1
                X[i][1][state[i][-1].long()] = 1
                if constant_probability: 
                    if self.est_pr_acc: 
                        X[i][2] = self.pr_acc/self.pr_acc.sum()
                    else:
                        X[i][2] = torch.FloatTensor(self.env.probabilities).view(1, -1)
                else:
                    X[i][2] = node_pbs[i]

                if self.env.arrival_rates: 
                    X[i][2] = X[i][2]* self.env.num_nodes
                
                
        if self.dir_graph:
            
            X = torch.transpose(X, -1, -2)
            im = torch.eye(self.env.num_nodes).to(self.device)                
            if self.var_distance: 
                # cm  = self.env.cost_matrix.unsqueeze(0)
                cm  = self.env.cost_matrix_nn.unsqueeze(0)
                edge_index = cm + im 
                if self.tr_dist == False:
                    edge_index = 1/edge_index
            else:
                edge_index = to_dense_adj(self.edge_index).to(self.device)                
                edge_index = edge_index + im
                
            edge_index = edge_index.repeat(self.batch_size, 1, 1)
            data = MyData(X, edge_index)

            return data 

        else:   
            if self.var_distance: 
                # Flatten the cost matrix into a 1D tensor
                # print(self.env.cost_matrix.size())
                # cost_vector = self.env.cost_matrix.view(-1)
                cost_vector = self.env.cost_matrix_nn.view(-1)
                # print(cost_vector.size())

                # Compute the linear indices corresponding to the edges in the edge_index
                linear_indices = self.edge_index[0] * self.env.cost_matrix.size(1) + self.edge_index[1]

                # Extract edge weights from the cost vector using linear indices
                edge_weight = cost_vector[linear_indices]
                # edge_weight = 1/self.env.cost_matrix[self.edge_index[0], self.edge_index[1]]
            data_list = []
            for i in range(self.batch_size):
                x = X[i] 
                if self.var_distance: 
                    res = Data(x = x.T, edge_index = self.edge_index, edge_weight = edge_weight)
                else: 
                    res = Data(x = x.T, edge_index = self.edge_index)
                data_list.append(res)
            train_loader = torch_geometric.loader.DataLoader(data_list, batch_size=self.batch_size, shuffle=False)
            data = next(iter(train_loader))
            return data

        




    def get_action(self, state, epsilon=0.1, failsafe = False):
        # with probability epsilon 
        if random.random() < epsilon:
            
    
            random_indices = torch.randint(low = 0, high=self.env.num_servers, size=(self.batch_size,)).to(self.device)
            action_batch = torch.gather(state[:, :self.env.num_servers], dim=1, index=random_indices.unsqueeze(1))
            if self.env.request_same_node: 
                for i in range(self.env.batch_size):
                    if state[i][-1] in state[i][:-1]:  
                        action_batch[i] = state[i][-1]
            return action_batch.long()
        else:
            
            with torch.no_grad():
                qt_index = state[:,:self.env.num_servers].to(self.device)
                data = self.observation_formation(state).to(self.device)
                dis_req = self.env.distance_request(state).to(self.device)
                if self.var_distance == True and self.dir_graph == False:
                    q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req,  edge_weight=data.edge_weight)
                else:
                    q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
                # print(data.x, data.edge_index, data.batch, dis_req)
                q_values = q_values.reshape(self.env.batch_size, -1)
                # create empty tensor C of size NxM
                C = torch.zeros_like(qt_index).to(self.device)

                # loop through each row of qt_index
                for i in range(qt_index.shape[0]):
                    # use row i of qt_index as an index to extract a row from q_values
                    row_b = q_values[i, qt_index[i].long()]
                    # assign the extracted row to the corresponding row in C
                    C[i] = row_b
                    
                max_index = torch.argmax(C, dim =1)
                action_batch = torch.gather(qt_index, 1, max_index.view(-1, 1))
                # this makes it slow 
                if self.env.request_same_node: 
                    for i in range(self.env.batch_size):
                        if state[i][-1] in state[i][:-1]:  
                            action_batch[i] = state[i][-1]

                if failsafe: 
                    min_distance, min_distance_actions = self.env.get_min_distance(state.to(self.device))
                    condition = min_distance > self.num_layers
                    action_batch.long()[condition] = min_distance_actions.long()[condition]
        
                return action_batch
    
    def update(self):
        
        batch = random.sample(self.memory, self.batch_size)
        concatenated = [torch.cat(tensors, dim=0) for tensors in zip(*batch)]
        if self.constant_probability:
            state_batch, action_batch, reward_batch, next_state_batch = concatenated[0], concatenated[1], concatenated[2], concatenated[3]
            # print(state_batch, action_batch, reward_batch, next_state_batch)
        else: 
            state_batch, action_batch, reward_batch, next_state_batch, node_pbs, node_pbs_next = concatenated[0], concatenated[1], concatenated[2], concatenated[3], concatenated[4], concatenated[5]
            
        if self.constant_probability:
            data = self.observation_formation(state_batch).to(self.device)
        else: 
            data = self.observation_formation(state_batch, node_pbs, constant_probability = False).to(self.device)
        
        
        # print(data.x)

        dis_req = self.env.distance_request(state_batch).to(self.device)    

        if self.var_distance == True and self.dir_graph == False:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req,  edge_weight=data.edge_weight)
        else:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)

        q_values = q_values.reshape(self.env.batch_size, -1)
        q_values = q_values.gather(1, action_batch.long())

        next_qt_index = next_state_batch[:,:self.env.num_servers]
        if self.constant_probability:
            data_next = self.observation_formation(next_state_batch).to(self.device)
        else:
            data_next = self.observation_formation(next_state_batch, node_pbs_next, constant_probability= False).to(self.device) 

        # print(data_next.x)

        dis_req_next = self.env.distance_request(next_state_batch).to(self.device)    
        if self.var_distance == True and self.dir_graph == False:
            next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next, edge_weight=data_next.edge_weight)
        else:
            next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next)

        next_q_values = next_q_values.reshape(self.env.batch_size, -1)

        # create empty tensor C of size NxM
        C = torch.zeros_like(next_qt_index).to(self.device)
        # loop through each row of qt_index
        for i in range(next_qt_index.shape[0]):
            # use row i of qt_index as an index to extract a row from q_values
            row_b = next_q_values[i, next_qt_index[i].long()]
            # assign the extracted row to the corresponding row in C
            C[i] = row_b

        next_q_values = C.max(1)[0].unsqueeze(1)

        expected_q_values = reward_batch.view(-1, 1) + self.gamma * next_q_values 
        loss = F.mse_loss(q_values, expected_q_values)
        
        # print(q_values.size(), expected_q_values.size())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Update target network
        self.soft_update_target_network()

    def remember(self, state, action, reward, next_state, node_pbs = 1, node_pbs_next = 1):
        if self.constant_probability:
            for i in range(self.batch_size):
                self.memory.append((state[i].unsqueeze(0), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[i].unsqueeze(0)))
        else: 
            for i in range(self.batch_size):
                self.memory.append((state[i].unsqueeze(0), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[i].unsqueeze(0), node_pbs, node_pbs_next))
    

    def soft_update_target_network(self, tau=0.01):

      for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)
    
    def optimize(self, num_steps=200, estimate_steps =20, epsilon_decay = False, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005, failsafe = False, save_results = False):

        if save_results:
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}'):  
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}')
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results')  
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models')  
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/raw_results'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/raw_results')
            if not os.path.exists(f'results/gen_testing/VD{self.var_distance}/train_curves'):
                os.makedirs(f'results/gen_testing/VD{self.var_distance}/train_curves')

        file_paths = ['results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}.csv']

        if any(os.path.exists(file_path) for file_path in file_paths):
            print(f'Skipping training, as one or more result files exist.')
        
        else: 
            print(f'Experiment  started')   

            state = self.env.reset()
            
            steps_for_display = int(10000/self.batch_size)
            # steps_for_display = 1
            num_steps = int(num_steps*1000/self.batch_size)
            
            # num_steps = 10

            estimate_steps = int(estimate_steps*1000/self.batch_size)

            if display_results:
                self.run = neptune.init_run(
                    project="iliyasbektas/kserver",
                    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJiOTRhNmFlNi0xMzU0LTRiNGUtODZmYy05ZWQyMDA4ZjJiZDQifQ==",
                )  # your credentials 
                self.run["agent"] = self.class_name
                self.run["num_nodes"] = self.env.num_nodes
                self.run["graph_type"] = self.env.graph_type
                self.run["gamma"] = self.gamma
                self.run["vd"] = self.var_distance
                self.run["td"] = self.tr_dist
                self.run["ta"] = self.tr_att
                self.run["seed"] = self.seed
            
            initial_percentage = explr  
            initial_limit = int(num_steps * initial_percentage)  
            
            # num_sequences = 10
            # requests, state_init = generate_requests(self.env, 42, 4000, num_sequences=num_sequences)

            # # Function to find the most frequent element in a 1D tensor
            # def most_frequent_element(tensor_row):
            #     return torch.mode(tensor_row).values.item()

            # # Apply the function to each row of the 2D tensor
            # most_frequent_elements = [most_frequent_element(row) for row in requests]

            # print("The most frequent elements for each row are:", most_frequent_elements)
            # index_of_largest_value = self.env.probabilities.index(max(self.env.probabilities))
            # print("The index of the largest value is:", index_of_largest_value)
            # # print(env.probabilities)
            # print(state_init)
            tot_params = self.total_params()
            best_estimate = float('-inf')
            start_time = time.time()
            for step in range(num_steps):
                
                if epsilon_decay:
                    epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon)*np.exp(-decay_rate*step)     
                else:
                    if step < initial_limit:
                        epsilon = 0.5
                    else: 
                        epsilon = 0.1
                
                # if step == 1:
                #     print(state)

                if self.constant_probability:
                    action = self.get_action(state.to(self.device), epsilon, failsafe = failsafe).to(self.device)
                    # print(state.size())
                    # print(action.size())
                    next_state, reward, _ = self.env.step(action, state.to(self.device))
                    self.remember(state.to(self.device),
                    action,
                    reward.to(self.device),
                    next_state.to(self.device)
                    )
                else: 
                    node_pbs = torch.FloatTensor(self.env.probabilities).view(1, -1).to(self.device)
                    action = self.get_action(state.to(self.device), epsilon, failsafe = failsafe).to(self.device)    
                    next_state, reward, _ = self.env.step(action, state.to(self.device))
                    node_pbs_next = torch.FloatTensor(self.env.probabilities).view(1, -1).to(self.device)     
                    self.remember(state.to(self.device),
                    action,
                    reward.to(self.device),
                    next_state.to(self.device),
                    node_pbs,
                    node_pbs_next
                    )

                # print(reward)
                
                self.update()
                # self.print_network_weights()
                
                state = next_state
                self.total_reward.append(reward)
                # print(self.total_reward)
                # print(len(self.total_reward))
                # print(-estimate_steps*self.batch_size)
                
                # print(len(self.total_reward[(-estimate_steps):]))

        #               start_time = time.time()

        # for i in range(num_requests): 
        #     # Print progress every 10%
        #     if i % progress_interval == 0 and i != 0:
        #         elapsed_time = time.time() - start_time
        #         print(f"{(i / num_requests) * 100:.0f}% of requests finished and it took {elapsed_time:.2f} seconds")
                
                # print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(estimate_steps)[0]:.2f}")
                
                if ((step+1)  % steps_for_display == 0):
                    if print_results:
                        elapsed_time = time.time() - start_time
                        start_time = time.time()
                        step_estimate = self.estimate(estimate_steps)
                        flattened_rewards = torch.cat(self.total_reward).view(-1)
                        # print(flattened_rewards.size())
                        selected_rewards = flattened_rewards[(-estimate_steps * self.batch_size):]
                        # print(selected_rewards.size())
                        # Compute the average reward
                        average_reward = torch.mean(selected_rewards)
                        # average_reward = torch.mean(torch.cat(self.total_reward[(-estimate_steps*self.batch_size):]).view(-1))
                        step1000 = self.round_to_nearest_10000((step+1)*self.batch_size)
                        print(f"Step {step1000}, Epsilon {epsilon:.2f}, Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}, Time Taken: {elapsed_time:.2f} seconds")

                    if display_results:
                        if print_results:
                            self.run["Estimate"].append(step_estimate[0]) 
                            self.run["average_reward"].append(average_reward)
                        else:
                            step_estimate = self.estimate(estimate_steps)
                            average_reward = torch.mean(torch.cat(self.total_reward[(-estimate_steps*self.batch_size):]).view(-1))
                            self.run["average_reward"].append(average_reward)
                            self.run["Estimate"].append(step_estimate[0]) 
                        # if self.step > int(num_steps * 0.95): 
                        #     self.run.stop()
                    if save_results:
                        if self.gen:
                            pass
                        else: 
                            if print_results == True or display_results ==True:
                                estimate, _, _, raw_result = step_estimate
                            else: 
                                average_reward = torch.mean(torch.cat(self.total_reward[(-estimate_steps*self.batch_size):]).view(-1))
                                step_estimate = self.estimate(estimate_steps)
                                estimate, _, _, raw_result = step_estimate
                                step1000 = self.round_to_nearest_10000((step+1)*self.batch_size)

                            # print(f"Current estimate: {estimate.item()}, Best estimate: {best_estimate}")
                                
                            if estimate.item() > best_estimate: 
                                torch.save(self.q_network.state_dict(), f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/model_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')
                                best_estimate = estimate.item()
                            
                            # print(f"Best estimate: {best_estimate}")

                    
                            # save the train curve
                            train_curve = f'results/gen_testing/VD{self.var_distance}/train_curves/train_curve_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}_ar.csv'

                            if not os.path.exists(train_curve):
                                with open(train_curve, 'w', newline='') as f:
                                    writer = csv.writer(f)
                                    writer.writerow(['step', 'graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'VD', 'tr_dist', 'tr_att', 'DG', 'batch_size', 'tot_params','average_reward', 'estimate'])
                                    writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate.item(), 3)]) 
                            else: 
                                with open(train_curve, 'a', newline='') as f:  # Open in append mode
                                    writer = csv.writer(f)
                                    # Append new row to CSV file
                                    writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate.item(), 3)]) 


                            # save the last estimate

                            output_file_name_raw = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/raw_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}_raw.csv'
                            with open(output_file_name_raw, 'w', newline='') as f:
                                    writer = csv.writer(f)
                                    # writer.writerow([raw_result])
                                    writer.writerow(raw_result.tolist())  
                        
            if display_results:
                self.run.stop()
                
            if print_results:
                try: 
                    print(f"Step {self.round_to_nearest_10000((step+1)*self.batch_size)}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(estimate_steps)[0]:.2f}")
                except: 
                    # print(f"Step {step+1}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(estimate_steps)[0]:.2f}")
                    print(f"Step {self.round_to_nearest_10000((step+1)*self.batch_size)}, Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}")
            
        


    def estimate(self, num_steps = 1, episode = False):

        state = self.env.reset()
        
        # num_steps = 10
        reward_list = []
        for step in range(num_steps):
            action = self.get_action(state.to(self.device), epsilon = 0)
            next_state, reward, _ = self.env.step(action, state)
            state = next_state
            reward_list.append(reward.reshape(self.env.batch_size, 1))
             
        estimates = torch.cat(reward_list, dim=1)
        # print('Estimate size:', estimates.size())
        if episode:  
            return torch.mean(estimates, 1), torch.quantile(estimates, 0.25, 1), torch.quantile(estimates, 0.75, 1), estimates
        else: 
            estimates = estimates.view(-1)
            return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates
        


    def estimate_seq(self, state, requests):
        total_reward_episode = []

        if self.est_pr_acc: 
            self.pr_acc = torch.zeros(self.env.batch_size, self.env.num_nodes).to(self.env.device)

        num_requests = requests.shape[1]
        progress_interval = num_requests // 10  # 10% of requests
        start_time = time.time()

        for i in range(num_requests): 
            # Print progress every 10%
            if i % progress_interval == 0 and i != 0:
                elapsed_time = time.time() - start_time
                print(f"{(i / num_requests) * 100:.0f}% of requests finished and it took {elapsed_time:.2f} seconds")

            action = self.get_action(state.to(self.device), epsilon = 0)

            if self.est_pr_acc: 
                self.pr_acc[:, int(state[:, -1].item())] += 1

            next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0], 1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0], 1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

        

        # sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
        # return sum_total_reward_episode
    
    def round_to_nearest_10000(self, number):
        """Round ``number`` to a multiple of 10000 using Python's built-in ``round``."""
        bucket = 10000
        return round(number / bucket) * bucket

    # def estimate_seq_pr_acc(self, state, requests):

    #     total_reward_episode = []

    #     for i in range(requests.shape[1]): 
    #     # for i in range(5):
    #         # print(f'request{i}')
    #         action = self.get_action(state.to(self.device), epsilon = 0)
    #         # print(state, state.size(), action, action.size(), requests[:, i].reshape(state.shape[0],1), requests[:, (i)].reshape(state.shape[0],1).size()) 
    #         next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
    #         state = next_state
    #         total_reward_episode.append(reward.reshape(state.shape[0],1))

    #     sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
    #     return sum_total_reward_episode
    

        




    
    @property
    def class_name(self):
        """Name of the instance's concrete class, used in result paths and logs."""
        return type(self).__name__
                    




        # output_file_name = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}.csv'
                            
                            
        # with open(output_file_name, 'w', newline='') as f:
        #     writer = csv.writer(f)
        #     writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
        #     writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)]) 
        # # save the curve of estimates
        # output_file_name_wqs = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/train_curve/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_wqs.csv'
        
        # if not os.path.exists(output_file_name_wqs):
        #     with open(output_file_name_wqs, 'w', newline='') as f:
        #         writer = csv.writer(f)
        #         writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)]) 
        # else: 
        #     with open(output_file_name_wqs, 'a', newline='') as f:  # Open in append mode
        #         writer = csv.writer(f)
        #         # Append new row to CSV file
        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])
        
        # # save the average reward curve
        # output_file_name_ar = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/train_curve/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_seed{self.seed}_DG{self.dir_graph}_bs{self.batch_size}_ar.csv'

        # if not os.path.exists(output_file_name_ar):
        #     with open(output_file_name_ar, 'w', newline='') as f:
        #         writer = csv.writer(f)
        #         writer.writerow(['graph_type', 'agent', 'gamma', 'num_nodes', 'seed', 'hidden_channels', 'average_reward'])
        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(average_reward.item(), 3)]) 
        # else: 
        #     with open(output_file_name_ar, 'a', newline='') as f:  # Open in append mode
        #         writer = csv.writer(f)
        #         # Append new row to CSV file
        #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(average_reward.item(), 3)])