from torch_geometric.nn import GCNConv, global_add_pool
import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import convert
import torch_geometric
from Policies.NET import Net
from Policies.Random_Greedy import GreedyPolicy
import neptune
from KServerEnv import KServerEnv 


class GCN_RL_GEN():
    """DQN-style agent whose Q-function is a graph neural network (``Net``).

    The environment must be created with ``general_model=True``. Every
    environment step processes a whole batch of states, and each state is
    paired with the index of the graph it lives on, so one batch can mix
    several graphs (see ``observation_formation``).

    Transitions are stored per sample in a replay buffer (``remember``). The
    target network is tracked with soft (Polyak) updates after every
    optimisation step (``update``).
    """

    def __init__(self, env, seed = 42, gamma=0.99, lr=0.001, memory_size=10000, hidden_channels = 128, shared_weights = True, min_epsilon=0.1, max_epsilon=0.5):
        """Seed the RNGs, then build the online/target networks and the optimiser.

        Args:
            env: environment exposing ``general_model``, ``batch_size``,
                ``device``, ``uniform_random``, ``constant_probability``,
                ``num_nodes`` and ``num_servers``, plus the attributes used
                by the other methods.
            seed: seed applied to ``random``, ``numpy`` and ``torch``.
            gamma: discount factor of the TD target.
            lr: Adam learning rate.
            memory_size: replay-buffer capacity, in individual transitions.
            hidden_channels: hidden width forwarded to ``Net``.
            shared_weights: forwarded to ``Net``.
            min_epsilon, max_epsilon: bounds of the exponential exploration
                schedule used by ``optimize(epsilon_decay=True)``. The defaults
                mirror the 0.5 -> 0.1 step schedule used when decay is off.

        Raises:
            ValueError: if ``env.general_model`` is falsy.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        if not self.env.general_model:
            raise ValueError("The environmet general_model argument must be True")
        self.lr = lr
        self.batch_size = self.env.batch_size
        self.device = self.env.device
        self.memory = deque(maxlen=memory_size)
        self.uniform_random = self.env.uniform_random
        self.constant_probability = self.env.constant_probability
        # Previously these were read by optimize() without ever being set.
        self.min_epsilon = min_epsilon
        self.max_epsilon = max_epsilon
        # Fixed depth budget; Net receives num_layers - 2.
        self.num_layers = 12
        self.gamma = gamma
        # Feature channels: server one-hot + last-entry one-hot (+ node probabilities).
        self.in_channells = 2 if self.uniform_random else 3

        net_kwargs = dict(
            in_channels=self.in_channells,
            hidden_channels=hidden_channels,
            out_channels=self.env.num_nodes,
            num_layers=self.num_layers - 2,
            shared_weights=shared_weights,
        )
        # Online action-value function Q and target Q' (initialised as a copy of Q).
        self.q_network = Net(**net_kwargs).to(self.device)
        self.target_network = Net(**net_kwargs).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # Reward histories. The first num_servers entries are placeholders that
        # readers skip with [num_servers:].
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.env.num_servers, device=self.device)

    def print_network_weights(self):
        """Print every tensor in the online network's ``state_dict``."""
        for name, param in self.q_network.state_dict().items():
            print(f"Layer: {name}\nWeights: {param}")

    def total_params(self):
        """Return the number of scalar parameters in the online network."""
        return sum(p.numel() for p in self.q_network.parameters())

    def observation_formation(self, state_batch, node_pbs = 1,  constant_probability = True):
        """Encode a batch of states as node features and merge them into one PyG batch.

        Args:
            state_batch: ``(state, graph_indices)``. ``state[i][:-1]`` are the
                server node indices and ``state[i][-1]`` is one further node
                (presumably the current request; TODO confirm).
                ``graph_indices[i]`` selects the graph of sample ``i``.
            node_pbs: per-sample node probabilities, used only when
                ``constant_probability`` is False and ``uniform_random`` is False.
            constant_probability: take channel 2 from ``env.graph_pbs`` (True)
                or from ``node_pbs`` (False).

        Returns:
            A ``torch_geometric`` batch with node features ``x`` of shape
            ``(sum of nodes, channels)``, the merged ``edge_index`` and the
            ``batch`` vector.
        """
        state, graph_indices = state_batch
        num_channels = 2 if self.uniform_random else 3
        X = torch.zeros(self.env.batch_size, num_channels, self.env.num_nodes).to(self.device)
        for i in range(self.env.batch_size):
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1
            if not self.uniform_random:
                if constant_probability:
                    X[i][2] = self.env.graph_pbs[graph_indices[i]].view(1, -1)
                else:
                    X[i][2] = node_pbs[i]

        # Each sample may live on a different graph, so it gets its own edge_index.
        data_list = [
            Data(x=X[i].T, edge_index=self.env.graph_edge_indices[graph_indices[i].item()])
            for i in range(self.batch_size)
        ]
        # Same collation a non-shuffling DataLoader performs, without building one per call.
        return torch_geometric.data.Batch.from_data_list(data_list)

    def node_pbs_formation(self, state_batch):
        """Return a ``(batch_size, num_nodes)`` tensor of ``env.graph_pbs`` rows, one per sample's graph."""
        _, graph_indices = state_batch
        node_pbs = torch.zeros(self.env.batch_size, self.env.num_nodes).to(self.device)
        for i in range(self.env.batch_size):
            node_pbs[i] = self.env.graph_pbs[graph_indices[i]].view(1, -1)
        return node_pbs

    @staticmethod
    def _gather_server_q_values(q_values, qt_index):
        """Return ``q_values[i, qt_index[i, j]]`` for every sample ``i`` and server ``j``.

        The result keeps the dtype of ``q_values``. Copying into a buffer shaped
        like ``qt_index`` would instead inherit the state's dtype.
        """
        return torch.gather(q_values, 1, qt_index.to(q_values.device).long())

    def get_action(self, state_batch, epsilon=0.1):
        """Choose one server node per sample with an epsilon-greedy rule.

        With probability ``epsilon`` the whole batch acts randomly: each sample
        picks a uniformly random one of its ``num_servers`` server positions.
        Otherwise each sample picks the server position whose node has the
        highest Q-value.

        Returns:
            ``(batch_size, 1)`` tensor of node indices taken from the first
            ``num_servers`` columns of ``state_batch[0]``.
        """
        if random.random() < epsilon:
            state = state_batch[0].to(self.device)
            random_indices = torch.randint(low = 0, high=self.env.num_servers, size=(self.batch_size,)).to(self.device)
            action_batch = torch.gather(state[:, :self.env.num_servers], dim=1, index=random_indices.unsqueeze(1))
            return action_batch.long()

        with torch.no_grad():
            data = self.observation_formation(state_batch).to(self.device)
            qt_index = state_batch[0][:, :self.env.num_servers].to(self.device)
            dis_req = self.env.distance_request(state_batch).to(self.device)
            q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
            q_values = q_values.reshape(self.env.batch_size, -1)
            # Only nodes currently occupied by a server are valid actions.
            server_q = self._gather_server_q_values(q_values, qt_index)
            max_index = torch.argmax(server_q, dim=1)
            return torch.gather(qt_index, 1, max_index.view(-1, 1))

    def update(self):
        """Run one DQN optimisation step on a random replay minibatch.

        The online Q(s, a) is regressed toward
        ``r + gamma * max_{a' in servers(s')} Q_target(s', a')``. The target is
        computed under ``torch.no_grad()`` so no graph or gradients are built
        for the target network. Does nothing until the buffer holds at least
        ``batch_size`` transitions. Ends with a soft target-network update.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        concatenated = [torch.cat(tensors, dim=0) for tensors in zip(*batch)]
        state, action_batch, reward_batch, next_state, graph_indices = concatenated[:5]
        state_batch = (state, graph_indices)
        # remember() stores only the current graph indices; the next state reuses them.
        next_state_batch = (next_state, graph_indices)

        if self.constant_probability:
            data = self.observation_formation(state_batch).to(self.device)
        else:
            node_pbs, node_pbs_next = concatenated[5], concatenated[6]
            data = self.observation_formation(state_batch, node_pbs, constant_probability= False).to(self.device)

        dis_req = self.env.distance_request(state_batch).to(self.device)
        q_values = self.q_network(data.x, data.edge_index, data.batch, dis_req)
        q_values = q_values.reshape(self.env.batch_size, -1).gather(1, action_batch.long())

        with torch.no_grad():
            next_qt_index = next_state_batch[0][:, :self.env.num_servers]
            if self.constant_probability:
                data_next = self.observation_formation(next_state_batch).to(self.device)
            else:
                data_next = self.observation_formation(next_state_batch, node_pbs_next, constant_probability= False).to(self.device)
            dis_req_next = self.env.distance_request(next_state_batch).to(self.device)
            next_q_values = self.target_network(data_next.x, data_next.edge_index, data_next.batch, dis_req_next)
            next_q_values = next_q_values.reshape(self.env.batch_size, -1)
            next_q_values = self._gather_server_q_values(next_q_values, next_qt_index).max(1)[0].unsqueeze(1)
            expected_q_values = reward_batch.view(-1, 1) + self.gamma * next_q_values

        loss = F.mse_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update_target_network()

    def remember(self, state, action, reward, next_state, node_pbs = 1, node_pbs_next = 1):
        """Split a batched transition into ``batch_size`` single-sample replay entries.

        Each entry is ``(state, action, reward, next_state, graph_index)``, each
        with a leading dimension of 1. When ``constant_probability`` is False,
        it is extended with ``(node_pbs, node_pbs_next)``.
        """
        for i in range(self.batch_size):
            transition = (
                state[0][i].unsqueeze(0).to(self.device),
                action[i].unsqueeze(0),
                reward[i].unsqueeze(0).unsqueeze(0),
                next_state[0][i].unsqueeze(0).to(self.device),
                state[1][i].unsqueeze(0).to(self.device),
            )
            if not self.constant_probability:
                transition += (node_pbs[i].unsqueeze(0), node_pbs_next[i].unsqueeze(0))
            self.memory.append(transition)

    def soft_update_target_network(self, tau=0.01):
        """Polyak-average the online weights into the target: ``target <- tau*q + (1-tau)*target``."""
        for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)

    def optimize(self, num_steps=200, epsilon_decay = False, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005):
        """Train the agent online, calling ``update`` after every environment step.

        Args:
            num_steps: training length in thousands of samples; converted to
                ``num_steps * 1000 / batch_size`` environment steps.
            epsilon_decay: if True, decay epsilon exponentially from
                ``max_epsilon`` to ``min_epsilon`` at ``decay_rate``. Otherwise
                use 0.5 for the first ``explr`` fraction of steps and 0.1 after.
            explr: fraction of steps that use epsilon 0.5 (step schedule only).
            display_results: log to a Neptune run. The API token is read from
                the ``NEPTUNE_API_TOKEN`` environment variable, and the run is
                stopped when training ends.
            print_results: roughly every 10000 samples, print the running
                average reward and greedy-policy estimates.
            decay_rate: exponential rate used when ``epsilon_decay`` is True.
        """
        import os

        state = self.env.reset()
        # max(1, ...) avoids a modulo-by-zero when batch_size > 10000.
        steps_for_display = max(1, int(10000/self.batch_size))
        num_steps = int(num_steps*1000/self.batch_size)
        initial_limit = int(num_steps * explr)

        if display_results:
            self.run = neptune.init_run(
                project="iliyasbektas/kserver",
                # Credentials must never live in source control.
                api_token=os.environ.get("NEPTUNE_API_TOKEN"),
            )
            self.run["agent"] = self.class_name
            self.run["num_nodes"] = self.env.num_nodes
            self.run["graph_type"] = self.env.graph_type
            self.run["gamma"] = self.gamma

        print('Starting Training....')
        try:
            for step in range(num_steps):
                if epsilon_decay:
                    epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon)*np.exp(-decay_rate*step)
                else:
                    epsilon = 0.5 if step < initial_limit else 0.1

                if self.constant_probability:
                    action = self.get_action(state, epsilon).to(self.device)
                    next_state, reward, _ = self.env.step(action, state)
                    self.remember(state, action, reward.to(self.device), next_state)
                else:
                    node_pbs = self.node_pbs_formation(state)
                    action = self.get_action(state, epsilon).to(self.device)
                    next_state, reward, _ = self.env.step(action, state)
                    # Probabilities of the *next* state, not the current one.
                    node_pbs_next = self.node_pbs_formation(next_state)
                    self.remember(state, action, reward.to(self.device), next_state, node_pbs, node_pbs_next)

                self.update()

                state = next_state
                self.total_reward = torch.cat((self.total_reward, reward), 0)

                if (step + 1) % steps_for_display == 0:
                    if print_results:
                        print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate all {self.estimate_all(40)[0]:.2f}")
                        self.estimate(40, print_results = print_results)
                    if display_results:
                        self.run["Average_Reward"].append(torch.mean(self.total_reward[self.env.num_servers:]).item())
        finally:
            if display_results:
                self.run.stop()

    def estimate(self, num_steps = 1, print_results = False):
        """Evaluate the greedy policy and average the reward per graph type.

        Runs ``num_steps * 1000 / batch_size`` greedy steps after ``env.reset()``.

        Returns:
            dict mapping each name in ``env.graph_types`` to the mean reward of
            the samples drawn on that graph. It is returned in both modes; with
            ``print_results=True`` the averages are also printed.
        """
        state = self.env.reset()
        num_steps = int(num_steps*1000/self.batch_size)
        result_dict = {}

        for step in range(num_steps):
            action = self.get_action(state, 0)
            next_state, reward, _ = self.env.step(action, state)

            rewards = reward.view(-1, 1)
            graph_ids = state[1]
            for index, graph_type in enumerate(self.env.graph_types):
                result_dict.setdefault(graph_type, []).append(rewards[graph_ids == index])

            state = next_state

        averages_dict = {
            graph_type: torch.mean(torch.cat(elements_list, dim=0), dim=0)
            for graph_type, elements_list in result_dict.items()
        }
        if print_results:
            for graph_type, average_tensor in averages_dict.items():
                print(graph_type, round(average_tensor.item(), 3))
        return averages_dict

    @property
    def class_name(self):
        """Name of the concrete agent class (logged to Neptune)."""
        return self.__class__.__name__

    def estimate_all(self, num_steps = 1):
        """Evaluate the greedy policy over all graphs.

        Runs ``num_steps * 1000 / batch_size`` greedy steps after ``env.reset()``
        and appends this call's rewards to ``total_reward_estimate``.

        Returns:
            ``(mean, 25% quantile, 75% quantile, rewards)`` of this call's rewards.

        Raises:
            ValueError: if the step count rounds down to zero.
        """
        state = self.env.reset()
        num_steps = int(num_steps*1000/self.batch_size)
        if num_steps <= 0:
            raise ValueError("estimate_all needs at least one environment step")

        collected = []
        for step in range(num_steps):
            action = self.get_action(state, 0)
            next_state, reward, _ = self.env.step(action, state)
            collected.append(reward)
            state = next_state

        # One concatenation instead of a growing torch.cat on every step.
        estimates = torch.cat(collected, 0)
        self.total_reward_estimate = torch.cat((self.total_reward_estimate, estimates), 0)
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates