import torch
import random


class RandomPolicy:
    """Baseline policy that picks a uniformly random candidate as its action.

    The code treats the last column of ``state`` as the current request and
    the preceding columns as candidate actions (server nodes). When
    ``env.request_same_node`` is set and the request node already appears among
    those candidates, the request node itself is used as the action.

    Only ``env.batch_size == 1`` is supported.
    """

    def __init__(self, env):
        self.env = env
        self.batch_size = env.batch_size  # has to be 1 for Qtable
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        self.device = self.env.device
        # Layout kept for callers that read it: ``num_servers`` leading pad
        # entries followed by the per-step rewards of the last ``estimate`` call.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)

    def _select_action(self, state):
        """Return a ``(1, 1)`` action tensor for ``state`` (random candidate or request node)."""
        candidate = state[0, :-1][random.randint(0, self.env.num_servers - 1)]
        # Indexing/unsqueeze return views into ``state``; copy so the same-node
        # override below cannot overwrite a slot of ``state`` itself.
        action = candidate.unsqueeze(0).unsqueeze(0).clone()
        if self.env.request_same_node:
            for i in range(self.env.batch_size):
                if state[i][-1] in state[i][:-1]:
                    action[i] = state[i][-1]
        return action.to(self.device)

    def estimate(self, num_steps=1, print_results=False):
        """Roll out the policy from ``env.reset()`` and summarise the rewards.

        Args:
            num_steps: Rollout length in thousands of steps; the loop runs
                ``int(num_steps * 1000 / batch_size)`` iterations.
            print_results: If True, print the running average reward every
                ``int(10000 / batch_size)`` steps and the final average.

        Returns:
            Tuple ``(mean, q25, q75, estimates)``. ``estimates`` holds the
            rewards collected during this call only. ``q25`` and ``q75`` are
            ``torch.quantile`` at 0.25 and 0.75, which fails if no steps ran.
        """
        state = self.env.reset().to(self.device)
        steps_for_display = int(10000 / self.batch_size)
        num_steps = int(num_steps * 1000 / self.batch_size)
        offset = self.env.num_servers

        # Fresh pad per call so earlier calls do not leak into these statistics.
        # Rewards are gathered in a list and concatenated once (avoids O(n^2) cat).
        chunks = [torch.zeros(offset, device=self.device)]
        for step in range(num_steps):
            action = self._select_action(state)
            next_state, reward, _ = self.env.step(action, state)
            chunks.append(reward.to(self.device))
            state = next_state.to(self.device)
            if print_results and (step + 1) % steps_for_display == 0:
                running = torch.cat(chunks, 0)[offset:]
                print(f"Step {step+1},  Average Reward {torch.mean(running):.2f}")

        self.total_reward = torch.cat(chunks, 0)
        estimates = self.total_reward[offset:]
        if print_results:
            print(f"Average Reward {torch.mean(estimates):.2f}")

        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates



class GreedyPolicy:
    """Baseline policy that picks the candidate with minimal ``env.cost_matrix`` cost.

    The code treats the last column of ``state`` as the current request and
    the preceding columns as candidate actions (server nodes). The cost of
    candidate ``a`` is ``env.cost_matrix[a, request]``.

    Only ``env.batch_size == 1`` is supported.
    """

    def __init__(self, env):
        self.env = env
        self.batch_size = env.batch_size  # has to be 1 for Qtable
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        self.device = self.env.device
        # Layout kept for callers that read it: ``num_servers`` leading pad
        # entries followed by the per-step rewards of the last ``estimate`` call.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)

    def _greedy_action(self, state):
        """Return a ``(batch,)`` copy of the candidate with the lowest cost to the request."""
        actions_all = state[:, :-1]
        costs = self.env.cost_matrix[actions_all.squeeze().long(), state[:, -1].long()]
        min_index = torch.argmin(costs)
        # Indexing with a 0-d index tensor can return a view into ``state``;
        # copy so later in-place edits of the action never touch ``state``.
        return actions_all[:, min_index].clone()

    def estimate(self, num_steps=1, print_results=False):
        """Roll out the policy from ``env.reset()`` and summarise the rewards.

        When ``env.request_same_node`` is set and the request node already
        appears among the candidates, the request node is used as the action.

        Args:
            num_steps: Rollout length in thousands of steps; the loop runs
                ``int(num_steps * 1000 / batch_size)`` iterations.
            print_results: If True, print the running average reward every
                ``int(10000 / batch_size)`` steps and the final average.

        Returns:
            Tuple ``(mean, q25, q75, estimates)``. ``estimates`` holds the
            rewards collected during this call only. ``q25`` and ``q75`` are
            ``torch.quantile`` at 0.25 and 0.75, which fails if no steps ran.
        """
        state = self.env.reset().to(self.device)
        steps_for_display = int(10000 / self.batch_size)
        num_steps = int(num_steps * 1000 / self.batch_size)
        offset = self.env.num_servers

        # Fresh pad per call so earlier calls do not leak into these statistics.
        # Rewards are gathered in a list and concatenated once (avoids O(n^2) cat).
        chunks = [torch.zeros(offset, device=self.device)]
        for step in range(num_steps):
            action = self._greedy_action(state)
            if self.env.request_same_node:
                for i in range(self.env.batch_size):
                    if state[i][-1] in state[i][:-1]:
                        action[i] = state[i][-1]

            next_state, reward, _ = self.env.step(action.unsqueeze(0), state)
            chunks.append(reward.to(self.device))
            state = next_state.to(self.device)
            if print_results and (step + 1) % steps_for_display == 0:
                running = torch.cat(chunks, 0)[offset:]
                print(f"Step {step+1},  Average Reward {torch.mean(running):.2f}")

        self.total_reward = torch.cat(chunks, 0)
        estimates = self.total_reward[offset:]
        if print_results:
            print(f"Average Reward {torch.mean(estimates):.2f}")

        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Play a fixed request sequence greedily, starting from ``state``.

        Column ``i`` of ``requests`` is passed to ``env.step`` as ``next_req``
        at step ``i``.

        NOTE(review): unlike ``estimate``, the ``request_same_node`` override
        is not applied here. Confirm this is intended.

        Returns:
            Tuple ``(total, q25, q75, estimates)``. ``estimates`` has shape
            ``(state.shape[0], requests.shape[1])`` and ``total`` is its sum
            (not its mean, unlike ``estimate``).
        """
        total_reward_episode = []
        for i in range(requests.shape[1]):
            action = self._greedy_action(state).unsqueeze(0)
            next_req = requests[:, i].reshape(state.shape[0], 1)
            next_state, reward, _ = self.env.step(action, state, next_req=next_req)
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0], 1))

        estimates = torch.cat(total_reward_episode, dim=1)

        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates
