from cmath import inf
import torch
import random
import numpy as np
import neptune
import csv
import os 


class Qtable:
    """Tabular Q-learning agent for a (presumably k-server) environment.

    The Q-table is a dense tensor with ``num_servers + 2`` axes of size
    ``num_nodes``. It is indexed by the state vector and then by the chosen
    node (action). States appear to be ``(server locations..., request)``:
    the first ``num_servers`` entries are read as server locations and the
    last entry as the current request. TODO confirm against the environment.
    """

    def __init__(self, env, seed=42, lr=0.1, gamma=0.9):
        """Create an agent bound to ``env``.

        Args:
            env: Environment exposing ``batch_size``, ``num_nodes``,
                ``num_servers``, ``device``, ``request_same_node``,
                ``graph_type``, ``reset()`` and ``step(action, state, ...)``.
            seed: Seed applied to ``random``, NumPy and PyTorch.
            lr: Q-learning step size.
            gamma: Discount factor.

        Raises:
            ValueError: If ``env.batch_size`` is greater than 1.
        """
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        self.batch_size = env.batch_size
        # A lookup table handles exactly one environment instance at a time.
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        self.num_nodes = self.env.num_nodes
        self.num_servers = self.env.num_servers
        self.lr = lr
        self.gamma = gamma
        self.device = self.env.device
        self.q_table = torch.zeros([self.num_nodes] * (self.num_servers + 2)).to(self.device)
        # Bounds of the exponential epsilon schedule used by optimize().
        self.max_epsilon = 1
        self.min_epsilon = 0.05
        # Reward histories. Each starts with ``num_servers`` placeholder entries
        # (zeros, not uninitialised memory); readers skip over that prefix.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.env.num_servers, device=self.device)

    @staticmethod
    def _state_index(state):
        """Convert a 1-D state tensor into a tuple of ints for basic indexing."""
        return tuple(state.tolist())

    def _best_known_action(self, state):
        """Return ``(position, q_value)`` of the best non-zero Q entry, or None.

        Only the nodes currently occupied by servers are candidates. Entries
        equal to zero are treated as "not yet learned" and ignored. On ties the
        first server position wins (``argmax`` semantics).
        """
        server_locations = state[:self.num_servers]
        q_values = self.q_table[self._state_index(state)][server_locations]
        known = torch.nonzero(q_values != 0).flatten()
        if known.numel() == 0:
            return None
        best = known[q_values[known].argmax()]
        return best.item(), q_values[best].item()

    def select_action(self, state, epsilon=0.1):
        """Choose the node a server should move to, epsilon-greedily.

        Args:
            state: 1-D long tensor of server locations followed by the request.
            epsilon: Probability of choosing a random server location.

        Returns:
            A ``(1, 1)`` tensor holding the chosen node index.
        """
        # Draw first so the RNG stream is the same whatever branch is taken.
        explore = random.random() < epsilon
        request = state[-1]
        if self.env.request_same_node and request in state[:-1]:
            # A server already sits on the request: serve it in place.
            return request.reshape(1, 1)

        server_locations = state[:self.num_servers]
        if explore:
            return server_locations[random.randint(0, self.num_servers - 1)].reshape(1, 1)

        best = self._best_known_action(state)
        # With nothing learned yet, fall back to the first server location.
        position = 0 if best is None else best[0]
        return server_locations[position].reshape(1, 1)

    def update_q_table(self, state, action, next_state, reward):
        """Apply one Q-learning update for the transition ``state -> next_state``.

        Args:
            state: 1-D long tensor of the current state.
            action: ``(1, 1)`` long tensor with the chosen node.
            next_state: 1-D long tensor of the resulting state.
            reward: Reward tensor returned by the environment.
        """
        best_next = self._best_known_action(next_state)
        q_next = 0 if best_next is None else best_next[1]
        # When the next request is already covered by a server, no future value
        # is bootstrapped from it.
        if self.env.request_same_node and next_state[-1] in next_state[:-1]:
            q_next = 0
        q_target = reward + self.gamma * q_next

        row = self.q_table[self._state_index(state)]  # view into the table
        q_current = row[action]
        row[action] = q_current + self.lr * (q_target - q_current)

    def _flush_rewards(self, pending):
        """Append buffered per-step rewards to ``self.total_reward`` and clear the buffer."""
        if pending:
            self.total_reward = torch.cat((self.total_reward, *pending), 0)
            pending.clear()

    def optimize(self, num_steps=100, epsilon=0.5, display_results=False, explr=0.6, print_results=False, decay_rate=0.0005, save_results=False):
        """Train the Q-table online against ``self.env``.

        Args:
            num_steps: Training length in thousands of environment steps
                (divided by the batch size).
            epsilon: ``'epsilon_decay'`` selects an exponential schedule from
                ``max_epsilon`` down to ``min_epsilon`` at rate ``decay_rate``.
                Any other value selects the fixed schedule: 0.5 for the first
                ``explr`` fraction of steps, then 0.1. The numeric value passed
                here is not itself used.
            display_results: Log metrics to a Neptune run. The API token is
                taken from the ``NEPTUNE_API_TOKEN`` environment variable.
            explr: Fraction of steps run at the higher exploration rate in the
                fixed schedule.
            print_results: Print progress every ``10000 / batch_size`` steps.
            decay_rate: Rate of the exponential epsilon schedule.
            save_results: At each report step, overwrite the results CSV with the
                latest evaluation, and checkpoint the Q-table whenever the
                evaluation improves.
        """
        state = self.env.reset().squeeze(0).long().to(self.device)
        steps_for_display = int(10000 / self.batch_size)
        num_steps = int(num_steps * 1000 / self.batch_size)
        initial_limit = int(num_steps * explr)
        # Decide the schedule once. ``epsilon`` is rebound to a float inside
        # the loop, so re-testing it there only honoured the decay on step 0.
        use_decay = epsilon == 'epsilon_decay'
        best_estimate = -inf
        base_dir = f'results/gen_testing/{self.class_name}'

        if display_results:
            self.run = neptune.init_run(
                project="iliyasbektas/kserver",
                # Credentials come from the environment, never from source code;
                # Neptune also falls back to NEPTUNE_API_TOKEN when given None.
                api_token=os.environ.get("NEPTUNE_API_TOKEN"),
            )
            self.run["agent"] = self.class_name
            self.run["num_nodes"] = self.env.num_nodes
            self.run["graph_type"] = self.env.graph_type
            self.run["gamma"] = self.gamma
            self.run["lr"] = self.lr

        if save_results:
            for directory in (
                base_dir,
                f'{base_dir}/models',
                f'{base_dir}/train_results/models',
                f'{base_dir}/train_results/raw_results',
            ):
                os.makedirs(directory, exist_ok=True)

        q_table_old = self.q_table.clone()
        # Rewards are buffered and concatenated in bulk; appending one tensor
        # per step with torch.cat is quadratic in the number of steps.
        pending_rewards = []
        try:
            for step in range(num_steps):
                if use_decay:
                    epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(-decay_rate * step)
                else:
                    epsilon = 0.5 if step < initial_limit else 0.1

                action = self.select_action(state, epsilon).to(self.device)
                next_state, reward, _ = self.env.step(action, state.unsqueeze(0))
                next_state = next_state.squeeze(0).long().to(self.device)
                reward = reward.to(self.device)
                self.update_q_table(state, action, next_state, reward)
                state = next_state
                pending_rewards.append(reward)

                if (step + 1) % steps_for_display != 0:
                    continue
                self._flush_rewards(pending_rewards)
                if not (print_results or display_results or save_results):
                    continue

                recent_mean = torch.mean(self.total_reward[-steps_for_display:])
                # Evaluate once per report step and share the result with every
                # consumer below.
                step_estimate = self.estimate(40)

                if print_results:
                    print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {recent_mean:.2f}, Estimate {step_estimate[0]:.2f}")
                    # Mean absolute change over the Q entries that changed since the last report.
                    diff_qtable = self.q_table - q_table_old
                    diff_qtable = diff_qtable[diff_qtable != 0]
                    print(f"Step {step+1}: average Q change =  {torch.mean(abs(diff_qtable)):.2f}")
                    q_table_old = self.q_table.clone()

                if display_results:
                    self.run["Average_Reward"].append(recent_mean)
                    self.run["Estimate"].append(step_estimate[0])

                if save_results:
                    estimate, q1, q3, _ = step_estimate
                    run_tag = f'{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_gamma{self.gamma}'
                    if estimate > best_estimate:
                        best_estimate = estimate
                        torch.save(self.q_table, f'{base_dir}/train_results/models/model_{run_tag}.pth')
                    # The CSV holds only the latest evaluation (overwritten each time).
                    with open(f'{base_dir}/train_results/results_{run_tag}.csv', 'w', newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow(['graph_type', 'agent', 'gamma', 'num_nodes', 'seed', 'estimate', 'q1', 'q3'])
                        writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, round(estimate.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])
        finally:
            self._flush_rewards(pending_rewards)
            if display_results:
                self.run.stop()

        if print_results:
            # Skip the placeholder prefix created in __init__.
            print(f"Step {num_steps}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(40)[0]:.2f}")

    def estimate(self, num_steps=1):
        """Roll out the greedy policy (epsilon = 0) from a fresh ``env.reset()``.

        The rewards are also appended to ``self.total_reward_estimate``.

        Args:
            num_steps: Rollout length in thousands of steps (divided by the
                batch size).

        Returns:
            ``(mean, first quartile, third quartile, rewards)`` of the rewards
            collected in this rollout.

        Raises:
            ValueError: If ``num_steps`` amounts to zero environment steps.
        """
        num_steps = int(num_steps * 1000 / self.batch_size)
        if num_steps <= 0:
            raise ValueError('estimate() needs at least one environment step')

        state = self.env.reset().squeeze(0).long().to(self.device)
        rewards = []
        for _ in range(num_steps):
            action = self.select_action(state, 0).to(self.device)
            next_state, reward, _ = self.env.step(action, state.unsqueeze(0))
            state = next_state.squeeze(0).long().to(self.device)
            rewards.append(reward.to(self.device))

        estimates = torch.cat(rewards, 0)
        self.total_reward_estimate = torch.cat((self.total_reward_estimate, estimates), 0)
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Roll out the greedy policy on a fixed sequence of requests.

        Args:
            state: Initial state tensor. Its first dimension is used as the
                batch size when reshaping requests and rewards.
            requests: Tensor whose columns ``requests[:, i]`` are fed to the
                environment as successive requests.

        Returns:
            ``(total reward, first quartile, third quartile, rewards)``, where
            ``rewards`` has one column per request. An empty request sequence
            makes ``torch.cat`` raise.
        """
        total_reward_episode = []
        for i in range(requests.shape[1]):
            action = self.select_action(state.squeeze(0).long(), 0).to(self.device)
            next_state, reward, _ = self.env.step(action, state, next_req=requests[:, i].reshape(state.shape[0], 1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0], 1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    @property
    def class_name(self):
        """Name of the concrete class, used for logging tags and result paths."""
        return self.__class__.__name__

      