import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import neptune

import csv
import os 
import time


class DQNAgent_10():
    """DQN agent that picks which server answers each request on a graph.

    A state row is laid out as ``[server_0, ..., server_{k-1}, request]``
    (node indices; the code slices ``[:, :num_servers]`` / ``[:, :-1]`` for
    servers and ``[:, -1]`` for the request). The Q-network scores every node
    of the graph; the chosen action is always one of the current server nodes.

    The environment must expose ``uniform_random``, ``var_distance``,
    ``batch_size``, ``device``, ``num_nodes``, ``num_servers``,
    ``request_same_node``, ``graph_type``, ``reset()`` and
    ``step(action, state[, next_req=...])`` returning ``(next_state, reward, _)``.
    """

    def __init__(self, env, seed = 42, gamma=0.99, lr=0.01, memory_size=10000, num_layers = 12):
        """Create the online/target networks, optimizer and replay memory.

        Args:
            env: environment object (see class docstring for required members).
            seed: seed for ``random``, NumPy and PyTorch.
            gamma: discount factor of the TD target.
            lr: Adam learning rate.
            memory_size: replay-memory capacity (oldest entries are evicted).
            num_layers: stored for API compatibility; the current network
                architecture does not use it.
        """
        # Seed every RNG the agent uses and force deterministic cuDNN kernels.
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        self.uniform_random = self.env.uniform_random
        self.var_distance = self.env.var_distance
        self.gamma = gamma
        self.lr = lr
        self.batch_size = self.env.batch_size
        self.device = self.env.device
        # Replay memory D with capacity N (= memory_size), FIFO eviction.
        self.memory = deque(maxlen=memory_size)
        self.num_layers = num_layers

        # Online action-value network Q and target network Q', initialised identically.
        self.q_network = self.create_network().to(self.device)
        self.target_network = self.create_network().to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        # One reward tensor per environment step taken in optimize().
        self.total_reward = []
        # Bounds of the exponential epsilon-decay schedule.
        self.max_epsilon = 1
        self.min_epsilon = 0.05

    def create_network(self):
        """Build the Q-network MLP.

        Input: the flattened ``(2, num_nodes)`` observation produced by
        ``observation_formation``. Output: one Q-value per graph node.
        """
        input_size = 2 * self.env.num_nodes
        output_size = self.env.num_nodes
        hidden_size = self.env.num_nodes * 12
        return nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def total_params(self):
        """Return the total number of parameters of the online Q-network."""
        return sum(p.numel() for p in self.q_network.parameters())

    def count_parameters(self):
        """Return the number of trainable parameters of the online Q-network."""
        return sum(p.numel() for p in self.q_network.parameters() if p.requires_grad)

    def print_network_weights(self):
        """Print every tensor in the online Q-network's state dict."""
        for name, param in self.q_network.state_dict().items():
            print(f"Layer: {name}\nWeights: {param}")

    def observation_formation(self, state):
        """Encode a batch of states as two one-hot planes per row.

        Returns a float tensor of shape ``(rows, 2, num_nodes)`` on
        ``self.device``: plane 0 has a 1 at every node holding a server
        (``state[:, :-1]``), plane 1 has a 1 at the requested node
        (``state[:, -1]``).
        """
        index = state.long().to(self.device)
        rows, nodes = index.shape[0], self.env.num_nodes
        servers = torch.zeros(rows, nodes, device=self.device).scatter_(1, index[:, :-1], 1.0)
        request = torch.zeros(rows, nodes, device=self.device).scatter_(1, index[:, -1:], 1.0)
        return torch.stack((servers, request), dim=1)

    def get_action(self, state, epsilon=0.1):
        """Choose, for each row, the server node that will serve the request.

        With probability ``epsilon`` (one coin flip for the whole batch) a
        uniformly random server is chosen per row; otherwise the server whose
        node has the highest Q-value is chosen. If ``env.request_same_node`` is
        set and the request already sits on a server, that node is chosen.

        Returns:
            Long tensor of shape ``(rows, 1)`` with the chosen node indices.
        """
        servers = state[:, :self.env.num_servers]
        if random.random() < epsilon:
            random_indices = torch.randint(low=0, high=self.env.num_servers, size=(state.shape[0],))
            action_batch = torch.gather(servers, dim=1, index=random_indices.unsqueeze(1).to(servers.device))
        else:
            with torch.no_grad():
                obs = self.observation_formation(state)
                q_values = self.q_network(obs.view(obs.size(0), -1))
                # Q-values of the server nodes only; gather keeps the float dtype
                # (a zeros_like(state) buffer would truncate them for int states).
                server_q = q_values.gather(1, servers.long().to(q_values.device))
                best = server_q.argmax(dim=1, keepdim=True)
                action_batch = torch.gather(servers, 1, best.to(servers.device))

        if self.env.request_same_node:
            # Rows whose request node already hosts a server answer in place.
            same = (state[:, :-1] == state[:, -1:]).any(dim=1, keepdim=True)
            action_batch = torch.where(same, state[:, -1:], action_batch)
        return action_batch.long()

    def update(self):
        """Perform one DQN gradient step on a random replay minibatch.

        TD target: ``r + gamma * max_{s in servers(s')} Q'(s', s)``, computed with
        the target network and without gradients. After the MSE step the
        target network is soft-updated. Does nothing while the replay memory
        holds fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        state_batch, action_batch, reward_batch, next_state_batch = (
            torch.cat(column, dim=0) for column in zip(*batch)
        )

        obs = self.observation_formation(state_batch)
        q_values = self.q_network(obs.view(obs.size(0), -1)).gather(1, action_batch.long().to(self.device))

        with torch.no_grad():
            next_obs = self.observation_formation(next_state_batch)
            next_q_all = self.target_network(next_obs.view(next_obs.size(0), -1))
            next_servers = next_state_batch[:, :self.env.num_servers].long().to(next_q_all.device)
            # Best next action is restricted to nodes that hold a server.
            next_q_values = next_q_all.gather(1, next_servers).max(dim=1, keepdim=True)[0]

        expected_q_values = reward_batch.view(-1, 1) + self.gamma * next_q_values
        loss = F.mse_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update_target_network()

    def soft_update_target_network(self, tau=0.01):
        """Polyak-average the online weights into the target network.

        ``target <- tau * online + (1 - tau) * target``.
        """
        for target_param, q_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_param.data.copy_(tau * q_param.data + (1 - tau) * target_param.data)

    def remember(self, state, action, reward, next_state):
        """Store each of the first ``batch_size`` rows as its own replay entry.

        Each entry is ``(state, action, reward, next_state)`` with a leading
        dimension of 1 so that ``update`` can ``torch.cat`` entries into a batch.
        """
        for i in range(self.batch_size):
            self.memory.append((
                state[i].unsqueeze(0),
                action[i].unsqueeze(0),
                reward[i].unsqueeze(0).unsqueeze(0),
                next_state[i].unsqueeze(0),
            ))

    def optimize(self, num_steps=100, estimate_steps = 20, epsilon = 0.5, explr = 0.6, display_results = False, print_results = False, decay_rate = 0.0005, save_results = False):
        """Train the agent online against ``self.env``.

        Each step selects an action, steps the environment, stores the batch of
        transitions and performs one ``update``. Every ``10000 / batch_size``
        steps a progress tick optionally prints, logs to Neptune and saves.

        Args:
            num_steps: training length in thousands of transitions
                (converted to ``num_steps * 1000 / batch_size`` env steps).
            estimate_steps: length of the greedy evaluation run at each tick,
                in thousands of transitions.
            epsilon: ``'epsilon_decay'`` selects an exponential schedule from
                ``max_epsilon`` towards ``min_epsilon`` with rate ``decay_rate``.
                Any other value selects the two-phase schedule (0.5 for the
                first ``explr`` fraction of steps, then 0.1); the numeric value
                itself is not used.
            explr: fraction of steps using the exploratory epsilon of the
                two-phase schedule.
            display_results: log to Neptune. The API token is read from the
                ``NEPTUNE_API_TOKEN`` environment variable.
            print_results: print progress at each tick and a final summary.
            decay_rate: decay constant of the exponential schedule.
            save_results: at each tick write the latest evaluation to CSV and
                checkpoint the best-so-far Q-network under ``results/gen_testing/``.
        """
        state = self.env.reset()
        steps_for_display = max(1, int(10000 / self.batch_size))
        num_steps = int(num_steps * 1000 / self.batch_size)
        estimate_steps = int(estimate_steps * 1000 / self.batch_size)

        initial_limit = int(num_steps * explr)
        lr_dcr_step = int(num_steps * 0.9)
        # Decide the schedule once; `epsilon` must not be reused as the mode flag
        # because the per-step value would overwrite it after the first step.
        use_decay = epsilon == 'epsilon_decay'

        if display_results:
            # Credentials come from the environment, never from source control.
            self.run = neptune.init_run(
                project="iliyasbektas/kserver",
                api_token=os.environ.get("NEPTUNE_API_TOKEN"),
            )
            self.run["agent"] = self.class_name
            self.run["num_nodes"] = self.env.num_nodes
            self.run["graph_type"] = self.env.graph_type
            self.run["gamma"] = self.gamma
            self.run["lr"] = self.lr

        if save_results:
            self._make_result_dirs()

        start_time = time.time()
        best_estimate = float('-inf')
        try:
            for step in range(num_steps):
                if use_decay:
                    eps = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(-decay_rate * step)
                else:
                    eps = 0.5 if step < initial_limit else 0.1
                    if step < lr_dcr_step:
                        # NOTE(review): this only rebinds self.lr; the Adam
                        # optimizer's param_groups are never updated, so it has
                        # no effect on training. Confirm the intended LR schedule.
                        self.lr = 0.001

                action = self.get_action(state.to(self.device), eps).to(self.device)
                next_state, reward, _ = self.env.step(action, state.to(self.device))
                self.remember(
                    state.to(self.device),
                    action,
                    reward.to(self.device),
                    next_state.to(self.device),
                )
                self.update()

                state = next_state
                self.total_reward.append(reward)

                if (step + 1) % steps_for_display == 0:
                    start_time, best_estimate = self._report_progress(
                        step, eps, start_time, best_estimate, estimate_steps,
                        display_results, print_results, save_results,
                    )
        finally:
            # Close the Neptune run even if training raised.
            if display_results:
                self.run.stop()

        if print_results:
            final_reward = self._mean_reward(self.total_reward[self.env.num_servers:])
            print(f"Step {self.round_to_nearest_10000(num_steps * self.batch_size)}, Average Reward {final_reward:.2f}, Estimate {self.estimate(40)[0]:.2f}")

    def _mean_reward(self, rewards):
        """Mean of a list of per-step reward tensors (NaN when the list is empty)."""
        if not rewards:
            return float('nan')
        return torch.mean(torch.cat(rewards).view(-1).float())

    def _results_dir(self):
        """Root directory for this agent's saved results."""
        return f'results/gen_testing/VD{self.var_distance}/{self.class_name}'

    def _make_result_dirs(self):
        """Create the results directory tree (no-op for existing directories)."""
        base = self._results_dir()
        os.makedirs(f'{base}/models', exist_ok=True)
        os.makedirs(f'{base}/train_results/models', exist_ok=True)
        os.makedirs(f'{base}/train_results/raw_results', exist_ok=True)

    def _report_progress(self, step, epsilon, start_time, best_estimate, estimate_steps,
                         display_results, print_results, save_results):
        """Handle one progress tick; return the updated ``(start_time, best_estimate)``."""
        step_estimate = None
        if print_results:
            elapsed_time = time.time() - start_time
            start_time = time.time()
        if print_results or display_results:
            # One greedy evaluation shared by printing, logging and saving.
            step_estimate = self.estimate(estimate_steps)
            # NOTE(review): total_reward holds one tensor per env step, so this
            # slice spans estimate_steps * batch_size steps — confirm the window.
            average_reward = self._mean_reward(self.total_reward[(-estimate_steps * self.batch_size):])

        if print_results:
            step1000 = self.round_to_nearest_10000((step + 1) * self.batch_size)
            print(f"Step {step1000}, Epsilon {epsilon:.2f}, Average Reward {average_reward:.2f}, Estimate {step_estimate[0]:.2f}, Time Taken: {elapsed_time:.2f} seconds")

        if display_results:
            self.run["Average_Reward"].append(float(average_reward))
            self.run["Estimate"].append(float(step_estimate[0]))

        if save_results:
            if step_estimate is None:
                step_estimate = self.estimate(40)
            best_estimate = self._save_results(step_estimate, best_estimate)

        return start_time, best_estimate

    def _save_results(self, step_estimate, best_estimate):
        """Checkpoint the Q-network if it beats ``best_estimate`` and overwrite the CSV.

        Returns the (possibly improved) best estimate.
        """
        est_mean, q1, q3, _ = step_estimate
        base = self._results_dir()
        tag = f'{self.env.graph_type}_{self.class_name}_{self.env.num_nodes}_{self.env.num_servers}_gamma{self.gamma}_vd{self.var_distance}'

        if est_mean.item() > best_estimate:
            torch.save(self.q_network.state_dict(), f'{base}/train_results/models/model_{tag}.pth')
            best_estimate = est_mean.item()

        # The CSV always holds the most recent evaluation (opened with 'w').
        with open(f'{base}/train_results/results_{tag}.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'seed', 'estimate', 'q1', 'q3'])
            writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.seed, round(est_mean.item(), 3), round(q1.item(), 3), round(q3.item(), 3)])

        return best_estimate

    def estimate(self, num_steps = 1, episode = False):
        """Run the greedy policy from ``env.reset()`` and summarise the rewards.

        Args:
            num_steps: evaluation length in thousands of transitions.
            episode: if True, statistics are per row (dim 1); otherwise over
                all rewards flattened.

        Returns:
            ``(mean, 25th percentile, 75th percentile, rewards)``.
        """
        state = self.env.reset()
        num_steps = int(num_steps * 1000 / self.batch_size)
        reward_list = []
        for _ in range(num_steps):
            action = self.get_action(state.to(self.device), epsilon = 0)
            next_state, reward, _ = self.env.step(action, state)
            state = next_state
            reward_list.append(reward.reshape(self.env.batch_size, 1))
        estimates = torch.cat(reward_list, dim=1)
        if episode:
            return torch.mean(estimates, 1), torch.quantile(estimates, 0.25, 1), torch.quantile(estimates, 0.75, 1), estimates
        estimates = estimates.view(-1)
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Greedily serve a fixed request sequence and summarise the rewards.

        Args:
            state: initial batch of states.
            requests: tensor whose column ``i`` holds the i-th request of every
                row (read as ``requests[:, i]``).

        Returns:
            ``(sum of all rewards, 25th percentile, 75th percentile, rewards)``
            where ``rewards`` has shape ``(rows, num_requests)``.
        """
        total_reward_episode = []
        num_requests = requests.shape[1]
        # Report roughly every 10%; at least 1 so short sequences don't do `i % 0`.
        progress_interval = max(1, num_requests // 10)
        start_time = time.time()

        for i in range(num_requests):
            if i % progress_interval == 0 and i != 0:
                elapsed_time = time.time() - start_time
                print(f"{(i / num_requests) * 100:.0f}% of requests finished and it took {elapsed_time:.2f} seconds")

            action = self.get_action(state.to(self.device), epsilon = 0)
            next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0],1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def round_to_nearest_10000(self, number):
        """Round ``number`` to the nearest multiple of 10000."""
        return round(number / 10000) * 10000

    @property
    def class_name(self):
        """Name of the concrete agent class (used in logs and result paths)."""
        return self.__class__.__name__