import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb
import torch.optim as optim
import numpy as np
from torch.autograd import Variable

def set_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group in *optimizer*.

    Args:
        optimizer: a torch optimizer (anything exposing ``param_groups``,
            a sequence of dicts).
        lr: the new learning rate, applied uniformly to all groups.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)

class GraphLayer(nn.Module):
    """One round of message passing between x-nodes and c-nodes of a bipartite graph.

    The x-node embeddings are updated first, from the summed embeddings of
    their neighbouring c-nodes. The c-node embeddings are then updated from
    the summed *already-updated* x-node embeddings. Each update also sees the
    node's current embedding and its raw input features.

    Args:
        input_dim: width of the raw node features ``F_x`` / ``F_c``.
        hidden_dim: width of the node embeddings ``h_x`` / ``h_c``.
    """

    def __init__(self, input_dim, hidden_dim):
        super(GraphLayer, self).__init__()
        # MLP input per node: [aggregated neighbours | own embedding | raw features]
        self.mlp_x = nn.Sequential(
            nn.Linear(2 * hidden_dim + input_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, hidden_dim)
        )
        self.mlp_c = nn.Sequential(
            nn.Linear(2 * hidden_dim + input_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, hidden_dim)
        )

    def forward(self, h_x, h_c, edges, F_x, F_c, var_num):
        """Run one x-update followed by one c-update.

        Args:
            h_x: x-node embeddings, ``(num_x, hidden_dim)``; a singleton axis 1
                (``(num_x, 1, hidden_dim)``) is squeezed away.
            h_c: c-node embeddings, ``(num_c, hidden_dim)`` (same squeeze rule).
            edges: pair ``(src_x, dst_c)`` of index tensors, one entry per edge.
                ``dst_c`` ids are offset by ``var_num``; that offset is
                subtracted before indexing ``h_c``.
            F_x: raw x-node features (same squeeze rule).
            F_c: raw c-node features (same squeeze rule).
            var_num: offset subtracted from ``dst_c`` to get rows of ``h_c``.

        Returns:
            ``(h_x, h_c)``: the updated 2-D embeddings.
        """
        src_x, dst_c = edges
        # Shift c-node ids to 0-based rows of h_c (computed once, used twice).
        dst_c_local = dst_c - var_num

        # Squeeze the optional singleton axis *before* aggregating. That way
        # both index_add_ calls see matching 2-D target and source shapes.
        h_x, F_x = h_x.squeeze(1), F_x.squeeze(1)
        h_c, F_c = h_c.squeeze(1), F_c.squeeze(1)

        # Update x-nodes: sum neighbouring c-node embeddings per x-node.
        h_x_agg = torch.zeros_like(h_x)
        h_x_agg.index_add_(0, src_x, h_c[dst_c_local])
        h_x = self.mlp_x(torch.cat((h_x_agg, h_x, F_x), dim=1))

        # Update c-nodes from the freshly updated x-node embeddings.
        h_c_agg = torch.zeros_like(h_c)
        h_c_agg.index_add_(0, dst_c_local, h_x[src_x])
        h_c = self.mlp_c(torch.cat((h_c_agg, h_c, F_c), dim=1))

        return h_x, h_c

class GNNNet(nn.Module):
    """Stack of ``GraphLayer``s with a per-variable policy head and a value head.

    Args:
        input_dim: width of the raw node features ``data.x`` / ``data.c``.
        hidden_dim: width of the node embeddings.
        output_dim: accepted for interface compatibility; not used here.
        num_layers: number of stacked ``GraphLayer`` message-passing rounds.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GNNNet, self).__init__()
        self.layers = nn.ModuleList([
            GraphLayer(input_dim, hidden_dim) for _ in range(num_layers)
        ])

        # Scores one variable from [its embedding | graph-level summary].
        self.mlp_q = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        # Project raw features into the embedding space.
        self.linear_x = nn.Linear(input_dim, hidden_dim)
        self.linear_c = nn.Linear(input_dim, hidden_dim)

        # Value head applied to the graph-level summary.
        self.mlp_v = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        """Compute per-variable log action probabilities and a state value.

        Args:
            data: object exposing ``x`` (variable-node features), ``c``
                (constraint-node features), ``edge_index`` and ``var_num``.
                The variable rows are regrouped as consecutive blocks of
                ``var_num`` rows (one block per graph) — assumes the loader
                stores them that way; TODO confirm.

        Returns:
            ``(log_act_probs, value)``. ``log_act_probs`` has shape
            ``(num_graphs, var_num, 1)`` and is log-softmax-normalised over the
            variable axis. ``value`` has shape ``(num_graphs, 1)``.
        """
        F_x, F_c = data.x, data.c
        edges, var_num = data.edge_index, data.var_num

        h_x = self.linear_x(F_x)
        h_c = self.linear_c(F_c)
        for layer in self.layers:
            h_x, h_c = layer(h_x, h_c, edges, F_x, F_c, var_num)

        # Regroup flat variable rows into (num_graphs, var_num, hidden_dim).
        h_x = h_x.view(h_x.size(0) // var_num, var_num, -1)

        # Graph-level summary: sum of variable embeddings -> (num_graphs, hidden_dim).
        h_x_sum = h_x.sum(dim=1)
        value = self.mlp_v(h_x_sum)

        # Broadcast the summary to every variable -> (num_graphs, var_num, hidden_dim),
        # then append it to each variable's own embedding.
        h_sum_expanded = h_x_sum.unsqueeze(1).expand(-1, h_x.size(1), -1)
        h_x_combined = torch.cat((h_x, h_sum_expanded), dim=-1)

        q_value = self.mlp_q(h_x_combined)  # (num_graphs, var_num, 1)
        log_act_probs = F.log_softmax(q_value, dim=1)
        return log_act_probs, value

class PolicyValueNet:
    """Epsilon-greedy agent with an online ``q_net`` and a periodically synced ``target_net``.

    Args:
        model_file: optional path to a ``state_dict`` written by ``save_model``;
            it is loaded into ``q_net`` before ``target_net`` is copied from it.
        device: torch device string used for the networks and batch tensors.
    """

    def __init__(self, model_file=None, device='cuda:1'):
        self.device = device
        self.action_dim = 10

        self.q_net = GNNNet(2, 32, 10, 5).to(device)
        if model_file:
            self.load_model(model_file)
        # Target network starts as an exact copy of the (possibly loaded) online net.
        self.target_net = GNNNet(2, 32, 10, 5).to(device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=2e-3, weight_decay=1e-4)
        # Exploration schedule; epsilon is decayed once per train_step call.
        self.epsilon = 0.9
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.gamma = 0.99
        # q_net weights are copied into target_net every this many train_step calls.
        self.target_update_freq = 10
        self.update_counter = 0

    def policy_value_fn(self, state, epsilon=None):
        """Epsilon-greedy query for *state*.

        Args:
            state: network input supporting ``.to(device)``.
            epsilon: exploration probability; defaults to ``self.epsilon``.

        Returns:
            With probability ``epsilon``, a random action index in
            ``[0, action_dim)``. Otherwise a tuple ``(act_probs, value)``:
            the exponentiated, flattened network output as a numpy array,
            plus the value tensor.
            NOTE(review): the two branches return different types, and the
            greedy branch queries ``target_net`` rather than ``q_net`` —
            confirm both are intended.
        """
        if epsilon is None:
            epsilon = self.epsilon
        if np.random.rand() < epsilon:
            return np.random.randint(self.action_dim)
        state = state.to(self.device)
        with torch.no_grad():
            log_act_probs, value = self.target_net(state)
        act_probs = np.exp(log_act_probs.cpu().numpy().flatten())
        return act_probs, value

    @staticmethod
    def _action_scores(net, states):
        """Stack per-action scores of *net* for each state into ``(len(states), num_actions)``.

        The network returns ``(log_act_probs, value)``; only the first element
        is used, flattened so each state contributes one row.
        NOTE(review): assumes each state is a single graph, and treats the
        log-softmax output as Q-values — confirm this is the intended target.
        """
        return torch.stack([net(s)[0].flatten() for s in states])

    def train_step(self, batch):
        """Run one Double-DQN update on a batch of transitions.

        Args:
            batch: tuple ``(states, actions, rewards, next_states, dones)`` of
                equal-length sequences. Each state must support ``.to(device)``
                and be a valid network input.

        Returns:
            The MSE loss as a Python float.
        """
        states, actions, rewards, next_states, dones = batch
        states = [s.to(self.device) for s in states]
        next_states = [s.to(self.device) for s in next_states]
        actions = torch.tensor(actions, dtype=torch.long, device=self.device)
        rewards = torch.tensor(rewards, dtype=torch.float, device=self.device)
        dones = torch.tensor(dones, dtype=torch.float, device=self.device)

        rows = torch.arange(len(states), device=self.device)
        q_values = self._action_scores(self.q_net, states)[rows, actions]

        with torch.no_grad():
            # Double DQN: the online net picks the next action, the target net scores it.
            next_actions = self._action_scores(self.q_net, next_states).argmax(dim=1)
            target_q_values = self._action_scores(self.target_net, next_states)[rows, next_actions]
            targets = rewards + (1 - dones) * self.gamma * target_q_values

        loss = F.mse_loss(q_values, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_counter += 1
        if self.update_counter % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()

    def save_model(self, model_file):
        """Save the online network's ``state_dict`` to *model_file*."""
        net_params = self.q_net.state_dict()
        torch.save(net_params, model_file)

    def load_model(self, model_file):
        """Load a ``state_dict`` from *model_file* into ``q_net``, mapped onto ``self.device``."""
        net_params = torch.load(model_file, map_location=self.device)
        self.q_net.load_state_dict(net_params)