from torch_geometric.nn import GCNConv, global_add_pool
import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import convert
import torch_geometric
from Policies.NET import Net
from Policies.Random_Greedy import GreedyPolicy
import neptune
from KServerEnv import KServerEnv 
import csv
import os


class GCN_RL():

    def __init__(self, env, seed = 42, gamma=0.99, lr=0.001, memory_size=10000, hidden_channels = 128, hidden_channels_node = 32, est_pr_acc = False, shared_weights = True, gen = False, use_batch_norm = False):
        """Build the online/target Q-networks, optimizer and replay memory.

        Args:
            env: Environment object. This constructor reads its
                ``batch_size``, ``device``, ``uniform_random``,
                ``constant_probability``, ``num_nodes``, ``num_servers`` and
                ``graph`` attributes.
            seed: Seed applied to ``random``, ``numpy`` and ``torch``.
            gamma: Discount factor stored for the TD target.
            lr: Learning rate of the Adam optimizer.
            memory_size: Maximum number of transitions kept in the replay deque.
            hidden_channels: Forwarded to ``Net`` as ``hidden_channels``.
            hidden_channels_node: Accepted but not used by this constructor.
            est_pr_acc: Stored flag; selects the ``self.pr_acc`` based feature
                in ``observation_formation``.
            shared_weights: Forwarded to ``Net``.
            gen: Stored flag; ``optimize`` skips writing result files when set.
            use_batch_norm: Forwarded to ``Net``.

        Raises:
            ValueError: If the environment reports ``uniform_random == True``
                together with ``constant_probability == False``.
        """
        # --- reproducibility: seed every RNG before any network is created ---
        self.seed = seed
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # --- plain configuration copied from the arguments / environment ---
        self.gen = gen
        self.env = env
        self.lr = lr
        self.batch_size = env.batch_size
        self.device = env.device
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma
        self.uniform_random = env.uniform_random
        self.constant_probability = env.constant_probability
        self.est_pr_acc = est_pr_acc
        self.hidden_channels = hidden_channels

        # Explicit comparisons are kept so non-bool flags behave as before.
        infeasible = self.uniform_random == True and self.constant_probability == False
        if infeasible:
            raise ValueError("Unfeasible configuration")

        # Fixed depth. An earlier (now disabled) variant derived it from the
        # 1% quantile of a greedy-policy estimate on a fresh KServerEnv.
        self.num_layers = 12

        # Two input channels for uniform requests, otherwise a third channel
        # is added (filled in observation_formation).
        self.in_channells = 2 if self.uniform_random else 3

        net_kwargs = dict(
            in_channels=self.in_channells,
            hidden_channels=hidden_channels,
            out_channels=env.num_nodes,
            num_layers=self.num_layers - 2,
            shared_weights=shared_weights,
            use_batch_norm=use_batch_norm,
        )
        # Action-value function Q and its target copy Q'.
        self.q_network = Net(**net_kwargs).to(self.device)
        self.target_network = Net(**net_kwargs).to(self.device)
        # The target starts as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)

        # Reward histories. They start with ``num_servers`` uninitialised
        # entries (torch.empty); optimize() skips that prefix when averaging.
        self.total_reward = torch.empty((env.num_servers)).to(self.device)
        self.total_reward_estimate = torch.empty((env.num_servers)).to(self.device)

        # Graph connectivity shared by every observation.
        self.edge_index = convert.from_networkx(env.graph).edge_index

        # Bounds of the exponential epsilon schedule used in optimize().
        self.max_epsilon=1
        self.min_epsilon=0.05

       

    
    def print_network_weights(self):
        network_weights = self.q_network.state_dict()
        # Print the weights of the network
        for name, param in network_weights.items():
            print(f"Layer: {name}\nWeights: {param}")
    
    def total_params(self):
        return sum(p.numel() for p in self.q_network.parameters())
    

        

        


        
    
    def observation_formation(self, state, node_pbs = 1, constant_probability = True): 
        """Encode a batch of states as one batched graph observation.

        For every row ``i`` of ``state`` a ``(channels, num_nodes)`` feature
        map is built:

        * channel 0 is set to 1 at the indices ``state[i][:-1]``,
        * channel 1 is set to 1 at the index ``state[i][-1]``,
        * channel 2 (only when ``self.uniform_random`` is false) holds a
          per-node value taken from ``self.pr_acc``, ``self.env.probabilities``
          or ``node_pbs[i]``, optionally scaled by ``num_nodes``.

        (Presumably channel 0 marks server locations and channel 1 the
        requested node -- confirm against the environment.)

        Args:
            state: Indexable batch with ``self.env.batch_size`` rows.
            node_pbs: Per-row node values; only read when
                ``constant_probability`` is false.
            constant_probability: Selects where channel 2 comes from.

        Returns:
            The single batch yielded by a torch_geometric ``DataLoader`` over
            all per-row ``Data`` objects; each has ``x`` of shape
            ``(num_nodes, channels)`` and the shared ``self.edge_index``.
        """
        rows = self.env.batch_size
        nodes = self.env.num_nodes
        with_third_channel = not self.uniform_random
        channels = 3 if with_third_channel else 2

        X = torch.zeros(rows, channels, nodes).to(self.device)
        for i in range(rows):
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1
            if not with_third_channel:
                continue

            if not constant_probability:
                X[i][2] = node_pbs[i]
            elif self.est_pr_acc:
                # NOTE(review): self.pr_acc is created in estimate_seq(); while
                # it is all zeros this is 0/0 -- confirm that is acceptable.
                X[i][2] = self.pr_acc/self.pr_acc.sum()
            else:
                X[i][2] = torch.FloatTensor(self.env.probabilities).view(1, -1)

            if self.env.arrival_rates:
                X[i][2] = X[i][2]* self.env.num_nodes

        # One Data object per row; node features are transposed to
        # (num_nodes, channels).
        graphs = [Data(x = X[i].T, edge_index = self.edge_index) for i in range(self.batch_size)]
        loader = torch_geometric.loader.DataLoader(graphs, batch_size=self.batch_size, shuffle=False)
        # The loader batch size equals the number of graphs, so the first
        # batch contains all of them.
        return next(iter(loader))




    def get_action(self, state, epsilon=0.1, failsafe = False):
        # with probability epsilon 
        if random.random() < epsilon:
            random_indices = torch.randint(low = 0, high=self.env.num_servers, size=(self.batch_size,)).to(self.device)
            action_batch = torch.gather(state[:, :self.env.num_servers], dim=1, index=random_indices.unsqueeze(1))
            if self.env.request_same_node: 
                for i in range(self.env.batch_size):
                    if state[i][-1] in state[i][:-1]:  
                        action_batch[i] = state[i][-1]
            return action_batch.long()
        else:
            with torch.no_grad():

                qt_index = state[:,:self.env.num_servers].to(self.device)
                data = self.observation_formation(state).to(self.device)

                dis_req = self.env.distance_request(state).to(self.device)
                q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
                q_values = q_values.reshape(self.env.batch_size, -1)
                # create empty tensor C of size NxM
                C = torch.zeros_like(qt_index).to(self.device)

                # loop through each row of qt_index
                for i in range(qt_index.shape[0]):
                    # use row i of qt_index as an index to extract a row from q_values
                    row_b = q_values[i, qt_index[i].long()]
                    # assign the extracted row to the corresponding row in C
                    C[i] = row_b
                    
                max_index = torch.argmax(C, dim =1)
                action_batch = torch.gather(qt_index, 1, max_index.view(-1, 1))

                if self.env.request_same_node: 
                    for i in range(self.env.batch_size):
                        if state[i][-1] in state[i][:-1]:  
                            action_batch[i] = state[i][-1]

                if failsafe: 
                    min_distance, min_distance_actions = self.env.get_min_distance(state.to(self.device))
                    condition = min_distance > self.num_layers
                    action_batch.long()[condition] = min_distance_actions.long()[condition]
                return action_batch
    
    def update(self):
        
        batch = random.sample(self.memory, self.batch_size)
        concatenated = [torch.cat(tensors, dim=0) for tensors in zip(*batch)]
        if self.constant_probability:
            state_batch, action_batch, reward_batch, next_state_batch = concatenated[0], concatenated[1], concatenated[2], concatenated[3]
            # print(state_batch, action_batch, reward_batch, next_state_batch)
        else: 
            state_batch, action_batch, reward_batch, next_state_batch, node_pbs, node_pbs_next = concatenated[0], concatenated[1], concatenated[2], concatenated[3], concatenated[4], concatenated[5]

        

        if self.constant_probability:
            data = self.observation_formation(state_batch).to(self.device)
        else: 
            data = self.observation_formation(state_batch, node_pbs, constant_probability= False).to(self.device)

        # print(data.x)

        dis_req = self.env.distance_request(state_batch).to(self.device)    
        q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
        q_values = q_values.reshape(self.env.batch_size, -1)
        q_values = q_values.gather(1, action_batch.long())

        next_qt_index = next_state_batch[:,:self.env.num_servers]
        if self.constant_probability:
            data_next = self.observation_formation(next_state_batch).to(self.device)
        else:
            data_next = self.observation_formation(next_state_batch, node_pbs_next, constant_probability= False).to(self.device) 

        # print(data_next.x)

        dis_req_next = self.env.distance_request(next_state_batch).to(self.device)    
        next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next)
        next_q_values = next_q_values.reshape(self.env.batch_size, -1)

        # create empty tensor C of size NxM
        C = torch.zeros_like(next_qt_index).to(self.device)
        # loop through each row of qt_index
        for i in range(next_qt_index.shape[0]):
            # use row i of qt_index as an index to extract a row from q_values
            row_b = next_q_values[i, next_qt_index[i].long()]
            # assign the extracted row to the corresponding row in C
            C[i] = row_b

        next_q_values = C.max(1)[0].unsqueeze(1)

        expected_q_values = reward_batch.view(-1, 1) + self.gamma * next_q_values 
        loss = F.mse_loss(q_values, expected_q_values)


        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Update target network
        self.soft_update_target_network()

    def remember(self, state, action, reward, next_state, node_pbs = 1, node_pbs_next = 1):
        if self.constant_probability:
            for i in range(self.batch_size):
                self.memory.append((state[i].unsqueeze(0), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[i].unsqueeze(0)))
        else: 
            for i in range(self.batch_size):
                self.memory.append((state[i].unsqueeze(0), action[i].unsqueeze(0), reward[i].unsqueeze(0).unsqueeze(0), next_state[i].unsqueeze(0), node_pbs, node_pbs_next))
    

    def soft_update_target_network(self, tau=0.01):

      for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)
    
    def optimize(self, num_steps=200, estimate_steps =20, epsilon_decay = False, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005, failsafe = False, save_results = False):

        if save_results:
            if not os.path.exists(f'results/gen_testing/{self.class_name}'):  
                os.makedirs(f'results/gen_testing/{self.class_name}')
            if not os.path.exists(f'results/gen_testing/{self.class_name}/train_results'):
                os.makedirs(f'results/gen_testing/{self.class_name}/train_results')  
            if not os.path.exists(f'results/gen_testing/{self.class_name}/train_results/models'):
                os.makedirs(f'results/gen_testing/{self.class_name}/train_results/models')  
            if not os.path.exists(f'results/gen_testing/{self.class_name}/train_results/raw_results'):
                os.makedirs(f'results/gen_testing/{self.class_name}/train_results/raw_results')

        file_paths = ['results/gen_testing/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}.csv']

        if any(os.path.exists(file_path) for file_path in file_paths):
            print(f'Skipping training, as one or more result files exist.')
        
        else: 
            print(f'Experiment  started')   

            state = self.env.reset()
            steps_for_display = int(10000/self.batch_size)
            # steps_for_display = 1
            num_steps = int(num_steps*1000/self.batch_size)
            # num_steps = 1

            if display_results == True:
                self.run = neptune.init_run(
                    project="iliyasbektas/kserver",
                    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJiOTRhNmFlNi0xMzU0LTRiNGUtODZmYy05ZWQyMDA4ZjJiZDQifQ==",
                )  # your credentials 
                self.run["agent"] = self.class_name
                self.run["num_nodes"] = self.env.num_nodes
                self.run["graph_type"] = self.env.graph_type
                self.run["gamma"] = self.gamma
                self.run["td"] = self.trainable_distance

            

            initial_percentage = explr  
            initial_limit = int(num_steps * initial_percentage)  
            

            for step in range(num_steps):
                if epsilon_decay == True:
                    epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon)*np.exp(-decay_rate*step)     
                else:
                    if step < initial_limit:
                        epsilon = 0.5
                    else: 
                        epsilon = 0.1
                if self.constant_probability:
                    action = self.get_action(state.to(self.device), 0, failsafe = failsafe).to(self.device)
                    next_state, reward, _ = self.env.step(action, state.to(self.device))
                
                    self.remember(state.to(self.device),
                    action,
                    reward.to(self.device),
                    next_state.to(self.device),
                    )
                else: 
                    node_pbs = torch.FloatTensor(self.env.probabilities).view(1, -1).to(self.device)
                    action = self.get_action(state.to(self.device), epsilon, failsafe = failsafe).to(self.device)
                    next_state, reward, _ = self.env.step(action, state.to(self.device))
                    node_pbs_next = torch.FloatTensor(self.env.probabilities).view(1, -1).to(self.device)
                
                    self.remember(state.to(self.device),
                    action,
                    reward.to(self.device),
                    next_state.to(self.device),
                    node_pbs,
                    node_pbs_next
                    )
                
                
                self.update()
                
                state = next_state
                self.total_reward = torch.cat((self.total_reward, reward), 0)
                # print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(estimate_steps)[0]:.2f}")
                if print_results == True:
                    if ((step+1)  % steps_for_display == 0) == True:
                        step_estimate = self.estimate(estimate_steps)
                        print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {step_estimate[0]:.2f}")
                if display_results == True:
                    if ((step+1)  % steps_for_display == 0) == True:
                        self.run["Average_Reward"].append(torch.mean(self.total_reward[-steps_for_display:]))
                        try: 
                            self.run["Estimate"].append(step_estimate[0]) 
                        except: 
                            step_estimate = self.estimate(estimate_steps)
                            self.run["Estimate"].append(step_estimate[0]) 
                        # if self.step > int(num_steps * 0.95): 
                        #     self.run.stop()
                if save_results:
                    if self.gen:
                        pass
                    else: 
                        if ((step+1)  % steps_for_display == 0) == True: 
                            try: 
                                estimate, q1, q3, raw_result = step_estimate
                            except: 
                                step_estimate = self.estimate(estimate_steps)
                                estimate, q1, q3, raw_result = step_estimate
                                
                            output_file_name = f'results/gen_testing/{self.class_name}/train_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_td{self.trainable_distance}.csv'
                            torch.save(self.q_network.state_dict(), f'results/gen_testing/{self.class_name}/train_results/models/model_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_td{self.trainable_distance}.pth')

                            with open(output_file_name, 'w', newline='') as f:
                                    writer = csv.writer(f)
                                    writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'hidden_channels', 'estimate', 'q1', 'q3'])
                                    writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, self.hidden_channels, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)]) 

                            output_file_name_raw = f'results/gen_testing/{self.class_name}/train_results/raw_results/results_{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_hidden_channels{self.hidden_channels}__gamma{self.gamma}_td{self.trainable_distance}_raw.csv'
                            with open(output_file_name_raw, 'w', newline='') as f:
                                    writer = csv.writer(f)
                                    # writer.writerow([raw_result])
                                    writer.writerow(raw_result.tolist())  
                        
            if display_results:
                self.run.stop()
                
            if print_results:
                try: 
                    print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {step_estimate[0]:.2f}")
                except: 
                    print(f"Step {step+1}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(estimate_steps)[0]:.2f}")

            
        


    def estimate(self, num_steps = 1, episode = False):

        state = self.env.reset()
        num_steps = int(num_steps*1000/self.batch_size)
        reward_list = []
        for step in range(num_steps):
            action = self.get_action(state.to(self.device), epsilon = 0)
            next_state, reward, _ = self.env.step(action, state)
            state = next_state
            if episode: 
                reward_list.append(reward.reshape(self.env.batch_size, 1))
            else: 
                self.total_reward_estimate = torch.cat((self.total_reward_estimate, reward), 0)
        if episode: 
            estimates = torch.cat(reward_list, dim=1)
            return torch.mean(estimates, 1), torch.quantile(estimates, 0.25, 1), torch.quantile(estimates, 0.75, 1), estimates
        else: 
            estimates = self.total_reward_estimate[-(num_steps*self.batch_size):]  
            return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates
        
    def estimate_seq(self, state, requests):

        total_reward_episode = []

        if self.est_pr_acc: 
            self.pr_acc = torch.zeros(self.env.batch_size, self.env.num_nodes).to(self.env.device)

        for i in range(requests.shape[1]): 
        # for i in range(5):
            # print(f'request{i}')
          
            action = self.get_action(state.to(self.device), epsilon = 0)
            if self.est_pr_acc: 
                self.pr_acc[:, int(state[:, -1].item())] += 1
            # print(state, state.size(), action, action.size(), requests[:, i].reshape(state.shape[0],1), requests[:, (i)].reshape(state.shape[0],1).size()) 
            # print(self.pr_acc)
            next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0],1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates
        

        # sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
        # return sum_total_reward_episode
    

    # def estimate_seq_pr_acc(self, state, requests):

    #     total_reward_episode = []

    #     for i in range(requests.shape[1]): 
    #     # for i in range(5):
    #         # print(f'request{i}')
    #         action = self.get_action(state.to(self.device), epsilon = 0)
    #         # print(state, state.size(), action, action.size(), requests[:, i].reshape(state.shape[0],1), requests[:, (i)].reshape(state.shape[0],1).size()) 
    #         next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
    #         state = next_state
    #         total_reward_episode.append(reward.reshape(state.shape[0],1))

    #     sum_total_reward_episode = torch.sum(torch.cat(total_reward_episode, dim=1), dim=1).reshape(state.shape[0],1)
    #     return sum_total_reward_episode
    

        




    
    @property
    def class_name(self):
        """Name of the instance's concrete class (used in result paths)."""
        concrete_class = type(self)
        return concrete_class.__name__
                    
                            