import torch
import random
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import random




class HarmonicPolicy:
    """Randomised "harmonic" baseline policy.

    At every step one server is sampled with probability proportional to the
    inverse of its distance (looked up in ``env.cost_matrix``) to the current
    request, and that server's location is passed to ``env.step`` as the
    action.

    The environment must expose ``batch_size``, ``device``, ``num_servers``,
    ``cost_matrix``, ``reset()`` and ``step(action, state[, next_req])``.
    States are read as ``state[:, :-1]`` = server locations and
    ``state[:, -1]`` = location of the current request.
    """

    def __init__(self, env):
        self.env = env
        # Must be 1: the distance lookup in _select_action squeezes the state
        # and is only shape-correct for a single row.
        self.batch_size = env.batch_size
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        self.device = self.env.device
        # The first ``num_servers`` entries are uninitialised placeholders;
        # every reader slices them off with ``[num_servers:]``.
        self.total_reward = torch.empty(self.env.num_servers, device=self.device)

    def _select_action(self, state):
        """Return the harmonic-rule action for ``state``.

        A server is sampled with probability proportional to
        1 / distance-to-request and its location(s) ``state[:, :-1][:, idx]``
        are returned. When the weights do not form a valid distribution
        (e.g. a zero distance, which makes a weight infinite and the
        normalised probabilities NaN) the request location ``state[:, -1]``
        is returned instead.
        """
        servers = state[:, :-1]
        request = state[:, -1]
        dists = self.env.cost_matrix[servers.squeeze().long(), request.long()]
        inverse_distances = 1.0 / dists
        probabilities = inverse_distances / torch.sum(inverse_distances)
        # Check explicitly rather than relying on torch.multinomial raising:
        # on CUDA an invalid distribution may trip a device-side assert
        # instead of a catchable Python exception.
        if not bool(torch.isfinite(probabilities).all()):
            return request
        try:
            index = torch.multinomial(probabilities, 1, replacement=True)
        except RuntimeError:
            # Remaining invalid distributions (e.g. negative weights) are
            # rejected by multinomial with a RuntimeError.
            return request
        return servers[:, index]

    def estimate(self, num_steps=1, print_results=False):
        """Run the policy for ``int(num_steps * 1000 / batch_size)`` steps.

        Args:
            num_steps: number of environment steps, in thousands.
            print_results: if True, print the running average reward every
                ``int(10000 / batch_size)`` steps and the final average.

        Returns:
            Tuple ``(mean, 25% quantile, 75% quantile, rewards)`` computed over
            ``self.total_reward`` with its placeholder prefix removed.

        NOTE(review): ``self.total_reward`` is never reset, so repeated calls
        include rewards from earlier calls in the returned statistics --
        confirm this is intended.
        """
        state = self.env.reset().to(self.device)
        steps_for_display = int(10000 / self.batch_size)
        num_steps = int(num_steps * 1000 / self.batch_size)
        skip = self.env.num_servers

        # Buffer rewards and concatenate in bulk: torch.cat on every step
        # would copy the whole history each time (quadratic in num_steps).
        pending = []
        for step in range(num_steps):
            action = self._select_action(state)
            next_state, reward, _ = self.env.step(action, state)
            pending.append(reward.to(self.device))
            state = next_state.to(self.device)
            if print_results and (step + 1) % steps_for_display == 0:
                self.total_reward = torch.cat([self.total_reward, *pending], 0)
                pending = []
                print(f"Step {step+1},  Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        self.total_reward = torch.cat([self.total_reward, *pending], 0)

        if print_results:
            print(f"Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        estimates = self.total_reward[skip:]
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Serve a given sequence of requests starting from ``state``.

        Args:
            state: initial state (server locations in ``state[:, :-1]``,
                current request in ``state[:, -1]``).
            requests: tensor whose column ``i`` is passed to ``env.step`` as
                ``next_req`` on step ``i``.

        Returns:
            Tuple ``(total reward, 25% quantile, 75% quantile, rewards)`` where
            ``rewards`` has one column per request.
        """
        total_reward_episode = []
        for i in range(requests.shape[1]):
            action = self._select_action(state)
            next_req = requests[:, i].reshape(state.shape[0], 1)
            next_state, reward, _ = self.env.step(action, state, next_req=next_req)
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0], 1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates




