import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import torch_geometric
import pathlib

from torch.utils.data import DataLoader, Dataset
from torch.distributions import Categorical
import gzip
import pickle


class SimplePolicy(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim`` outputs.

    The layer lives inside a one-element ``Sequential`` stored as ``linear1``.
    Its weight is re-initialised with Xavier-normal (gain 1); the bias keeps
    the default ``nn.Linear`` initialisation.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.linear1 = nn.Sequential(nn.Linear(input_dim, output_dim))

        # Re-draw the weight matrix of every linear sub-module.
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for linear in linear_layers:
            nn.init.xavier_normal_(linear.weight, gain=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the raw result."""
        return self.linear1(x)


# class Agent:

#     def __init__(self, policy, device, greedy=True):
#         self.policy = policy
#         self.device = device
#         self.greedy = greedy

#     def select_action(self, state):
#         state = torch.from_numpy(state).float().view(1, -1)

#         if self.device is not None:
#             state.to(self.device)
#         preds = self.policy(state)

#         return preds
    

class Agent:
    """Wrap a policy network and turn its output into a discrete action.

    Args:
        policy: Callable invoked as ``policy(constraint_features, edge_indices,
            edge_features, variable_features)``. It must return a tensor whose
            last dimension indexes the available actions.
        device: Device the observation tensors are moved to before the call.
        greedy: If true, pick the highest-scoring action; otherwise sample one.
    """

    # Keys read from the observation dict, in the order the policy expects them.
    _STATE_KEYS = (
        'constraint_features',
        'edge_indices',
        'edge_features',
        'variable_features',
    )

    def __init__(self, policy, device, greedy=True):
        self.policy = policy
        self.device = device
        self.greedy = greedy

    def select_action(self, state):
        """Run the policy on ``state`` and return the chosen action index.

        Args:
            state: Mapping holding tensors under the keys
                ``constraint_features``, ``edge_indices``, ``edge_features``
                and ``variable_features``.

        Returns:
            A tensor of action indices with the last (action) dimension of the
            policy output removed: 0-d for a 1-D policy output, shape ``(B,)``
            for a ``(B, n_actions)`` output.
        """
        # Move every input tensor to the agent's device before the forward pass.
        inputs = [state[key].to(self.device) for key in self._STATE_KEYS]
        action_probs = self.policy(*inputs)

        # Reduce over the LAST dimension so that both an un-batched 1-D output
        # and a batched (B, n_actions) output are handled; for 2-D output this
        # is identical to reducing over dim 1.
        if self.greedy:
            return torch.argmax(action_probs, dim=-1)

        # torch.multinomial requires non-negative weights with a positive sum.
        return torch.multinomial(action_probs, num_samples=1).squeeze(-1)
class BipartiteNodeData(torch_geometric.data.Data):
    """
    Encode a node bipartite graph observation, as returned by the
    `ecole.observation.NodeBipartite` observation function, in a format understood by
    the pytorch geometric data handlers.

    Stored attributes:
        constraint_features: float tensor built from ``constraint_features``.
        edge_index: int64 tensor built from ``edge_indices``.
        edge_attr: float tensor built from ``edge_features`` with a trailing
            singleton dimension appended.
        variable_features: float tensor built from ``variable_features``.
    """
    def __init__(self, constraint_features, edge_indices, edge_features, variable_features, k_init=0):
        super().__init__()
        self.constraint_features = torch.FloatTensor(constraint_features)
        self.edge_index = torch.LongTensor(edge_indices.astype(np.int64))
        self.edge_attr = torch.FloatTensor(edge_features).unsqueeze(1)
        self.variable_features = torch.FloatTensor(variable_features)

        # NOTE: `k_init` is accepted for interface compatibility but is not stored.

    def __inc__(self, key, value, *args, **kwargs):
        """
        Overload the pytorch geometric method that tells how to increment indices when
        concatenating graphs, for those entries (edge index) for which this is not obvious.

        Extra positional/keyword arguments are accepted and forwarded because newer
        pytorch geometric releases call ``__inc__`` with more than ``(key, value)``;
        a fixed two-argument signature would raise ``TypeError`` during batching there.
        """
        if key == 'edge_index':
            # Row 0 is shifted by the number of constraints, row 1 by the number
            # of variables of the graphs that precede this one in the batch.
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)

class GNNPolicy(torch.nn.Module):
    """Bipartite graph network producing three sigmoid-activated scores.

    Constraint features (1 per node) and variable features (10 per node) are
    embedded to 64 dimensions. The embeddings then go through one
    variable-to-constraint and one constraint-to-variable convolution. A final
    MLP maps every variable embedding to 3 values, which are averaged over
    dimension 0 and passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        n_cons_feats = 1
        n_var_feats = 10

        def two_layer_embedding(n_in):
            # Bias-free Linear -> ReLU -> Linear -> ReLU stack ending at `hidden`.
            return torch.nn.Sequential(
                torch.nn.Linear(n_in, hidden, bias=False),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, hidden, bias=False),
                torch.nn.ReLU(),
            )

        # Constraint embedding.
        self.cons_embedding = two_layer_embedding(n_cons_feats)

        # Edge embedding: an empty Sequential, i.e. edge features pass through unchanged.
        self.edge_embedding = torch.nn.Sequential()

        # Variable embedding.
        self.var_embedding = two_layer_embedding(n_var_feats)

        # Two half-convolutions, one per message direction.
        self.conv_v_to_c = BipartiteGraphConvolution()
        self.conv_c_to_v = BipartiteGraphConvolution()

        # Per-variable head producing 3 values.
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(hidden, hidden, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 3, bias=False),
        )
        self.pool_activation = torch.nn.Sequential(torch.nn.Sigmoid())

        # Xavier-normal (gain 0.2) on every linear weight, including the ones
        # inside the convolution sub-modules.
        for module in self.modules():
            if not isinstance(module, torch.nn.Linear):
                continue
            torch.nn.init.xavier_normal_(module.weight, gain=0.2)

    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        """Return the pooled, sigmoid-activated output for one graph (or batch of graphs).

        ``edge_indices`` is used as-is for the constraint-to-variable pass and
        with its two rows swapped for the variable-to-constraint pass.
        """
        cons_to_var = edge_indices
        var_to_cons = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # Embed all three feature groups.
        cons_emb = self.cons_embedding(constraint_features)
        edge_emb = self.edge_embedding(edge_features)
        var_emb = self.var_embedding(variable_features)

        # Variables -> constraints, then constraints -> variables.
        cons_emb = self.conv_v_to_c(var_emb, var_to_cons, edge_emb, cons_emb)
        var_emb = self.conv_c_to_v(cons_emb, cons_to_var, edge_emb, var_emb)

        # Per-variable scores, mean-pooled over dimension 0, then squashed.
        per_variable = self.output_module(var_emb)
        pooled = per_variable.mean(dim=0)
        return self.pool_activation(pooled)
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Half-convolution over a bipartite graph. Pytorch geometric provides the message
    passing machinery (with 'add' aggregation); this class only defines the form of
    the messages and how the aggregate is merged with the receiving side's features.
    """

    def __init__(self):
        super().__init__('add')
        width = 64

        # Projections of the sending node features and of the scalar edge feature.
        self.feature_module_left = torch.nn.Sequential(torch.nn.Linear(width, width))
        self.feature_module_edge = torch.nn.Sequential(torch.nn.Linear(1, width))
        # Projection of the receiving node features; `message` does not use it.
        self.feature_module_right = torch.nn.Sequential(torch.nn.Linear(width, width))
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width, bias=False),
        )

        # Empty Sequential: the aggregated messages pass through unchanged.
        self.post_conv_module = torch.nn.Sequential()

        # Output layers applied to [aggregated messages, receiving features].
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * width, width, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width, bias=False),
        )

        # Xavier-normal (gain 0.2) on every linear weight; biases keep their default init.
        for module in self.modules():
            if not isinstance(module, torch.nn.Linear):
                continue
            torch.nn.init.xavier_normal_(module.weight, gain=0.2)

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """
        Propagate messages from the left nodes to the right nodes and return the
        updated right-node features.
        """
        n_left, n_right = left_features.shape[0], right_features.shape[0]
        aggregated = self.propagate(
            edge_indices,
            size=(n_left, n_right),
            node_features=(left_features, right_features),
            edge_features=edge_features,
        )
        merged = torch.cat([self.post_conv_module(aggregated), right_features], dim=-1)
        return self.output_module(merged)

    def message(self, node_features_i, node_features_j, edge_features):
        """
        Build one message per edge from the `_j` node features and the edge feature.
        `node_features_i` is received but not used.
        """
        sender_term = self.feature_module_left(node_features_j)
        edge_term = self.feature_module_edge(edge_features)
        return self.feature_module_final(sender_term + edge_term)
        
class AgentReinforce(Agent):
    """Agent for REINFORCE-style training with optional epsilon exploration.

    Args:
        policy: Callable invoked as ``policy(constraint_features, edge_index,
            edge_attr, variable_features)`` and returning action weights whose
            last dimension indexes the actions.
        device: Stored by the base class; this class does not move tensors,
            so inputs must already be on the policy's device.
        greedy: If true, always take the arg-max action.
        opt: Optimizer handle; when not ``None`` the log-probability of every
            sampled action is appended to ``self.log_probs``.
        epsilon: Probability of replacing the policy sample by a uniform draw.
    """

    def __init__(self, policy, device, greedy, opt=None, epsilon=0):
        super().__init__(policy, device, greedy)
        self.opt = opt
        self.epsilon = torch.tensor(epsilon, dtype=torch.float)
        self.rewards = []
        self.log_probs = []

    def select_action(self, state):
        """Return the chosen action as a Python int.

        Args:
            state: Sequence whose first element is a graph object exposing
                ``constraint_features``, ``edge_index``, ``edge_attr`` and
                ``variable_features``.
        """
        graph = state[0]  # the graph observation is the first element of the state

        # Call the policy network (tensors are used as-is, no device transfer).
        action_probs = self.policy(
            graph.constraint_features,
            graph.edge_index,
            graph.edge_attr,
            graph.variable_features,
        )

        if self.greedy:
            # Reduce over the last dimension so that both an un-batched 1-D
            # output and a (1, n_actions) output work; the sampling branch
            # below already accepts both.
            return torch.argmax(action_probs, dim=-1).item()

        # Detect NaN output BEFORE building the distribution: with argument
        # validation enabled, Categorical raises on NaN probabilities and the
        # uniform fallback below would never be reached.
        probs_invalid = bool(torch.isnan(action_probs.sum()))
        m = Categorical(action_probs, validate_args=False if probs_invalid else None)

        # The Bernoulli draw is made on every call, as before.
        explore = bool(torch.bernoulli(self.epsilon) == 1)
        if explore or probs_invalid:
            # Uniform random action.
            action = Categorical(torch.ones_like(action_probs)).sample()
        else:
            action = m.sample()

        if self.opt is not None:
            # Log-probability under the POLICY distribution, also for explored
            # actions. NOTE: this is NaN when the policy output contained NaN.
            self.log_probs.append(m.log_prob(action))
        return action.item()
# class AgentReinforce(Agent):

#     def __init__(self, policy, device, greedy, opt=None, epsilon=0):
#         super().__init__(policy, device, greedy)
#         self.opt = opt
#         self.epsilon = torch.tensor(epsilon, dtype=torch.float)

#         self.rewards = []
#         self.log_probs = []

#     def select_action(self, state):
#         preds = super().select_action(state)
#         if not self.greedy:
#             probs = F.log_softmax(preds)
#             m = Categorical(logits=probs)

#             if torch.bernoulli(self.epsilon) == 1 or torch.isnan(probs.exp().sum()):

#                 random_choice = torch.ones(m._num_events)
#                 if self.device is not None:
#                     random_choice.to(self.device)

#                 # print('sum of probs: {}'.format(probs.exp().sum()))
#                 m_rand = Categorical(random_choice)
#                 action = m_rand.sample()
#             else:
#                 action = m.sample()

#             # action = m.sample()
#             if self.opt is not None:
#                 self.log_probs.append(m.log_prob(action))
#             action = action.item()
#         else:
#             action = preds.argmax(1, keepdim=False).item()
#         return action
    
class ImitationLbDataset(Dataset):
    """Dataset over gzip-compressed pickle files, one ``(state, label)`` pair per file.

    ``transform`` is accepted by the constructor but is neither stored nor applied.
    """

    def __init__(self, sample_files, transform=None):
        self.sample_files = sample_files

    def __len__(self):
        return len(self.sample_files)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(state, label)`` as flat tensors.

        The state is returned as a 1-D float tensor and the label as a 1-D
        long tensor.
        """
        sample_path = self.sample_files[index]
        with gzip.open(sample_path, 'rb') as handle:
            # NOTE: unpickling can execute arbitrary code; only read trusted files.
            loaded = pickle.load(handle)
        state, label = loaded

        flat_state = torch.FloatTensor(state).view(-1)
        flat_label = torch.LongTensor(np.array(label).ravel())
        return flat_state, flat_label
