from torch_geometric.nn import GCNConv, global_add_pool
import torch.nn.functional as F
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import convert
import torch_geometric
from Policies.NET import Net






class GCN_SL():
    """Supervised (imitation) training of a GCN policy from a pre-trained MLP.

    ``q_network`` (a project ``Net`` GCN) scores every graph node. Its scores on
    the current server nodes are trained with cross-entropy to reproduce the
    greedy choice of ``target_network``. ``target_network`` is an MLP whose
    weights are loaded from the ``model`` checkpoint; it is evaluated under
    ``torch.no_grad()`` and is not part of the optimizer.

    The environment must expose ``batch_size``, ``device``, ``num_nodes``,
    ``num_servers``, ``graph`` (converted with ``from_networkx``),
    ``average_hops()``, ``reset()`` and ``step(action, state)``.

    NOTE(review): ``state`` appears to be a (batch, num_servers + 1) tensor of
    node indices. The first ``num_servers`` columns are server nodes; the last
    column is one further node (presumably the requesting node — confirm).
    """

    def __init__(self, env, model='results/all_results/models/model_cycle_DQNAgent_10_10_seed42.pth', seed = 42, lr=0.001, hidden_channels = 16):
        """Seed RNGs, build both networks and load the target checkpoint.

        Args:
            env: environment object (see the class docstring for what is used).
            model: path to a ``state_dict`` matching ``create_target_network``.
            seed: seed for ``random``, ``numpy`` and ``torch``.
            lr: Adam learning rate for ``q_network``.
            hidden_channels: hidden width passed to ``Net``.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        self.lr = lr
        self.batch_size = self.env.batch_size
        self.device = self.env.device
        self.model = model

        # GCN depth: 3 when the average hop count rounds to 1, otherwise
        # round(2.5 * average hops); Net receives (num_layers - 2).
        avg_hops = self.env.average_hops()
        if round(avg_hops) == 1:
            self.num_layers = 3
        else:
            self.num_layers = round(2.5 * avg_hops)

        # Trainable GCN policy: 2 input channels per node, one score per node.
        self.q_network = Net(in_channels=2, hidden_channels = hidden_channels, out_channels=self.env.num_nodes, num_layers = (self.num_layers - 2)).to(self.device)
        # Fixed teacher network. map_location lets a checkpoint saved on
        # another device (e.g. GPU) load onto self.device.
        self.target_network = self.create_target_network().to(self.device)
        self.target_network.load_state_dict(torch.load(self.model, map_location=self.device))
        # Only the GCN is optimized.
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
        self.loss = nn.CrossEntropyLoss()
        # Reward histories. The first num_servers entries are uninitialized
        # placeholders (torch.empty); averages skip them via [num_servers:].
        self.total_reward = torch.empty((self.env.num_servers)).to(self.device)
        self.total_reward_estimate = torch.empty((self.env.num_servers)).to(self.device)
        # Graph connectivity shared by every sample in a batch.
        self.edge_index = convert.from_networkx(self.env.graph).edge_index

    def create_target_network(self):
        """Return the teacher MLP: num_nodes -> 128 -> 128 -> num_nodes (ReLU).

        The layout must match the checkpoint loaded in ``__init__``, because
        ``load_state_dict`` is called with its default strict key matching.
        """
        input_size = self.env.num_nodes
        output_size = self.env.num_nodes
        hidden_size = 128
        return nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def print_network_weights(self):
        """Print every tensor in ``q_network``'s state_dict."""
        for name, param in self.q_network.state_dict().items():
            print(f"Layer: {name}\nWeights: {param}")

    def observation_formation(self, state):
        """Encode ``state`` as one batched PyG graph for ``q_network``.

        Each node gets 2 features: channel 0 is 1 for nodes listed in
        ``state[i][:-1]``, and channel 1 is 1 for the node ``state[i][-1]``.
        All samples share ``self.edge_index``.

        Returns:
            A batched ``Data`` object with ``x``, ``edge_index`` and ``batch``.
        """
        X = torch.zeros(self.env.batch_size, 2, self.env.num_nodes).to(self.device)
        for i in range(self.env.batch_size):
            X[i][0][state[i][:-1].long()] = 1
            X[i][1][state[i][-1].long()] = 1

        data_list = [Data(x=X[i].T, edge_index=self.edge_index) for i in range(self.batch_size)]
        # One loader batch containing every sample, in order.
        train_loader = torch_geometric.loader.DataLoader(data_list, batch_size=self.batch_size, shuffle=False)
        return next(iter(train_loader))

    def target_observation_formation(self, state):
        """Encode ``state`` as a dense (batch, num_nodes) input for the MLP.

        Nodes in ``state[i][:-1]`` are set to 1, and node ``state[i][-1]`` is
        set to -0.5. The -0.5 is written last, so it wins if the two overlap.
        """
        C = torch.zeros(self.env.batch_size, self.env.num_nodes).to(self.device)
        for i in range(self.env.batch_size):
            C[i][state[i][:-1].long()] = 1
            C[i][state[i][-1].long()] = -0.5
        return C

    def get_action(self, state, epsilon=0):
        """Greedily pick one server node per sample using ``q_network``.

        Args:
            state: batch of states (see the class docstring).
            epsilon: accepted for callers such as ``estimate`` that pass an
                exploration rate. It is unused; selection is always greedy.

        Returns:
            ``(action_batch, scores)``. ``action_batch`` has shape (batch, 1)
            and holds the chosen node ids, taken from ``state``. ``scores``
            has shape (batch, num_servers) and holds the GCN scores of each
            sample's server nodes; it keeps the autograd graph for the loss.
        """
        qt_index = state[:, :self.env.num_servers].to(self.device)
        data = self.observation_formation(state).to(self.device)
        q_values = self.q_network(data.x, data.edge_index, data.batch)
        q_values = q_values.reshape(self.env.batch_size, -1)
        # scores[i, j] = q_values[i, qt_index[i, j]]. gather keeps q_values'
        # float dtype; the old zeros_like(qt_index) inherited the state's dtype.
        scores = torch.gather(q_values, 1, qt_index.long().to(q_values.device))
        max_index = torch.argmax(scores, dim=1)
        action_batch = torch.gather(qt_index, 1, max_index.view(-1, 1))
        return action_batch, scores

    def get_target_values(self, state):
        """Return, per sample, the position (0..num_servers-1) of the server
        that ``target_network`` scores highest; used as the class label."""
        with torch.no_grad():
            qt_index = state[:, :self.env.num_servers]
            obs = self.target_observation_formation(state).to(self.device)
            q_values = self.target_network(obs)
            # Same float-preserving lookup as get_action (an integer-typed
            # state used to truncate scores and corrupt the argmax).
            scores = torch.gather(q_values, 1, qt_index.long().to(q_values.device))
            return torch.argmax(scores, dim=1)

    def optimize(self, num_steps=200, display_results=False, print_results=False, run=None):
        """Train ``q_network`` to imitate ``target_network``.

        Args:
            num_steps: training length in thousands of samples; the loop runs
                ``int(num_steps * 1000 / batch_size)`` batches.
            display_results: with ``print_results``, also log metrics to ``run``.
            print_results: print progress about every 10000 samples and at the end.
            run: experiment-tracking handle supporting ``run[key].append(v)``.
                If None, a module-level ``run`` is used when one exists.

        Raises:
            ValueError: if metric logging is requested but no ``run`` is found.
        """
        log_to_run = bool(print_results and display_results)
        if log_to_run and run is None:
            run = globals().get("run")
            if run is None:
                raise ValueError("display_results=True requires a `run` logger")

        state = self.env.reset()
        # Guard against 0 (batch_size > 10000), which would divide by zero below.
        steps_for_display = max(1, int(10000 / self.batch_size))
        num_steps = int(num_steps * 1000 / self.batch_size)

        for step in range(num_steps):
            action, y_pred = self.get_action(state.to(self.device))
            target = self.get_target_values(state.to(self.device))
            next_state, reward, _ = self.env.step(action, state)

            loss = self.loss(y_pred, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            state = next_state
            self.total_reward = torch.cat((self.total_reward, reward), 0)
            if print_results and (step + 1) % steps_for_display == 0:
                # NOTE(review): estimate() calls env.reset(); if the env keeps
                # internal state, this may interfere with the training rollout.
                print(f"Step {step+1}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(40)[0]:.2f}")
                if log_to_run:
                    # Skip the num_servers placeholders, as the printed average
                    # does, and log the mean, which is what is printed.
                    run["Average_Reward"].append(torch.mean(self.total_reward[self.env.num_servers:]))
                    run["Estimate"].append(self.estimate()[0])

        if print_results:
            # num_steps equals step + 1 after a non-empty loop and is defined
            # even when the loop ran zero times.
            print(f"Step {num_steps}, Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}, Estimate {self.estimate(40)[0]:.2f}")

    def estimate(self, num_steps = 1):
        """Roll out the greedy GCN policy without training and summarize rewards.

        Args:
            num_steps: rollout length in thousands of samples. At least one
                batch is always run, so the result never includes older
                history or the uninitialized placeholders.

        Returns:
            ``(mean, 25th percentile, 75th percentile)`` of the rewards from this
            rollout (the last ``steps * batch_size`` entries of the history).
        """
        state = self.env.reset()
        num_steps = max(1, int(num_steps * 1000 / self.batch_size))

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(num_steps):
                action = self.get_action(state.to(self.device), epsilon=0)[0]
                state, reward, _ = self.env.step(action, state)
                self.total_reward_estimate = torch.cat((self.total_reward_estimate, reward), 0)

        estimates = self.total_reward_estimate[-(num_steps * self.batch_size):]
        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75)
                
                            