import torch
import random
import networkx as nx
import math
import numpy as np


class Qtable:
    """Tabular Q-learning agent whose actions are the current server locations.

    The Q-table has ``num_servers + 2`` axes, each of size ``num_nodes``.
    A 1-D integer state tensor selects one row of the table (indexed by
    every state entry), and the row's last axis is indexed by the node
    chosen as the action.  Only the nodes listed in the first
    ``num_servers`` entries of the state are considered as actions.

    NOTE(review): the indexing in this class presumes a state of length
    ``num_servers + 1`` (server locations followed by one extra entry, likely
    the request location) -- confirm against the environment.
    """

    def __init__(self, env, lr=0.1, device='cuda', gamma=0.9):
        """Create the agent and allocate a zero-initialised Q-table.

        Args:
            env: environment exposing ``batch_size``, ``num_nodes``,
                ``num_servers``, ``reset()`` and ``step(action, state)``.
            lr: learning rate of the Q-value update.
            device: ``'cuda'`` picks CUDA when available and falls back to
                CPU otherwise; any other value is passed to ``torch.device``.
            gamma: discount factor applied to the next state's best Q-value.
        """
        self.env = env
        self.batch_size = env.batch_size  # has to be 1 for Qtable
        if self.batch_size > 1:
            # Kept as a message (not an exception) to preserve the original
            # best-effort behaviour.
            print('The environment batch size has to be equal to 1')
        self.num_nodes = self.env.num_nodes
        self.num_servers = self.env.num_servers
        self.lr = lr
        self.gamma = gamma
        if device == 'cuda':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.q_table = torch.zeros(
            [self.num_nodes] * (self.num_servers + 2), device=self.device
        )
        # Bounds of the exponentially decaying exploration rate.
        self.max_epsilon = 1
        self.min_epsilon = 0.05
        # Reward histories.  Both start with `num_servers` placeholder entries
        # that readers skip via `[num_servers:]`; zeros are used instead of
        # uninitialised memory so the attributes are always well defined.
        self.total_reward = torch.zeros(self.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.num_servers, device=self.device)

    @staticmethod
    def _table_index(state):
        """Return the entries of a 1-D integer state tensor as a tuple of ints.

        Indexing the Q-table with plain Python ints is basic indexing, so the
        resulting row is a view that can be updated in place.
        """
        return tuple(state.tolist())

    def select_action(self, state, epsilon=0.1):
        """Pick a server location epsilon-greedily.

        Args:
            state: 1-D integer tensor; its first ``num_servers`` entries are
                the candidate actions.
            epsilon: probability of picking a uniformly random candidate
                instead of the one with the highest Q-value.

        Returns:
            The chosen node as a tensor of shape ``(1, 1)``.
        """
        server_locations = state[:self.num_servers]
        if random.random() < epsilon:
            choice = random.randint(0, self.num_servers - 1)
        else:
            q_values = self.q_table[self._table_index(state)][server_locations]
            choice = torch.argmax(q_values)
        return server_locations[choice].unsqueeze(0).unsqueeze(0)

    def update_q_table(self, state, action, next_state, reward):
        """Apply one Q-learning update for the observed transition.

        Computes ``Q(s, a) += lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))``
        where the maximum ranges over the server locations of ``next_state``.

        Args:
            state: 1-D integer state tensor before the action.
            action: integer tensor holding the node that was chosen.
            next_state: 1-D integer state tensor after the action.
            reward: reward tensor returned by the environment.
        """
        next_row = self.q_table[self._table_index(next_state)]
        q_next = next_row[next_state[:self.num_servers]].max()
        q_target = reward + self.gamma * q_next
        # `row` is a view into the table, so the in-place update is persisted.
        row = self.q_table[self._table_index(state)]
        row[action] += self.lr * (q_target - row[action])

    def optimize(self, num_steps=100, epsilon_decay=True, display_results=False,
                 decay_rate=0.0005, run=None, display_every=None):
        """Train the Q-table by interacting with the environment.

        Args:
            num_steps: number of training steps in thousands; the loop runs
                ``int(num_steps * 1000 / batch_size)`` iterations.
            epsilon_decay: if truthy, epsilon decays exponentially from
                ``max_epsilon`` towards ``min_epsilon``; otherwise it is 0.5.
            display_results: if truthy, progress values are also appended to
                ``run["Average_Reward"]`` and ``run["Estimate"]``.
            decay_rate: rate of the exponential epsilon decay.
            run: optional logger supporting ``run[key].append(value)``.  When
                omitted, a module-level ``run`` is used if one exists;
                otherwise the values are only printed.
            display_every: number of steps between progress reports (also the
                number of greedy steps used for each estimate).  Defaults to
                ``int(10000 / batch_size)``.
        """
        state = self.env.reset().squeeze(0).long().to(self.device)
        if display_every is None:
            display_every = int(10000 / self.batch_size)
        # Guard against a zero interval, which would break the modulo below.
        steps_for_display = max(1, int(display_every))
        num_steps = max(0, int(num_steps * 1000 / self.batch_size))

        logger = run if run is not None else globals().get("run")
        if display_results and logger is None:
            print("display_results=True but no `run` logger is available; "
                  "results will only be printed")

        for step in range(num_steps):
            if epsilon_decay:
                epsilon = self.min_epsilon + (
                    self.max_epsilon - self.min_epsilon
                ) * np.exp(-decay_rate * step)
            else:
                epsilon = 0.5
            action = self.select_action(state, epsilon).to(self.device)
            next_state, reward, _ = self.env.step(action, state.unsqueeze(0))
            next_state = next_state.squeeze(0).long().to(self.device)
            reward = reward.to(self.device)
            self.update_q_table(state, action, next_state, reward)
            state = next_state

            self.total_reward = torch.cat((self.total_reward, reward), 0)
            if (step + 1) % steps_for_display == 0:
                average = torch.mean(self.total_reward[self.env.num_servers:])
                # Evaluate once and reuse the value for printing and logging.
                # NOTE(review): estimate() resets the environment; training then
                # continues from the locally held `state` -- confirm that
                # env.step depends only on the state passed in.
                estimate = self.estimate(steps_for_display)
                print(f"Step {step+1}, Epsilon {epsilon:.2f}, Average Reward {average:.2f}, Estimate {estimate:.2f}")
                if display_results and logger is not None:
                    logger["Average_Reward"].append(average)
                    logger["Estimate"].append(estimate)

        # Uses the step count rather than the loop variable, which is unbound
        # when no iteration ran.
        print(f"Step {num_steps}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

    def estimate(self, num_steps=1000):
        """Run the greedy policy (epsilon = 0) and return its mean reward.

        The rewards are also appended to ``self.total_reward_estimate``.

        Args:
            num_steps: number of greedy environment steps to run.

        Returns:
            0-d tensor with the mean of the rewards collected by this call
            (NaN when ``num_steps`` is not positive).
        """
        state = self.env.reset().squeeze(0).long().to(self.device)

        for _ in range(num_steps):
            action = self.select_action(state, 0).to(self.device)
            next_state, reward, _ = self.env.step(action, state.unsqueeze(0))
            state = next_state.squeeze(0).long().to(self.device)
            self.total_reward_estimate = torch.cat(
                (self.total_reward_estimate, reward.to(self.device)), 0
            )

        window = num_steps * self.batch_size
        if window <= 0:
            # No step was run: average over nothing instead of the history.
            return torch.mean(self.total_reward_estimate[:0])
        return torch.mean(self.total_reward_estimate[-window:])

      