import torch
import random
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import random



class BalancePolicy():
    """Greedy "balance" baseline policy for a multi-server environment.

    At every step the policy looks at the cost of sending each server to the
    current request and dispatches the server whose *cumulative* travelled
    distance (``total_distance`` so far plus the cost of this move) would be
    the smallest.

    The state is read as ``state[:, :-1]`` (one entry per server, used both as
    the candidate actions and as row indices into ``env.cost_matrix``) and
    ``state[:, -1]`` (used as the column index into ``env.cost_matrix``) --
    presumably server node ids and the requested node id; confirm against the
    environment.

    Attributes:
        env: The wrapped environment.
        batch_size: ``env.batch_size``; must not exceed 1.
        device: ``env.device``; all bookkeeping tensors live there.
        total_reward: 1-D tensor of every reward collected by ``estimate``.
            The first ``env.num_servers`` entries are placeholders and are
            skipped whenever statistics are computed.
        total_distance: Per-server cumulative distance. It is never reset, so
            it keeps growing across calls to ``estimate``/``estimate_seq``.
    """

    def __init__(self, env):
        """Wrap ``env`` and allocate the per-server bookkeeping tensors.

        Raises:
            ValueError: If ``env.batch_size`` is greater than 1 or
                ``env.balanced_algorithm`` is not set.
        """
        self.env = env
        self.batch_size = env.batch_size
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        if not self.env.balanced_algorithm:
            raise ValueError('env.balanced_algorithm has to be True')
        self.device = self.env.device
        # Placeholder prefix of length num_servers; zeros (rather than
        # uninitialised memory) so the attribute is always safe to inspect.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_distance = torch.zeros(self.env.num_servers, device=self.device)

    def _initial_server_indices(self):
        """Return the identity server ordering ``[0, 1, ..., num_servers - 1]``.

        Created directly on ``self.device`` so it can be compared with tensors
        produced there; the floating dtype of the original ``torch.Tensor``
        construction is kept because the tensor is handed to ``env.step``.
        """
        return torch.arange(
            self.env.num_servers,
            dtype=torch.get_default_dtype(),
            device=self.device,
        )

    def _select_action(self, state, server_indices):
        """Pick the server to dispatch and charge its move to ``total_distance``.

        Args:
            state: Tensor with a single batch row (see the class docstring).
            server_indices: For every column of ``state[:, :-1]``, the id of
                the server stored in that column.

        Returns:
            ``state[:, slot]`` for the selected column, i.e. a tensor with one
            row and one column holding the chosen action.
        """
        positions = state[:, :-1]
        request = state[:, -1].long()
        # Only the batch dimension is squeezed (batch size is fixed to 1).
        move_costs = self.env.cost_matrix[positions.squeeze(0).long(), request]
        move_costs = move_costs.to(self.device)

        server_ids = server_indices.to(self.device).long()
        # projected[s] = distance server s would have travelled after serving
        # this request. Indexed by server id so that ties are broken by id.
        projected = torch.zeros(self.env.num_servers, device=self.device)
        projected[server_ids] = (
            self.total_distance[server_ids] + move_costs
        ).to(projected.dtype)

        chosen_server = torch.argmin(projected)
        slot = torch.where(server_ids == chosen_server)[0]
        self.total_distance[chosen_server] += move_costs[slot].squeeze(0)
        return positions[:, slot]

    def estimate(self, state=1, requests=1, num_steps=1, print_results=False):
        """Run the policy and return reward statistics.

        If ``env.seq_req`` is falsy the environment is reset and run for
        ``int(num_steps * 1000 / batch_size)`` steps; ``state`` and
        ``requests`` are ignored. Otherwise the given ``state`` is used and
        one step is taken per column of ``requests``, column ``step`` being
        passed to ``env.step`` as ``next_req``.

        Args:
            state: Initial state tensor (sequential mode only).
            requests: Request tensor with one column per step (sequential
                mode only).
            num_steps: Run length in thousands of steps (non-sequential mode).
            print_results: Print the running average every
                ``int(10000 / batch_size)`` steps (non-sequential mode) and a
                summary at the end.

        Returns:
            Tuple ``(mean, 25% quantile, 75% quantile, rewards)`` computed
            over every reward stored in ``total_reward``, which includes the
            rewards of earlier ``estimate`` calls on this instance.

        Raises:
            ValueError: In sequential mode when ``state`` is not a tensor or
                ``requests`` has no ``shape``.
        """
        sequential = bool(self.env.seq_req)
        steps_for_display = None
        if sequential:
            if not torch.is_tensor(state) or not hasattr(requests, 'shape'):
                raise ValueError(
                    'state and requests have to be provided when env.seq_req is set'
                )
            state = state.to(self.device)
            num_steps = requests.shape[1]
        else:
            state = self.env.reset().to(self.device)
            steps_for_display = int(10000 / self.batch_size)
            num_steps = int(num_steps * 1000 / self.batch_size)

        placeholder_len = self.env.num_servers
        server_indices = self._initial_server_indices()
        # Rewards are gathered in a list and concatenated lazily; calling
        # torch.cat on every step would copy the whole history each time.
        reward_chunks = [self.total_reward]

        for step in range(num_steps):
            action = self._select_action(state, server_indices)

            # NOTE(review): when the action is overridden below,
            # total_distance has already been charged for the server picked
            # by _select_action -- confirm that this is intended.
            if self.env.request_same_node:
                for i in range(self.env.batch_size):
                    if state[i][-1] in state[i][:-1]:
                        action[i] = state[i][-1]

            if sequential:
                next_req = requests[:, step].reshape(state.shape[0], 1)
                next_state, reward, _, next_server_indices = self.env.step(
                    action, state, server_indices, next_req=next_req
                )
            else:
                next_state, reward, _, next_server_indices = self.env.step(
                    action, state, server_indices
                )

            reward_chunks.append(reward.to(self.device))
            state = next_state.to(self.device)
            server_indices = next_server_indices.to(self.device)

            show_progress = (
                print_results
                and not sequential
                and (step + 1) % steps_for_display == 0
            )
            if show_progress:
                self.total_reward = torch.cat(reward_chunks, 0)
                reward_chunks = [self.total_reward]
                print(f"Step {step+1},  Average Reward {torch.mean(self.total_reward[placeholder_len:]):.2f}")

        self.total_reward = torch.cat(reward_chunks, 0)
        estimates = self.total_reward[placeholder_len:]

        if print_results:
            print(f"Average Reward {torch.mean(estimates):.2f}")
            print(f"Servers Distances Traveled: {self.total_distance} ")

        return (
            torch.mean(estimates),
            torch.quantile(estimates, 0.25),
            torch.quantile(estimates, 0.75),
            estimates,
        )

    def estimate_seq(self, state, requests, print_results=False):
        """Run the policy over a fixed request sequence.

        One step is taken per column of ``requests``; column ``step`` is
        passed to ``env.step`` as ``next_req``. Unlike ``estimate`` the
        rewards are kept local to the call (``total_reward`` is untouched)
        and the first returned value is the *sum* of the rewards.

        Args:
            state: Initial state tensor.
            requests: Request tensor with one column per step.
            print_results: Print the per-server distances at the end.

        Returns:
            Tuple ``(sum, 25% quantile, 75% quantile, rewards)`` where
            ``rewards`` has one column per step.
        """
        state = state.to(self.device)
        server_indices = self._initial_server_indices()
        episode_rewards = []

        for step in range(requests.shape[1]):
            action = self._select_action(state, server_indices)
            next_req = requests[:, step].reshape(state.shape[0], 1)
            next_state, reward, _, next_server_indices = self.env.step(
                action, state, server_indices, next_req=next_req
            )
            state = next_state.to(self.device)
            server_indices = next_server_indices.to(self.device)
            episode_rewards.append(reward.reshape(state.shape[0], 1))

        if print_results:
            print(f"Servers Distances Traveled: {self.total_distance} ")

        estimates = torch.cat(episode_rewards, dim=1)
        return (
            torch.sum(estimates),
            torch.quantile(estimates, 0.25),
            torch.quantile(estimates, 0.75),
            estimates,
        )






