import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import convert
import torch_geometric
from Policies.NET import Net
import neptune
import csv
import os
import torch.nn.functional as F
import torch.nn as nn
import json
import re
from torch_geometric.utils import to_dense_adj
from Policies.GraphDataset import MyData


class GCN_RL_SL():
    """Supervised graph-network policy for the k-server environment.

    A ``Net`` graph network scores every graph node. The chosen action is the
    server-occupied node with the highest score (see ``get_action``).
    ``optimize`` trains the network by regressing these scores (MSE) onto a
    precomputed ``Q_table`` loaded from disk, restricted to the nodes that
    currently hold servers.
    """

    def __init__(self, env, seed = 42, gamma=0.99, lr=0.001, num_layers = 12, memory_size=10000, hidden_channels = 128, hidden_channels_node = 32, est_pr_acc = False, shared_weights = True, gen = False, use_batch_norm = False, tr_dist = False, tr_att = False, dir_graph = False):
        """Build the network, optionally restore a checkpoint, and load the Q-table.

        Args:
            env: environment object. Its ``batch_size``, ``device``, flags,
                ``graph`` (passed to ``from_networkx``) and ``graph_type`` are read here.
            seed: seed for ``random``, ``numpy`` and ``torch``.
            gamma, lr, num_layers, hidden_channels: hyper-parameters. They are
                also encoded in the checkpoint file name.
            memory_size: capacity of ``self.memory``. That deque is not read
                anywhere in this class.
            hidden_channels_node: accepted for interface compatibility; unused.
            est_pr_acc: use running per-node request counts (``self.pr_acc``,
                filled by ``estimate_seq``) as the probability feature.
            shared_weights, use_batch_norm, tr_dist, tr_att, dir_graph:
                forwarded to ``Net``.
            gen: when True, ``optimize`` does not save progress.

        Raises:
            ValueError: for flag combinations the model does not support.
        """
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.gen = gen
        self.env = env
        self.lr = lr
        self.batch_size = self.env.batch_size
        self.device = self.env.device
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma
        self.uniform_random = self.env.uniform_random
        self.constant_probability = self.env.constant_probability
        self.var_distance = self.env.var_distance
        self.tr_dist = tr_dist
        self.tr_att = tr_att
        self.dir_graph = dir_graph
        # tr_dist needs variable edge distances, and tr_att needs tr_dist.
        if self.var_distance == False and self.tr_dist == True:
            raise ValueError("Unfeasible configuration")
        if self.tr_dist == False and self.tr_att == True:
            raise ValueError("Unfeasible configuration")
        self.est_pr_acc = est_pr_acc
        self.hidden_channels = hidden_channels

        if self.uniform_random == True and self.constant_probability== False:
            raise ValueError("Unfeasible configuration")
        self.num_layers = num_layers

        # Node features: server occupancy and request location, plus a
        # per-node request-probability channel unless requests are uniform.
        self.in_channells = 2 if self.uniform_random else 3
        self.q_network = Net(in_channels=self.in_channells, hidden_channels = hidden_channels,  num_layers = self.num_layers,  num_nodes = self.env.num_nodes, shared_weights = shared_weights, use_batch_norm = use_batch_norm, var_distance = self.var_distance, tr_dist = tr_dist, tr_att = tr_att, dir_graph = self.dir_graph).to(self.device)
        model_path = self._model_path()
        if os.path.exists(model_path):
            # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
            self.q_network.load_state_dict(torch.load(model_path, map_location=self.device))
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # Per-step reward tensors collected during optimize().
        self.total_reward = []
        self.edge_index = convert.from_networkx(self.env.graph).edge_index

        # Q-tables are grouped by the graph-type prefix before the first '_'.
        if '_' in self.env.graph_type:
            graph_group = re.match(r'^\D+(?=_)', self.env.graph_type).group(0)
        else:
            graph_group = self.env.graph_type
        # NOTE(review): machine-specific absolute path; a relative alternative was
        # f"data/qtable_wql/{graph_group}/{graph_type}_{num_nodes}_{num_servers}".
        path = f"/data/home/ifb5104/K_server_RL/data/qtable_wql/var_distance{self.env.var_distance}/{graph_group}/{self.env.graph_type}_{self.env.num_nodes}_{self.env.num_servers}"
        with open(os.path.join(path, 'locations_to_id.json'), 'r') as f:
            locations_to_id_str_keys = json.load(f)
        # NOTE(review): eval() executes the JSON keys as Python. It is only safe
        # because this file is produced locally; never point it at untrusted data.
        self.locations_to_id = {(eval(key)): value for key, value in locations_to_id_str_keys.items()}
        self.Q_policy = torch.load(os.path.join(path, 'Q_policy.pth'), map_location=self.device).to(self.device)
        self.Q_table = torch.load(os.path.join(path, 'Q_table.pth'), map_location=self.device).to(self.device)
        self.criterion = nn.CrossEntropyLoss()

    def _model_path(self):
        """Return the checkpoint path used both when loading (__init__) and saving (optimize)."""
        return (f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/models/'
                f'model_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}'
                f'_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}'
                f'_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}'
                f'_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.pth')

    def _train_curve_path(self):
        """Return the CSV path that accumulates one row per progress step."""
        return (f'results/gen_testing/VD{self.var_distance}/train_curves/'
                f'train_curve_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}'
                f'_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}'
                f'_td{self.tr_dist}_ta{self.tr_att}_seed{self.seed}_DG{self.dir_graph}'
                f'_bs{self.batch_size}_lr{self.lr}_nl{self.num_layers}.csv')

    def _raw_results_path(self):
        """Return the CSV path holding the latest raw evaluation rewards."""
        return (f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/raw_results/'
                f'results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}'
                f'_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}'
                f'_td{self.tr_dist}_ta{self.tr_att}_lr{self.lr}_nl{self.num_layers}_raw.csv')

    def print_network_weights(self):
        """Print every tensor in the network's state dict."""
        for name, param in self.q_network.state_dict().items():
            print(f"Layer: {name}\nWeights: {param}")

    def total_params(self):
        """Return the total number of scalar parameters in ``q_network``."""
        return sum(p.numel() for p in self.q_network.parameters())

    def observation_formation(self, state, node_pbs = 1, constant_probability = True):
        """Encode a batch of states as graph node features.

        Args:
            state: 2-D tensor. In each row, ``row[:-1]`` holds the server node
                ids and ``row[-1]`` holds the requested node id.
            node_pbs: per-row request probabilities. Used for channel 2 only
                when ``constant_probability`` is False and requests are not uniform.
            constant_probability: when True, channel 2 is taken from the running
                counts ``self.pr_acc`` (if ``est_pr_acc``) or from ``env.probabilities``.

        Returns:
            If ``dir_graph`` is set: ``MyData`` built from the node features and a
            dense adjacency repeated per batch row.
            Otherwise: a PyG ``Batch`` of one ``Data`` graph per row, carrying
            ``edge_weight`` when ``var_distance`` is set.
        """
        num_features = 2 if self.uniform_random else 3
        X = torch.zeros(self.env.batch_size, num_features, self.env.num_nodes).to(self.device)
        for i in range(self.env.batch_size):
            # Channel 0: one-hot of nodes holding a server; channel 1: requested node.
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1
            if self.uniform_random:
                continue
            # Channel 2: per-node request probability.
            if constant_probability:
                if self.est_pr_acc:
                    # Normalise this row's own counts. NOTE(review): 0/0 gives NaN
                    # until estimate_seq has counted at least one request.
                    counts = self.pr_acc[i]
                    X[i][2] = counts/counts.sum()
                else:
                    X[i][2] = torch.FloatTensor(self.env.probabilities).view(1, -1)
            else:
                X[i][2] = node_pbs[i]
            if self.env.arrival_rates:
                X[i][2] = X[i][2]* self.env.num_nodes

        if self.dir_graph:
            X = torch.transpose(X, -1, -2)
            dense_adj = to_dense_adj(self.edge_index).repeat(self.batch_size, 1, 1)
            return MyData(X, dense_adj)

        edge_weight = None
        if self.var_distance:
            # Read cost_matrix[src, dst] for each edge through row-major linear indices.
            num_cols = self.env.cost_matrix.size(1)
            linear_indices = self.edge_index[0] * num_cols + self.edge_index[1]
            edge_weight = self.env.cost_matrix.view(-1)[linear_indices]

        data_list = []
        for i in range(self.batch_size):
            if self.var_distance:
                data_list.append(Data(x = X[i].T, edge_index = self.edge_index, edge_weight = edge_weight))
            else:
                data_list.append(Data(x = X[i].T, edge_index = self.edge_index))
        # Same collation a DataLoader would do, without building a loader per call.
        return torch_geometric.data.Batch.from_data_list(data_list)

    def get_action(self, state, failsafe = False):
        """Pick one server to move for each row of ``state``.

        Args:
            state: batch of states (see ``observation_formation``).
            failsafe: when True, rows whose ``env.get_min_distance`` distance
                exceeds ``num_layers`` use the env's min-distance action instead.

        Returns:
            ``(action_batch, result)``:
            - ``action_batch`` has shape (batch, 1) and holds the chosen server's
              node id.
            - ``result`` has shape (batch, num_nodes) and holds the network scores
              at server nodes, with zeros elsewhere.
        """
        qt_index = state[:,:self.env.num_servers].to(self.device)
        data = self.observation_formation(state).to(self.device)
        dis_req = self.env.distance_request(state).to(self.device)
        if self.var_distance:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req,  edge_weight=data.edge_weight)
        else:
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
        q_values = q_values.reshape(self.env.batch_size, -1)

        server_nodes = qt_index.long()
        # Scores of the server-occupied nodes. gather keeps q_values' float dtype;
        # a zeros_like(qt_index) buffer would truncate scores for integer states.
        server_scores = torch.gather(q_values, 1, server_nodes)
        max_index = torch.argmax(server_scores, dim=1)
        action_batch = torch.gather(qt_index, 1, max_index.view(-1, 1))

        if self.env.request_same_node:
            # If a server already sits on the requested node, that server is the action.
            for i in range(self.env.batch_size):
                if state[i][-1] in state[i][:-1]:
                    action_batch[i] = state[i][-1]

        result = torch.zeros_like(q_values)
        for i, row in enumerate(server_nodes):
            result[i, row] = q_values[i, row]

        if failsafe:
            min_distance, min_distance_actions = self.env.get_min_distance(state.to(self.device))
            condition = min_distance > self.num_layers
            # Write into action_batch itself; `action_batch.long()[...] = ...`
            # targets a temporary copy whenever action_batch is not int64.
            action_batch[condition] = min_distance_actions.long()[condition].to(action_batch.dtype)

        return action_batch, result

    def get_target_values(self, state):
        """Look up Q-table targets for each row, scattered onto server nodes.

        The Q-table is indexed by (``locations_to_id[server tuple]``, value in
        column ``num_servers`` of the state). The resulting per-server values are
        written at the servers' node ids of a zero (batch, num_nodes) tensor.
        """
        with torch.no_grad():
            qt_index = state[:,:self.env.num_servers].to(self.device)
            request_nodes = state[:,self.env.num_servers].to(self.device).long()
            qt_index_tuple = tuple(map(tuple, qt_index.int().tolist()))
            q_id_tensor = torch.tensor([self.locations_to_id[index] for index in qt_index_tuple], dtype=torch.long, device=self.Q_table.device)
            target_q_values = self.Q_table[q_id_tensor, request_nodes, :].to(self.device)

            num_rows = qt_index.shape[0]
            target_q_values_full = torch.zeros(num_rows, self.env.num_nodes, dtype=target_q_values.dtype).to(self.device)
            for i in range(num_rows):
                target_q_values_full[i, qt_index[i].long()] = target_q_values[i]
            return target_q_values_full

    def _recent_average_reward(self, estimate_steps):
        """Mean of the last ``estimate_steps*batch_size`` individual rewards in ``total_reward``."""
        return torch.mean(torch.cat(self.total_reward).view(-1)[(-estimate_steps*self.batch_size):])

    def _save_progress(self, step1000, average_reward, step_estimate, tot_params, best_estimate):
        """Checkpoint on a new best estimate, append to the train curve, and dump raw rewards.

        Returns the (possibly updated) best estimate.
        """
        estimate, _q1, _q3, raw_result = step_estimate
        if estimate.item() > best_estimate:
            torch.save(self.q_network.state_dict(), self._model_path())
            best_estimate = estimate.item()

        train_curve = self._train_curve_path()
        # train_curves/ is not one of the directories optimize() pre-creates.
        os.makedirs(os.path.dirname(train_curve), exist_ok=True)
        write_header = not os.path.exists(train_curve)
        with open(train_curve, 'a', newline='') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(['step', 'graph_type', 'agent', 'gamma', 'lr', 'num_nodes', 'num_servers', 'seed', 'hidden_channels', 'num_layers', 'VD', 'tr_dist', 'tr_att', 'DG', 'batch_size', 'tot_params','average_reward', 'estimate'])
            writer.writerow([step1000, self.env.graph_type, self.class_name, self.gamma, self.lr, self.env.num_nodes, self.env.num_servers, self.seed, self.hidden_channels, self.num_layers, self.var_distance, self.tr_dist, self.tr_att, self.dir_graph, self.batch_size, tot_params, round(average_reward.item(), 3), round(estimate.item(), 3)])

        with open(self._raw_results_path(), 'w', newline='') as f:
            csv.writer(f).writerow(raw_result.tolist())
        return best_estimate

    def optimize(self, num_steps=200, estimate_steps =20, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005, failsafe = False, save_results = False):
        """Train ``q_network`` with an MSE loss against the Q-table targets.

        Args:
            num_steps: training length in thousands of samples. It is converted
                to ``num_steps*1000/batch_size`` optimizer steps.
            estimate_steps: evaluation/averaging window in thousands of samples.
            explr, decay_rate: accepted for interface compatibility; unused.
            display_results: log to a Neptune run. The API token is read from the
                ``NEPTUNE_API_TOKEN`` environment variable.
            print_results: print progress every ``10000/batch_size`` steps, and a
                final summary at the end.
            failsafe: forwarded to ``get_action``.
            save_results: at each progress step, checkpoint on a new best
                estimate and write the CSV logs. Skipped when ``self.gen`` is set.
        """
        base_dir = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results'
        if save_results:
            for leaf in ('models', 'raw_results', 'train_curve'):
                os.makedirs(os.path.join(base_dir, leaf), exist_ok=True)

        # NOTE(review): this is a plain string, not an f-string, so the
        # placeholders are never filled and the check never matches; training
        # always runs. Left as-is because nothing in this class writes this file.
        file_paths = ['results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}.csv']
        if any(os.path.exists(file_path) for file_path in file_paths):
            print(f'Skipping training, as one or more result files exist.')
            return

        print(f'Experiment  started')
        state = self.env.reset()
        # max(1, ...) keeps both windows positive when batch_size is large;
        # a zero display period would make the modulo below raise.
        steps_for_display = max(1, int(10000/self.batch_size))
        num_steps = int(num_steps*1000/self.batch_size)
        estimate_steps = max(1, int(estimate_steps*1000/self.batch_size))

        if display_results:
            self.run = neptune.init_run(
                project="iliyasbektas/kserver",
                # Credentials must not live in source control.
                api_token=os.environ.get("NEPTUNE_API_TOKEN"),
            )
            self.run["agent"] = self.class_name
            self.run["num_nodes"] = self.env.num_nodes
            self.run["graph_type"] = self.env.graph_type
            self.run["gamma"] = self.gamma
            self.run["vd"] = self.var_distance
            self.run["td"] = self.tr_dist
            self.run["ta"] = self.tr_att
            self.run["lr"] = self.lr

        tot_params = self.total_params()
        best_estimate = float('-inf')
        step_estimate = None
        # One evaluation rollout per progress step, shared by every consumer.
        needs_estimate = print_results or display_results or (save_results and not self.gen)

        try:
            for step in range(num_steps):
                action, preds = self.get_action(state.to(self.device),failsafe = failsafe)
                next_state, reward, _ = self.env.step(action, state.to(self.device))
                target_q_values = self.get_target_values(state.to(self.device))
                loss = F.mse_loss(preds, target_q_values.float())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                state = next_state
                self.total_reward.append(reward)

                if not needs_estimate or (step+1) % steps_for_display != 0:
                    continue

                step_estimate = self.estimate(estimate_steps)
                step1000 = self.round_to_nearest_10000((step+1)*self.batch_size)
                average_reward = self._recent_average_reward(estimate_steps)
                if print_results:
                    print(f"Step {step1000},  Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}")
                    print("Loss:", round(loss.item(), 3))
                if display_results:
                    self.run["average_reward"].append(average_reward)
                    self.run["Estimate"].append(step_estimate[0])
                if save_results and not self.gen:
                    best_estimate = self._save_progress(step1000, average_reward, step_estimate, tot_params, best_estimate)
        finally:
            # Close the Neptune run even if training raises.
            if display_results:
                self.run.stop()

        if print_results and num_steps > 0:
            if step_estimate is None:
                # No progress step was reached (num_steps < steps_for_display).
                step_estimate = self.estimate(estimate_steps)
            final_average = self._recent_average_reward(estimate_steps)
            print(f"Step {self.round_to_nearest_10000(num_steps*self.batch_size)}, Average Reward {final_average:.2f}, Estimate {step_estimate[0]:.2f}")

    def estimate(self, num_steps = 1, episode = False):
        """Roll the greedy policy for ``num_steps`` env steps from a fresh reset.

        Returns:
            ``(mean, q25, q75, rewards)``. With ``episode=True`` the statistics are
            per batch row over a (batch, num_steps) reward tensor; otherwise they
            are computed over all rewards flattened to 1-D.
        """
        state = self.env.reset()
        reward_list = []
        for _ in range(num_steps):
            action, _ = self.get_action(state.to(self.device))
            next_state, reward, _ = self.env.step(action, state)
            state = next_state
            reward_list.append(reward.reshape(self.env.batch_size, 1))

        estimates = torch.cat(reward_list, dim=1)
        if episode:
            return torch.mean(estimates, 1), torch.quantile(estimates, 0.25, 1), torch.quantile(estimates, 0.75, 1), estimates
        estimates = estimates.view(-1)
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Serve a fixed request sequence from ``state``.

        Args:
            state: starting batch of states.
            requests: tensor whose column ``i`` is the next request fed to
                ``env.step`` at step ``i``.

        Returns:
            ``(sum, q25, q75, rewards)`` over the (batch, len) reward tensor.
        """
        total_reward_episode = []
        if self.est_pr_acc:
            self.pr_acc = torch.zeros(self.env.batch_size, self.env.num_nodes).to(self.env.device)

        for i in range(requests.shape[1]):
            action, _ = self.get_action(state.to(self.device))
            if self.est_pr_acc:
                # Count each row's current request in its own row; the former
                # `.item()` call only worked when the batch had a single row.
                rows = torch.arange(state.shape[0], device=self.pr_acc.device)
                self.pr_acc[rows, state[:, -1].long().to(self.pr_acc.device)] += 1
            next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0],1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def round_to_nearest_10000(self, number):
        """Round ``number`` to the nearest multiple of 10000 (Python ``round`` semantics)."""
        return round(number / 10000) * 10000

    @property
    def class_name(self):
        """Name of the concrete class; used in result paths and logs."""
        return self.__class__.__name__
                    

# def estimate(self, num_steps = 1, episode = False):

    #     state = self.env.reset()
    #     num_steps = int(num_steps*1000/self.batch_size)
    #     # num_steps = 10
    #     reward_list = []
    #     for step in range(num_steps):
    #         action = self.get_action(state.to(self.device))
    #         next_state, reward, _ = self.env.step(action, state)
    #         state = next_state
    #         if episode: 
    #             reward_list.append(reward.reshape(self.env.batch_size, 1))
    #         else: 
    #             self.total_reward_estimate = torch.cat((self.total_reward_estimate, reward), 0)
    #     if episode: 
    #         estimates = torch.cat(reward_list, dim=1)
    #         return torch.mean(estimates, 1), torch.quantile(estimates, 0.25, 1), torch.quantile(estimates, 0.75, 1), estimates
    #     else: 
    #         estimates = self.total_reward_estimate[-(num_steps*self.batch_size):]  
    #         print(estimates.size())
    #         return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

            
            
            # gt_lst = []

            # for i in range(qt_index.shape[0]):           
            #     q_id = self.locations_to_id[tuple(qt_index[i].squeeze().tolist())]
            #     gt = self.Q_policy[q_id, int(server_location_batch[i])]
            #     # gt = qt_index[i][action_index_gt]
            #     gt_lst.append(gt)
                
            # concatenated_gt = torch.stack(gt_lst, dim=0).to(self.device)

            # return concatenated_gt


    # sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
        # return sum_total_reward_episode
    

    # def estimate_seq_pr_acc(self, state, requests):

    #     total_reward_episode = []

    #     for i in range(requests.shape[1]): 
    #     # for i in range(5):
    #         # print(f'request{i}')
    #         action = self.get_action(state.to(self.device))
    #         # print(state, state.size(), action, action.size(), requests[:, i].reshape(state.shape[0],1), requests[:, (i)].reshape(state.shape[0],1).size()) 
    #         next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
    #         state = next_state
    #         total_reward_episode.append(reward.reshape(state.shape[0],1))

    #     sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
    #     return sum_total_reward_episode






            
                            # output_file_name = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}.csv'
                            
                            # with open(output_file_name, 'w', newline='') as f:
                            #     writer = csv.writer(f)
                            #     writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
                            #     writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)]) 
                            
                            # output_file_name_wqs = f'results/gen_testing/VD{self.var_distance}/{self.class_name}/train_results/train_curve/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_vd{self.var_distance}_td{self.tr_dist}_wqs.csv'
                            
                            # if not os.path.exists(output_file_name_wqs):
                            #     with open(output_file_name_wqs, 'w', newline='') as f:
                            #             writer = csv.writer(f)
                            #             writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
                            #             writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)]) 
                            # else: 
                            #     with open(output_file_name_wqs, 'a', newline='') as f:  # Open in append mode
                            #         writer = csv.writer(f)
                            #         # Append new row to CSV file
                            #         writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])

                    