import torch
import ecole as ec
import numpy as np

from agents.model import GNNPolicy,GNNPolicyAno,GNNPolicyLoad


class ObservationFunction(ec.observation.NodeBipartite):
    """Bipartite-graph observation function built on ``ec.observation.NodeBipartite``.

    The subclass adds no behavior of its own: it only adapts the constructor
    and ``seed`` signatures so the object can be created with a problem name
    and seeded by the caller.
    """

    def __init__(self, problem):
        """Create the observation function.

        ``problem`` is accepted for interface compatibility and is not used;
        the parent class is initialised with its default arguments.
        """
        super().__init__()

    def seed(self, seed):
        """Accept a seed and ignore it; this class holds no random state."""
        return None


class Policy:
    """Branching policy that picks the candidate with the highest GNN score.

    A pre-trained network is loaded from
    ``agents/trained_models/<problem>/best_params.pkl`` and evaluated on the
    bipartite observation at every call.
    """

    def __init__(self, problem):
        """Build the network for ``problem`` and load its trained weights.

        Args:
            problem: one of ``'item_placement'``, ``'anonymous'`` or
                ``'load_balancing'``.

        Raises:
            ValueError: if ``problem`` is not one of the names above.
        """
        # Kept for interface parity with seed(); __call__ is a deterministic
        # argmax and never draws from this generator.
        self.rng = np.random.RandomState()

        # get parameters
        params_path = f'agents/trained_models/{problem}/best_params.pkl'
        # Prefer the first GPU, but fall back to CPU so the agent can still be
        # constructed and run on hosts without CUDA.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # NOTE(review): 'anonymous' uses GNNPolicy even though GNNPolicyAno is
        # imported at file level -- confirm which architecture the stored
        # checkpoint was trained with before changing this.
        if problem in ('item_placement', 'anonymous'):
            policy_cls = GNNPolicy
        elif problem == 'load_balancing':
            policy_cls = GNNPolicyLoad
        else:
            raise ValueError("this problem is not in ml4co datasets")

        # set up policy
        self.policy = policy_cls().to(self.device)
        # map_location lets a checkpoint that was saved from a GPU be
        # deserialised on whichever device was selected above.
        self.policy.load_state_dict(
            torch.load(params_path, map_location=self.device)
        )

    def seed(self, seed):
        """Re-create the internal NumPy random generator from ``seed``."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Return the entry of ``action_set`` with the largest predicted logit.

        Args:
            action_set: integer indices of the candidate variables.
            observation: object exposing ``row_features``, ``column_features``
                and ``edge_features`` (with ``indices`` and ``values``).

        Returns:
            A 0-dim ``torch.LongTensor`` on ``self.device`` holding the chosen
            variable index.
        """
        # mask variable features (no incumbent info): drop columns 13 and 14
        variable_features = np.delete(
            observation.column_features, [13, 14], axis=1
        )

        constraint_features = torch.FloatTensor(observation.row_features).to(self.device)
        edge_index = torch.LongTensor(
            observation.edge_features.indices.astype(np.int64)
        ).to(self.device)
        # Edge values get a trailing axis so each edge carries a 1-length
        # feature vector.
        edge_attr = torch.FloatTensor(
            np.expand_dims(observation.edge_features.values, axis=-1)
        ).to(self.device)
        variable_features = torch.FloatTensor(variable_features).to(self.device)
        action_set = torch.LongTensor(
            np.array(action_set, dtype=np.int64)
        ).to(self.device)

        # Pure inference: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = self.policy(
                constraint_features, edge_index, edge_attr, variable_features
            )

        # Restrict the scores to the candidates, then map the best position
        # back to the corresponding variable index.
        candidate_logits = logits[action_set]
        action_idx = candidate_logits.argmax().item()
        return action_set[action_idx]
