import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np




class DQNAgent:
    """Deep Q-Network agent that chooses one candidate server per batch row.

    The Q-network scores the ``num_servers`` candidate positions held in
    ``state[:, :num_servers]``. The chosen action is the node index stored in
    the best-scoring position. Rows presumably hold ``num_servers`` candidate
    node indices followed by one extra node index in the last column.
    TODO confirm this layout against the environment.
    """

    def __init__(self, env, seed=42, gamma=0.99, lr=0.001, memory_size=10000):
        """Create the agent.

        Args:
            env: environment exposing ``batch_size``, ``device``,
                ``num_nodes``, ``num_servers``, ``reset()`` and
                ``step(action, state)``.
            seed: seed applied to ``random``, NumPy and PyTorch.
            gamma: discount factor of the TD target.
            lr: Adam learning rate.
            memory_size: replay-buffer capacity (oldest entries are dropped).
        """
        # Seed every RNG the agent uses and make cuDNN deterministic.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        self.gamma = gamma
        self.lr = lr
        self.batch_size = env.batch_size
        self.device = self.env.device
        # Replay memory with capacity ``memory_size``.
        self.memory = deque(maxlen=memory_size)
        # Online action-value function Q.
        self.q_network = self.create_network().to(self.device)
        # Target action-value function Q', initialised as a copy of Q.
        self.target_network = self.create_network().to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # Reward histories for training and evaluation. The first
        # ``num_servers`` entries are placeholders that reporting skips via
        # ``[num_servers:]``. Zeros (not ``torch.empty``) keep them deterministic.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.env.num_servers, device=self.device)
        # Bounds of the exponential epsilon-decay schedule used by ``optimize``.
        self.max_epsilon = 1
        self.min_epsilon = 0.05

    def create_network(self):
        """Build an MLP mapping a ``num_nodes`` observation to ``num_servers`` Q-values."""
        input_size = self.env.num_nodes
        output_size = self.env.num_servers
        hidden_size = 128
        return nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def print_network_weights(self):
        """Print every tensor in the online Q-network's ``state_dict``."""
        network_weights = self.q_network.state_dict()
        for name, param in network_weights.items():
            print(f"Layer: {name}\nWeights: {param}")

    def observation_formation(self, state):
        """Encode a batch of states as a ``(rows, num_nodes)`` float tensor.

        In each row, the nodes listed in every column except the last are set
        to 1. Then the node in the last column is set to -0.5, overriding a 1
        if that node is also listed. The result lives on ``state.device``.
        """
        n_rows = state.shape[0]
        C = torch.zeros(n_rows, self.env.num_nodes, device=state.device)
        rows = torch.arange(n_rows, device=state.device)
        # Advanced-index assignment (rather than scatter_) keeps plain
        # indexing semantics, including wrap-around for negative indices.
        # The -0.5 write happens second so it wins on overlap.
        C[rows.unsqueeze(1), state[:, :-1].long()] = 1
        C[rows, state[:, -1].long()] = -0.5
        return C

    def get_action(self, state, epsilon=0.1):
        """Choose one candidate node per batch row.

        With probability ``epsilon`` the whole batch explores: each row takes
        a uniformly random candidate column. Otherwise each row takes the
        candidate with the highest Q-value.

        Returns:
            ``(rows, 1)`` long tensor of node indices taken from
            ``state[:, :num_servers]``.
        """
        candidates = state[:, :self.env.num_servers]
        if random.random() < epsilon:
            random_indices = torch.randint(
                low=0, high=self.env.num_servers, size=(state.shape[0],)
            ).to(candidates.device)
            action_batch = torch.gather(candidates, dim=1, index=random_indices.unsqueeze(1))
            return action_batch.long()
        with torch.no_grad():
            C = self.observation_formation(state).to(self.device)
            q_values = self.q_network(C)
            max_index = torch.argmax(q_values, dim=1).to(candidates.device)
            action_batch = torch.gather(candidates, 1, max_index.view(-1, 1))
        # Cast so that both branches return the same dtype.
        return action_batch.long()

    def update(self):
        """Run one DQN gradient step on a random minibatch, then soft-update Q'.

        Does nothing until the replay memory holds ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        state_batch, action_batch, reward_batch, next_state_batch = (
            torch.cat(tensors, dim=0) for tensors in zip(*batch)
        )
        # Column of the chosen node inside each state row (first match).
        action_indices = torch.argmax((state_batch == action_batch).int(), dim=1)

        C = self.observation_formation(state_batch).to(self.device)
        q_values = self.q_network(C)
        q_values = torch.gather(q_values, 1, action_indices.long().view(-1, 1))

        # The TD target is a constant. Without no_grad, backward() would also
        # fill (and keep accumulating) gradients on the target network.
        with torch.no_grad():
            C_next = self.observation_formation(next_state_batch).to(self.device)
            next_q_values = self.target_network(C_next).max(1)[0].unsqueeze(1)
            expected_q_values = reward_batch + self.gamma * next_q_values
        loss = F.mse_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update_target_network()

    def soft_update_target_network(self, tau=0.01):
        """Polyak-average Q into Q': ``target <- tau * q + (1 - tau) * target``."""
        for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)

    def remember(self, state, action, reward, next_state):
        """Store ``batch_size`` transitions, one per batch row, in replay memory.

        Each stored tuple is ``(state, action, reward, next_state)`` with a
        leading dimension of 1. The reward gets an extra trailing dimension
        so that concatenated rewards form a ``(B, 1)`` column. This presumes
        ``reward[i]`` is a scalar; confirm with the environment.
        """
        for i in range(self.batch_size):
            self.memory.append((
                state[i].unsqueeze(0),
                action[i].unsqueeze(0),
                reward[i].unsqueeze(0).unsqueeze(0),
                next_state[i].unsqueeze(0),
            ))

    def optimize(self, num_steps=100, epsilon=0.5, display_results=False, print_results=False, decay_rate=0.0005):
        """Train the agent online.

        Args:
            num_steps: training length in thousands of samples. The loop runs
                ``int(num_steps * 1000 / batch_size)`` batched env steps.
            epsilon: fixed exploration rate, or ``'epsilon_decay'`` to use
                ``min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * step)``.
            display_results: also append metrics to ``run``. Only takes effect
                when ``print_results`` is true.
            print_results: print progress every ``int(10000 / batch_size)``
                steps and once at the end.
            decay_rate: decay constant of the epsilon schedule.
        """
        # Keep the schedule selector separate from the per-step value so the
        # decay is re-evaluated on every step.
        use_decay = isinstance(epsilon, str) and epsilon == 'epsilon_decay'

        state = self.env.reset()
        steps_for_display = max(1, int(10000 / self.batch_size))
        num_steps = int(num_steps * 1000 / self.batch_size)

        for step in range(num_steps):
            if use_decay:
                current_epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(-decay_rate * step)
            else:
                current_epsilon = epsilon
            action = self.get_action(state.to(self.device), current_epsilon).to(self.device)
            next_state, reward, _ = self.env.step(action, state.to(self.device))

            self.remember(
                state.to(self.device),
                action,
                reward.to(self.device),
                next_state.to(self.device),
            )
            self.update()

            state = next_state
            self.total_reward = torch.cat((self.total_reward, reward.to(self.device)), 0)

            if print_results and (step + 1) % steps_for_display == 0:
                # Evaluate once and reuse the value for printing and logging.
                mean_estimate = self.estimate(40)[0]
                average_reward = torch.mean(self.total_reward[self.env.num_servers:])
                print(f"Step {step+1}, Epsilon {current_epsilon:.2f}, Average Reward {average_reward:.2f}, Estimate {mean_estimate:.2f}")
                if display_results:
                    # NOTE(review): ``run`` is not defined in this class;
                    # presumably a module-level experiment-tracking handle.
                    # Confirm it exists before enabling display_results.
                    run["Average_Reward"].append(average_reward)
                    run["Estimate"].append(mean_estimate)

        if print_results:
            # ``num_steps`` equals the last ``step + 1`` and is defined even
            # when the loop ran zero times.
            print(f"Step {num_steps}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(40)[0]:.2f}")

    def estimate(self, num_steps=1):
        """Evaluate the greedy policy (epsilon = 0) from a fresh ``env.reset()``.

        Args:
            num_steps: evaluation length in thousands of samples. Runs
                ``max(1, int(num_steps * 1000 / batch_size))`` batched steps.

        Returns:
            ``(mean, 25th percentile, 75th percentile)`` of the rewards
            collected during this call. The rewards are also appended to
            ``total_reward_estimate``.
        """
        state = self.env.reset()
        # At least one step, so the statistics never cover an empty set.
        num_steps = max(1, int(num_steps * 1000 / self.batch_size))

        rewards = []
        for step in range(num_steps):
            action = self.get_action(state.to(self.device), epsilon=0)
            next_state, reward, _ = self.env.step(action, state)
            rewards.append(reward)
            state = next_state

        estimates = torch.cat(rewards, 0)
        self.total_reward_estimate = torch.cat(
            (self.total_reward_estimate, estimates.to(self.total_reward_estimate.device)), 0
        )
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75)
                
                            